import os
import math
import sys
from typing import Iterable, Optional

import numpy as np
import torch
from scipy.special import softmax
from timm.utils import ModelEma

import utils
from datasets.mixup import Mixup


def train_class_batch(model, samples, target, criterion):
    """Run a forward pass and compute the loss for one batch.

    The model is assumed to return logits of shape
    (batch, num_categories, num_classes); they are permuted to
    (batch, num_classes, num_categories), the (N, C, d1) layout that
    torch.nn.CrossEntropyLoss expects for targets of shape
    (batch, num_categories).
    """
    outputs = model(samples)
    outputs = torch.permute(outputs, (0, 2, 1))
    loss = criterion(outputs, target)
    return loss, outputs
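
# A minimal shape sketch of the layout above (illustrative sizes only, not
# part of the training pipeline):
#
#   logits = torch.randn(2, 5, 10)           # (batch, categories, classes)
#   targets = torch.randint(0, 10, (2, 5))   # one class index per category
#   loss = torch.nn.CrossEntropyLoss()(logits.permute(0, 2, 1), targets)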


def get_loss_scale_for_deepspeed(model):
    """Return the current loss scale from a DeepSpeed engine's optimizer.

    Depending on the DeepSpeed version, the FP16 optimizer exposes the
    scale as either ``loss_scale`` or ``cur_scale``.
    """
    optimizer = model.optimizer
    return optimizer.loss_scale if hasattr(optimizer, "loss_scale") else optimizer.cur_scale


def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None,
                    start_steps=None, lr_schedule_values=None, wd_schedule_values=None,
                    num_training_steps_per_epoch=None, update_freq=None):
    model.train(True)
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('train_acc', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10
    if loss_scaler is None:
        # DeepSpeed path: the engine owns the optimizer and gradient state.
        model.zero_grad()
        model.micro_steps = 0
    else:
        optimizer.zero_grad()
    for data_iter_step, (samples, targets, _, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        step = data_iter_step // update_freq
        if step >= num_training_steps_per_epoch:
            continue
        it = start_steps + step  # global training iteration
        # Update per-group LR and weight decay from the schedules once per
        # effective (accumulated) step. The parentheses are required:
        # without them, `and` binds tighter than `or`.
        if (lr_schedule_values is not None or wd_schedule_values is not None) and data_iter_step % update_freq == 0:
            for i, param_group in enumerate(optimizer.param_groups):
                if lr_schedule_values is not None:
                    param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]
                if wd_schedule_values is not None and param_group["weight_decay"] > 0:
                    param_group["weight_decay"] = wd_schedule_values[it]
        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if loss_scaler is None:
            # DeepSpeed handles mixed precision itself; feed half-precision inputs.
            samples = samples.half()
            loss, output = train_class_batch(
                model, samples, targets, criterion)
        else:
            with torch.cuda.amp.autocast():
                loss, output = train_class_batch(
                    model, samples, targets, criterion)

        # Permute back to (batch, categories, classes) and average the
        # per-category accuracies into a single batch accuracy.
        curr_output = torch.permute(output, (0, 2, 1))
        preds = torch.argmax(curr_output, dim=-1)
        acc_cats = torch.mean((preds == targets).to(torch.float), dim=0)
        acc1 = torch.mean(acc_cats)
        metric_logger.update(train_acc=acc1.item())
        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        if loss_scaler is None:
            # DeepSpeed: scale the loss for gradient accumulation, then let
            # the engine run backward and step.
            loss /= update_freq
            model.backward(loss)
            model.step()
            if (data_iter_step + 1) % update_freq == 0:
                if model_ema is not None:
                    model_ema.update(model)
            grad_norm = None
            loss_scale_value = get_loss_scale_for_deepspeed(model)
        else:
            # Native AMP: the loss scaler runs backward, unscales, clips, and
            # steps the optimizer on accumulation boundaries.
            is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
            loss /= update_freq
            grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
                                    parameters=model.parameters(), create_graph=is_second_order,
                                    update_grad=(data_iter_step + 1) % update_freq == 0)
            if (data_iter_step + 1) % update_freq == 0:
                optimizer.zero_grad()
                if model_ema is not None:
                    model_ema.update(model)
            loss_scale_value = loss_scaler.state_dict()["scale"]
        torch.cuda.synchronize()

        class_acc = acc1.item()
        metric_logger.update(loss=loss_value)
        metric_logger.update(class_acc=class_acc)
        metric_logger.update(loss_scale=loss_scale_value)
        min_lr = 10.
        max_lr = 0.
        for group in optimizer.param_groups:
            min_lr = min(min_lr, group["lr"])
            max_lr = max(max_lr, group["lr"])

        metric_logger.update(lr=max_lr)
        metric_logger.update(min_lr=min_lr)
        weight_decay_value = None
        for group in optimizer.param_groups:
            if group["weight_decay"] > 0:
                weight_decay_value = group["weight_decay"]
        metric_logger.update(weight_decay=weight_decay_value)
        metric_logger.update(grad_norm=grad_norm)

        if log_writer is not None:
            log_writer.update(loss=loss_value, head="loss")
            log_writer.update(class_acc=class_acc, head="loss")
            log_writer.update(loss_scale=loss_scale_value, head="opt")
            log_writer.update(lr=max_lr, head="opt")
            log_writer.update(min_lr=min_lr, head="opt")
            log_writer.update(weight_decay=weight_decay_value, head="opt")
            log_writer.update(grad_norm=grad_norm, head="opt")
            log_writer.set_step()
    # Gather the stats from all processes before reporting.
    metric_logger.synchronize_between_processes()
    print('** Train Acc {top1.global_avg:.4f}'.format(top1=metric_logger.train_acc))
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
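
# Typical call site (hypothetical names, mirroring the signature above):
#
#   train_stats = train_one_epoch(
#       model, criterion, train_loader, optimizer, device, epoch,
#       loss_scaler, max_norm=args.clip_grad, model_ema=model_ema,
#       mixup_fn=None, log_writer=log_writer,
#       start_steps=epoch * num_steps, lr_schedule_values=lr_values,
#       wd_schedule_values=wd_values,
#       num_training_steps_per_epoch=num_steps, update_freq=args.update_freq)
#
# Pass loss_scaler=None to take the DeepSpeed branch.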


@torch.no_grad()
def validation_one_epoch(data_loader, model, device, fp32=False):
    criterion = torch.nn.CrossEntropyLoss()

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Val:'

    # Switch to evaluation mode.
    model.eval()
    for batch in metric_logger.log_every(data_loader, 10, header):
        videos = batch[0]
        target = batch[1]
        videos = videos.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        # Compute output; honor the fp32 flag by disabling autocast.
        with torch.cuda.amp.autocast(enabled=not fp32):
            output = model(videos)
            output = torch.permute(output, (0, 2, 1))
            loss = criterion(output, target)

        # Permute back to (batch, categories, classes) for the accuracy.
        output = torch.permute(output, (0, 2, 1))
        preds = torch.argmax(output, dim=-1)
        acc_cats = torch.mean((preds == target).to(torch.float), dim=0)
        acc1 = torch.mean(acc_cats)

        batch_size = videos.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
    # Gather the stats from all processes before reporting.
    metric_logger.synchronize_between_processes()
    print('* Acc@1 {top1.global_avg:.4f} loss {losses.global_avg:.3f}'
          .format(top1=metric_logger.acc1, losses=metric_logger.loss))

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
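
# Typical use (hypothetical loader/device names):
#
#   val_stats = validation_one_epoch(val_loader, model, device)
#   print(val_stats["acc1"], val_stats["loss"])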


@torch.no_grad()
def final_test(data_loader, model, device, file):
    criterion = torch.nn.CrossEntropyLoss()

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Test:'

    # One accuracy meter per category.
    num_cats = 20
    for i in range(num_cats):
        metric_logger.add_meter(f'test_acc_cat_{i}', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))

    model.eval()
    final_result = []
    for batch in metric_logger.log_every(data_loader, 10, header):
        videos = batch[0]
        target = batch[1]
        ids = batch[2]
        chunk_nb = batch[3]
        split_nb = batch[4]
        videos = videos.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        # Compute output.
        with torch.cuda.amp.autocast():
            output = model(videos)
            output = torch.permute(output, (0, 2, 1))
            loss = criterion(output, target)

        # Record one line per sample so merge() can aggregate across ranks.
        for i in range(output.size(0)):
            string = "{} {} {} {} {}\n".format(ids[i],
                                               str(output.data[i].cpu().numpy().tolist()),
                                               str(target[i].cpu().numpy()),
                                               str(chunk_nb[i].cpu().numpy()),
                                               str(split_nb[i].cpu().numpy()))
            final_result.append(string)
        # Permute back to (batch, categories, classes) for the accuracy.
        output = torch.permute(output, (0, 2, 1))
        preds = torch.argmax(output, dim=-1)
        acc_cats = torch.mean((preds == target).to(torch.float), dim=0)
        acc1 = torch.mean(acc_cats)

        batch_size = videos.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        # Feed the per-category meters that are printed below; without this
        # they would never receive a value.
        for i in range(num_cats):
            metric_logger.meters[f'test_acc_cat_{i}'].update(acc_cats[i].item(), n=batch_size)
    # Write this rank's predictions; open(..., 'w') creates the file, so the
    # former os.mknod() call is unnecessary.
    with open(file, 'w') as f:
        f.write("{}\n".format(acc1.item()))
        for line in final_result:
            f.write(line)
    # Gather the stats from all processes before reporting.
    metric_logger.synchronize_between_processes()
    print('* Acc@1 {top1.global_avg:.4f} loss {losses.global_avg:.3f}'
          .format(top1=metric_logger.acc1, losses=metric_logger.loss))

    for i in range(num_cats):
        print('** Test Acc {cat} {top1.global_avg:.4f}'.format(cat=i, top1=metric_logger.meters[f'test_acc_cat_{i}']))

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
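
# File layout produced above and re-parsed by merge() below: the first line
# is the last batch's accuracy (skipped on read), followed by one line per
# sample of the form
#   "<id> [<logits as a Python list>] <label> <chunk_nb> <split_nb>"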


def merge(eval_path, num_tasks):
    dict_feats = {}
    dict_label = {}
    dict_pos = {}
    print("Reading individual output files")

    for x in range(num_tasks):
        file = os.path.join(eval_path, str(x) + '.txt')
        lines = open(file, 'r').readlines()[1:]  # skip the accuracy header line
        for line in lines:
            line = line.strip()
            name = line.split('[')[0]
            label = line.split(']')[1].split(' ')[1]
            chunk_nb = line.split(']')[1].split(' ')[2]
            split_nb = line.split(']')[1].split(' ')[3]
            # np.float was removed from NumPy; parse the logits with the
            # builtin float dtype instead.
            data = np.array(line.split('[')[1].split(']')[0].split(','), dtype=float)
            data = softmax(data)
            if name not in dict_feats:
                dict_feats[name] = []
                dict_label[name] = 0
                dict_pos[name] = []
            if chunk_nb + split_nb in dict_pos[name]:
                continue  # skip duplicate (chunk, split) views of the same clip
            dict_feats[name].append(data)
            dict_pos[name].append(chunk_nb + split_nb)
            dict_label[name] = label
    print("Computing final results")
    input_lst = []
    print(len(dict_feats))
    for i, item in enumerate(dict_feats):
        input_lst.append([i, item, dict_feats[item], dict_label[item]])
    from multiprocessing import Pool
    with Pool(64) as p:
        ans = p.map(compute_video, input_lst)
    top1 = [x[1] for x in ans]
    top5 = [x[2] for x in ans]
    pred = [x[0] for x in ans]
    label = [x[3] for x in ans]
    final_top1, final_top5 = np.mean(top1), np.mean(top5)
    return final_top1 * 100, final_top5 * 100


def compute_video(lst):
    """Average the softmaxed views of one video and score top-1/top-5."""
    i, video_id, data, label = lst
    feat = np.mean(data, axis=0)
    pred = np.argmax(feat)
    top1 = (int(pred) == int(label)) * 1.0
    top5 = (int(label) in np.argsort(-feat)[:5]) * 1.0
    return [pred, top1, top5, int(label)]
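
# Typical use after a distributed final_test run (hypothetical names): each
# of the `world_size` ranks writes "<rank>.txt" under eval_dir via
# final_test(..., file=os.path.join(eval_dir, f"{rank}.txt")), and rank 0
# then aggregates:
#
#   top1, top5 = merge(eval_dir, num_tasks=world_size)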