|
|
|
|
|
|
|
|
|
|
| import logging
|
| import sys
|
| import time
|
| from dataclasses import dataclass, field
|
| from functools import partial
|
| from typing import Any, Dict, List, Optional
|
|
|
| import torch
|
| import torch.backends.cudnn as cudnn
|
| import torch.distributed
|
| from omegaconf import MISSING
|
| from torch import nn
|
| from torch.utils.data import TensorDataset
|
| from torchmetrics import MetricTracker
|
|
|
| from dinov3.data import SamplerType, make_data_loader, make_dataset
|
| from dinov3.data.adapters import DatasetWithEnumeratedTargets
|
| from dinov3.data.transforms import CROP_DEFAULT_SIZE, get_target_transform, make_classification_eval_transform
|
| from dinov3.distributed import get_rank, get_world_size
|
| from dinov3.eval.data import (
|
| create_train_dataset_dict,
|
| extract_features_for_dataset_dict,
|
| get_num_classes,
|
| split_train_val_datasets,
|
| )
|
| from dinov3.eval.helpers import args_dict_to_dataclass, cli_parser, write_results
|
| from dinov3.eval.metrics import ClassificationMetricType, build_classification_metric
|
| from dinov3.eval.setup import ModelConfig, load_model_and_context
|
| from dinov3.eval.utils import average_metrics, evaluate, extract_features
|
| from dinov3.eval.utils import save_results as default_save_results_func
|
| from dinov3.run.init import job_context
|
| from dinov3.utils.dtype import as_torch_dtype
|
|
|
# Logger used by every function in this module.
logger = logging.getLogger("fairvit")


# File name passed to ``write_results`` for the final metrics.
RESULTS_FILENAME = "results-log-regression.csv"
# Name pattern(s) of the headline metric(s) -- presumably regular expressions; confirm with consumers.
MAIN_METRICS = ["top-1(_mean)?"]


# sklearnex is optional: when it is importable, scikit-learn gets patched; otherwise only warn.
try:
    from sklearnex import patch_sklearn

    patch_sklearn()
except ImportError:
    logger.warning("Can't import sklearnex. If installed, that speeds up scikit-learn 10-100x")

# scikit-learn itself is mandatory for this evaluation.
try:
    from sklearn.linear_model import LogisticRegression as sklearnLogisticRegression
    from sklearn.multiclass import OneVsRestClassifier
except ImportError:
    logger.warning("Can't import scikit-learn. This is necessary for evaluating log regression")
    # Re-raise the original exception so its message and module name are kept,
    # instead of replacing it with an empty ImportError.
    raise


# Exponents of the regularisation strengths to sweep: C = 10 ** p for 45 evenly spaced p in [-6, 5].
C_POWER_RANGE = torch.linspace(-6, 5, 45)
# Reference device used to check that the sklearn estimator runs on CPU.
_CPU_DEVICE = torch.device("cpu")
|
|
|
|
|
@dataclass
class TrainConfig:
    """Settings for fitting the logistic regression and for the sweep over C."""

    # Dataset string handed to ``make_dataset`` for the training split.
    dataset: str = MISSING
    # Optional dataset string for validation; when None a validation split is carved out of ``dataset``.
    val_dataset: Optional[str] = None
    # Metric built with ``build_classification_metric`` to select the best C.
    val_metric_type: ClassificationMetricType = ClassificationMetricType.MEAN_ACCURACY
    # Batch size / worker count used for train and val feature extraction and the val data loader.
    batch_size: int = 256
    num_workers: int = 5
    # Forwarded as ``tol`` to the scikit-learn LogisticRegression estimator.
    tol: float = 1e-12
    # Device holding the training features; LogRegModule asserts that this is the CPU.
    train_features_device: str = "cpu"
    # Dtype name converted with ``as_torch_dtype`` before fitting.
    train_dtype: str = "float64"
    # Forwarded as ``max_iter`` to the scikit-learn LogisticRegression estimator.
    max_train_iters: int = 1_000
