# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This software may be used and distributed in accordance with
# the terms of the DINOv3 License Agreement.
import gc
import logging
import os
from enum import Enum
from typing import Any, Dict, List, Optional
import numpy as np
import torch
from torch import nn
from torchmetrics import Metric
import dinov3.distributed as distributed
from dinov3.data import DatasetWithEnumeratedTargets, SamplerType, make_data_loader
from dinov3.eval.accumulators import NoOpAccumulator, ResultsAccumulator
from dinov3.logging import MetricLogger
# Module-level logger used by every helper in this file.
# NOTE(review): the "fairvit" name looks like a legacy identifier; records only reach
# handlers attached to that logger name — confirm it matches the project's logging setup.
logger = logging.getLogger("fairvit")
class LossType(Enum):
    """Supported loss types.

    Each member's value is its string identifier, so a member can be looked up
    from a configuration string with ``LossType("cross_entropy")``.
    """

    CROSS_ENTROPY = "cross_entropy"
    BINARY_CROSS_ENTROPY = "binary_cross_entropy"
class ModelWithNormalize(torch.nn.Module):
    """Wrap a model so that its outputs are L2-normalized along dimension 1."""

    def __init__(self, model: torch.nn.Module) -> None:
        super().__init__()
        # Registered as the `_model` submodule (state_dict keys carry the "_model." prefix).
        self._model = model

    def forward(self, samples):
        embeddings = self._model(samples)
        return torch.nn.functional.normalize(embeddings, p=2, dim=1)
class ModelWithMultiScale(torch.nn.Module):
    """Average a model's outputs over several rescaled copies of the input.

    Each scale factor is applied to ``samples`` with
    ``torch.nn.functional.interpolate``; a factor of 1 uses (a copy of) the input
    unchanged. The per-scale outputs must have matching shapes (e.g. pooled
    embeddings) and are averaged with equal weight.

    Args:
        model: Module applied to every rescaled input.
        mode: Interpolation mode passed to ``interpolate``.
        scales: Scale factors to evaluate; defaults to ``(1, 0.5**0.5, 0.5)``.

    Raises:
        ValueError: if ``scales`` is empty.
    """

    # Modes for which `interpolate` accepts `align_corners`; passing it with
    # e.g. "nearest" or "area" makes `interpolate` raise a ValueError.
    _ALIGN_CORNERS_MODES = frozenset({"linear", "bilinear", "bicubic", "trilinear"})

    def __init__(
        self,
        model: torch.nn.Module,
        mode: str = "bilinear",
        scales: tuple[float, ...] = (1, 0.5**0.5, 0.5),
    ) -> None:
        super().__init__()
        scales = tuple(scales)
        if not scales:
            raise ValueError("scales must contain at least one scale factor")
        self._model = model
        self._mode = mode
        self._scales = scales

    def forward(self, samples):
        interp_kwargs = {}
        if self._mode in self._ALIGN_CORNERS_MODES:
            interp_kwargs["align_corners"] = False

        output = None
        for scale in self._scales:
            if scale == 1:
                # Pass a copy so in-place changes by the model cannot leak back into `samples`.
                resized_samples = samples.clone()
            else:
                resized_samples = nn.functional.interpolate(
                    samples, scale_factor=scale, mode=self._mode, **interp_kwargs
                )
            scale_output = self._model(resized_samples)
            # Out-of-place sum: never writes into a tensor returned by the model.
            output = scale_output if output is None else output + scale_output
        # Divide by the number of scales actually evaluated (not a hard-coded 3).
        return output / len(self._scales)
def wrap_model(
    model: nn.Module,
    *,
    normalize: bool = True,
    multi_scale: bool = False,
) -> nn.Module:
    """Optionally wrap ``model`` with multi-scale averaging and/or L2 normalization.

    Multi-scale averaging (``ModelWithMultiScale``) is applied first, so when both
    options are enabled the normalization runs on the averaged output. The state of
    each option is logged.

    Args:
        model: Module to wrap.
        normalize: Wrap with ``ModelWithNormalize`` when true.
        multi_scale: Wrap with ``ModelWithMultiScale`` when true.

    Returns:
        The wrapped module, or ``model`` itself when both flags are false.
    """
    # (log template, enabled flag, wrapper class), applied innermost-first.
    wrapping_steps = (
        ("multi-scale: {}", multi_scale, ModelWithMultiScale),
        ("normalize: {}", normalize, ModelWithNormalize),
    )
    for log_template, enabled, wrapper_cls in wrapping_steps:
        logger.info(log_template.format("enabled" if enabled else "disabled"))
        if enabled:
            model = wrapper_cls(model)
    return model
