"""Evaluate a trained QD_MODEL checkpoint on a (possibly different) test set.

The script loads a checkpoint, runs the model over the videos listed in a
CSV, maps the de-normalized predictions from the training score range onto
the test score range, reports loss / PLCC / SRCC / RMSE and optionally saves
per-video predictions to a CSV file.
"""

import os

# NOTE(review): this is set before the project imports below, presumably so
# that decord (if it is loaded transitively) sees the value at import time.
# Confirm before reordering the imports.
os.environ["DECORD_DUPLICATE_WARNING_THRESHOLD"] = "1.0"

import argparse
import csv
from pathlib import Path

import torch
from torch.amp import autocast
from tqdm import tqdm

from train import VQADataset, com_loss, pearsonr, read_vid_mos_csv, spearmanr
from model.qd_model import QD_MODEL


def load_checkpoint(ckpt_path, device):
    """Load a checkpoint file and normalize it into a single dict layout.

    Two on-disk layouts are accepted:

    * a "full" checkpoint: a dict with a ``"model"`` key holding the state
      dict, plus optional ``"mos_mean"``, ``"mos_std"`` and ``"args"`` keys;
    * a bare state dict: any other dict, used as the state dict as-is.

    Args:
        ckpt_path: Path of the checkpoint file.
        device: ``map_location`` passed to ``torch.load``.

    Returns:
        A dict with the keys ``state_dict``, ``train_mos_mean``,
        ``train_mos_std``, ``train_args`` and ``is_full_checkpoint``. The two
        MOS statistics are ``None`` when the checkpoint does not carry them.

    Raises:
        TypeError: If the loaded object is not a dict.
    """
    ckpt = torch.load(str(ckpt_path), map_location=device, weights_only=True)

    if isinstance(ckpt, dict) and "model" in ckpt:
        return {
            "state_dict": ckpt["model"],
            "train_mos_mean": ckpt.get("mos_mean"),
            "train_mos_std": ckpt.get("mos_std"),
            "train_args": ckpt.get("args", {}),
            "is_full_checkpoint": True,
        }

    if isinstance(ckpt, dict):
        return {
            "state_dict": ckpt,
            "train_mos_mean": None,
            "train_mos_std": None,
            "train_args": {},
            "is_full_checkpoint": False,
        }

    raise TypeError(f"Unsupported checkpoint type: {type(ckpt)!r}")


def infer_test_scale(rows):
    """Guess the score range of a test set from its MOS values.

    The observed minimum and maximum are matched, in order, against the
    ranges [0, 1], [1, 5] and [0, 5]; anything else falls back to [0, 100].

    Args:
        rows: Iterable of ``(vid, mos)`` pairs; ``mos`` must be convertible
            with ``float``.

    Returns:
        A ``(scale_min, scale_max)`` tuple of floats.

    Raises:
        ValueError: If ``rows`` is empty.
    """
    mos_values = [float(mos) for _vid, mos in rows]
    if not mos_values:
        raise ValueError("Cannot infer test scale from empty rows")

    lo = min(mos_values)
    hi = max(mos_values)

    if 0.0 <= lo and hi <= 1.0:
        return 0.0, 1.0
    if 1.0 <= lo and hi <= 5.0:
        return 1.0, 5.0
    if 0.0 <= lo and hi <= 5.0:
        return 0.0, 5.0
    return 0.0, 100.0


def linear_remap(x, src_min, src_max, dst_min, dst_max):
    """Linearly map ``x`` from ``[src_min, src_max]`` to ``[dst_min, dst_max]``.

    Values outside the source range are extrapolated, not clipped.

    Args:
        x: Value(s) to remap; anything supporting ``-``, ``/``, ``*`` and
            ``+`` with floats (the script passes a tensor).
        src_min: Lower bound of the source range.
        src_max: Upper bound of the source range.
        dst_min: Lower bound of the destination range.
        dst_max: Upper bound of the destination range.

    Returns:
        The remapped value(s).

    Raises:
        ValueError: If the source range has (near) zero width.
    """
    src_min = float(src_min)
    src_max = float(src_max)
    dst_min = float(dst_min)
    dst_max = float(dst_max)

    if abs(src_max - src_min) <= 1e-12:
        raise ValueError("Source scale range must be non-zero")

    return (x - src_min) / (src_max - src_min) * (dst_max - dst_min) + dst_min


