import os
import torch
from contextlib import nullcontext
from tqdm import tqdm

from da2 import (
    prepare_to_run,
    load_model
)
from eval.utils import run_evaluation


def eval(model, config, accelerator, output_dir):
    # Switch to inference mode and read the list of evaluation datasets from the config.
    model = model.eval()
    eval_datasets = config['evaluation']['datasets']
    if accelerator.is_main_process:
        # torch.autocast is not supported on Apple MPS, so fall back to a no-op context there.
        if torch.backends.mps.is_available():
            autocast_ctx = nullcontext()
        else:
            autocast_ctx = torch.autocast(accelerator.device.type)
        with autocast_ctx, torch.no_grad():
            # Evaluate each configured dataset and log the metrics selected for display.
            for dataset_name in eval_datasets.keys():
                metrics = run_evaluation(model, config, dataset_name, output_dir, accelerator.device)
                for metric_name in config['evaluation']['metric_show']:
                    config['env']['logger'].info(
                        f"\033[92mEVAL --> {dataset_name}: "
                        f"{config['evaluation']['metric_show'][metric_name]} = {metrics[metric_name]}.\033[0m"
                    )


if __name__ == '__main__':
    config, accelerator, output_dir = prepare_to_run()
    model = load_model(config, accelerator)
    eval(model, config, accelerator, output_dir)