class ModelWithIntermediateLayers(nn.Module):
    """Run a backbone in eval/inference mode and return its intermediate-layer features.

    Args:
        feature_model: Backbone exposing ``get_intermediate_layers``; it is put in
            eval mode on construction.
        n_last_blocks: Forwarded as ``n`` to ``get_intermediate_layers``.
        autocast_ctx: Zero-argument callable returning a context manager that is
            entered around the backbone call (presumably an autocast context).
    """

    def __init__(self, feature_model, n_last_blocks, autocast_ctx):
        super().__init__()
        self.n_last_blocks = n_last_blocks
        self.autocast_ctx = autocast_ctx
        feature_model.eval()
        self.feature_model = feature_model

    def forward(self, images):
        # No autograd tracking: the features are produced under inference mode.
        with torch.inference_mode(), self.autocast_ctx():
            return self.feature_model.get_intermediate_layers(
                images, n=self.n_last_blocks, return_class_token=True
            )
@torch.inference_mode()
def evaluate(
    model: nn.Module,
    data_loader,
    postprocessors: Dict[str, nn.Module],
    metrics: Dict[str, Metric],
    device: torch.device,
    criterion: Optional[nn.Module] = None,
    accumulate_results: bool = False,
):
    """Run ``model`` over ``data_loader`` and compute the given metrics.

    Batches must unpack as ``(samples, (index, targets), *rest)``, i.e. the dataset is
    wrapped in ``DatasetWithEnumeratedTargets``. Entries with a negative index are
    dropped before the forward pass (presumably padding entries — confirm with the
    dataset wrapper), and batches left empty are skipped.

    Args:
        model: Module applied to each batch of samples.
        data_loader: Iterable of batches as described above.
        postprocessors: Callables mapping ``(outputs, targets)`` to the keyword
            arguments of ``Metric.update``; the result must contain ``"preds"`` and
            ``"target"``. Every key of ``metrics`` must also be a key here.
        metrics: Metrics to update and compute.
        device: Device that samples, targets, indices and metrics are moved to.
        criterion: Optional loss; when given, its per-batch value is logged as ``loss``.
        accumulate_results: When true, use ``ResultsAccumulator`` to keep per-sample
            predictions/targets per key; otherwise a ``NoOpAccumulator``.

    Returns:
        ``(metric_logger_stats, stats, accumulated_results)``: the global average of
        each logged meter, the computed value of each metric, and each accumulator's
        ``accumulate()`` output (``None`` for ``NoOpAccumulator``).
    """
    gc.collect()  # Avoids garbage collection errors in DataLoader workers
    model.eval()
    if criterion is not None:
        criterion.eval()
    # Metric is an nn.Module, whose .to() moves it in place; no rebinding needed.
    for metric in metrics.values():
        metric.to(device)

    metric_logger = MetricLogger(delimiter=" ")
    header = "Test:"

    accumulator_class = ResultsAccumulator if accumulate_results else NoOpAccumulator
    accumulators = {k: accumulator_class() for k in postprocessors.keys()}

    # The dataset must be wrapped in dinov3.data.DatasetWithEnumeratedTargets so each
    # batch carries (index, targets).
    for samples, (index, targets), *_ in metric_logger.log_every(data_loader, 10, header):
        keep = index >= 0
        samples, targets, index = samples[keep], targets[keep], index[keep]
        if len(index) == 0:
            continue
        outputs = model(samples.to(device))
        index = index.to(device)
        targets = targets.to(device)
        if criterion is not None:
            loss = criterion(outputs, targets)
            metric_logger.update(loss=loss.item())
        for k, metric in metrics.items():
            metric_inputs = postprocessors[k](outputs, targets)
            metric.update(**metric_inputs)
            accumulators[k].update(preds=metric_inputs["preds"], target=metric_inputs["target"], index=index)

    metric_logger.synchronize_between_processes()
    logger.info(f"Averaged stats: {metric_logger}")
    stats = {k: metric.compute() for k, metric in metrics.items()}
    metric_logger_stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
    # accumulator.accumulate() returns None for the NoOpAccumulator
    accumulated_results = {k: accumulator.accumulate() for k, accumulator in accumulators.items()}
    return metric_logger_stats, stats, accumulated_results
