File size: 1,125 Bytes
7382c66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import os
import torch
from contextlib import nullcontext
from tqdm import tqdm
from da2 import (
    prepare_to_run,
    load_model
)
from eval.utils import run_evaluation


def eval(model, config, accelerator, output_dir):
    """Evaluate ``model`` on each configured dataset and log the chosen metrics.

    The model is switched to evaluation mode with ``model.eval()``. Only the
    process for which ``accelerator.is_main_process`` is true goes on to run
    the evaluation; every other process returns without doing anything more.

    NOTE(review): the function name shadows the builtin ``eval``; it is kept
    because the script entry point in this file calls it by this name.

    Args:
        model: Object exposing ``eval()``; the value returned by that call is
            what gets handed to ``run_evaluation``.
        config: Nested mapping. Reads ``config['evaluation']['datasets']``
            (iterated for dataset names), ``config['evaluation']['metric_show']``
            (metric key -> text printed in the log line) and
            ``config['env']['logger']`` (object with an ``info`` method).
        accelerator: Supplies ``is_main_process`` and ``device``.
        output_dir: Forwarded unchanged to ``run_evaluation``.

    Returns:
        None.
    """
    model = model.eval()
    eval_datasets = config['evaluation']['datasets']

    # Every process other than the main one stops here.
    if not accelerator.is_main_process:
        return

    # Autocast is skipped whenever an MPS backend is available; otherwise it
    # is enabled for the accelerator's device type.
    mps_present = torch.backends.mps.is_available()
    autocast_ctx = nullcontext() if mps_present else torch.autocast(accelerator.device.type)

    with autocast_ctx:
        with torch.no_grad():
            for dataset_name in eval_datasets:
                metrics = run_evaluation(model, config, dataset_name, output_dir, accelerator.device)
                # One green (ANSI 92) log line per metric listed in metric_show.
                for metric_key, shown_as in config['evaluation']['metric_show'].items():
                    logger = config['env']['logger']
                    logger.info(f"\033[92mEVAL --> {dataset_name}: {shown_as} = {metrics[metric_key]}.\033[0m")

if __name__ == "__main__":
    # Script entry point: obtain the run configuration, accelerator and output
    # directory, load the model, then run the evaluation defined above.
    config, accelerator, output_dir = prepare_to_run()
    eval(load_model(config, accelerator), config, accelerator, output_dir)