# Provenance: uploaded by yyliu01 using huggingface_hub (commit c6dfc69, verified).
"""Training and validation for Ref-AVS (text + audio + SAM2 multimask decoding)."""
import numpy
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
_DECODE_MODES = frozenset({'', 'iou_select', 'iou_occ_select'})
def _decode_mode_and_wandb_tag(process):
"""Match tmp.code: `process` is decode mode for known strings; else Ref split tag + default decode."""
if process in _DECODE_MODES:
return process, process
return 'iou_select', process
class Trainer:
    """Run train / valid / null-valid epochs with composite loss, contrastive term, and metrics.

    The instance keeps references to the hyper-parameters, the loss callable,
    an optional logger and a mapping of metric objects; it owns no model or
    optimiser state itself.
    """

    def __init__(self, hyp_param, loss, tensorboard, metrics):
        """Store collaborators and build the contrastive loss.

        Args:
            hyp_param: Hyper-parameter object. This class reads ``local_rank``,
                ``gpus``, ``image_size``, ``lr`` and ``epochs`` from it and
                passes the whole object to ``ContrastLoss``.
            loss: Callable invoked as ``loss(outputs, label)``; its result is
                indexed with ``'core_loss'``, ``'loss_dice'`` and ``'loss_mask'``.
            tensorboard: Object exposing ``upload_wandb_info(dict)``, or ``None``
                to disable logging.
            metrics: Mapping indexed with ``'foreground_s'``, ``'foreground_iou'``
                and ``'foreground_f-score'``.
        """
        self.param = hyp_param
        self.loss = loss
        self.tensorboard = tensorboard
        self.metrics = metrics
        # Kept as a function-scope import, exactly as before.
        # NOTE(review): presumably deferred to avoid an import cycle -- confirm.
        from loss.training.contrastive_learning import ContrastLoss
        self.cl = ContrastLoss(self.param)

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _to_float(value):
        """Return `value` as a Python float; tensors are detached and read via ``.item()``."""
        if isinstance(value, torch.Tensor):
            return float(value.detach().cpu().item())
        return float(value)

    @staticmethod
    def _poly_lr(base_lr, step, total_steps, power=0.9):
        """Polynomial decay ``base_lr * (1 - step / total_steps) ** power``.

        The base is clamped at 0 so that a step past `total_steps` yields 0.0
        rather than a complex number (a negative float raised to a fractional
        power is complex in Python 3).
        """
        remaining = max(1 - step / total_steps, 0.0)
        return base_lr * remaining ** power

    @staticmethod
    def _select_masks(multimasks, iou_scores, object_scores, decode_mode, num_frames):
        """Pick one mask per frame out of the multimask predictions.

        Args:
            multimasks: Tensor indexed as ``[frame, candidate, ...]``.
            iou_scores: Tensor indexed as ``[frame, candidate]``; the arg-max
                over dim 1 picks the candidate.
            object_scores: Per-frame object-score logits; frames whose squeezed
                score is negative are zeroed in ``'iou_occ_select'`` mode.
            decode_mode: ``'iou_select'``, ``'iou_occ_select'``, or anything
                else to take candidate 0.
            num_frames: Number of leading rows to index (callers pass the
                frame count of the clip).

        Returns:
            The selected masks with the candidate dimension removed.
        """
        if decode_mode not in ('iou_select', 'iou_occ_select'):
            return multimasks[:, 0, ...]

        best = torch.argmax(iou_scores, dim=1)
        # Build the row index on the mask's device to avoid a host-side index tensor.
        rows = torch.arange(0, num_frames, device=multimasks.device)
        # Advanced indexing returns a copy, so the in-place zeroing below does
        # not touch `multimasks`.
        selected = multimasks[rows, best, ...]
        if decode_mode == 'iou_occ_select':
            selected[object_scores.squeeze() < 0, ...] = 0.
        return selected

    def _predict_clip(self, model, frame, spect, prompt, decode_mode):
        """Run the wrapped model on one clip and decode a single mask per frame.

        `frame` and `spect` must already be on the target device. The model is
        called through ``model.module`` with the prompt wrapped in a one-element
        list, under bfloat16 autocast.
        """
        with torch.autocast("cuda", dtype=torch.bfloat16):
            outputs, _ = model.module(frame, spect, [prompt], sam_process=False)
        multimasks = torch.cat(
            [torch.cat(out['multistep_pred_multimasks_high_res']) for out in outputs])
        iou_scores = torch.cat(
            [torch.cat(out['multistep_pred_ious']) for out in outputs])
        object_scores = torch.cat(
            [torch.cat(out['multistep_object_score_logits']) for out in outputs])
        return self._select_masks(
            multimasks, iou_scores, object_scores, decode_mode, frame.shape[0])

    # --------------------------------------------------------------- validation

    @torch.no_grad()
    def valid_null(self, epoch, dataloader, model, process='test_n'):
        """Evaluate the null-split metric from ``metrics['foreground_s']``.

        Args:
            epoch: Epoch number, used only in the progress-bar text.
            dataloader: A ``torch.utils.data.DataLoader`` whose batches are
                dicts with ``'frame'``, ``'spectrogram'`` and ``'text'``.
            model: Wrapped model; called through ``model.module``.
            process: Decode mode or logging tag (see
                ``_decode_mode_and_wandb_tag``).

        Returns:
            The gathered metric rounded to 5 decimals.

        Raises:
            TypeError: If `dataloader` is not a ``DataLoader``.
            ValueError: If `dataloader` yields no batches.
        """
        if not isinstance(dataloader, DataLoader):
            raise TypeError("valid_null() expects a torch.utils.data.DataLoader (do not pass iter(dataloader) first).")
        if len(dataloader) == 0:
            # Previously this surfaced as an UnboundLocalError after the loop.
            raise ValueError("valid_null() received an empty dataloader; no metric can be computed.")
        decode_mode, wandb_tag = _decode_mode_and_wandb_tag(process)
        self.metrics['foreground_s'].reset()

        is_main = self.param.local_rank <= 0
        device = self.param.local_rank
        image_size = self.param.image_size
        tbar = tqdm(dataloader, ncols=135) if is_main else dataloader
        p_pool = [None] * self.param.gpus
        n_pool = [None] * self.param.gpus

        for items in tbar:
            frame, spect, prompt_dicts = items['frame'], items['spectrogram'], items['text']
            per_clip = [
                self._predict_clip(
                    model,
                    frame_.cuda(device, non_blocking=True),
                    spect_.cuda(device, non_blocking=True),
                    prompt_,
                    decode_mode,
                )
                for frame_, spect_, prompt_ in zip(frame, spect, prompt_dicts)
            ]
            # The reshape always yields a 4-D tensor, so no extra unsqueeze is needed.
            logits = torch.cat(per_clip).reshape(frame.shape[0], -1, image_size, image_size)

            stats = self.metrics['foreground_s'].metric_s_for_null(logits, get_entire_list=True)
            # Every rank must reach these collectives once per batch.
            torch.distributed.all_gather_object(p_pool, stats['foreground_p'])
            torch.distributed.all_gather_object(n_pool, stats['foreground_n'])
            foreground_s = sum(p[0].cpu() for p in p_pool) / sum(n[0] for n in n_pool)

            if is_main:
                tbar.set_description(
                    'epoch {} | valid.null_s {}'.format(epoch, numpy.round(foreground_s, 5)),
                )
            # NOTE(review): released per batch to bound peak memory; confirm
            # this placement (once after the loop would be cheaper).
            torch.cuda.empty_cache()

        if is_main and self.tensorboard is not None:
            self.tensorboard.upload_wandb_info({"valid.f_s/{}".format(wandb_tag): foreground_s})
        return numpy.round(foreground_s, 5)

    @torch.no_grad()
    def valid(self, epoch, dataloader, model, process='iou_select'):
        """Evaluate foreground IoU and F-score.

        Args:
            epoch: Epoch number, used only in the progress-bar text.
            dataloader: A ``torch.utils.data.DataLoader`` whose batches are
                dicts with ``'frame'``, ``'spectrogram'``, ``'label'`` and
                ``'text'``.
            model: Wrapped model; called through ``model.module``.
            process: Decode mode, or a split tag such as ``test_s`` / ``test_u``
                (see ``_decode_mode_and_wandb_tag``). It also forms the suffix
                of the logged keys.

        Returns:
            ``(iou, f_score)``, each a float rounded to 5 decimals.

        Raises:
            TypeError: If `dataloader` is not a ``DataLoader``.
            ValueError: If `dataloader` yields no batches.
        """
        if not isinstance(dataloader, DataLoader):
            raise TypeError("valid() expects a torch.utils.data.DataLoader (do not pass iter(dataloader) first).")
        if len(dataloader) == 0:
            # Previously this surfaced as an UnboundLocalError after the loop.
            raise ValueError("valid() received an empty dataloader; no metric can be computed.")
        decode_mode, wandb_tag = _decode_mode_and_wandb_tag(process)
        self.metrics['foreground_iou'].reset()
        self.metrics['foreground_f-score'].reset()

        is_main = self.param.local_rank <= 0
        device = self.param.local_rank
        tbar = tqdm(dataloader, ncols=135) if is_main else dataloader
        iou_pool = [None] * self.param.gpus
        fscore_pool = [None] * self.param.gpus

        for items in tbar:
            frame, spect, label, prompt_dicts = (
                items['frame'], items['spectrogram'], items['label'], items['text']
            )
            per_clip_logits = []
            per_clip_labels = []
            for frame_, spect_, label_, prompt_ in zip(frame, spect, label, prompt_dicts):
                frame_ = frame_.cuda(device, non_blocking=True)
                spect_ = spect_.cuda(device, non_blocking=True)
                label_ = label_.cuda(device, non_blocking=True)
                per_clip_logits.append(
                    self._predict_clip(model, frame_, spect_, prompt_, decode_mode))
                per_clip_labels.append(label_)
            logits = torch.cat(per_clip_logits)
            labels = torch.cat(per_clip_labels)

            # IoU is fed the thresholded mask; the F-score metric gets raw logits.
            iou_rank = self.metrics['foreground_iou'].calculate_iou(
                (logits > 0.).squeeze().long(), labels.squeeze().long(), get_entire_list=True,
            )
            fscore_rank = self.metrics['foreground_f-score'].calculate_f_score(
                logits.squeeze(), labels.squeeze().long(), get_entire_list=True,
            )
            # Every rank must reach these collectives once per batch.
            torch.distributed.all_gather_object(iou_pool, iou_rank)
            torch.distributed.all_gather_object(fscore_pool, fscore_rank)
            foreground_iou = (
                sum(r['foreground_iou'][0].cpu() for r in iou_pool)
                / sum(r['foreground_iou'][1] for r in iou_pool)
            )
            foreground_f_score = (
                sum(r['foreground_f-score'][0] for r in fscore_pool)
                / sum(r['foreground_f-score'][1] for r in fscore_pool)
            )

            if is_main:
                tbar.set_description(
                    'epoch {} | valid.f_iou {}, valid.f_f-score {}'.format(
                        epoch,
                        numpy.round(foreground_iou.cpu().numpy(), 5),
                        numpy.round(foreground_f_score, 5),
                    ),
                )
            # NOTE(review): released per batch to bound peak memory; confirm
            # this placement (once after the loop would be cheaper).
            torch.cuda.empty_cache()

        if is_main and self.tensorboard is not None:
            self.tensorboard.upload_wandb_info({
                "valid.f_iou/{}".format(wandb_tag): foreground_iou,
                "valid.f_f-score/{}".format(wandb_tag): foreground_f_score,
            })
        return (
            numpy.round(self._to_float(foreground_iou), 5),
            numpy.round(self._to_float(foreground_f_score), 5),
        )

    # ----------------------------------------------------------------- training

    def train(self, epoch, dataloader, model, optimiser):
        """Run one training epoch.

        For every batch: forward under bfloat16 autocast, back-propagate
        ``core_loss + contrastive loss``, step the optimiser, then set the
        learning rate of every parameter group from a polynomial schedule.

        Args:
            epoch: Zero-based epoch index; combined with the batch index to
                form the global step of the LR schedule.
            dataloader: A ``torch.utils.data.DataLoader`` whose batches are
                dicts with ``'frame'``, ``'spectrogram'``, ``'label'`` and
                ``'text'``.
            model: Model called as ``model(frame, spect, prompts, sam_process=False)``.
            optimiser: Optimiser whose ``param_groups`` may carry a ``"name"``
                entry used to scale the LR of "vgg" groups by 0.1.

        Raises:
            TypeError: If `dataloader` is not a ``DataLoader``.
        """
        if not isinstance(dataloader, DataLoader):
            raise TypeError("train() expects a torch.utils.data.DataLoader (do not pass iter(dataloader) first).")
        dataloader_length = len(dataloader)
        total_steps = dataloader_length * self.param.epochs
        device = self.param.local_rank
        should_log = self.param.local_rank <= 0 and self.tensorboard is not None
        tbar = tqdm(dataloader, ncols=135) if self.param.local_rank <= 0 else dataloader

        for batch_index, items in enumerate(tbar):
            current_index = dataloader_length * epoch + batch_index
            prompt_dicts = items['text']
            # Merge the first two dimensions before moving the tensors to the device.
            frame, spect, label = (
                torch.flatten(items[key], start_dim=0, end_dim=1).cuda(device, non_blocking=True)
                for key in ('frame', 'spectrogram', 'label')
            )

            with torch.autocast("cuda", dtype=torch.bfloat16):
                outputs, proj_feats = model(frame, spect, prompt_dicts, sam_process=False)
                # NOTE(review): losses are evaluated inside the autocast region
                # (the usual AMP recipe); confirm this matches the intended setup.
                loss_dict = self.loss(outputs, label.unsqueeze(1))
                cl_loss = self.cl(proj_feats, outputs, label)

            optimiser.zero_grad()
            (loss_dict['core_loss'] + cl_loss).backward()
            optimiser.step()

            # The LR is written after the step, so it takes effect from the next update.
            current_lr = self._poly_lr(self.param.lr, current_index, total_steps)
            for group in optimiser.param_groups:
                names = group.get("name", [])
                # NOTE(review): `names` is iterated element-wise, so this assumes
                # a list of strings rather than a single string -- confirm.
                is_vgg = bool(names) and any("vgg" in n for n in names)
                group['lr'] = current_lr * 0.1 if is_vgg else current_lr

            if should_log:
                logits = torch.cat([out['multistep_pred_multimasks_high_res'][0] for out in outputs])
                foreground_iou = self.metrics['foreground_iou'].calculate_iou(
                    (logits > 0)[:, 0, ...].long(), label.long(),
                )
                core_loss_value = loss_dict['core_loss'].item()
                iou_value = foreground_iou.item()
                self.tensorboard.upload_wandb_info({
                    "loss": core_loss_value,
                    "f_iou": iou_value,
                    "lr": optimiser.param_groups[0]['lr'],
                    # Logged as plain floats, consistent with the other entries.
                    "loss_dice": self._to_float(loss_dict['loss_dice']),
                    "loss_focal": self._to_float(loss_dict['loss_mask']),
                    "loss_contras": cl_loss.item(),
                })
                tbar.set_description(
                    'epoch {} | loss {}, f_iou {}'.format(epoch, core_loss_value, iou_value),
                )