| import pprint |
| from tqdm import tqdm, trange |
| import numpy as np |
| import os |
| from collections import OrderedDict, defaultdict |
| from utils.basic_utils import AverageMeter |
|
|
| import torch |
| import torch.nn.functional as F |
| import torch.backends.cudnn as cudnn |
| from torch.utils.data import DataLoader |
|
|
| from cg_detr.config import TestOptions |
| from cg_detr.model import build_model |
| from cg_detr.span_utils import span_cxw_to_xx |
| from cg_detr.start_end_dataset import StartEndDataset, start_end_collate, prepare_batch_inputs |
| from cg_detr.postprocessing_cg_detr import PostProcessorDETR |
| from standalone_eval.eval import eval_submission |
| from utils.basic_utils import save_jsonl, save_json |
| from utils.temporal_nms import temporal_nms |
|
|
| import logging |
|
|
| logger = logging.getLogger(__name__) |
| logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s", |
| datefmt="%Y-%m-%d %H:%M:%S", |
| level=logging.INFO) |
|
|
|
|
def post_processing_mr_nms(mr_res, nms_thd, max_before_nms, max_after_nms):
    """Apply temporal NMS to every entry's predicted windows.

    Each entry's "pred_relevant_windows" list is truncated to at most
    `max_before_nms` candidates, suppressed with IoU threshold `nms_thd`,
    and capped at `max_after_nms` survivors. Entries are mutated in place
    and returned in their original order.
    """
    processed = []
    for entry in mr_res:
        candidates = entry["pred_relevant_windows"][:max_before_nms]
        entry["pred_relevant_windows"] = temporal_nms(
            candidates,
            nms_thd=nms_thd,
            max_after_nms=max_after_nms,
        )
        processed.append(entry)
    return processed
|
|
|
|
def eval_epoch_post_processing(submission, opt, gt_data, save_submission_filename):
    """Save the raw submission, optionally evaluate it against `gt_data`,
    then (when enabled) run temporal NMS and save/evaluate the filtered copy.

    Returns:
        (metrics, metrics_nms, latest_file_paths). Either metrics dict is
        None when its evaluation was skipped (non-val split or NMS disabled).
    """
    logger.info("Saving/Evaluating before nms results")
    submission_path = os.path.join(opt.results_dir, save_submission_filename)
    save_jsonl(submission, submission_path)

    metrics = None
    latest_file_paths = [submission_path]
    if opt.eval_split_name in ["val"]:  # ground truth only exists for val
        metrics = eval_submission(
            submission, gt_data,
            verbose=opt.debug, match_number=not opt.debug,
        )
        save_metrics_path = submission_path.replace(".jsonl", "_metrics.json")
        save_json(metrics, save_metrics_path, save_pretty=True, sort_keys=False)
        latest_file_paths = [submission_path, save_metrics_path]

    metrics_nms = None
    if opt.nms_thd != -1:  # -1 disables NMS entirely
        logger.info("[MR] Performing nms with nms_thd {}".format(opt.nms_thd))
        submission_after_nms = post_processing_mr_nms(
            submission, nms_thd=opt.nms_thd,
            max_before_nms=opt.max_before_nms, max_after_nms=opt.max_after_nms,
        )

        logger.info("Saving/Evaluating nms results")
        submission_nms_path = submission_path.replace(
            ".jsonl", "_nms_thd_{}.jsonl".format(opt.nms_thd))
        save_jsonl(submission_after_nms, submission_nms_path)
        if opt.eval_split_name == "val":
            metrics_nms = eval_submission(
                submission_after_nms, gt_data,
                verbose=opt.debug, match_number=not opt.debug,
            )
            save_metrics_nms_path = submission_nms_path.replace(".jsonl", "_metrics.json")
            save_json(metrics_nms, save_metrics_nms_path, save_pretty=True, sort_keys=False)
            latest_file_paths += [submission_nms_path, save_metrics_nms_path]
        else:
            # NOTE(review): on non-val splits only the NMS file is kept here,
            # dropping the pre-NMS path collected above — this mirrors the
            # original control flow; confirm it is intentional.
            latest_file_paths = [submission_nms_path]
    return metrics, metrics_nms, latest_file_paths