|
|
|
|
|
@dataclass
class EvalConfig:
    """Settings for the final evaluation on the test dataset."""

    # Dataset string handed to ``make_dataset`` for the test split.
    test_dataset: str = MISSING
    # When None, ``eval_log_regression_with_model`` substitutes the training batch size.
    batch_size: int | None = None
    # Worker count for test feature extraction and the test data loader.
    num_workers: int = 5
    # When None, the validation metric type from the train config is used for testing.
    test_metric_type: Optional[ClassificationMetricType] = None
|
|
|
|
|
@dataclass
class TransformConfig:
    """Resize and crop sizes passed to ``make_classification_eval_transform``."""

    # Both default to CROP_DEFAULT_SIZE; ``make_transform`` warns when the two sizes differ.
    resize_size: int = CROP_DEFAULT_SIZE
    crop_size: int = CROP_DEFAULT_SIZE
|
|
|
|
|
@dataclass
class FewShotConfig:
    """Few-shot evaluation options forwarded to ``create_train_dataset_dict``."""

    # Turns few-shot evaluation on (passed as ``few_shot_eval``).
    enable: bool = False
    # Passed as ``few_shot_k_or_percent``.
    k_or_percent: Optional[float] = None
    # Passed as ``few_shot_n_tries``.
    n_tries: int = 1
|
|
|
|
|
@dataclass
class LogregEvalConfig:
    """Top-level configuration for the logistic-regression evaluation."""

    # Model description consumed by ``load_model_and_context``; it has no default and must be given.
    model: ModelConfig
    # Nested sections; each one is built fresh per instance through ``default_factory``.
    train: TrainConfig = field(default_factory=TrainConfig)
    eval: EvalConfig = field(default_factory=EvalConfig)
    transform: TransformConfig = field(default_factory=TransformConfig)
    few_shot: FewShotConfig = field(default_factory=FewShotConfig)
    # When true, accumulated evaluation results are handed to the results-saving helper.
    save_results: bool = False
    # Directory passed as ``output_dir`` to the results-saving helper.
    output_dir: str = ""
|
|
|
|
|
class LogRegModule(nn.Module):
    """Wrap a scikit-learn logistic regression behind an ``nn.Module`` interface.

    The estimator is an L2-penalised LBFGS ``LogisticRegression``; for
    multi-label targets it is wrapped in a ``OneVsRestClassifier``.
    """

    def __init__(self, C, multi_label=False, logreg_config=None):
        """Build the estimator.

        Args:
            C: value passed as ``C`` to the scikit-learn LogisticRegression.
            multi_label: when true, wrap the estimator in ``OneVsRestClassifier``.
            logreg_config: training settings (dtype, device, ``tol``, iteration
                cap). ``None`` means a default ``TrainConfig()``.

        Raises:
            AssertionError: if the configured features device is not the CPU.
        """
        super().__init__()
        # Use a real instance for the default rather than the TrainConfig class object.
        if logreg_config is None:
            logreg_config = TrainConfig()
        self.dtype = as_torch_dtype(logreg_config.train_dtype)
        self.device = torch.device(logreg_config.train_features_device)
        assert self.device == _CPU_DEVICE, f"SKLearn can only work on CPU device, got {self.device}"
        self.estimator = sklearnLogisticRegression(
            penalty="l2",
            solver="lbfgs",
            C=C,
            max_iter=logreg_config.max_train_iters,
            n_jobs=-1,
            tol=logreg_config.tol,
        )
        if multi_label:
            self.estimator = OneVsRestClassifier(self.estimator, n_jobs=-1)

    def forward(self, samples, targets):
        """Predict probabilities for ``samples``.

        Returns:
            A dict with ``"preds"`` (the ``predict_proba`` output as a tensor
            on the original device of ``samples``) and ``"target"`` (the
            ``targets`` argument, returned untouched).
        """
        samples_device = samples.device
        samples = samples.to(dtype=self.dtype, device=self.device)
        if self.device == _CPU_DEVICE:
            # scikit-learn consumes NumPy arrays.
            samples = samples.numpy()
        probas = self.estimator.predict_proba(samples)
        return {"preds": torch.from_numpy(probas).to(samples_device), "target": targets}

    def fit(self, train_features, train_labels):
        """Fit the estimator on the features and labels, both cast to the configured dtype and device."""
        train_features = train_features.to(dtype=self.dtype, device=self.device)
        train_labels = train_labels.to(dtype=self.dtype, device=self.device)
        if self.device == _CPU_DEVICE:
            # scikit-learn consumes NumPy arrays.
            train_features = train_features.numpy()
            train_labels = train_labels.numpy()
        self.estimator.fit(train_features, train_labels)