def save_predictions_csv(save_path, vids, y_true_raw, pred_train_scale, pred_eval_scale):
    """Write per-video ground truth and predictions to a CSV file.

    The parent directory is created if needed. The file has the header
    ``vid, y_true_raw, pred_train_scale, pred_eval_scale`` followed by one
    row per video.

    Args:
        save_path: Destination file path.
        vids: Sequence of video identifiers.
        y_true_raw: Ground-truth scores; must provide ``.tolist()``.
        pred_train_scale: Predictions on the training score range; must
            provide ``.tolist()``.
        pred_eval_scale: Predictions on the test score range; must provide
            ``.tolist()``.

    Returns:
        The destination as a ``Path``.

    Raises:
        ValueError: If the four columns do not have the same length.
    """
    save_path = Path(save_path)

    vid_col = list(vids)
    true_col = y_true_raw.tolist()
    train_col = pred_train_scale.tolist()
    eval_col = pred_eval_scale.tolist()

    # Validate before touching the filesystem: a length mismatch would
    # otherwise silently drop rows from the saved predictions.
    lengths = {len(vid_col), len(true_col), len(train_col), len(eval_col)}
    if len(lengths) != 1:
        raise ValueError(
            "Prediction columns have mismatched lengths: "
            f"vids={len(vid_col)}, y_true_raw={len(true_col)}, "
            f"pred_train_scale={len(train_col)}, pred_eval_scale={len(eval_col)}"
        )

    save_path.parent.mkdir(parents=True, exist_ok=True)
    with open(save_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["vid", "y_true_raw", "pred_train_scale", "pred_eval_scale"])
        for vid, y_true, pred_train, pred_eval in zip(
            vid_col, true_col, train_col, eval_col, strict=True
        ):
            writer.writerow([vid, float(y_true), float(pred_train), float(pred_eval)])

    return save_path


