| """Training and validation for Ref-AVS (text + audio + SAM2 multimask decoding).""" |
| import numpy |
| import torch |
| from torch.utils.data import DataLoader |
| from tqdm import tqdm |
|
|
# Values of `process` that are themselves decode modes (the empty string
# selects the "take mask 0" fallback in the validation loops).
_DECODE_MODES = frozenset(('', 'iou_select', 'iou_occ_select'))


def _decode_mode_and_wandb_tag(process):
    """Split `process` into a ``(decode_mode, wandb_tag)`` pair.

    When `process` is one of the known decode-mode strings it is used both as
    the decode mode and as the logging tag.  Any other value (for example a
    split name such as ``'test_s'``) is kept as the logging tag only, and the
    decode mode falls back to ``'iou_select'``.

    Args:
        process: A decode-mode string or an arbitrary tag.

    Returns:
        Tuple ``(decode_mode, wandb_tag)``; ``wandb_tag`` is always `process`.
    """
    decode_mode = process if process in _DECODE_MODES else 'iou_select'
    return decode_mode, process
|
|
|
|
class Trainer:
    """Train / valid / null-valid steps with composite loss, contrastive term, and metrics.

    The three public entry points are ``train``, ``valid`` and ``valid_null``.
    Both validation methods call ``model.module(...)`` and
    ``torch.distributed.all_gather_object``, so they expect a wrapped model and
    an initialised process group.
    """

    def __init__(self, hyp_param, loss, tensorboard, metrics):
        """Store the collaborators and build the contrastive loss.

        Args:
            hyp_param: Hyper-parameter object.  This class reads ``local_rank``,
                ``gpus``, ``image_size``, ``lr`` and ``epochs`` from it.
            loss: Callable invoked as ``loss(outputs, label)``; its result is
                indexed with ``'core_loss'``, ``'loss_dice'`` and ``'loss_mask'``.
            tensorboard: Logger exposing ``upload_wandb_info(dict)``, or
                ``None`` to disable logging.
            metrics: Mapping holding the metric objects under the keys
                ``'foreground_s'``, ``'foreground_iou'`` and
                ``'foreground_f-score'``.
        """
        self.param = hyp_param
        self.loss = loss
        self.tensorboard = tensorboard
        self.metrics = metrics
        # Kept as a function-scope import, exactly as in the original code.
        # NOTE(review): presumably this avoids an import cycle -- confirm.
        from loss.training.contrastive_learning import ContrastLoss
        self.cl = ContrastLoss(self.param)

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _require_dataloader(dataloader, caller):
        """Raise ``TypeError`` unless `dataloader` is a ``DataLoader``."""
        if not isinstance(dataloader, DataLoader):
            raise TypeError(
                "{}() expects a torch.utils.data.DataLoader "
                "(do not pass iter(dataloader) first).".format(caller)
            )

    def _progress(self, length):
        """Return the step iterable: a tqdm bar when ``local_rank <= 0``, else a plain range."""
        steps = range(length)
        return tqdm(steps, ncols=135) if self.param.local_rank <= 0 else steps

    @staticmethod
    def _to_float(x):
        """Convert a one-element tensor or a plain number to a Python float."""
        if isinstance(x, torch.Tensor):
            return float(x.detach().cpu().item())
        return float(x)

    @staticmethod
    def _select_masks(multimasks, iou_scores, occ_scores, num_frames, decode_mode):
        """Pick one mask per frame out of the multimask predictions.

        * ``'iou_select'``: take, per frame, the candidate with the highest
          predicted IoU score.
        * ``'iou_occ_select'``: same, then zero the logits of every frame whose
          object-score logit is negative.
        * anything else: take candidate 0.
        """
        if decode_mode not in ('iou_select', 'iou_occ_select'):
            return multimasks[:, 0, ...]

        best = torch.argmax(iou_scores, dim=1)
        # Build the row index on the same device as `best` so indexing does not
        # need a host-to-device copy for every clip.
        rows = torch.arange(num_frames, device=best.device)
        # Advanced indexing returns a copy, so the in-place zeroing below never
        # touches the model's own output tensors.
        chosen = multimasks[rows, best, ...]
        if decode_mode == 'iou_occ_select':
            chosen[occ_scores.squeeze() < 0, ...] = 0.
        return chosen

    def _predict_clip(self, model, frame_, spect_, prompt, decode_mode):
        """Run the model on one clip and return the selected mask logits."""
        with torch.autocast("cuda", dtype=torch.bfloat16):
            outputs, _ = model.module(frame_, spect_, [prompt], sam_process=False)

        multimasks = torch.cat([torch.cat(o['multistep_pred_multimasks_high_res']) for o in outputs])
        iou_scores = torch.cat([torch.cat(o['multistep_pred_ious']) for o in outputs])
        occ_scores = torch.cat([torch.cat(o['multistep_object_score_logits']) for o in outputs])
        return self._select_masks(multimasks, iou_scores, occ_scores, frame_.shape[0], decode_mode)

    # --------------------------------------------------------------- validation

    @torch.no_grad()
    def valid_null(self, epoch, dataloader, model, process='test_n'):
        """Evaluate the null-reference metric (``foreground_s``) over `dataloader`.

        Args:
            epoch: Only used in the progress-bar description.
            dataloader: ``DataLoader`` whose batches provide ``'frame'``,
                ``'spectrogram'`` and ``'text'``.
            model: Wrapped model; called through ``model.module``.
            process: Decode mode, or a tag used for the logging key (the decode
                mode then defaults to ``'iou_select'``).

        Returns:
            The gathered score of the last step, rounded to 5 decimals with
            ``numpy.round``.

        Raises:
            TypeError: If `dataloader` is not a ``DataLoader``.
            ValueError: If `dataloader` yields no batches.
        """
        self._require_dataloader(dataloader, 'valid_null')
        dataloader_length = len(dataloader)
        if dataloader_length == 0:
            # Without this guard the score variable would never be bound and
            # the method would fail later with an opaque UnboundLocalError.
            raise ValueError("valid_null() received an empty dataloader; nothing to evaluate.")

        decode_mode, wandb_tag = _decode_mode_and_wandb_tag(process)
        self.metrics['foreground_s'].reset()
        tbar = self._progress(dataloader_length)
        p_pool = [None] * self.param.gpus
        n_pool = [None] * self.param.gpus
        device = self.param.local_rank

        data_iter = iter(dataloader)
        for _ in tbar:
            items = next(data_iter)
            frame, spect, prompt_dicts = items['frame'], items['spectrogram'], items['text']

            clip_logits = []
            for frame_, spect_, prompt_ in zip(frame, spect, prompt_dicts):
                frame_ = frame_.cuda(device, non_blocking=True)
                spect_ = spect_.cuda(device, non_blocking=True)
                clip_logits.append(self._predict_clip(model, frame_, spect_, prompt_, decode_mode))

            # The reshape always yields a 4-D tensor: (batch, -1, size, size).
            logits = torch.cat(clip_logits).reshape(
                frame.shape[0], -1, self.param.image_size, self.param.image_size,
            )

            rank_stats = self.metrics['foreground_s'].metric_s_for_null(logits, get_entire_list=True)
            torch.distributed.all_gather_object(p_pool, rank_stats['foreground_p'])
            torch.distributed.all_gather_object(n_pool, rank_stats['foreground_n'])
            foreground_s = sum([i[0].cpu() for i in p_pool]) / sum([i[0] for i in n_pool])

            if self.param.local_rank <= 0:
                tbar.set_description(
                    'epoch {} | valid.null_s {}'.format(epoch, numpy.round(foreground_s, 5)),
                )
            torch.cuda.empty_cache()

        final_s = foreground_s
        if self.param.local_rank <= 0 and self.tensorboard is not None:
            self.tensorboard.upload_wandb_info({"valid.f_s/{}".format(wandb_tag): final_s})

        return numpy.round(final_s, 5)

    @torch.no_grad()
    def valid(self, epoch, dataloader, model, process='iou_select'):
        """Evaluate foreground IoU and F-score over `dataloader`.

        Args:
            epoch: Only used in the progress-bar description.
            dataloader: ``DataLoader`` whose batches provide ``'frame'``,
                ``'spectrogram'``, ``'label'`` and ``'text'``.
            model: Wrapped model; called through ``model.module``.
            process: Decode mode, or a split tag (e.g. ``test_s`` / ``test_u``)
                used for the logging keys; the decode mode then defaults to
                ``'iou_select'``.

        Returns:
            Tuple ``(iou, f_score)`` from the last step, each converted to a
            float and rounded to 5 decimals.

        Raises:
            TypeError: If `dataloader` is not a ``DataLoader``.
            ValueError: If `dataloader` yields no batches.
        """
        self._require_dataloader(dataloader, 'valid')
        dataloader_length = len(dataloader)
        if dataloader_length == 0:
            # Same reasoning as in valid_null(): fail early and clearly.
            raise ValueError("valid() received an empty dataloader; nothing to evaluate.")

        decode_mode, wandb_tag = _decode_mode_and_wandb_tag(process)
        self.metrics['foreground_iou'].reset()
        self.metrics['foreground_f-score'].reset()
        tbar = self._progress(dataloader_length)
        iou_pool = [None] * self.param.gpus
        fscore_pool = [None] * self.param.gpus
        device = self.param.local_rank

        data_iter = iter(dataloader)
        for _ in tbar:
            items = next(data_iter)
            frame, spect, label, prompt_dicts = (
                items['frame'], items['spectrogram'], items['label'], items['text']
            )

            clip_logits = []
            clip_labels = []
            for frame_, spect_, label_, prompt_ in zip(frame, spect, label, prompt_dicts):
                frame_ = frame_.cuda(device, non_blocking=True)
                spect_ = spect_.cuda(device, non_blocking=True)
                label_ = label_.cuda(device, non_blocking=True)
                clip_logits.append(self._predict_clip(model, frame_, spect_, prompt_, decode_mode))
                clip_labels.append(label_)

            logits = torch.cat(clip_logits)
            labels = torch.cat(clip_labels)
            foreground_iou_rank = self.metrics['foreground_iou'].calculate_iou(
                (logits > 0.).squeeze().long(), labels.squeeze().long(), get_entire_list=True,
            )
            foreground_f_score_rank = self.metrics['foreground_f-score'].calculate_f_score(
                logits.squeeze(), labels.squeeze().long(), get_entire_list=True,
            )
            torch.distributed.all_gather_object(iou_pool, foreground_iou_rank)
            torch.distributed.all_gather_object(fscore_pool, foreground_f_score_rank)
            foreground_iou = sum([i['foreground_iou'][0].cpu() for i in iou_pool]) / sum(
                [i['foreground_iou'][1] for i in iou_pool])
            foreground_f_score = sum([i['foreground_f-score'][0] for i in fscore_pool]) / sum(
                [i['foreground_f-score'][1] for i in fscore_pool])

            if self.param.local_rank <= 0:
                tbar.set_description(
                    'epoch {} | valid.f_iou {}, valid.f_f-score {}'.format(
                        epoch,
                        numpy.round(foreground_iou.cpu().numpy(), 5),
                        numpy.round(foreground_f_score, 5),
                    ),
                )
            torch.cuda.empty_cache()

        final_iou = foreground_iou
        final_fscore = foreground_f_score
        if self.param.local_rank <= 0 and self.tensorboard is not None:
            self.tensorboard.upload_wandb_info({
                "valid.f_iou/{}".format(wandb_tag): final_iou,
                "valid.f_f-score/{}".format(wandb_tag): final_fscore,
            })

        return (
            numpy.round(self._to_float(final_iou), 5),
            numpy.round(self._to_float(final_fscore), 5),
        )

    # ----------------------------------------------------------------- training

    def train(self, epoch, dataloader, model, optimiser):
        """Run one training epoch.

        Each step flattens the first two dimensions of frame / spectrogram /
        label, runs the model under bfloat16 autocast, back-propagates
        ``core_loss + contrastive loss`` and then updates the learning rate of
        every parameter group with a polynomial decay (power 0.9).  Groups
        whose ``"name"`` entries contain ``"vgg"`` get a tenth of that rate.

        Args:
            epoch: Zero-based epoch index; used for the schedule and the logs.
            dataloader: ``DataLoader`` whose batches provide ``'frame'``,
                ``'spectrogram'``, ``'label'`` and ``'text'``.
            model: Model called as ``model(frame, spect, prompts, sam_process=False)``.
            optimiser: Optimiser whose ``param_groups`` are updated in place.

        Returns:
            None.

        Raises:
            TypeError: If `dataloader` is not a ``DataLoader``.
        """
        self._require_dataloader(dataloader, 'train')
        dataloader_length = len(dataloader)
        tbar = self._progress(dataloader_length)
        device = self.param.local_rank

        data_iter = iter(dataloader)
        for batch_index in tbar:
            current_index = dataloader_length * epoch + batch_index
            items = next(data_iter)
            frame, spect, label, prompt_dicts = (
                items['frame'], items['spectrogram'], items['label'], items['text'],
            )
            frame = torch.flatten(frame, start_dim=0, end_dim=1).cuda(device, non_blocking=True)
            spect = torch.flatten(spect, start_dim=0, end_dim=1).cuda(device, non_blocking=True)
            label = torch.flatten(label, start_dim=0, end_dim=1).cuda(device, non_blocking=True)

            with torch.autocast("cuda", dtype=torch.bfloat16):
                outputs, proj_feats = model(frame, spect, prompt_dicts, sam_process=False)
                loss_dict = self.loss(outputs, label.unsqueeze(1))
                cl_loss = self.cl(proj_feats, outputs, label)

            optimiser.zero_grad()
            (loss_dict['core_loss'] + cl_loss).backward()
            optimiser.step()

            # The rate is written after step(), so it takes effect on the next
            # optimiser step rather than on the one just performed.
            current_lr = self.param.lr * (1 - current_index / (dataloader_length * self.param.epochs)) ** 0.9
            for params_lr in optimiser.param_groups:
                names = params_lr.get("name", [])
                if names and any("vgg" in n for n in names):
                    params_lr['lr'] = current_lr * 0.1
                else:
                    params_lr['lr'] = current_lr

            if self.param.local_rank <= 0 and self.tensorboard is not None:
                logits = torch.cat([i['multistep_pred_multimasks_high_res'][0] for i in outputs])
                foreground_iou = self.metrics['foreground_iou'].calculate_iou(
                    (logits > 0)[:, 0, ...].long(), label.long(),
                )
                # NOTE(review): 'loss_dice' and 'loss_mask' are logged as-is
                # (no .item()), unlike the other entries -- confirm the logger
                # accepts whatever type the loss returns for them.
                self.tensorboard.upload_wandb_info({
                    "loss": loss_dict['core_loss'].item(), "f_iou": foreground_iou.item(),
                    "lr": optimiser.param_groups[0]['lr'],
                    "loss_dice": loss_dict['loss_dice'],
                    "loss_focal": loss_dict['loss_mask'],
                    "loss_contras": cl_loss.item(),
                })
                tbar.set_description(
                    'epoch {} | loss {}, f_iou {}'.format(
                        epoch, loss_dict['core_loss'].item(), foreground_iou.item(),
                    ),
                )
        return
|
|