|
|
|
|
|
def evaluate_logreg_model(*, logreg_model, test_metric, test_data_loader, save_results_func=None):
    """Evaluate a fitted logistic-regression module on a data loader.

    The loader output goes through an identity model, so ``logreg_model`` acts
    as the single postprocessor and ``test_metric`` as the single metric, both
    registered under the key ``"metrics"``.

    Args:
        logreg_model: fitted module used as postprocessor.
        test_metric: metric object updated by ``evaluate``.
        test_data_loader: loader passed to ``evaluate``.
        save_results_func: optional callable; when given, results are
            accumulated and passed to it as keyword arguments.

    Returns:
        The metrics object returned by ``evaluate``.
    """
    results_key = "metrics"
    wants_raw_results = save_results_func is not None
    _, eval_metrics, accumulated_results = evaluate(
        nn.Identity(),
        test_data_loader,
        {results_key: logreg_model},
        {results_key: test_metric},
        torch.cuda.current_device(),
        accumulate_results=wants_raw_results,
    )
    if wants_raw_results:
        save_results_func(**accumulated_results[results_key])
    return eval_metrics
|
|
|
|
|
def train_for_C(*, C, train_features, train_labels, logreg_config: TrainConfig):
    """Fit and return a ``LogRegModule`` for one value of ``C``.

    Labels with more than one dimension are treated as multi-label targets.
    """
    is_multi_label = len(train_labels.shape) > 1
    fitted_module = LogRegModule(C, multi_label=is_multi_label, logreg_config=logreg_config)
    fitted_module.fit(train_features, train_labels)
    return fitted_module
|
|
|
|
|
def sweep_C_values(
    *,
    train_features,
    train_labels,
    val_data_loader,
    val_metric,
    logreg_config: TrainConfig,
):
    """Train one model per candidate C and pick the best one on validation.

    The candidates are ``10 ** C_POWER_RANGE``. Each rank trains the candidates
    at indices ``rank, rank + world_size, ...``; the fitted models are then
    exchanged with ``all_gather_object`` so that every rank evaluates all of
    them, in order, with a ``MetricTracker`` wrapped around ``val_metric``.

    Returns:
        ``(best_stats, best_C)``: the tracker's best metrics and the C whose
        step index matches the best ``"top-1"`` step.
    """
    tracker = MetricTracker(val_metric, maximize=True)
    candidate_cs = 10**C_POWER_RANGE
    models_by_c: Dict[float, Any] = {}

    features_device = torch.device(logreg_config.train_features_device)
    features_dtype = as_torch_dtype(logreg_config.train_dtype)
    train_features = train_features.to(dtype=features_dtype, device=features_device)
    train_labels = train_labels.to(device=features_device)

    # Round-robin assignment of the candidates to the ranks.
    rank = get_rank()
    world_size = get_world_size()
    for idx in range(rank, len(candidate_cs), world_size):
        C = candidate_cs[idx].item()
        logger.info(
            f"Training for C = {C:.4g}, dtype={features_dtype}, "
            f"features: {train_features.shape}, {train_features.dtype}, "
            f"labels: {train_labels.shape}, {train_labels.dtype}"
        )
        models_by_c[C] = train_for_C(
            C=C,
            train_features=train_features,
            train_labels=train_labels,
            logreg_config=logreg_config,
        )

    # Share the locally trained models so every rank holds the full set.
    per_rank_models: List[Dict[float, Any]] = [{} for _ in range(world_size)]
    torch.distributed.all_gather_object(per_rank_models, models_by_c)

    for rank_models in per_rank_models:
        models_by_c.update(rank_models)
    per_rank_models.clear()

    for step, c_tensor in enumerate(candidate_cs):
        # One tracker step per candidate, so the step index identifies the candidate.
        tracker.increment()
        C = c_tensor.item()
        evals = evaluate_logreg_model(
            logreg_model=models_by_c.pop(C),
            test_metric=tracker,
            test_data_loader=val_data_loader,
        )
        logger.info(f"Trained for C = {C:.4g}, accuracies = {evals}")
        best_stats, which_epoch = tracker.best_metric(return_step=True)
        best_stats_100 = {k: 100.0 * v for k, v in best_stats.items()}
        if which_epoch["top-1"] == step:
            best_C = C
    logger.info(f"Sweep best {best_stats_100}, best C = {best_C:.4g}")

    return best_stats, best_C