|
|
|
|
| |
def _ranked_list_ap(ranked_labels):
    """Average precision of a ranked list of binary relevance labels.

    Uses trapezoidal interpolation of the precision/recall curve, exactly
    matching the inline computation this helper replaces. Returns 0 when
    the list contains no positives.
    """
    num_gt = sum(ranked_labels)
    if num_gt == 0:
        return 0
    hits = ap = rec = 0
    prc = 1
    for j, gt in enumerate(ranked_labels):
        hits += gt
        _rec = hits / num_gt   # recall after rank j+1
        _prc = hits / (j + 1)  # precision after rank j+1
        # area of the trapezoid between consecutive (recall, precision) points
        ap += (_rec - rec) * (prc + _prc) / 2
        rec, prc = _rec, _prc
    return ap


@torch.no_grad()
def compute_hl_results(model, eval_loader, opt, epoch_i=None, criterion=None, tb_writer=None):
    """Compute highlight-detection mAP over `eval_loader`.

    For each video, clips are ranked by predicted saliency and average
    precision is computed against the binarized ground-truth labels:
    per-annotator (20 annotators, top-5 clips) for tvsum, a single
    annotation for youtube_uni.

    Returns:
        (submission, loss_meters): submission is ``{"mAP": float}``;
        loss_meters stays empty in this path (no per-batch losses are
        accumulated) and is returned for interface symmetry with
        compute_mr_results.
    """
    model.eval()
    if criterion:
        assert eval_loader.dataset.load_labels
        criterion.eval()

    loss_meters = defaultdict(AverageMeter)
    write_tb = tb_writer is not None and epoch_i is not None

    topk = 5  # tvsum AP is computed over the top-5 ranked clips only

    video_ap_collected = []
    for batch in tqdm(eval_loader, desc="compute st ed scores"):
        query_meta = batch[0]
        model_inputs, targets = prepare_batch_inputs(batch[1], opt.device, non_blocking=opt.pin_memory)
        outputs = model(**model_inputs)

        preds = outputs['saliency_scores'].clone().detach()

        for meta, pred in zip(query_meta, preds):
            label = meta['label']
            video_ap = []
            if opt.dset_name in ["tvsum"]:
                # Ranking is independent of the annotator column, so the
                # cpu() transfer, slice and argsort are hoisted out of the
                # per-annotator loop (they were recomputed 20x before).
                cur_pred = pred.cpu()[:len(label)]
                inds = torch.argsort(cur_pred, descending=True, dim=-1)
                all_labels = torch.Tensor(label)
                for i in range(20):  # tvsum provides 20 annotators per video
                    cur_label = all_labels[:, i]
                    # binarize: clips scored above this annotator's median are positives
                    cur_label = torch.where(cur_label > cur_label.median(), 1.0, .0)
                    ranked = cur_label[inds].tolist()[:topk]
                    video_ap.append(_ranked_list_ap(ranked))
            elif opt.dset_name in ["youtube_uni"]:
                cur_pred = pred[:len(label)].cpu()
                inds = torch.argsort(cur_pred, descending=True, dim=-1)
                ranked = torch.Tensor(label).squeeze()[inds].tolist()
                video_ap.append(float(_ranked_list_ap(ranked)))
            else:
                # NOTE(review): hard process exit inside library code; kept
                # for behavior compatibility — consider raising instead.
                print("No such dataset")
                exit(-1)
            video_ap_collected.append(video_ap)

    mean_ap = np.mean(video_ap_collected)
    submission = dict(mAP=round(mean_ap, 5))

    if write_tb and criterion:
        for k, v in loss_meters.items():
            tb_writer.add_scalar("Eval/{}".format(k), v.avg, epoch_i + 1)

    return submission, loss_meters