def all_gather_and_flatten(tensor_rank):
    """Gather ``tensor_rank`` from every process and stack the results along dim 0.

    ``torch.distributed.all_gather`` requires every rank to contribute a tensor of
    the same shape. For an input of shape ``(N, *rest)`` the result has shape
    ``(world_size * N, *rest)``, with rank 0's rows first.
    """
    world_size = distributed.get_world_size()
    gathered = tensor_rank.new_empty((world_size, *tensor_rank.shape))
    # all_gather writes directly into these per-rank views of `gathered`.
    per_rank_views = list(gathered.unbind(0))
    torch.distributed.all_gather(per_rank_views, tensor_rank.contiguous())
    return gathered.flatten(end_dim=1)
def extract_features(model, dataset, batch_size, num_workers, gather_on_cpu=False):
    """Extract features and labels for every sample of ``dataset``.

    The dataset is wrapped in ``DatasetWithEnumeratedTargets`` and read with a
    distributed loader that neither shuffles nor drops the last batch. Returns the
    ``(features, labels)`` pair produced by ``extract_features_with_dataloader``.
    """
    enumerated_dataset = DatasetWithEnumeratedTargets(dataset)
    n_samples = len(enumerated_dataset)
    loader = make_data_loader(
        dataset=enumerated_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        sampler_type=SamplerType.DISTRIBUTED,
        drop_last=False,
        shuffle=False,
    )
    return extract_features_with_dataloader(model, loader, n_samples, gather_on_cpu)
@torch.inference_mode()
def extract_features_with_dataloader(model, data_loader, sample_count, gather_on_cpu=False):
    """Run ``model`` over a distributed loader and collect all features and labels.

    Batches must unpack as ``(samples, (index, labels))`` where ``index`` gives each
    sample's position in the dataset. Each batch is moved to CUDA, the per-rank
    results are all-gathered, and each gathered row is written at its dataset index,
    so the output rows follow dataset order regardless of how the data was sharded.

    Args:
        model: Module mapping a batch of samples to 2-D ``(batch, dim)`` features;
            the features are cast to float32.
        data_loader: Loader over a dataset wrapped in ``DatasetWithEnumeratedTargets``.
        sample_count: Number of samples in the dataset (rows of the outputs).
        gather_on_cpu: Store the gathered tensors on CPU instead of CUDA.

    Returns:
        ``(features, labels)``: a ``(sample_count, dim)`` tensor and a labels tensor
        with leading dimension ``sample_count`` and the same dtype as the incoming
        labels. Rows never written keep their fill values (0 for features, -1 for labels).
    """
    gather_device = torch.device("cpu") if gather_on_cpu else torch.device("cuda")
    metric_logger = MetricLogger(delimiter=" ")
    features, all_labels = None, None
    for samples, (index, labels_rank) in metric_logger.log_every(data_loader, 10):
        samples = samples.cuda(non_blocking=True)
        labels_rank = labels_rank.cuda(non_blocking=True)
        index = index.cuda(non_blocking=True)
        features_rank = model(samples).float()

        # Allocate the storage lazily, once the feature dim and label shape are known.
        if features is None:
            features = torch.zeros(sample_count, features_rank.shape[-1], device=gather_device)
            labels_shape = list(labels_rank.shape)
            labels_shape[0] = sample_count
            # Use the labels' own dtype: index_copy_ requires source and destination to
            # share a dtype, and an int64-only buffer rejects e.g. float multi-label targets.
            all_labels = torch.full(
                labels_shape, fill_value=-1, dtype=labels_rank.dtype, device=gather_device
            )
            logger.info(f"Storing features into tensor of shape {features.shape}")

        # share indexes, features and labels between processes
        index_all = all_gather_and_flatten(index).to(gather_device)
        features_all_ranks = all_gather_and_flatten(features_rank).to(gather_device)
        labels_all_ranks = all_gather_and_flatten(labels_rank).to(gather_device)

        # write the gathered rows at their dataset positions
        if len(index_all) > 0:
            features.index_copy_(0, index_all, features_all_ranks)
            all_labels.index_copy_(0, index_all, labels_all_ranks)

    logger.info(f"Features Shape {features.shape}")
    logger.info(f"Labels Shape {all_labels.shape}")
    return features, all_labels
def save_features_dict(features_dict: Dict[str, torch.Tensor], path: str) -> None:
    """Save a ``str -> Tensor`` mapping to ``path``, choosing the format by extension.

    ``.pt`` files are written with ``torch.save``. ``.npy`` files hold the dict
    pickled by ``np.save``, with every tensor moved to CPU and converted to a
    NumPy array first.

    Raises:
        AssertionError: if a key is not a ``str`` or a value is not a tensor.
        ValueError: if the extension is neither ``.pt`` nor ``.npy``.
    """
    logger.info(f'saving features to "{path}"')
    for name, tensor in features_dict.items():
        assert isinstance(name, str)
        assert isinstance(tensor, torch.Tensor)

    _, ext = os.path.splitext(path)
    if ext == ".pt":
        torch.save(features_dict, path)
        return
    if ext != ".npy":
        raise ValueError(f'Unsupported features dict extension "{ext}"')

    arrays = {name: tensor.cpu().numpy() for name, tensor in features_dict.items()}
    np.save(path, arrays, allow_pickle=True)