|
|
|
|
|
def make_logreg_data_loader(batch_size: int, num_workers: int, features: torch.Tensor, labels: torch.Tensor):
    """Build a non-shuffled, distributed data loader over precomputed features.

    The feature/label pairs are wrapped in ``DatasetWithEnumeratedTargets``
    with padding enabled and one replica per distributed process.
    """
    feature_label_pairs = TensorDataset(features, labels)
    enumerated_dataset = DatasetWithEnumeratedTargets(
        feature_label_pairs,
        pad_dataset=True,
        num_replicas=get_world_size(),
    )
    return make_data_loader(
        dataset=enumerated_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        sampler_type=SamplerType.DISTRIBUTED,
        drop_last=False,
        shuffle=False,
    )
|
|
|
|
|
def get_best_logreg_with_features(
    *,
    train_features: torch.Tensor,
    train_labels: torch.Tensor,
    val_features: torch.Tensor,
    val_labels: torch.Tensor,
    val_metric,
    concatenate_train_val: bool,
    train_config: TrainConfig,
):
    """Select C on the validation features, then fit and return the final model.

    When ``concatenate_train_val`` is true the final model is trained on the
    train and validation features concatenated; otherwise it is trained on the
    train features only.
    """
    loader_for_val = make_logreg_data_loader(
        train_config.batch_size,
        train_config.num_workers,
        val_features,
        val_labels,
    )
    # Only the selected C is needed here; the best validation stats are dropped.
    best_C_t = sweep_C_values(
        train_features=train_features,
        train_labels=train_labels,
        val_data_loader=loader_for_val,
        val_metric=val_metric,
        logreg_config=train_config,
    )[1]

    if concatenate_train_val:
        logger.info("Best parameter found, concatenating features")
        train_features = torch.cat((train_features, val_features))
        train_labels = torch.cat((train_labels, val_labels))

    logger.info("Training final model")

    return train_for_C(
        C=best_C_t,
        logreg_config=train_config,
        train_features=train_features,
        train_labels=train_labels,
    )
|
|
|
|
|
def make_transform(config: TransformConfig):
    """Build the classification eval transform from the resize and crop sizes.

    A warning is logged when the two sizes differ, because the default
    resize/crop ratio is 1.
    """
    # Compare the sizes directly: dividing them only to test for equality raised
    # ZeroDivisionError when crop_size was 0.
    if config.resize_size != config.crop_size:
        logger.warning(f"Default resize / crop ratio is 1, here we have {config.resize_size} / {config.crop_size}")
    return make_classification_eval_transform(resize_size=config.resize_size, crop_size=config.crop_size)