|
|
|
|
|
|
@torch.no_grad()
def compute_mr_results(model, eval_loader, opt, epoch_i=None, criterion=None, tb_writer=None):
    """Run moment-retrieval inference over `eval_loader`.

    Builds one prediction dict per query (qid/query/vid, ranked
    [st, ed, score] windows and — for the l1 span head — per-clip saliency
    scores), optionally accumulates criterion losses, then post-processes
    the windows with a dataset-specific PostProcessorDETR.

    Returns:
        (mr_res, loss_meters): list of prediction dicts and a defaultdict of
        AverageMeter loss trackers (empty when criterion is None).
    """
    model.eval()
    if criterion:
        # loss computation requires the dataset to carry ground-truth labels
        assert eval_loader.dataset.load_labels
        criterion.eval()

    loss_meters = defaultdict(AverageMeter)
    write_tb = tb_writer is not None and epoch_i is not None

    mr_res = []
    for batch in tqdm(eval_loader, desc="compute st ed scores"):
        query_meta = batch[0]
        model_inputs, targets = prepare_batch_inputs(batch[1], opt.device, non_blocking=opt.pin_memory)
        outputs = model(**model_inputs)
        prob = F.softmax(outputs["pred_logits"], -1)
        if opt.span_loss_type == "l1":
            scores = prob[..., 0]  # class-0 probability used as span confidence
            pred_spans = outputs["pred_spans"]  # presumably normalized (center, width) pairs — see span_cxw_to_xx below
            _saliency_scores = outputs["saliency_scores"].half()
            saliency_scores = []
            # keep only scores for real (unpadded) clips of each video
            valid_vid_lengths = model_inputs["src_vid_mask"].sum(1).cpu().tolist()
            for j in range(len(valid_vid_lengths)):
                saliency_scores.append(_saliency_scores[j, :int(valid_vid_lengths[j])].tolist())
        else:
            # classification-style span head: logits over max_v_l clip indices
            bsz, n_queries = outputs["pred_spans"].shape[:2]
            pred_spans_logits = outputs["pred_spans"].view(bsz, n_queries, 2, opt.max_v_l)
            pred_span_scores, pred_spans = F.softmax(pred_spans_logits, dim=-1).max(-1)
            scores = torch.prod(pred_span_scores, 2)  # joint probability of (st, ed)
            pred_spans[:, 1] += 1  # make the end index exclusive
            pred_spans *= opt.clip_length  # clip indices -> seconds

        # compose one prediction dict per query in this batch
        for idx, (meta, spans, score) in enumerate(zip(query_meta, pred_spans.cpu(), scores.cpu())):
            if opt.span_loss_type == "l1":
                # normalized (center, width) -> (st, ed) in seconds, clamped to the video
                spans = span_cxw_to_xx(spans) * meta["duration"]
                spans = torch.clamp(spans, 0, meta["duration"])
            # rows of [st (sec), ed (sec), score]
            cur_ranked_preds = torch.cat([spans, score[:, None]], dim=1).tolist()
            if not opt.no_sort_results:
                cur_ranked_preds = sorted(cur_ranked_preds, key=lambda x: x[2], reverse=True)
            cur_ranked_preds = [[float(f"{e:.4f}") for e in row] for row in cur_ranked_preds]
            # NOTE(review): saliency_scores is only assigned in the l1 branch;
            # the classification branch would hit a NameError here — confirm
            # this path is l1-only in practice.
            cur_query_pred = dict(
                qid=meta["qid"],
                query=meta["query"],
                vid=meta["vid"],
                pred_relevant_windows=cur_ranked_preds,
                pred_saliency_scores=saliency_scores[idx]
            )
            mr_res.append(cur_query_pred)

        if criterion:
            loss_dict = criterion(outputs, targets)
            weight_dict = criterion.weight_dict
            losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
            loss_dict["loss_overall"] = float(losses)
            for k, v in loss_dict.items():
                loss_meters[k].update(float(v) * weight_dict[k] if k in weight_dict else float(v))

        if opt.debug:
            break  # debug mode: evaluate a single batch only

    if write_tb and criterion:
        for k, v in loss_meters.items():
            tb_writer.add_scalar("Eval/{}".format(k), v.avg, epoch_i + 1)

    # dataset-specific clamping/rounding of the predicted windows
    if opt.dset_name in ['hl']:
        post_processor = PostProcessorDETR(
            clip_length=opt.clip_length, min_ts_val=0, max_ts_val=150,
            min_w_l=2, max_w_l=150, move_window_method="left",
            process_func_names=("clip_ts", "round_multiple")
        )
    elif opt.dset_name in ['charadesSTA']:
        # presumably 4096-dim features select a different charades variant — TODO confirm
        if opt.v_feat_dim == 4096:
            post_processor = PostProcessorDETR(
                clip_length=opt.clip_length, min_ts_val=0, max_ts_val=360,
                min_w_l=12, max_w_l=360, move_window_method="left",
                process_func_names=("clip_ts", "round_multiple")
            )
        else:
            post_processor = PostProcessorDETR(
                clip_length=opt.clip_length, min_ts_val=0, max_ts_val=150,
                min_w_l=2, max_w_l=60, move_window_method="left",
                process_func_names=("clip_ts", "round_multiple")
            )
    else:
        # permissive bounds: effectively rounding only
        post_processor = PostProcessorDETR(
            clip_length=opt.clip_length, min_ts_val=0, max_ts_val=50000,
            min_w_l=0, max_w_l=50000, move_window_method="left",
            process_func_names=(["round_multiple"])
        )

    mr_res = post_processor(mr_res)
    return mr_res, loss_meters