@torch.no_grad()
def evaluate_and_collect(
    model,
    loader,
    device,
    *,
    amp=True,
    train_mos_mean,
    train_mos_std,
    train_scale_min,
    train_scale_max,
    test_scale_min,
    test_scale_max,
    desc="",
    show_pbar=True,
    log_interval=10,
):
    """Run the model over ``loader`` and compute evaluation metrics.

    Each batch is unpacked as ``(rgb, w_art, w_str, y, vid)`` and the model
    is called as ``model(rgb, w_art, w_str)``, returning ``(yhat, aux)``.
    Targets and predictions are de-normalized with the training mean/std;
    predictions are then remapped from the training score range to the test
    score range before PLCC, SRCC and RMSE are computed against the targets.

    Args:
        model: Module to evaluate; switched to eval mode.
        loader: Iterable of batches as described above.
        device: Device the batch tensors are moved to.
        amp: Enable autocast; only takes effect on CUDA devices.
        train_mos_mean: Mean used to de-normalize targets and predictions.
        train_mos_std: Std used to de-normalize targets and predictions.
        train_scale_min: Lower bound of the training score range.
        train_scale_max: Upper bound of the training score range.
        test_scale_min: Lower bound of the test score range.
        test_scale_max: Upper bound of the test score range.
        desc: Progress-bar description.
        show_pbar: Wrap the loader in a tqdm progress bar.
        log_interval: Refresh the running loss every this many steps
            (values below 1 are treated as 1).

    Returns:
        A dict with ``loss`` (mean of the per-batch losses), ``plcc``,
        ``srcc``, ``rmse``, ``vids``, ``y_true_raw``, ``pred_train_scale``
        and ``pred_eval_scale``. Metrics that cannot be computed (too few
        samples) are reported as ``0.0``.
    """
    model.eval()

    losses = []
    y_all = []
    yhat_all = []
    vids_all = []

    # Loop-invariant settings, computed once instead of per batch.
    device_type = "cuda" if str(device).startswith("cuda") else "cpu"
    use_amp = amp and device_type == "cuda"
    log_every = max(1, int(log_interval))  # guard against modulo by zero

    it = loader
    total_steps = None
    if show_pbar:
        it = tqdm(loader, desc=desc, leave=False, dynamic_ncols=True)
        try:
            total_steps = len(loader)
        except TypeError:
            # Loader without a length: only the periodic refresh applies.
            total_steps = None

    for step, (rgb, w_art, w_str, y, vid) in enumerate(it, start=1):
        rgb = rgb.to(device, non_blocking=True)
        w_art = w_art.to(device, non_blocking=True)
        w_str = w_str.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True).float()

        with autocast(device_type=device_type, enabled=use_amp):
            yhat, _aux = model(rgb, w_art, w_str)
            loss, _loss_reg, _loss_rank = com_loss(yhat, y)

        losses.append(loss.detach().float().cpu())
        y_all.append(y.detach().float().cpu())
        yhat_all.append(yhat.detach().float().cpu())
        vids_all.extend(list(vid))

        if show_pbar and (step % log_every == 0 or step == total_steps):
            avg_loss_so_far = torch.stack(losses).mean().item()
            it.set_postfix({"loss": f"{avg_loss_so_far:.4f}"})

    if y_all:
        y_all = torch.cat(y_all, dim=0)
        yhat_all = torch.cat(yhat_all, dim=0)
    else:
        y_all = torch.empty(0)
        yhat_all = torch.empty(0)

    # Undo the z-score normalization.
    # NOTE(review): assumes the dataset normalized its targets with the same
    # mean/std, so y_true_raw is the original test MOS — confirm in VQADataset.
    y_true_raw = y_all * float(train_mos_std) + float(train_mos_mean)
    pred_train_scale = yhat_all * float(train_mos_std) + float(train_mos_mean)

    pred_eval_scale = linear_remap(
        pred_train_scale,
        src_min=float(train_scale_min),
        src_max=float(train_scale_max),
        dst_min=float(test_scale_min),
        dst_max=float(test_scale_max),
    )

    n_samples = y_true_raw.numel()
    plcc = pearsonr(y_true_raw, pred_eval_scale).item() if n_samples > 1 else 0.0
    srcc = spearmanr(y_true_raw, pred_eval_scale).item() if n_samples > 1 else 0.0
    rmse = (
        torch.sqrt(torch.mean((pred_eval_scale - y_true_raw) ** 2)).item()
        if n_samples > 0
        else 0.0
    )
    avg_loss = torch.stack(losses).mean().item() if losses else 0.0

    return {
        "loss": avg_loss,
        "plcc": plcc,
        "srcc": srcc,
        "rmse": rmse,
        "vids": vids_all,
        "y_true_raw": y_true_raw,
        "pred_train_scale": pred_train_scale,
        "pred_eval_scale": pred_eval_scale,
    }


