| | from accelerate import Accelerator |
| | from tqdm.auto import tqdm |
| | import torch |
| | from torch.utils.data import DataLoader |
| | from datasets.multiple_datasets import MultipleDatasets, datasets_dict |
| | from datasets.common import COMMON |
| | from transformers import get_scheduler |
| | from safetensors.torch import load_file |
| | import os |
| | import re |
| | import time |
| | import datetime |
| | from models import build_sat_model |
| | from .funcs.eval_funcs import * |
| | from .funcs.infer_funcs import inference |
| | from utils import misc |
| | from utils.misc import get_world_size |
| | import torch.multiprocessing |
| | import numpy as np |
| |
|
| |
|
| | class Engine(): |
| | def __init__(self, args, mode='train'): |
| | self.exp_name = args.exp_name |
| | self.mode = mode |
| | assert mode in ['train','eval','infer'] |
| | self.conf_thresh = args.conf_thresh |
| | self.eval_func_maps = {'agora_validation': evaluate_agora, |
| | 'bedlam_validation_6fps': evaluate_agora, |
| | 'agora_test': test_agora} |
| | self.inference_func = inference |
| |
|
| | if self.mode == 'train': |
| | self.output_dir = os.path.join('./outputs') |
| | self.log_dir = os.path.join(self.output_dir,'logs') |
| | self.ckpt_dir = os.path.join(self.output_dir,'ckpts') |
| | self.distributed_eval = args.distributed_eval |
| | self.eval_vis_num = args.eval_vis_num |
| | elif self.mode == 'eval': |
| | self.output_dir = os.path.join('./results') |
| | self.distributed_eval = args.distributed_eval |
| | self.eval_vis_num = args.eval_vis_num |
| | elif self.mode == 'infer': |
| | output_dir = getattr(args, 'output_dir', None) |
| | if output_dir is not None: |
| | self.output_dir = output_dir |
| | else: |
| | now = datetime.datetime.now() |
| | timestamp = now.strftime("%Y%m%d_%H%M%S") |
| | self.output_dir = os.path.join('./results',f'{self.exp_name}_infer_{timestamp}') |
| | self.distributed_infer = args.distributed_infer |
| |
|
| | self.prepare_accelerator() |
| | self.prepare_models(args) |
| | self.prepare_datas(args) |
| | if self.mode == 'train': |
| | self.prepare_training(args) |
| |
|
| | total_cnt = sum(p.numel() for p in self.model.parameters()) |
| | trainable_cnt = sum(p.numel() for p in self.model.parameters() if p.requires_grad) |
| | self.accelerator.print(f'Initialization finished.\n{trainable_cnt} trainable parameters({total_cnt} total).') |
| |
|
| | def prepare_accelerator(self): |
| | if self.mode == 'train': |
| | self.accelerator = Accelerator( |
| | log_with="tensorboard", |
| | project_dir=os.path.join(self.log_dir) |
| | ) |
| | if self.accelerator.is_main_process: |
| | os.makedirs(self.log_dir, exist_ok=True) |
| | os.makedirs(os.path.join(self.ckpt_dir,self.exp_name),exist_ok=True) |
| | self.accelerator.init_trackers(self.exp_name) |
| | else: |
| | self.accelerator = Accelerator() |
| | if self.accelerator.is_main_process: |
| | os.makedirs(self.output_dir, exist_ok=True) |
| | |
| | def prepare_models(self, args): |
| | |
| | self.accelerator.print('Preparing models...') |
| | self.unwrapped_model, self.criterion = build_sat_model(args, set_criterion = (self.mode == 'train')) |
| | if self.criterion is not None: |
| | self.weight_dict = self.criterion.weight_dict |
| | |
| | if args.pretrain: |
| | self.accelerator.print(f'Loading pretrained weights: {args.pretrain_path}') |
| | state_dict = torch.load(args.pretrain_path) |
| | self.unwrapped_model.load_state_dict(state_dict,strict=False) |
| |
|
| | |
| | self.model = self.accelerator.prepare(self.unwrapped_model) |
| | |
| | def prepare_datas(self, args): |
| | |
| | if self.mode == 'train': |
| | self.accelerator.print('Loading training datasets:\n', |
| | [f'{d}_{s}' for d,s in zip(args.train_datasets_used, args.train_datasets_split)]) |
| | self.train_batch_size = args.train_batch_size |
| | train_dataset = MultipleDatasets(args.train_datasets_used, args.train_datasets_split, |
| | make_same_len=False, input_size=args.input_size, aug=True, |
| | mode = 'train', sat_cfg=args.sat_cfg, |
| | aug_cfg=args.aug_cfg) |
| | self.train_dataloader = DataLoader(dataset=train_dataset, batch_size=self.train_batch_size, |
| | shuffle=True,collate_fn=misc.collate_fn, |
| | num_workers=args.train_num_workers,pin_memory=True) |
| | self.train_dataloader = self.accelerator.prepare(self.train_dataloader) |
| |
|
| | if self.mode != 'infer': |
| | self.accelerator.print('Loading evaluation datasets:', |
| | [f'{d}_{s}' for d,s in zip(args.eval_datasets_used, args.eval_datasets_split)]) |
| | self.eval_batch_size = args.eval_batch_size |
| | eval_ds = {f'{ds}_{split}': datasets_dict[ds](split = split, |
| | mode = 'eval', |
| | input_size = args.input_size, |
| | aug = False, |
| | sat_cfg=args.sat_cfg)\ |
| | for (ds, split) in zip(args.eval_datasets_used, args.eval_datasets_split)} |
| | self.eval_dataloaders = {k: DataLoader(dataset=v, batch_size=self.eval_batch_size, |
| | shuffle=False,collate_fn=misc.collate_fn, |
| | num_workers=args.eval_num_workers,pin_memory=True)\ |
| | for (k,v) in eval_ds.items()} |
| | if self.distributed_eval: |
| | for (k,v) in self.eval_dataloaders.items(): |
| | self.eval_dataloaders.update({k: self.accelerator.prepare(v)}) |
| | |
| | else: |
| | img_folder = args.input_dir |
| | self.accelerator.print(f'Loading inference images from {img_folder}') |
| | self.infer_batch_size = args.infer_batch_size |
| | infer_ds = COMMON(img_folder = img_folder, input_size=args.input_size,aug=False, |
| | mode = 'infer', sat_cfg=args.sat_cfg) |
| | self.infer_dataloader = DataLoader(dataset=infer_ds, batch_size=self.infer_batch_size, |
| | shuffle=False,collate_fn=misc.collate_fn, |
| | num_workers=args.infer_num_workers,pin_memory=True) |
| |
|
| | if self.distributed_infer: |
| | self.infer_dataloader = self.accelerator.prepare(self.infer_dataloader) |
| |
|
| | def prepare_training(self, args): |
| | self.start_epoch = 0 |
| | self.num_epochs = args.num_epochs |
| | self.global_step = 0 |
| | if hasattr(args, 'sat_gt_epoch'): |
| | self.sat_gt_epoch = args.sat_gt_epoch |
| | self.accelerator.print(f'Use GT for the first {self.sat_gt_epoch} epoch(s)...') |
| | else: |
| | self.sat_gt_epoch = -1 |
| | self.save_and_eval_epoch = args.save_and_eval_epoch |
| | self.least_eval_epoch = args.least_eval_epoch |
| |
|
| | self.detach_j3ds = args.detach_j3ds |
| |
|
| | self.accelerator.print('Preparing optimizer and lr_scheduler...') |
| | param_dicts = [ |
| | { |
| | "params": |
| | [p for n, p in self.unwrapped_model.named_parameters() |
| | if not misc.match_name_keywords(n, args.lr_encoder_names) and p.requires_grad], |
| | "lr": args.lr, |
| | }, |
| | { |
| | "params": |
| | [p for n, p in self.unwrapped_model.named_parameters() |
| | if misc.match_name_keywords(n, args.lr_encoder_names) and p.requires_grad], |
| | "lr": args.lr_encoder, |
| | } |
| | ] |
| |
|
| | |
| | if args.optimizer == 'adamw': |
| | self.optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, |
| | weight_decay=args.weight_decay) |
| | else: |
| | raise NotImplementedError |
| | |
| | |
| | if args.lr_scheduler == 'cosine': |
| | self.lr_scheduler = get_scheduler(name="cosine", optimizer=self.optimizer, |
| | num_warmup_steps=args.num_warmup_steps, |
| | num_training_steps=get_world_size() * self.num_epochs * len(self.train_dataloader)) |
| | elif args.lr_scheduler == 'multistep': |
| | self.lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, args.milestones, gamma=args.gamma) |
| | else: |
| | raise NotImplementedError |
| |
|
| | self.optimizer, self.lr_scheduler = self.accelerator.prepare(self.optimizer, self.lr_scheduler) |
| |
|
| | |
| | if args.resume: |
| | if hasattr(args, 'ckpt_epoch'): |
| | self.load_ckpt(args.ckpt_epoch,args.ckpt_step) |
| | else: |
| | self.accelerator.print('Auto resume from latest ckpt...') |
| | epoch, step = -1, -1 |
| | pattern = re.compile(r'epoch_(\d+)_step_(\d+)') |
| | for folder_name in os.listdir(os.path.join(self.output_dir,'ckpts',self.exp_name)): |
| | match = pattern.match(folder_name) |
| | if match: |
| | i, j = int(match.group(1)), int(match.group(2)) |
| | if i > epoch: |
| | epoch, step = i, j |
| | if epoch >= 0: |
| | self.load_ckpt(epoch, step) |
| | else: |
| | self.accelerator.print('No existing ckpts! Train from scratch.') |
| |
|
| | def load_ckpt(self, epoch, step): |
| | self.accelerator.print(f'Loading checkpoint: epoch_{epoch}_step_{step}') |
| | ckpts_save_path = os.path.join(self.output_dir,'ckpts',self.exp_name, f'epoch_{epoch}_step_{step}') |
| | self.start_epoch = epoch + 1 |
| | self.global_step = step + 1 |
| | self.accelerator.load_state(ckpts_save_path) |
| |
|
| | def train(self): |
| | |
| | self.accelerator.print('Start training!') |
| | for epoch in range(self.start_epoch, self.num_epochs): |
| | torch.cuda.empty_cache() |
| | progress_bar = tqdm(total=len(self.train_dataloader), disable=not self.accelerator.is_local_main_process) |
| | progress_bar.set_description(f"Epoch {epoch}") |
| |
|
| | self.model.train() |
| | self.criterion.train() |
| |
|
| | sat_use_gt = (epoch < self.sat_gt_epoch) |
| |
|
| | for step, (samples,targets) in enumerate(self.train_dataloader): |
| |
|
| | outputs = self.model(samples, targets, sat_use_gt = sat_use_gt, detach_j3ds = self.detach_j3ds) |
| | loss_dict = self.criterion(outputs, targets) |
| |
|
| | loss = sum(loss_dict[k] * self.weight_dict[k] for k in loss_dict.keys()) |
| | |
| |
|
| | self.accelerator.backward(loss) |
| |
|
| | if self.accelerator.sync_gradients: |
| | self.accelerator.clip_grad_norm_(self.model.parameters(), 1.0) |
| |
|
| | self.optimizer.step() |
| |
|
| | self.lr_scheduler.step() |
| | self.optimizer.zero_grad() |
| | |
| | reduced_dict = self.accelerator.reduce(loss_dict,reduction='mean') |
| | simplified_logs = {k: v.item() for k, v in reduced_dict.items() if '.' not in k} |
| | |
| | |
| | if self.accelerator.is_main_process: |
| | tqdm.write(f'[{epoch}-{step+1}/{len(self.train_dataloader)}]: ' + str(simplified_logs)) |
| |
|
| | if step % 10 == 0: |
| | self.accelerator.log({('train/'+k):v for k,v in simplified_logs.items()}, |
| | step=self.global_step) |
| |
|
| | progress_bar.update(1) |
| | progress_bar.set_postfix(**{"lr": self.lr_scheduler.get_last_lr()[0], "step": self.global_step}) |
| |
|
| | self.global_step += 1 |
| | self.accelerator.wait_for_everyone() |
| |
|
| | |
| |
|
| | if epoch % self.save_and_eval_epoch == 0 or epoch == self.num_epochs-1: |
| | self.save_and_eval(epoch, save_ckpt=True) |
| | |
| | self.accelerator.end_training() |
| |
|
| | def eval(self, results_save_path = None, epoch = -1): |
| | if results_save_path is None: |
| | results_save_path = os.path.join(self.output_dir,self.exp_name,'evaluation') |
| | |
| | self.model.eval() |
| | unwrapped_model = self.unwrapped_model |
| | if self.accelerator.is_main_process: |
| | os.makedirs(results_save_path,exist_ok=True) |
| | |
| | for i, (key, eval_dataloader) in enumerate(self.eval_dataloaders.items()): |
| | assert key in self.eval_func_maps |
| | img_cnt = len(eval_dataloader) * self.eval_batch_size |
| | if self.distributed_eval: |
| | img_cnt *= self.accelerator.num_processes |
| | self.accelerator.print(f'Evaluate on {key}: {img_cnt} images') |
| | self.accelerator.print('Using following threshold(s): ', self.conf_thresh) |
| | conf_thresh = self.conf_thresh if 'agora' in key or 'bedlam' in key else [0.2] |
| | for thresh in conf_thresh: |
| | if self.accelerator.is_main_process or self.distributed_eval: |
| | error_dict = self.eval_func_maps[key](model = unwrapped_model, |
| | eval_dataloader = eval_dataloader, |
| | conf_thresh = thresh, |
| | vis_step = img_cnt // self.eval_vis_num, |
| | results_save_path = os.path.join(results_save_path,key,f'thresh_{thresh}'), |
| | distributed = self.distributed_eval, |
| | accelerator = self.accelerator, |
| | vis=True) |
| | if isinstance(error_dict,dict) and self.mode == 'train': |
| | log_dict = flatten_dict(error_dict) |
| | self.accelerator.log({(f'{key}_thresh_{thresh}/'+k):v for k,v in log_dict.items()}, step=epoch) |
| |
|
| | self.accelerator.print(f'thresh_{thresh}: ',error_dict) |
| | self.accelerator.wait_for_everyone() |
| | |
| | def save_and_eval(self, epoch, save_ckpt=False): |
| | torch.cuda.empty_cache() |
| | |
| | if self.accelerator.is_main_process and save_ckpt: |
| | ckpts_save_path = os.path.join(self.output_dir,'ckpts',self.exp_name, f'epoch_{epoch}_step_{self.global_step-1}') |
| | os.makedirs(ckpts_save_path,exist_ok=True) |
| | self.accelerator.save_state(ckpts_save_path, safe_serialization=False) |
| | self.accelerator.wait_for_everyone() |
| | |
| | if epoch < self.least_eval_epoch: |
| | return |
| | results_save_path = os.path.join(self.output_dir,'results',self.exp_name, f'epoch_{epoch}_step_{self.global_step-1}') |
| | self.eval(results_save_path, epoch=epoch) |
| |
|
| | def infer(self): |
| | self.model.eval() |
| | |
| | unwrapped_model = self.unwrapped_model |
| | |
| | results_save_path = self.output_dir |
| | if self.accelerator.is_main_process: |
| | os.makedirs(results_save_path,exist_ok=True) |
| | |
| | self.accelerator.print('Using following threshold(s): ', self.conf_thresh) |
| | for thresh in self.conf_thresh: |
| | if self.accelerator.is_main_process or self.distributed_infer: |
| | self.inference_func(model = unwrapped_model, |
| | infer_dataloader = self.infer_dataloader, |
| | conf_thresh = thresh, |
| | results_save_path = os.path.join(results_save_path,f'thresh_{thresh}'), |
| | distributed = self.distributed_infer, |
| | accelerator = self.accelerator) |
| | self.accelerator.wait_for_everyone() |
| |
|
| |
|
| | def flatten_dict(d, parent_key='', sep='-'): |
| | items = [] |
| | for k, v in d.items(): |
| | new_key = f"{parent_key}{sep}{k}" if parent_key else k |
| | if isinstance(v, dict): |
| | items.extend(flatten_dict(v, new_key, sep=sep).items()) |
| | else: |
| | items.append((new_key, v)) |
| | return dict(items) |
| |
|
| |
|