|
|
|
|
|
def make_train_val_datasets(train_config: TrainConfig, few_shot_config: FewShotConfig, transform):
    """Create the training dataset dict, the validation dataset and the class count.

    If no validation dataset is configured, one is split off the training
    dataset: 1% of it in few-shot mode, 10% otherwise.

    Returns:
        ``(train_dataset_dict, val_dataset, num_classes)``.
    """
    train_dataset = make_dataset(
        dataset_str=train_config.dataset,
        transform=transform,
        target_transform=get_target_transform(train_config.dataset),
    )

    if train_config.val_dataset is None:
        if few_shot_config.enable:
            split_percentage = 0.01
        else:
            split_percentage = 0.1
        train_dataset, val_dataset = split_train_val_datasets(train_dataset, split_percentage=split_percentage)
    else:
        val_dataset = make_dataset(
            dataset_str=train_config.val_dataset,
            transform=transform,
            target_transform=get_target_transform(train_config.val_dataset),
        )

    train_dataset_dict = create_train_dataset_dict(
        train_dataset,
        few_shot_eval=few_shot_config.enable,
        few_shot_k_or_percent=few_shot_config.k_or_percent,
        few_shot_n_tries=few_shot_config.n_tries,
    )
    # The class count is read from the training dataset as it stands after any split.
    num_classes = get_num_classes(train_dataset)
    return train_dataset_dict, val_dataset, num_classes
|
|
|
|
|
def make_test_dataset_and_data_loader(model, config: EvalConfig, transform, gather_on_cpu: bool):
    """Create the test dataset, extract its features and wrap them in a data loader.

    Args:
        model: feature extractor passed to ``extract_features``.
        config: evaluation settings; ``batch_size`` must already be an int.
        transform: image transform applied by the dataset.
        gather_on_cpu: forwarded to ``extract_features``.

    Returns:
        ``(test_dataset, test_data_loader)``; the loader iterates over the
        extracted features and labels.

    Raises:
        AssertionError: if ``config.batch_size`` is not an int.
    """
    # Check the batch size before it is used, not after feature extraction.
    assert isinstance(config.batch_size, int)
    test_dataset = make_dataset(
        dataset_str=config.test_dataset,
        transform=transform,
        target_transform=get_target_transform(config.test_dataset),
    )
    test_features, test_labels = extract_features(
        model, test_dataset, config.batch_size, config.num_workers, gather_on_cpu=gather_on_cpu
    )
    test_data_loader = make_logreg_data_loader(config.batch_size, config.num_workers, test_features, test_labels)
    return test_dataset, test_data_loader
