File size: 1,125 Bytes
7382c66 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 | import os
import torch
from contextlib import nullcontext
from tqdm import tqdm
from da2 import (
prepare_to_run,
load_model
)
from eval.utils import run_evaluation
def eval(model, config, accelerator, output_dir):
    """Run evaluation over every configured dataset and log the chosen metrics.

    ``model.eval()`` is called on every process, but the datasets are only
    evaluated (and results only logged) when ``accelerator.is_main_process``
    is true; other processes return right after the config lookup.

    Args:
        model: Object exposing ``eval()``; the value it returns is what gets
            handed to ``run_evaluation``.
        config: Nested mapping. Reads ``config['evaluation']['datasets']``,
            ``config['evaluation']['metric_show']`` and
            ``config['env']['logger']``.
        accelerator: Object providing ``is_main_process`` and ``device``.
        output_dir: Forwarded unchanged to ``run_evaluation``.

    Returns:
        None. Results are reported through the configured logger only.
    """
    model = model.eval()
    eval_datasets = config['evaluation']['datasets']

    # Guard clause: only the main process evaluates and logs.
    if not accelerator.is_main_process:
        return

    # No autocast context is used whenever an MPS backend is available
    # (presumably because autocast is not wanted there -- TODO confirm);
    # otherwise autocast targets the accelerator's device type.
    autocast_ctx = (
        nullcontext()
        if torch.backends.mps.is_available()
        else torch.autocast(accelerator.device.type)
    )

    with autocast_ctx, torch.no_grad():
        for dataset_name in eval_datasets.keys():
            metrics = run_evaluation(
                model, config, dataset_name, output_dir, accelerator.device
            )
            for metric_name in config['evaluation']['metric_show']:
                logger = config['env']['logger']
                # metric_show maps a metric key to the text printed for it.
                shown_as = config['evaluation']['metric_show'][metric_name]
                value = metrics[metric_name]
                # ANSI escape codes wrap the message so it prints in green.
                logger.info(f"\033[92mEVAL --> {dataset_name}: {shown_as} = {value}.\033[0m")
if __name__ == '__main__':
    # Script entry point: obtain the run context, load the model, evaluate it.
    config, accelerator, output_dir = prepare_to_run()
    eval(load_model(config, accelerator), config, accelerator, output_dir)
|