"""Training and validation loop for the AV segmentation model.""" import numpy import torch from torch.utils.data import DataLoader from tqdm import tqdm class Trainer: """Wraps train/valid steps with optional loss, metrics, and logging.""" def __init__(self, hyp_param, loss, tensorboard, metrics): self.param = hyp_param self.loss = loss self.tensorboard = tensorboard self.metrics = metrics from loss.training.contrastive_learning import ContrastLoss self.cl = ContrastLoss(self.param) @torch.no_grad() def valid(self, epoch, dataloader, model, process=''): """Evaluate foreground IoU / F-score. `process` selects SAM multimask decoding (see branch below).""" if not isinstance(dataloader, DataLoader): raise TypeError( "valid() expects a torch.utils.data.DataLoader (do not pass iter(dataloader) first)." ) self.metrics['foreground_iou'].reset() self.metrics['foreground_f-score'].reset() dataloader_length = len(dataloader) tbar = range(dataloader_length) tbar = tqdm(tbar, ncols=135) if self.param.local_rank <= 0 else tbar iou_pool = [None] * self.param.gpus fscore_pool = [None] * self.param.gpus data_iter = iter(dataloader) for batch_index in tbar: items = next(data_iter) frame, spect, label, prompt_dicts = items['frame'], items['spectrogram'], items['label'], items['prompts'] frame = torch.flatten(frame, start_dim=0, end_dim=1).cuda(self.param.local_rank, non_blocking=True) spect = torch.flatten(spect, start_dim=0, end_dim=1).cuda(self.param.local_rank, non_blocking=True) label = torch.flatten(label, start_dim=0, end_dim=1).cuda(self.param.local_rank, non_blocking=True) with torch.autocast("cuda", dtype=torch.bfloat16): outputs, _ = model.module(frame, spect, prompt_dicts, sam_process=True) logits = torch.cat([torch.cat(i['multistep_pred_multimasks_high_res']) for i in outputs]) ious_scores = torch.cat([torch.cat(i['multistep_pred_ious']) for i in outputs]) occ_scores = torch.cat([torch.cat(i['multistep_object_score_logits']) for i in outputs]) # process: '' = first multimask; iou_select = argmax IoU head; iou_occ_select = + objectness gate if process == 'iou_select': ious_scores = torch.argmax(ious_scores, dim=1) logits = logits[torch.arange(0, frame.shape[0]), ious_scores, ...] elif process == 'iou_occ_select': ious_scores = torch.argmax(ious_scores, dim=1) logits = logits[torch.arange(0, frame.shape[0]), ious_scores, ...] logits[occ_scores.squeeze() < 0, ...] = 0. else: logits = logits[:, 0, ...] masks = logits > 0. foreground_iou_rank = self.metrics['foreground_iou'].calculate_iou(masks.squeeze().long(), label.squeeze().long(), get_entire_list=True) foreground_f_score_rank = self.metrics['foreground_f-score'].calculate_f_score(logits.squeeze(), label.squeeze(), get_entire_list=True) torch.distributed.all_gather_object(iou_pool, foreground_iou_rank) torch.distributed.all_gather_object(fscore_pool, foreground_f_score_rank) foreground_iou = sum([i['foreground_iou'][0].cpu() for i in iou_pool]) / sum( [i['foreground_iou'][1] for i in iou_pool]) foreground_f_score = sum([i['foreground_f-score'][0] for i in fscore_pool]) / sum( [i['foreground_f-score'][1] for i in fscore_pool]) if self.param.local_rank <= 0: tbar.set_description('epoch {} | valid.f_iou {}, valid.f_f-score {}'.format(epoch, numpy.round( foreground_iou.cpu().numpy(), 5), numpy.round( foreground_f_score, 5))) torch.cuda.empty_cache() final_iou = foreground_iou final_fscore = foreground_f_score if self.param.local_rank <= 0 and self.tensorboard is not None: self.tensorboard.upload_wandb_info({"valid.f_iou/{}".format(process): final_iou, "valid.f_f-score/{}".format(process): final_fscore}) def _to_float(x): if isinstance(x, torch.Tensor): return float(x.detach().cpu().item()) return float(x) return numpy.round(_to_float(final_iou), 5), numpy.round(_to_float(final_fscore), 5) def train(self, epoch, dataloader, model, optimiser): """One epoch: SAM frozen, AuralFuser + heads trained with composite loss + contrastive term.""" if not isinstance(dataloader, DataLoader): raise TypeError( "train() expects a torch.utils.data.DataLoader (do not pass iter(dataloader) first)." ) self.metrics['foreground_iou'].reset() self.metrics['foreground_f-score'].reset() dataloader_length = len(dataloader) tbar = range(dataloader_length) tbar = tqdm(tbar, ncols=135) if self.param.local_rank <= 0 else tbar data_iter = iter(dataloader) for batch_index in tbar: current_index = dataloader_length * epoch + batch_index items = next(data_iter) frame, spect, label, prompt_dicts = items['frame'], items['spectrogram'], items['label'], items['prompts'] frame = torch.flatten(frame, start_dim=0, end_dim=1).cuda(self.param.local_rank, non_blocking=True) spect = torch.flatten(spect, start_dim=0, end_dim=1).cuda(self.param.local_rank, non_blocking=True) label = torch.flatten(label, start_dim=0, end_dim=1).cuda(self.param.local_rank, non_blocking=True) with torch.autocast("cuda", dtype=torch.bfloat16): outputs, proj_feats = model(frame, spect, prompt_dicts, sam_process=False) # Use label_index to pick one supervised frame (legacy v1s behavior). label_index = prompt_dicts.get('label_index', None) if label_index is not None: if not isinstance(label_index, torch.Tensor): label_index = torch.as_tensor(label_index) label_index = label_index.flatten().to(device=label.device, dtype=torch.bool) if label_index.any(): frame_idx = int(torch.where(label_index)[0][0].item()) else: frame_idx = 0 else: frame_idx = 0 outputs_sel = outputs[frame_idx:frame_idx + 1] label_sel = label[frame_idx:frame_idx + 1] vision_feats, audio_feats = proj_feats # Keep the same nested-list structure as legacy v1s code. proj_feats_sel = ( [[vision_feats[i][frame_idx]] for i in range(3)], [[audio_feats[i][frame_idx]] for i in range(3)], ) loss_dict = self.loss(outputs_sel, label_sel.unsqueeze(1)) cl_loss = self.cl(proj_feats_sel, outputs_sel, label_sel) optimiser.zero_grad() (loss_dict['core_loss'] + cl_loss).backward() optimiser.step() current_lr = self.param.lr * (1 - current_index / (dataloader_length * self.param.epochs)) ** 0.9 for params_lr in optimiser.param_groups: names = params_lr.get("name", []) if names and any("vgg" in n for n in names): params_lr['lr'] = current_lr * 0.1 else: params_lr['lr'] = current_lr if self.param.local_rank <= 0: logits = torch.cat([i['multistep_pred_multimasks_high_res'][0] for i in outputs_sel]) foreground_iou = self.metrics['foreground_iou'].calculate_iou((logits > 0)[:, 0, ...].long(), label_sel.long()) self.tensorboard.upload_wandb_info({"loss": loss_dict['core_loss'].item(), "f_iou": foreground_iou.item(), "lr": optimiser.param_groups[0]['lr'], "loss_dice": loss_dict['loss_dice'], "loss_focal": loss_dict['loss_mask'], "loss_contras": cl_loss.item()}) tbar.set_description('epoch {} | loss {}, f_iou {}'.format(epoch, loss_dict['core_loss'].item(), foreground_iou.item())) ''' if batch_index % 200 == 0: pred_mask = (logits > 0)[:, 0, ...].long() n_vis = min(4, frame.shape[0], pred_mask.shape[0], label.shape[0]) self.tensorboard.upload_wandb_image( frame[:n_vis], pred_mask[:n_vis], label[:n_vis].long() ) ''' return