SAT-HMR / engines /engine.py
ChiSu001's picture
Upload model files
ff07ed4 verified
from accelerate import Accelerator
from tqdm.auto import tqdm
import torch
from torch.utils.data import DataLoader
from datasets.multiple_datasets import MultipleDatasets, datasets_dict
from datasets.common import COMMON
from transformers import get_scheduler
from safetensors.torch import load_file
import os
import re
import time
import datetime
from models import build_sat_model
from .funcs.eval_funcs import *
from .funcs.infer_funcs import inference
from utils import misc
from utils.misc import get_world_size
import torch.multiprocessing
import numpy as np
class Engine():
def __init__(self, args, mode='train'):
self.exp_name = args.exp_name
self.mode = mode
assert mode in ['train','eval','infer']
self.conf_thresh = args.conf_thresh
self.eval_func_maps = {'agora_validation': evaluate_agora,
'bedlam_validation_6fps': evaluate_agora,
'agora_test': test_agora}
self.inference_func = inference
if self.mode == 'train':
self.output_dir = os.path.join('./outputs')
self.log_dir = os.path.join(self.output_dir,'logs')
self.ckpt_dir = os.path.join(self.output_dir,'ckpts')
self.distributed_eval = args.distributed_eval
self.eval_vis_num = args.eval_vis_num
elif self.mode == 'eval':
self.output_dir = os.path.join('./results')
self.distributed_eval = args.distributed_eval
self.eval_vis_num = args.eval_vis_num
elif self.mode == 'infer':
output_dir = getattr(args, 'output_dir', None)
if output_dir is not None:
self.output_dir = output_dir
else:
now = datetime.datetime.now()
timestamp = now.strftime("%Y%m%d_%H%M%S")
self.output_dir = os.path.join('./results',f'{self.exp_name}_infer_{timestamp}')
self.distributed_infer = args.distributed_infer
self.prepare_accelerator()
self.prepare_models(args)
self.prepare_datas(args)
if self.mode == 'train':
self.prepare_training(args)
total_cnt = sum(p.numel() for p in self.model.parameters())
trainable_cnt = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
self.accelerator.print(f'Initialization finished.\n{trainable_cnt} trainable parameters({total_cnt} total).')
def prepare_accelerator(self):
if self.mode == 'train':
self.accelerator = Accelerator(
log_with="tensorboard",
project_dir=os.path.join(self.log_dir)
)
if self.accelerator.is_main_process:
os.makedirs(self.log_dir, exist_ok=True)
os.makedirs(os.path.join(self.ckpt_dir,self.exp_name),exist_ok=True)
self.accelerator.init_trackers(self.exp_name)
else:
self.accelerator = Accelerator()
if self.accelerator.is_main_process:
os.makedirs(self.output_dir, exist_ok=True)
def prepare_models(self, args):
# load model and criterion
self.accelerator.print('Preparing models...')
self.unwrapped_model, self.criterion = build_sat_model(args, set_criterion = (self.mode == 'train'))
if self.criterion is not None:
self.weight_dict = self.criterion.weight_dict
# load weights
if args.pretrain:
self.accelerator.print(f'Loading pretrained weights: {args.pretrain_path}')
state_dict = torch.load(args.pretrain_path)
self.unwrapped_model.load_state_dict(state_dict,strict=False)
# to gpu
self.model = self.accelerator.prepare(self.unwrapped_model)
def prepare_datas(self, args):
# load dataset and dataloader
if self.mode == 'train':
self.accelerator.print('Loading training datasets:\n',
[f'{d}_{s}' for d,s in zip(args.train_datasets_used, args.train_datasets_split)])
self.train_batch_size = args.train_batch_size
train_dataset = MultipleDatasets(args.train_datasets_used, args.train_datasets_split,
make_same_len=False, input_size=args.input_size, aug=True,
mode = 'train', sat_cfg=args.sat_cfg,
aug_cfg=args.aug_cfg)
self.train_dataloader = DataLoader(dataset=train_dataset, batch_size=self.train_batch_size,
shuffle=True,collate_fn=misc.collate_fn,
num_workers=args.train_num_workers,pin_memory=True)
self.train_dataloader = self.accelerator.prepare(self.train_dataloader)
if self.mode != 'infer':
self.accelerator.print('Loading evaluation datasets:',
[f'{d}_{s}' for d,s in zip(args.eval_datasets_used, args.eval_datasets_split)])
self.eval_batch_size = args.eval_batch_size
eval_ds = {f'{ds}_{split}': datasets_dict[ds](split = split,
mode = 'eval',
input_size = args.input_size,
aug = False,
sat_cfg=args.sat_cfg)\
for (ds, split) in zip(args.eval_datasets_used, args.eval_datasets_split)}
self.eval_dataloaders = {k: DataLoader(dataset=v, batch_size=self.eval_batch_size,
shuffle=False,collate_fn=misc.collate_fn,
num_workers=args.eval_num_workers,pin_memory=True)\
for (k,v) in eval_ds.items()}
if self.distributed_eval:
for (k,v) in self.eval_dataloaders.items():
self.eval_dataloaders.update({k: self.accelerator.prepare(v)})
else:
img_folder = args.input_dir
self.accelerator.print(f'Loading inference images from {img_folder}')
self.infer_batch_size = args.infer_batch_size
infer_ds = COMMON(img_folder = img_folder, input_size=args.input_size,aug=False,
mode = 'infer', sat_cfg=args.sat_cfg)
self.infer_dataloader = DataLoader(dataset=infer_ds, batch_size=self.infer_batch_size,
shuffle=False,collate_fn=misc.collate_fn,
num_workers=args.infer_num_workers,pin_memory=True)
if self.distributed_infer:
self.infer_dataloader = self.accelerator.prepare(self.infer_dataloader)
def prepare_training(self, args):
self.start_epoch = 0
self.num_epochs = args.num_epochs
self.global_step = 0
if hasattr(args, 'sat_gt_epoch'):
self.sat_gt_epoch = args.sat_gt_epoch
self.accelerator.print(f'Use GT for the first {self.sat_gt_epoch} epoch(s)...')
else:
self.sat_gt_epoch = -1
self.save_and_eval_epoch = args.save_and_eval_epoch
self.least_eval_epoch = args.least_eval_epoch
self.detach_j3ds = args.detach_j3ds
self.accelerator.print('Preparing optimizer and lr_scheduler...')
param_dicts = [
{
"params":
[p for n, p in self.unwrapped_model.named_parameters()
if not misc.match_name_keywords(n, args.lr_encoder_names) and p.requires_grad],
"lr": args.lr,
},
{
"params":
[p for n, p in self.unwrapped_model.named_parameters()
if misc.match_name_keywords(n, args.lr_encoder_names) and p.requires_grad],
"lr": args.lr_encoder,
}
]
# optimizer
if args.optimizer == 'adamw':
self.optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
weight_decay=args.weight_decay)
else:
raise NotImplementedError
# lr_scheduler
if args.lr_scheduler == 'cosine':
self.lr_scheduler = get_scheduler(name="cosine", optimizer=self.optimizer,
num_warmup_steps=args.num_warmup_steps,
num_training_steps=get_world_size() * self.num_epochs * len(self.train_dataloader))
elif args.lr_scheduler == 'multistep':
self.lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, args.milestones, gamma=args.gamma)
else:
raise NotImplementedError
self.optimizer, self.lr_scheduler = self.accelerator.prepare(self.optimizer, self.lr_scheduler)
# resume
if args.resume: #load model, optimizer, lr_scheduler and random_state
if hasattr(args, 'ckpt_epoch'):
self.load_ckpt(args.ckpt_epoch,args.ckpt_step)
else:
self.accelerator.print('Auto resume from latest ckpt...')
epoch, step = -1, -1
pattern = re.compile(r'epoch_(\d+)_step_(\d+)')
for folder_name in os.listdir(os.path.join(self.output_dir,'ckpts',self.exp_name)):
match = pattern.match(folder_name)
if match:
i, j = int(match.group(1)), int(match.group(2))
if i > epoch:
epoch, step = i, j
if epoch >= 0:
self.load_ckpt(epoch, step)
else:
self.accelerator.print('No existing ckpts! Train from scratch.')
def load_ckpt(self, epoch, step):
self.accelerator.print(f'Loading checkpoint: epoch_{epoch}_step_{step}')
ckpts_save_path = os.path.join(self.output_dir,'ckpts',self.exp_name, f'epoch_{epoch}_step_{step}')
self.start_epoch = epoch + 1
self.global_step = step + 1
self.accelerator.load_state(ckpts_save_path)
def train(self):
# torch.autograd.set_detect_anomaly(True)
self.accelerator.print('Start training!')
for epoch in range(self.start_epoch, self.num_epochs):
torch.cuda.empty_cache()
progress_bar = tqdm(total=len(self.train_dataloader), disable=not self.accelerator.is_local_main_process)
progress_bar.set_description(f"Epoch {epoch}")
self.model.train()
self.criterion.train()
sat_use_gt = (epoch < self.sat_gt_epoch)
for step, (samples,targets) in enumerate(self.train_dataloader):
outputs = self.model(samples, targets, sat_use_gt = sat_use_gt, detach_j3ds = self.detach_j3ds)
loss_dict = self.criterion(outputs, targets)
loss = sum(loss_dict[k] * self.weight_dict[k] for k in loss_dict.keys())
self.accelerator.backward(loss)
if self.accelerator.sync_gradients:
self.accelerator.clip_grad_norm_(self.model.parameters(), 1.0)
self.optimizer.step()
self.lr_scheduler.step()
self.optimizer.zero_grad()
reduced_dict = self.accelerator.reduce(loss_dict,reduction='mean')
simplified_logs = {k: v.item() for k, v in reduced_dict.items() if '.' not in k}
# logs.update({"lr": self.lr_scheduler.get_last_lr()[0], "step": self.global_step})
if self.accelerator.is_main_process:
tqdm.write(f'[{epoch}-{step+1}/{len(self.train_dataloader)}]: ' + str(simplified_logs))
if step % 10 == 0:
self.accelerator.log({('train/'+k):v for k,v in simplified_logs.items()},
step=self.global_step)
progress_bar.update(1)
progress_bar.set_postfix(**{"lr": self.lr_scheduler.get_last_lr()[0], "step": self.global_step})
self.global_step += 1
self.accelerator.wait_for_everyone()
# self.lr_scheduler.step()
if epoch % self.save_and_eval_epoch == 0 or epoch == self.num_epochs-1:
self.save_and_eval(epoch, save_ckpt=True)
self.accelerator.end_training()
def eval(self, results_save_path = None, epoch = -1):
if results_save_path is None:
results_save_path = os.path.join(self.output_dir,self.exp_name,'evaluation')
# preparing
self.model.eval()
unwrapped_model = self.unwrapped_model # self.accelerator.unwrap_model(self.model)
if self.accelerator.is_main_process:
os.makedirs(results_save_path,exist_ok=True)
# evaluate
for i, (key, eval_dataloader) in enumerate(self.eval_dataloaders.items()):
assert key in self.eval_func_maps
img_cnt = len(eval_dataloader) * self.eval_batch_size
if self.distributed_eval:
img_cnt *= self.accelerator.num_processes
self.accelerator.print(f'Evaluate on {key}: {img_cnt} images')
self.accelerator.print('Using following threshold(s): ', self.conf_thresh)
conf_thresh = self.conf_thresh if 'agora' in key or 'bedlam' in key else [0.2]
for thresh in conf_thresh:
if self.accelerator.is_main_process or self.distributed_eval:
error_dict = self.eval_func_maps[key](model = unwrapped_model,
eval_dataloader = eval_dataloader,
conf_thresh = thresh,
vis_step = img_cnt // self.eval_vis_num,
results_save_path = os.path.join(results_save_path,key,f'thresh_{thresh}'),
distributed = self.distributed_eval,
accelerator = self.accelerator,
vis=True)
if isinstance(error_dict,dict) and self.mode == 'train':
log_dict = flatten_dict(error_dict)
self.accelerator.log({(f'{key}_thresh_{thresh}/'+k):v for k,v in log_dict.items()}, step=epoch)
self.accelerator.print(f'thresh_{thresh}: ',error_dict)
self.accelerator.wait_for_everyone()
def save_and_eval(self, epoch, save_ckpt=False):
torch.cuda.empty_cache()
# save current state and model
if self.accelerator.is_main_process and save_ckpt:
ckpts_save_path = os.path.join(self.output_dir,'ckpts',self.exp_name, f'epoch_{epoch}_step_{self.global_step-1}')
os.makedirs(ckpts_save_path,exist_ok=True)
self.accelerator.save_state(ckpts_save_path, safe_serialization=False)
self.accelerator.wait_for_everyone()
if epoch < self.least_eval_epoch:
return
results_save_path = os.path.join(self.output_dir,'results',self.exp_name, f'epoch_{epoch}_step_{self.global_step-1}')
self.eval(results_save_path, epoch=epoch)
def infer(self):
self.model.eval()
# unwrapped_model = self.accelerator.unwrap_model(self.model)
unwrapped_model = self.unwrapped_model
results_save_path = self.output_dir
if self.accelerator.is_main_process:
os.makedirs(results_save_path,exist_ok=True)
self.accelerator.print('Using following threshold(s): ', self.conf_thresh)
for thresh in self.conf_thresh:
if self.accelerator.is_main_process or self.distributed_infer:
self.inference_func(model = unwrapped_model,
infer_dataloader = self.infer_dataloader,
conf_thresh = thresh,
results_save_path = os.path.join(results_save_path,f'thresh_{thresh}'),
distributed = self.distributed_infer,
accelerator = self.accelerator)
self.accelerator.wait_for_everyone()
def flatten_dict(d, parent_key='', sep='-'):
items = []
for k, v in d.items():
new_key = f"{parent_key}{sep}{k}" if parent_key else k
if isinstance(v, dict):
items.extend(flatten_dict(v, new_key, sep=sep).items())
else:
items.append((new_key, v))
return dict(items)