def load_features_dict(path: str) -> Dict[str, torch.Tensor]:
    """Load a ``str -> Tensor`` mapping from ``path``, choosing the format by extension.

    ``.pt`` files are read with ``torch.load``. ``.npy`` files must contain a pickled
    dict of NumPy arrays, which are turned back into tensors with
    ``torch.from_numpy`` (sharing memory with the arrays).

    Security: the ``.npy`` path loads with ``allow_pickle=True``, and ``torch.load``
    may unpickle depending on the torch version's ``weights_only`` default — only
    load files from trusted sources.

    Raises:
        ValueError: if the extension is neither ``.pt`` nor ``.npy``.
        AssertionError: if a loaded key is not a ``str`` or a value is not a tensor.
    """
    logger.info(f'loading features from "{path}"')
    _, ext = os.path.splitext(path)
    if ext == ".npy":
        arrays = np.load(path, allow_pickle=True).item()
        features_dict = {name: torch.from_numpy(array) for name, array in arrays.items()}
    elif ext == ".pt":
        features_dict = torch.load(path)
    else:
        raise ValueError(f'Unsupported features dict extension "{ext}"')

    for name, tensor in features_dict.items():
        assert isinstance(name, str)
        assert isinstance(tensor, torch.Tensor)
    return features_dict
def average_metrics(
    eval_metrics_dict: dict[Any, dict[str, torch.Tensor]],
    ignore_keys: Optional[List[str]] = None,
):
    """Compute the mean and standard deviation of each metric across several runs.

    Args:
        eval_metrics_dict: Maps a run identifier (presumably a seed or repeat index)
            to that run's metrics. All runs are expected to report the same metric
            names; each value must be a scalar accepted by ``torch.tensor``.
        ignore_keys: Metric names to skip. For example, a linear evaluation
            dictionary contains ``"best_classifier"``, which callers should list
            here. Defaults to skipping nothing.

    Returns:
        A dict with ``"<metric>_mean"`` and ``"<metric>_std"`` for every metric that
        is not ignored. The std is ``torch.std``'s Bessel-corrected estimate, so it
        is NaN for a single run. Integer-valued metrics are averaged as floats.
        An empty ``eval_metrics_dict`` yields ``{}``.
    """
    if not eval_metrics_dict:
        return {}
    # Avoid a mutable default argument; a set also makes membership checks O(1).
    ignored = set(ignore_keys) if ignore_keys is not None else set()
    # Take metric names from run 0 when present (historical behaviour), otherwise from
    # the first run, so non-integer run identifiers also work.
    if 0 in eval_metrics_dict:
        reference_run = eval_metrics_dict[0]
    else:
        reference_run = next(iter(eval_metrics_dict.values()))

    output_metrics_dict = {}
    for metric in reference_run.keys():
        if metric in ignored:
            continue
        stats_tensor = torch.tensor([run_metrics[metric] for run_metrics in eval_metrics_dict.values()])
        # mean()/std() reject integer tensors; float results are left untouched.
        if not stats_tensor.is_floating_point():
            stats_tensor = stats_tensor.float()
        output_metrics_dict[metric + "_mean"] = stats_tensor.mean().item()
        output_metrics_dict[metric + "_std"] = torch.std(stats_tensor).item()
    return output_metrics_dict
def save_results(
    preds: torch.Tensor,
    target: torch.Tensor,
    output_dir: str,
    filename_suffix: Optional[str] = None,
) -> None:
    """Save model predictions and their targets as ``.npy`` files in ``output_dir``.

    The files are named ``preds{suffix}.npy`` and ``target{suffix}.npy``. ``suffix``
    is ``"_" + filename_suffix`` when a suffix is given and empty otherwise. Both
    tensors are moved to CPU before saving. ``preds`` and ``target`` are expected to
    be aligned row by row (by sample index).
    """
    suffix = f"_{filename_suffix}" if filename_suffix is not None else ""
    outputs = (
        (f"preds{suffix}.npy", preds),
        (f"target{suffix}.npy", target),
    )
    for file_name, tensor in outputs:
        file_path = os.path.join(output_dir, file_name)
        logger.info(f"Saving to {file_path}")
        np.save(file_path, tensor.cpu().numpy())