|
|
|
|
def get_eval_res(model, eval_loader, opt, epoch_i, criterion, tb_writer):
    """Thin wrapper over compute_mr_results; returns (predictions, loss meters)."""
    return compute_mr_results(model, eval_loader, opt, epoch_i, criterion, tb_writer)
|
|
|
|
def eval_epoch(model, eval_dataset, opt, save_submission_filename, epoch_i=None, criterion=None, tb_writer=None):
    """Run one full evaluation pass over `eval_dataset`.

    Highlight-detection datasets (tvsum/youtube_uni) return a brief metric
    dict; all other datasets produce a per-query submission that is saved,
    optionally NMS-filtered, and scored via eval_epoch_post_processing.

    Returns:
        (metrics, metrics_nms, eval_loss_meters, latest_file_paths).
    """
    logger.info("Generate submissions")
    model.eval()
    if criterion is not None and eval_dataset.load_labels:
        criterion.eval()
    else:
        # without labels no losses can be computed
        criterion = None

    # NOTE(review): shuffling an eval loader is unusual; kept because the
    # original enables it for tacos only — confirm this is intentional.
    shuffle = opt.dset_name == 'tacos'

    eval_loader = DataLoader(
        eval_dataset,
        collate_fn=start_end_collate,
        batch_size=opt.eval_bsz,
        num_workers=opt.num_workers,
        shuffle=shuffle,
        pin_memory=opt.pin_memory
    )

    if opt.dset_name in ['tvsum', 'youtube_uni']:
        # highlight detection: only a scalar metric dict is produced
        metrics, eval_loss_meters = compute_hl_results(model, eval_loader, opt, epoch_i, criterion, tb_writer)
        submission = [{"brief": metrics}]
        submission_path = os.path.join(opt.results_dir, "latest_metric.jsonl")
        save_jsonl(submission, submission_path)
        return submission[0], submission[0], eval_loss_meters, [submission_path]

    submission, eval_loss_meters = get_eval_res(model, eval_loader, opt, epoch_i, criterion, tb_writer)

    if opt.dset_name in ['charadesSTA', 'tacos', 'nlq']:
        # moment-retrieval-only datasets: drop saliency predictions.
        # pop() mutates the dicts in place; rebuilding the list (as the
        # previous version did) was redundant.
        for s in submission:
            s.pop('pred_saliency_scores', None)

    if opt.no_sort_results:
        save_submission_filename = save_submission_filename.replace(".jsonl", "_unsorted.jsonl")
    metrics, metrics_nms, latest_file_paths = eval_epoch_post_processing(
        submission, opt, eval_dataset.data, save_submission_filename)
    return metrics, metrics_nms, eval_loss_meters, latest_file_paths