|
|
|
|
|
def eval_log_regression_with_model(*, model: torch.nn.Module, autocast_dtype, config: LogregEvalConfig):
    """
    Implements the "standard" process for log regression evaluation:
    The value of C is chosen by training on train_dataset and evaluating on
    val_dataset. Then, the final model is trained on a concatenation of
    train_dataset and val_dataset (on train_dataset only in few-shot mode),
    and is evaluated on test_dataset.
    If there is no val_dataset, the value of C is the one that yields
    the best results on a random subset of the train dataset (10%, or 1%
    in few-shot mode).

    Side effects: sets ``cudnn.benchmark``, fills in ``config.eval.batch_size``
    when it is unset, moves ``model`` to the CPU after feature extraction and
    ends with a ``torch.distributed.barrier()``.

    Returns:
        A dict mapping metric names to values scaled by 100; when there are
        several training splits the per-split results are combined with
        ``average_metrics``.

    Raises:
        ValueError: if feature extraction produced no training split.
    """
    start = time.time()
    cudnn.benchmark = True

    transform = make_transform(config.transform)
    # The eval batch size falls back to the training one when left unset.
    config.eval.batch_size = config.eval.batch_size or config.train.batch_size

    train_dataset_dict, val_dataset, num_classes = make_train_val_datasets(config.train, config.few_shot, transform)

    with torch.autocast("cuda", dtype=autocast_dtype):
        gather_on_cpu = torch.device(config.train.train_features_device) == _CPU_DEVICE
        train_data_dict = extract_features_for_dataset_dict(
            model, train_dataset_dict, config.train.batch_size, config.train.num_workers, gather_on_cpu=gather_on_cpu
        )
        logger.info("Choosing hyperparameters on the val dataset")
        val_features, val_labels = extract_features(
            model, val_dataset, config.train.batch_size, config.train.num_workers, gather_on_cpu=gather_on_cpu
        )
        test_dataset, test_data_loader = make_test_dataset_and_data_loader(model, config.eval, transform, gather_on_cpu)

    # Fail with a clear message instead of an unbound-variable error further down.
    if not train_data_dict:
        raise ValueError("No training split available for log regression evaluation")

    # The backbone is no longer needed once all features are extracted.
    model.cpu()
    torch.cuda.empty_cache()

    val_metric = build_classification_metric(config.train.val_metric_type, num_classes=num_classes, dataset=val_dataset)
    test_metric_type = config.eval.test_metric_type or config.train.val_metric_type
    test_metric = build_classification_metric(test_metric_type, num_classes=num_classes, dataset=test_dataset)

    save_results_func = None
    if config.save_results:
        save_results_func = partial(default_save_results_func, output_dir=config.output_dir)

    has_multiple_tries = len(train_data_dict) > 1
    results_dict = {}
    for try_key, try_data in train_data_dict.items():
        logreg_model = get_best_logreg_with_features(
            train_features=try_data["train_features"],
            train_labels=try_data["train_labels"],
            val_features=val_features,
            val_labels=val_labels,
            val_metric=val_metric,
            concatenate_train_val=not config.few_shot.enable,
            train_config=config.train,
        )
        # With several tries, tag each saved result with the try it belongs to.
        if has_multiple_tries and save_results_func is not None:
            split_results_saver = partial(save_results_func, filename_suffix=str(try_key))
        else:
            split_results_saver = save_results_func

        eval_metrics = evaluate_logreg_model(
            logreg_model=logreg_model,
            # Each try gets its own copy so metric state does not carry over between tries.
            test_metric=test_metric.clone(),
            test_data_loader=test_data_loader,
            save_results_func=split_results_saver,
        )
        results_dict[try_key] = {k: v.item() * 100.0 for k, v in eval_metrics["metrics"].items()}

    if has_multiple_tries:
        results_dict = average_metrics(results_dict)
    else:
        # Exactly one try: unwrap its metrics without relying on the loop variable.
        (single_try_results,) = results_dict.values()
        results_dict = {**single_try_results}

    logger.info(f"Log regression evaluation done in {int(time.time() - start)}s")
    logger.info("Training of the supervised logistic regression on frozen features completed.")
    results_string = "\n".join([f"{k}: {results_dict[k]:.4g}" for k in sorted(results_dict.keys())])
    logger.info("Results:\n" + results_string)

    torch.distributed.barrier()
    return results_dict
|
|
|
|
|
def benchmark_launcher(eval_args: dict[str, object]) -> dict[str, Any]:
    """Run the evaluation from raw arguments and write the results file.

    Initialization of distributed and logging are preconditions for this method.
    """
    parsed = args_dict_to_dataclass(eval_args=eval_args, config_dataclass=LogregEvalConfig)
    dataclass_config, output_dir = parsed
    model, model_context = load_model_and_context(dataclass_config.model, output_dir=output_dir)
    autocast_dtype = model_context["autocast_dtype"]
    results_dict = eval_log_regression_with_model(
        model=model,
        config=dataclass_config,
        autocast_dtype=autocast_dtype,
    )
    write_results(results_dict, output_dir, RESULTS_FILENAME)
    return results_dict
|
|
|
|
|
def main(argv=None):
    """Command-line entry point: parse the arguments and run the benchmark.

    Args:
        argv: argument list; ``sys.argv[1:]`` is used when it is None.

    Returns:
        0 once the benchmark has run.
    """
    cli_arguments = sys.argv[1:] if argv is None else argv
    eval_args = cli_parser(cli_arguments)
    with job_context(output_dir=eval_args["output_dir"]):
        benchmark_launcher(eval_args=eval_args)
    return 0
|
|
|
|
|
if __name__ == "__main__":
    # Pass main()'s return code on as the process exit status instead of discarding it.
    sys.exit(main())
|
|
|