def main():
    """Parse CLI arguments, evaluate the checkpoint and report metrics.

    Raises:
        ValueError: If the training MOS mean/std are available neither from
            the checkpoint nor from the command line, if the std is not
            positive, or if the CSV yields no rows.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt_path", type=str, default="/home/xinyi/Project/FD-VQA/src/checkpoints/lsvq/qd_model.best.pt")
    ap.add_argument("--csv_path", type=str, default="/home/xinyi/Project/FD-VQA/metadata/KVQ_metadata.csv")
    ap.add_argument("--db_path", type=str, default="/media/xinyi/server/video_dataset/KVQ")
    ap.add_argument("--clip_len", type=int, default=16)
    ap.add_argument("--resize", type=int, default=224)
    ap.add_argument("--win", type=int, default=6)
    ap.add_argument("--win_step", type=int, default=1)
    ap.add_argument("--batch_size", type=int, default=8)
    ap.add_argument("--num_workers", type=int, default=2)
    ap.add_argument("--device", type=str, default="cuda")
    ap.add_argument("--no_amp", action="store_true")
    ap.add_argument("--train_scale_min", type=float, default=0.0)
    ap.add_argument("--train_scale_max", type=float, default=100.0)
    ap.add_argument("--test_scale_min", type=float, default=1.0)
    ap.add_argument("--test_scale_max", type=float, default=5.0)
    ap.add_argument(
        "--infer_test_scale",
        action="store_true",
        help="Infer the test score range from the CSV instead of using "
        "--test_scale_min/--test_scale_max.",
    )
    ap.add_argument(
        "--train_mos_mean",
        type=float,
        default=None,
        help="Training MOS mean; takes precedence over the checkpoint value.",
    )
    ap.add_argument(
        "--train_mos_std",
        type=float,
        default=None,
        help="Training MOS std; takes precedence over the checkpoint value.",
    )
    ap.add_argument("--save_pred_csv", type=str, default="/home/xinyi/Project/FD-VQA/src/transfer_test/transfer_test_only_konvid_1k.csv")
    args = ap.parse_args()

    device = torch.device(args.device)
    amp = not bool(args.no_amp)

    ckpt_info = load_checkpoint(Path(args.ckpt_path), device)

    # Explicit command-line values win; otherwise use what the checkpoint has.
    train_mos_mean = (
        args.train_mos_mean
        if args.train_mos_mean is not None
        else ckpt_info["train_mos_mean"]
    )
    train_mos_std = (
        args.train_mos_std
        if args.train_mos_std is not None
        else ckpt_info["train_mos_std"]
    )
    if train_mos_mean is None or train_mos_std is None:
        raise ValueError(
            "Checkpoint has no mos_mean/mos_std. "
            "Prefer loading *.best.pt / *.pt, or pass --train_mos_mean and --train_mos_std manually."
        )
    if float(train_mos_std) <= 1e-8:
        raise ValueError("train_mos_std must be > 0")

    rows = read_vid_mos_csv(args.csv_path)
    if not rows:
        raise ValueError(f"No rows found in csv: {args.csv_path}")

    if args.infer_test_scale or args.test_scale_min is None or args.test_scale_max is None:
        test_scale_min, test_scale_max = infer_test_scale(rows)
    else:
        test_scale_min = float(args.test_scale_min)
        test_scale_max = float(args.test_scale_max)

    dataset = VQADataset(
        rows,
        args.db_path,
        clip_len=args.clip_len,
        size=args.resize,
        win=args.win,
        win_step=args.win_step,
        mos_mean=float(train_mos_mean),
        mos_std=float(train_mos_std),
    )

    pin = str(device).startswith("cuda")
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=int(args.batch_size),
        shuffle=False,
        num_workers=int(args.num_workers),
        pin_memory=pin,
        drop_last=False,
        # prefetch_factor is only passed when worker processes are used.
        prefetch_factor=4 if int(args.num_workers) > 0 else None,
    )

    model = QD_MODEL(
        clip_model="openai/clip-vit-base-patch16",
    ).to(device)
    model.load_state_dict(ckpt_info["state_dict"], strict=True)

    print(f"Loaded checkpoint: {args.ckpt_path}")
    print(f"Training normalization: mean={float(train_mos_mean):.6f}, std={float(train_mos_std):.6f}")
    print(
        f"Scale mapping: train=[{float(args.train_scale_min):.3f}, {float(args.train_scale_max):.3f}] -> "
        f"test=[{float(test_scale_min):.3f}, {float(test_scale_max):.3f}]"
    )
    print(f"Test rows: {len(rows)}")

    metrics = evaluate_and_collect(
        model,
        loader,
        device,
        amp=amp,
        train_mos_mean=float(train_mos_mean),
        train_mos_std=float(train_mos_std),
        train_scale_min=float(args.train_scale_min),
        train_scale_max=float(args.train_scale_max),
        test_scale_min=float(test_scale_min),
        test_scale_max=float(test_scale_max),
        desc="Cross-dataset test",
        show_pbar=True,
        log_interval=10,
    )

    print(
        "TEST | "
        f"loss={metrics['loss']:.4f} "
        f"plcc={metrics['plcc']:.4f} "
        f"srcc={metrics['srcc']:.4f} "
        f"rmse={metrics['rmse']:.4f}"
    )

    if args.save_pred_csv:
        save_path = save_predictions_csv(
            args.save_pred_csv,
            metrics["vids"],
            metrics["y_true_raw"],
            metrics["pred_train_scale"],
            metrics["pred_eval_scale"],
        )
        print(f"Saved predictions to: {save_path}")


if __name__ == "__main__":
    main()