|
|
|
|
def setup_model(opt):
    """Build model/criterion, create the optimizer and LR scheduler, and
    optionally resume weights (and training state) from a checkpoint.

    Returns:
        (model, criterion, optimizer, lr_scheduler)
    """
    logger.info("setup model/optimizer/scheduler")
    model, criterion = build_model(opt)
    if opt.device.type == "cuda":
        logger.info("CUDA enabled.")
        model.to(opt.device)
        criterion.to(opt.device)

    param_dicts = [{"params": [p for n, p in model.named_parameters() if p.requires_grad]}]
    optimizer = torch.optim.AdamW(param_dicts, lr=opt.lr, weight_decay=opt.wd)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, opt.lr_drop)

    if opt.resume is not None:
        logger.info(f"Load checkpoint from {opt.resume}")
        checkpoint = torch.load(opt.resume, map_location="cpu")
        # NOTE(review): these substring checks on the checkpoint *path* decide
        # whether state-dict keys carry a 'module.' prefix — fragile heuristic;
        # confirm against the project's checkpoint naming convention.
        if 'pt' in opt.resume[:-4]:
            if 'asr' in opt.resume[:25]:
                model.load_state_dict(checkpoint["model"])
            else:
                # strip the 7-char 'module.' prefix added by DataParallel.
                # (OrderedDict comes from the module-level import; the
                # redundant function-local import was removed.)
                new_state_dict = OrderedDict(
                    (k[7:], v) for k, v in checkpoint["model"].items())
                model.load_state_dict(new_state_dict)
        else:
            model.load_state_dict(checkpoint["model"])
        if opt.resume_all:
            # also restore optimizer/scheduler state and the epoch counter
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            opt.start_epoch = checkpoint['epoch'] + 1
        logger.info(f"Loaded model saved at epoch {checkpoint['epoch']} from checkpoint: {opt.resume}")
    else:
        logger.warning("If you intend to evaluate the model, please specify --resume with ckpt path")

    return model, criterion, optimizer, lr_scheduler
|
|
|
|
def start_inference(train_opt=None, split=None, splitfile=None):
    """Standalone inference entry point: parse test options, build the eval
    dataset and model, run eval_epoch, and log the resulting metrics.

    Args:
        train_opt: when given, its a_feat_dir is forwarded to
            TestOptions.parse (re-uses the audio-feature config).
        split: overrides opt.eval_split_name when not None.
        splitfile: overrides opt.eval_path when not None.
    """
    if train_opt is not None:
        opt = TestOptions().parse(train_opt.a_feat_dir)
    else:
        opt = TestOptions().parse()
    if split is not None:
        opt.eval_split_name = split
    if splitfile is not None:
        opt.eval_path = splitfile

    print(opt.eval_split_name)
    print(opt.eval_path)
    logger.info("Setup config, data and model...")

    cudnn.benchmark = True
    cudnn.deterministic = False

    assert opt.eval_path is not None
    # ground-truth labels are only loaded for the val split
    if opt.eval_split_name == 'val':
        loadlabel = True
    else:
        loadlabel = False

    eval_dataset = StartEndDataset(
        dset_name=opt.dset_name,
        data_path=opt.eval_path,
        v_feat_dirs=opt.v_feat_dirs,
        q_feat_dir=opt.t_feat_dir,
        q_feat_type="last_hidden_state",
        max_q_l=opt.max_q_l,
        max_v_l=opt.max_v_l,
        ctx_mode=opt.ctx_mode,
        data_ratio=opt.data_ratio,
        normalize_v=not opt.no_norm_vfeat,
        normalize_t=not opt.no_norm_tfeat,
        clip_len=opt.clip_length,
        max_windows=opt.max_windows,
        load_labels=loadlabel,
        span_loss_type=opt.span_loss_type,
        txt_drop_ratio=0,  # no text dropout at evaluation time
        dset_domain=opt.dset_domain,
    )

    # optimizer / lr_scheduler are unused at inference time
    model, criterion, _, _ = setup_model(opt)

    save_submission_filename = "hl_{}_submission.jsonl".format(
        opt.eval_split_name)

    logger.info("Starting inference...")
    with torch.no_grad():
        metrics_no_nms, metrics_nms, eval_loss_meters, latest_file_paths = \
            eval_epoch(model, eval_dataset, opt, save_submission_filename, criterion=criterion)
    # metrics only exist when ground truth was loaded (val split)
    if opt.eval_split_name == 'val':
        logger.info("metrics_no_nms {}".format(pprint.pformat(metrics_no_nms["brief"], indent=4)))
        if metrics_nms is not None:
            logger.info("metrics_nms {}".format(pprint.pformat(metrics_nms["brief"], indent=4)))
|
|
from sys import argv
if __name__ == '__main__':
    # Positional CLI contract: exactly 7 argv entries are required by the
    # unpack below (ValueError otherwise); argv[4] is the eval split name
    # and argv[6] the split-file path — the other positions are ignored.
    _,_,_,_,split,_,splitfile = argv


    start_inference(split=split, splitfile=splitfile)
|
|