| |
| |
| |
| |
| import os |
| import sys |
# Extend the import search path with the directory two levels above this
# file and with that directory's own parent, so imports rooted at either
# location resolve no matter where the interpreter was started.
current_file_path = os.path.abspath(__file__)
_containing_dir = os.path.dirname(current_file_path)
parent_dir = os.path.dirname(_containing_dir)
project_root_dir = os.path.dirname(parent_dir)
sys.path.extend((parent_dir, project_root_dir))
|
|
| import pickle |
| import datetime |
| import logging |
| import numpy as np |
| from copy import deepcopy |
| from collections import defaultdict |
| from tqdm import tqdm |
| import time |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.optim as optim |
| from torch.nn import DataParallel |
| from torch.utils.tensorboard import SummaryWriter |
| from metrics.base_metrics_class import Recorder |
| from torch.optim.swa_utils import AveragedModel, SWALR |
| from torch import distributed as dist |
| from torch.nn.parallel import DistributedDataParallel as DDP |
| from sklearn import metrics |
| from metrics.utils import get_test_metrics |
| from metrics.base_metrics_class import calculate_acc_for_test |
|
|
# Dataset keys for which `Trainer.save_best` never writes a "best" checkpoint.
FFpp_pool = [
    'FaceForensics++',
    'FF-DF',
    'FF-F2F',
    'FF-FS',
    'FF-NT',
]

