File size: 6,339 Bytes
456c878 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
# --------------------------------------------------------
# Make evaluation along with training. Swith time with space/computation.
# Licensed under The MIT License [see LICENSE for details]
# Written by Peng Zheng
# --------------------------------------------------------
import os
from glob import glob
from time import sleep
import argparse
import torch
from config import Config
from models.birefnet import BiRefNet
from dataset import MyData
from evaluation.valid import valid
parser = argparse.ArgumentParser(description='')
parser.add_argument('--cuda_idx', default=-1, type=int)
parser.add_argument('--val_step', default=5*1, type=int)
parser.add_argument('--program_id', default=0, type=int)
# id-th one of this program will evaluate val_step * N + program_id -th epoch model.
# Test more models, number of programs == number of GPUs: [models[num_all - program_id_1], models[num_all - program_id_max(n, val_step-1)], ...] programs with id>val_step will speed up the evaluation on (val_step - id)%val_step -th epoch models.
# Test fastest, only sequentially searched val_step*N -th models -- set all program_id as the same.
parser.add_argument('--testsets', default='DIS-VD+DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4', type=str)
args_eval = parser.parse_args()
args_eval.program_id = (args_eval.val_step - args_eval.program_id) % args_eval.val_step
config = Config()
config.only_S_MAE = True
device = 'cpu' if args_eval.cuda_idx < 0 else 'cuda:{}'.format(args_eval.cuda_idx)
ckpt_dir, testsets = glob(os.path.join('ckpt', '*'))[0], args_eval.testsets
def validate_model(model, test_loaders, epoch):
num_image_testset_all = {'DIS-VD': 470, 'DIS-TE1': 500, 'DIS-TE2': 500, 'DIS-TE3': 500, 'DIS-TE4': 500}
num_image_testset = {}
for testset in testsets.split('+'):
if 'DIS-TE' in testset:
num_image_testset[testset] = num_image_testset_all[testset]
weighted_scores = {'f_max': 0, 'sm': 0, 'e_max': 0, 'mae': 0}
len_all_data_loaders = 0
model.epoch = epoch
for testset, data_loader_test in test_loaders.items():
print('Validating {}...'.format(testset))
performance_dict = valid(
model,
data_loader_test,
pred_dir='.',
method=ckpt_dir.split('/')[-1] if ckpt_dir.split('/')[-1].strip('.').strip('/') else 'tmp_val',
testset=testset,
only_S_MAE=config.only_S_MAE,
device=device
)
print('Test set: {}:'.format(testset))
if config.only_S_MAE:
print('Smeasure: {:.4f}, MAE: {:.4f}'.format(
performance_dict['sm'], performance_dict['mae']
))
else:
print('Fmax: {:.4f}, Fwfm: {:.4f}, Smeasure: {:.4f}, Emean: {:.4f}, MAE: {:.4f}'.format(
performance_dict['f_max'], performance_dict['f_wfm'], performance_dict['sm'], performance_dict['e_mean'], performance_dict['mae']
))
if '-TE' in testset:
for metric in ['sm', 'mae'] if config.only_S_MAE else ['f_max', 'f_wfm', 'sm', 'e_mean', 'mae']:
weighted_scores[metric] += performance_dict[metric] * len(data_loader_test)
len_all_data_loaders += len(data_loader_test)
print('Weighted Scores:')
for metric, score in weighted_scores.items():
if score:
print('\t{}: {:.4f}.'.format(metric, score / len_all_data_loaders))
@torch.no_grad()
def main():
config = Config()
# Dataloader
test_loaders = {}
for testset in testsets.split('+'):
dataset = MyData(
datasets=testset,
image_size=config.size, is_train=False
)
_data_loader_test = torch.utils.data.DataLoader(
dataset=dataset, batch_size=config.batch_size_valid, num_workers=min(config.num_workers, config.batch_size_valid),
pin_memory=device != 'cpu', shuffle=False
)
print(len(_data_loader_test), "batches of valid dataloader {} have been created.".format(testset))
test_loaders[testset] = _data_loader_test
# Model, 3070MiB GPU memory for inference
model = BiRefNet(bb_pretrained=False).to(device)
models_evaluated = []
continous_sleep_time = 0
while True:
if (
(models_evaluated and continous_sleep_time > 60*60*2) or
(not models_evaluated and continous_sleep_time > 60*60*24)
):
# If no ckpt has been saved, we wait for 24h;
# elif some ckpts have been saved, we wait for 2h for new ones;
# else: exit this waiting.
print('Exiting the waiting for evaluation.')
break
models_evaluated_record = 'tmp_models_evaluated.txt'
if os.path.exists(models_evaluated_record):
with open(models_evaluated_record, 'r') as f:
models_evaluated_global = f.read().splitlines()
else:
models_evaluated_global = []
models_detected = [
m for idx_m, m in enumerate(sorted(
glob(os.path.join(ckpt_dir, '*.pth')),
key=lambda x: int(x.rstrip('.pth').split('epoch_')[-1]), reverse=True
)) if idx_m % args_eval.val_step == args_eval.program_id and m not in models_evaluated + models_evaluated_global
]
if models_detected:
from time import time
time_st = time()
# register the evaluated models
model_not_evaluated_latest = models_detected[0]
with open('tmp_models_evaluated.txt', 'a') as f:
f.write(model_not_evaluated_latest + '\n')
models_evaluated.append(model_not_evaluated_latest)
print('Loading {} for validation...'.format(model_not_evaluated_latest))
# evaluate the current model
state_dict = torch.load(model_not_evaluated_latest, map_location=device)
model.load_state_dict(state_dict, strict=False)
validate_model(model, test_loaders, int(model_not_evaluated_latest.rstrip('.pth').split('epoch_')[-1]))
continous_sleep_time = 0
print('Duration of this evaluation:', time() - time_st)
else:
sleep_interval = 60 * 2
sleep(sleep_interval)
continous_sleep_time += sleep_interval
continue
if __name__ == '__main__':
main()
|