# Single module-wide compute device: CUDA when present, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
|
|
|
|
class Trainer(object):
    """Drive training, periodic testing, logging and checkpointing of a model.

    The wrapped model is expected to be callable on a batch dictionary and to
    expose ``get_losses(data_dict, predictions)`` and
    ``get_train_metrics(data_dict, predictions)``; the loss dictionary must
    contain an ``'overall'`` entry, which is the one that is back-propagated.

    All artefacts (TensorBoard events, checkpoints, pickles) are written below
    ``<config['log_dir']>/<model_name>[_<task_target>]_<timestamp>``.
    """

    # Metric-dictionary entries that are not plain scalars and therefore must
    # not be formatted with ``:.6f`` or sent to ``add_scalar``.
    _NON_SCALAR_METRIC_KEYS = ('pred', 'label', 'dataset_dict')

    def __init__(
        self,
        config,
        model,
        optimizer,
        scheduler,
        logger,
        metric_scoring='auc',
        time_now=None,
        swa_model=None
    ):
        """Store the collaborators, move the model to the device, create the log dir.

        Args:
            config: mapping with at least ``log_dir``, ``model_name``, ``ddp``
                (plus ``local_rank`` when ``ddp`` is enabled); further keys
                are read lazily by the other methods.
            model: the network to train.
            optimizer: optimizer stepping the model parameters.
            scheduler: learning-rate scheduler (may be ``None``).
            logger: object exposing ``info``.
            metric_scoring: name of the metric used to decide what "best"
                means; ``'eer'`` is minimised, everything else maximised.
            time_now: timestamp string used in the log directory name.
                ``None`` (default) means "now", resolved per instance.
            swa_model: optional averaged model updated during training.

        Raises:
            ValueError: if ``config``, ``model``, ``optimizer`` or ``logger``
                is ``None``.
        """
        if config is None or model is None or optimizer is None or logger is None:
            raise ValueError("config, model, optimizer, and logger must be provided")

        self.config = config
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.swa_model = swa_model
        self.writers = {}  # cache of SummaryWriter objects, see get_writer
        self.logger = logger
        self.metric_scoring = metric_scoring

        # best_metrics_all_time[dataset_key][metric_name] -> best value so far;
        # unseen entries start at the worst possible value for the metric.
        self.best_metrics_all_time = defaultdict(
            lambda: defaultdict(
                lambda: float('-inf') if self.metric_scoring != 'eer' else float('inf')
            )
        )
        self.speed_up()

        # Resolve the timestamp here rather than in the signature: a default
        # expression in the signature is evaluated only once, at definition.
        if time_now is None:
            time_now = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
        self.timenow = time_now

        task_target = config['task_target'] if 'task_target' in config else None
        task_str = f"_{task_target}" if task_target is not None else ""
        self.log_dir = os.path.join(
            self.config['log_dir'],
            self.config['model_name'] + task_str + '_' + self.timenow
        )
        os.makedirs(self.log_dir, exist_ok=True)

    # ------------------------------------------------------------------ #
    # small private helpers
    # ------------------------------------------------------------------ #
    def _core_model(self):
        """Return the model itself, unwrapping DistributedDataParallel if present."""
        if isinstance(self.model, DDP):
            return self.model.module
        return self.model

    @staticmethod
    def _move_batch_to_device(data_dict):
        """Move every entry of a batch to the module-level ``device`` in place.

        ``None`` values and the ``'name'`` entry are left untouched.
        """
        for key, value in data_dict.items():
            if value is not None and key != 'name':
                data_dict[key] = value.to(device)

    def _log_train_recorders(self, recorders, kind, step_cnt, dataset_name):
        """Log the running averages of ``recorders`` to the logger and TensorBoard.

        ``kind`` is ``'loss'`` or ``'metric'`` and selects both the text label
        and the TensorBoard tag prefix (``train_loss`` / ``train_metric``).
        """
        message = f"Iter: {step_cnt} "
        for name, recorder in recorders.items():
            average = recorder.average()
            if average is None:
                message += f"training-{kind}, {name}: not calculated "
                continue
            message += f"training-{kind}, {name}: {average:.6f} "
            writer = self.get_writer('train', dataset_name, name)
            writer.add_scalar(f'train_{kind}/{name}', average, global_step=step_cnt)
        self.logger.info(message)

    # ------------------------------------------------------------------ #
    # setup
    # ------------------------------------------------------------------ #
    def get_writer(self, phase, dataset_key, metric_key):
        """Return the cached SummaryWriter for this triple, creating it on first use.

        Events go to ``<log_dir>/<phase>/<dataset_key>/<metric_key>/metric_board``.
        """
        writer_key = f"{phase}-{dataset_key}-{metric_key}"
        if writer_key not in self.writers:
            writer_path = os.path.join(
                self.log_dir,
                phase,
                dataset_key,
                metric_key,
                "metric_board"
            )
            os.makedirs(writer_path, exist_ok=True)
            self.writers[writer_key] = SummaryWriter(writer_path)
        return self.writers[writer_key]

    def speed_up(self):
        """Move the model to the device and wrap it in DDP when ``config['ddp']`` is set."""
        self.model.to(device)
        self.model.device = device
        if self.config['ddp']:
            num_gpus = torch.cuda.device_count()
            print(f'avai gpus: {num_gpus}')
            local_rank = self.config['local_rank']
            self.model = DDP(
                self.model,
                device_ids=[local_rank],
                find_unused_parameters=True,
                output_device=local_rank,
            )

    def setTrain(self):
        """Switch the model to training mode and record it in ``self.train``."""
        self.model.train()
        self.train = True

    def setEval(self):
        """Switch the model to evaluation mode and record it in ``self.train``."""
        self.model.eval()
        self.train = False

    # ------------------------------------------------------------------ #
    # persistence
    # ------------------------------------------------------------------ #
    def load_ckpt(self, model_path):
        """Load weights from ``model_path`` into the model.

        A file whose extension is ``p`` is treated as a pickled module whose
        ``state_dict()`` is used; any other file is treated as a state dict.

        Raises:
            NotImplementedError: if ``model_path`` is not an existing file.
        """
        if not os.path.isfile(model_path):
            raise NotImplementedError(
                "=> no model found at '{}'".format(model_path))
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # that come from a trusted source.
        saved = torch.load(model_path, map_location='cpu')
        suffix = model_path.split('.')[-1]
        if suffix == 'p':
            self.model.load_state_dict(saved.state_dict())
        else:
            self.model.load_state_dict(saved)
        self.logger.info('Model found in {}'.format(model_path))

    def save_ckpt(self, phase, dataset_key, ckpt_info=None, ckpt_name="ckpt_best.pth"):
        """Save the model weights to ``<log_dir>/<phase>/<dataset_key>/<ckpt_name>``.

        Outside DDP, models whose name contains ``'svdd'`` are stored as a
        dict with ``R``, ``c`` and ``state_dict``; everything else is stored
        as a bare state dict. ``ckpt_info`` is only echoed in the log line.
        """
        save_dir = os.path.join(self.log_dir, phase, dataset_key)
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, ckpt_name)

        if not self.config['ddp'] and 'svdd' in self.config['model_name']:
            payload = {
                'R': self.model.R,
                'c': self.model.c,
                'state_dict': self.model.state_dict(),
            }
        else:
            payload = self.model.state_dict()
        torch.save(payload, save_path)
        self.logger.info(f"Checkpoint saved to {save_path}, current ckpt is {ckpt_info}")

    def save_swa_ckpt(self):
        """Save the averaged (SWA) model's state dict to ``<log_dir>/swa.pth``."""
        save_dir = self.log_dir
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, "swa.pth")
        torch.save(self.swa_model.state_dict(), save_path)
        self.logger.info(f"SWA Checkpoint saved to {save_path}")

    def save_feat(self, phase, fea, dataset_key):
        """Save ``fea`` with ``np.save`` to ``<log_dir>/<phase>/<dataset_key>/feat_best.npy``."""
        save_dir = os.path.join(self.log_dir, phase, dataset_key)
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, "feat_best.npy")
        np.save(save_path, fea)
        self.logger.info(f"Feature saved to {save_path}")

    def save_data_dict(self, phase, data_dict, dataset_key):
        """Pickle ``data_dict`` next to the other artefacts (rank 0 only)."""
        if self.config['local_rank'] != 0:
            return
        save_dir = os.path.join(self.log_dir, phase, dataset_key)
        os.makedirs(save_dir, exist_ok=True)
        file_path = os.path.join(save_dir, f'data_dict_{phase}.pickle')
        with open(file_path, 'wb') as file:
            pickle.dump(data_dict, file)
        self.logger.info(f"data_dict saved to {file_path}")

    def save_metrics(self, phase, metric_one_dataset, dataset_key):
        """Pickle a metric dictionary to ``.../<dataset_key>/metric_dict_best.pickle``."""
        save_dir = os.path.join(self.log_dir, phase, dataset_key)
        os.makedirs(save_dir, exist_ok=True)
        file_path = os.path.join(save_dir, 'metric_dict_best.pickle')
        with open(file_path, 'wb') as file:
            pickle.dump(metric_one_dataset, file)
        self.logger.info(f"Metrics saved to {file_path}")

    # ------------------------------------------------------------------ #
    # training
    # ------------------------------------------------------------------ #
    def train_step(self, data_dict):
        """Run one optimisation step on a batch.

        With ``config['optimizer']['type'] == 'sam'`` two forward/backward
        passes are made (``first_step`` then ``second_step``) and the losses
        and predictions of the first pass are returned.

        Returns:
            tuple: ``(losses, predictions)``.
        """
        # get_losses lives on the bare model, not on the DDP wrapper.
        core_model = self._core_model()

        if self.config['optimizer']['type'] == 'sam':
            losses_first = pred_first = None
            for i in range(2):
                predictions = self.model(data_dict)
                losses = core_model.get_losses(data_dict, predictions)
                if i == 0:
                    pred_first = predictions
                    losses_first = losses
                self.optimizer.zero_grad()
                losses['overall'].backward()
                if i == 0:
                    self.optimizer.first_step(zero_grad=True)
                else:
                    self.optimizer.second_step(zero_grad=True)
            return losses_first, pred_first

        predictions = self.model(data_dict)
        losses = core_model.get_losses(data_dict, predictions)
        self.optimizer.zero_grad()
        losses['overall'].backward()
        self.optimizer.step()
        return losses, predictions

    def train_epoch(self, epoch, train_data_loader, test_data_loaders=None):
        """Train for one epoch and test once at its end.

        Args:
            epoch: zero-based epoch index; also used to derive the global step.
            train_data_loader: loader whose ``dataset`` exposes ``data_dict``.
            test_data_loaders: optional mapping ``dataset_key -> loader``.

        Returns:
            The value returned by :meth:`test_epoch` for the test run of this
            epoch, or ``None`` when no test was run on this process.
        """
        self.logger.info("===> Epoch[{}] start!".format(epoch))

        times_per_epoch = 1
        num_batches = len(train_data_loader)
        test_step = num_batches // times_per_epoch
        step_cnt = epoch * num_batches

        train_set_name = ','.join(self.config['train_dataset'])
        self.save_data_dict('train', train_data_loader.dataset.data_dict, train_set_name)

        train_recorder_loss = defaultdict(Recorder)
        train_recorder_metric = defaultdict(Recorder)

        swa_enabled = bool(self.config.get('SWA', False))
        # Bound before the loop so that an empty loader still returns None.
        test_best_metric = None

        self.logger.info("===> Start For Loop!")
        for iteration, data_dict in tqdm(enumerate(train_data_loader), total=num_batches):
            self.setTrain()
            self._move_batch_to_device(data_dict)

            losses, predictions = self.train_step(data_dict)

            if swa_enabled and epoch > self.config['swa_start']:
                self.swa_model.update_parameters(self.model)

            batch_metrics = self._core_model().get_train_metrics(data_dict, predictions)
            for name, value in batch_metrics.items():
                train_recorder_metric[name].update(value)
            for name, value in losses.items():
                train_recorder_loss[name].update(value)

            # Periodic reporting, restricted to the process with local rank 0.
            if iteration % 100 == 0 and self.config['local_rank'] == 0:
                if swa_enabled and (
                    epoch > self.config['swa_start'] or self.config.get('dry_run', False)
                ):
                    self.scheduler.step()

                self._log_train_recorders(train_recorder_loss, 'loss', step_cnt, train_set_name)
                self._log_train_recorders(train_recorder_metric, 'metric', step_cnt, train_set_name)

                # Start a fresh averaging window after each report.
                for recorder in train_recorder_loss.values():
                    recorder.clear()
                for recorder in train_recorder_metric.values():
                    recorder.clear()

            # With times_per_epoch == 1 this fires on the last batch of the epoch.
            if (step_cnt + 1) % test_step == 0 and test_data_loaders is not None:
                if not self.config['ddp'] or dist.get_rank() == 0:
                    self.logger.info("===> Test start!")
                    test_best_metric = self.test_epoch(
                        epoch, iteration, test_data_loaders, step_cnt)

            step_cnt += 1

        torch.cuda.empty_cache()
        return test_best_metric

    # ------------------------------------------------------------------ #
    # accuracy helpers
    # ------------------------------------------------------------------ #
    def get_respect_acc_bin(self, prob, label):
        """Return ``(acc_real, acc_fake)`` for binary scores thresholded at 0.5.

        Samples with label ``0`` count as real and all others as fake; the
        samples may come in any order. A class with no samples gets ``0.0``.
        """
        prob = np.asarray(prob)
        label = np.asarray(label)
        pred = np.where(prob > 0.5, 1, 0)
        judge = pred == label
        real_mask = label == 0

        def _accuracy(mask):
            total = int(np.count_nonzero(mask))
            if total == 0:
                return 0.0
            return np.count_nonzero(judge[mask]) / total

        return _accuracy(real_mask), _accuracy(~real_mask)

    def get_respect_acc(self, pred_probs, labels):
        """Return per-class accuracy as ``{class_label: accuracy}``.

        Args:
            pred_probs: array-like of shape ``(num_samples, num_scores)``;
                the arg-max along axis 1 is the predicted class.
            labels: array-like of ``num_samples`` ground-truth labels. Only
                labels that actually occur appear in the result, and they
                need not be contiguous or start at zero.
        """
        pred_probs = torch.as_tensor(pred_probs)
        labels = torch.as_tensor(labels)
        if labels.numel() == 0:
            return {}

        _, preds = torch.max(pred_probs, dim=1)

        class_acc = {}
        for cls in torch.unique(labels):
            # Select by label value; never use the label as an array index.
            mask = labels == cls
            total = int(mask.sum().item())
            correct = int((preds[mask] == cls).sum().item())
            class_acc[cls.item()] = correct / total
        return class_acc

    # ------------------------------------------------------------------ #
    # testing
    # ------------------------------------------------------------------ #
    def test_one_dataset(self, data_loader):
        """Run inference over one loader and collect outputs.

        Losses are recorded unless the model is an ``AveragedModel``.

        Returns:
            tuple: ``(loss_recorders, predictions, labels, features)`` where
            the last three are NumPy arrays stacked over all batches, taken
            from ``predictions['prob']``, ``data_dict['label']`` and
            ``predictions['feat']``.
        """
        test_recorder_loss = defaultdict(Recorder)
        prediction_lists, feature_lists, label_lists = [], [], []
        record_losses = not isinstance(self.model, AveragedModel)

        for data_dict in tqdm(data_loader, total=len(data_loader)):
            data_dict.pop('label_spe', None)
            self._move_batch_to_device(data_dict)

            predictions = self.inference(data_dict)

            label_lists.extend(data_dict['label'].cpu().detach().numpy())
            prediction_lists.extend(predictions['prob'].cpu().detach().numpy())
            feature_lists.extend(predictions['feat'].cpu().detach().numpy())

            if record_losses:
                losses = self._core_model().get_losses(data_dict, predictions)
                for name, value in losses.items():
                    test_recorder_loss[name].update(value)

        return (
            test_recorder_loss,
            np.array(prediction_lists),
            np.array(label_lists),
            np.array(feature_lists),
        )

    def save_best(self, epoch, iteration, step, losses_one_dataset_recorder, key, metric_one_dataset):
        """Update the best score for ``key``, save artefacts and log the results.

        Args:
            epoch, iteration: only used to label the checkpoint in the log.
            step: global step used for the TensorBoard scalars.
            losses_one_dataset_recorder: mapping ``loss_name -> recorder`` or
                ``None`` when no losses are available (e.g. for ``'avg'``).
            key: dataset key, or ``'avg'`` for the cross-dataset average.
            metric_one_dataset: metric dictionary; must contain
                ``self.metric_scoring``.
        """
        lower_is_better = self.metric_scoring == 'eer'
        worst = float('inf') if lower_is_better else float('-inf')
        best_metric = self.best_metrics_all_time[key].get(self.metric_scoring, worst)
        current = metric_one_dataset[self.metric_scoring]
        improved = current < best_metric if lower_is_better else current > best_metric

        if improved:
            self.best_metrics_all_time[key][self.metric_scoring] = current
            if key == 'avg':
                self.best_metrics_all_time[key]['dataset_dict'] = metric_one_dataset['dataset_dict']
            if self.config['save_ckpt'] and key not in FFpp_pool:
                self.save_ckpt('test', key, f"{epoch}+{iteration}")
            self.save_metrics('test', metric_one_dataset, key)

        # The "latest" checkpoint is refreshed on every call, improved or not.
        if self.config.get('save_latest_ckpt', False):
            self.save_ckpt('test', key, f"{epoch}+{iteration}", ckpt_name="ckpt_latest.pth")

        if losses_one_dataset_recorder is not None:
            loss_str = f"dataset: {key} step: {step} "
            for k, v in losses_one_dataset_recorder.items():
                writer = self.get_writer('test', key, k)
                v_avg = v.average()
                if v_avg is None:
                    print(f'{k} is not calculated')
                    continue
                writer.add_scalar(f'test_losses/{k}', v_avg, global_step=step)
                loss_str += f"testing-loss, {k}: {v_avg:.6f} "
            self.logger.info(loss_str)

        metric_str = f"dataset: {key} step: {step} "
        for k, v in metric_one_dataset.items():
            if k in self._NON_SCALAR_METRIC_KEYS:
                continue
            metric_str += f"testing-metric, {k}: {v:.6f} "
            writer = self.get_writer('test', key, k)
            writer.add_scalar(f'test_metrics/{k}', v, global_step=step)
        self.logger.info(metric_str)

        if self.config['local_rank'] == 0 and 'pred' in metric_one_dataset:
            self.logger.info("Start get_respect_acc()")
            per_class_acc = self.get_respect_acc(
                metric_one_dataset['pred'], metric_one_dataset['label'])
            self.logger.info("End get_respect_acc()")
            for cls_name, cls_val in per_class_acc.items():
                self.logger.info(f"testing-metric, {cls_name}-acc:{cls_val:.4f}")
                # One writer per class, fetched explicitly instead of reusing
                # whatever writer an earlier loop happened to leave behind.
                writer = self.get_writer('test', key, f'acc_{cls_name}')
                writer.add_scalar(f'test_metrics/acc_{cls_name}', cls_val, global_step=step)

    def test_epoch(self, epoch, iteration, test_data_loaders, step):
        """Evaluate on every test loader and update the best metrics.

        Args:
            epoch, iteration, step: forwarded to :meth:`save_best`.
            test_data_loaders: mapping ``dataset_key -> loader``; each
                loader's ``dataset`` exposes ``data_dict``.

        Returns:
            ``self.best_metrics_all_time``.
        """
        self.setEval()

        avg_metric = {'acc': 0, 'mAP': 0, 'dataset_dict': {}}
        is_swa_model = isinstance(self.model, AveragedModel)

        keys = test_data_loaders.keys()
        for key in keys:
            data_dict = test_data_loaders[key].dataset.data_dict
            self.save_data_dict('test', data_dict, key)

            losses_one_dataset_recorder, predictions_nps, label_nps, feature_nps = \
                self.test_one_dataset(test_data_loaders[key])
            print(f'stack len:{predictions_nps.shape};{label_nps.shape};{len(data_dict["image"])}')

            metric_one_dataset = calculate_acc_for_test(
                label_nps, predictions_nps, self.config['backbone_config']['num_classes'])

            # Accumulate only the metrics tracked in avg_metric.
            for metric_name, value in metric_one_dataset.items():
                if metric_name in avg_metric:
                    avg_metric[metric_name] += value
            avg_metric['dataset_dict'][key] = metric_one_dataset[self.metric_scoring]

            if is_swa_model:
                # Final SWA evaluation: report only, no best-metric bookkeeping.
                metric_str = f"Iter Final for SWA: "
                for k, v in metric_one_dataset.items():
                    if k in self._NON_SCALAR_METRIC_KEYS:
                        continue
                    metric_str += f"testing-metric, {k}: {v:.6f} "
                self.logger.info(metric_str)
                continue

            self.save_best(epoch, iteration, step, losses_one_dataset_recorder, key, metric_one_dataset)

        if len(keys) > 0 and self.config.get('save_avg', False):
            for metric_name in avg_metric:
                if metric_name != 'dataset_dict':
                    avg_metric[metric_name] /= len(keys)
            self.save_best(epoch, iteration, step, None, 'avg', avg_metric)

        self.logger.info('===> Test Done!')
        return self.best_metrics_all_time

    @torch.no_grad()
    def inference(self, data_dict):
        """Run the model on a batch with ``inference=True`` and gradients disabled."""
        predictions = self.model(data_dict, inference=True)
        return predictions
|
|
| |
| |
| |
| |
| |
| |
|
|