| import os |
| os.environ["DECORD_DUPLICATE_WARNING_THRESHOLD"] = "1.0" |
| import argparse |
| import csv |
| from pathlib import Path |
| import numpy as np |
| from scipy.stats import pearsonr as _pearsonr_scipy, spearmanr as _spearmanr_scipy |
| import torch |
| import torch.nn.functional as F |
| import cv2 |
| from decord import VideoReader, cpu |
| from torch.amp import GradScaler, autocast |
| from tqdm import tqdm |
|
|
| from module.compute_weight_map import process_video, compute_weight_map |
| from model.qd_model import QD_MODEL |
|
|
|
|
| |
| |
| |
def read_vid_mos_csv(csv_path):
    """Read ``(vid, mos)`` pairs from a metadata CSV that has a header row.

    The header must contain ``vid`` and ``mos`` columns; any other columns are
    ignored. ``vid`` is whitespace-stripped and ``mos`` is parsed as float. The
    file is decoded as UTF-8 and a leading byte-order mark is tolerated.

    Args:
        csv_path: path (str or path-like) to the CSV file.

    Returns:
        list of ``(vid: str, mos: float)`` tuples, in file order.

    Raises:
        RuntimeError: the CSV has no header, or lacks a ``vid``/``mos`` column.
        ValueError: a ``mos`` cell cannot be parsed as a float.
    """
    rows = []
    # "utf-8-sig" strips a BOM if present (otherwise the first header would be
    # "\ufeffvid" and every r["vid"] lookup would fail); newline="" is what the
    # csv module expects so quoted fields are handled correctly.
    with open(csv_path, "r", encoding="utf-8-sig", newline="") as f:
        reader = csv.DictReader(f)
        header = reader.fieldnames
        if not header:
            raise RuntimeError("CSV has no header")
        missing = [col for col in ("vid", "mos") if col not in header]
        if missing:
            raise RuntimeError(
                f"CSV {csv_path} is missing required column(s) {missing}; header is {header}"
            )
        for r in reader:
            vid = str(r["vid"]).strip()
            try:
                mos = float(r["mos"])
            except (TypeError, ValueError) as err:
                # TypeError covers short rows, where DictReader fills missing cells with None.
                raise ValueError(
                    f"{csv_path}:{reader.line_num}: invalid mos value {r['mos']!r} for vid {vid!r}"
                ) from err
            rows.append((vid, mos))
    return rows
|
|
def split_rows(rows, seed=42, test_ratio=0.2, val_ratio=0.1):
    """Randomly partition ``rows`` into ``(train, val, test)`` lists.

    A seeded NumPy permutation orders the rows. The last
    ``round(len(rows) * test_ratio)`` shuffled rows form the test split. From
    the remaining prefix, the first ``round(prefix_len * val_ratio)`` rows form
    the validation split and the rest are the training split.
    """
    total = len(rows)
    order = np.arange(total)
    np.random.default_rng(int(seed)).shuffle(order)

    trainval_end = total - int(round(total * test_ratio))
    val_end = int(round(trainval_end * val_ratio))

    shuffled = [rows[k] for k in order]
    train_part = shuffled[val_end:trainval_end]
    val_part = shuffled[:val_end]
    test_part = shuffled[trainval_end:]
    return train_part, val_part, test_part
|
|
def split_train_val(rows, seed=42, val_ratio=0.1):
    """Randomly partition ``rows`` into ``(train, val)`` lists.

    The rows are permuted with a NumPy generator seeded by ``seed``. The first
    ``round(len(rows) * val_ratio)`` permuted rows become the validation split
    and the rest become the training split.
    """
    order = np.arange(len(rows))
    np.random.default_rng(int(seed)).shuffle(order)
    cut = int(round(len(rows) * val_ratio))
    permuted = [rows[k] for k in order]
    return permuted[cut:], permuted[:cut]
|
|
def pearsonr(x, y, eps=1e-12):
    """Pearson linear correlation (PLCC) between two score sequences.

    ``x`` and ``y`` may be torch tensors (detached and moved to CPU), NumPy
    arrays or array-likes; both are flattened to 1-D. The result is a 0-dim
    float tensor. Degenerate cases return 0.0 instead of NaN: fewer than two
    samples, a standard deviation below ``eps`` in either input, or a NaN
    coefficient from scipy.
    """
    def _flat(v):
        if hasattr(v, "detach"):
            v = v.detach().cpu().numpy()
        return np.asarray(v).reshape(-1)

    xs = _flat(x)
    ys = _flat(y)
    if xs.size < 2 or np.std(xs) < eps or np.std(ys) < eps:
        return torch.tensor(0.0)

    coef, _pval = _pearsonr_scipy(xs, ys)
    return torch.tensor(0.0 if np.isnan(coef) else float(coef))
|
|
def spearmanr(x, y, eps=1e-12):
    """Spearman rank correlation (SRCC) between two score sequences.

    ``x`` and ``y`` may be torch tensors (detached and moved to CPU), NumPy
    arrays or array-likes; both are flattened to 1-D. The result is a 0-dim
    float tensor. Degenerate cases return 0.0 instead of NaN: fewer than two
    samples, a standard deviation below ``eps`` in either input, or a NaN
    coefficient from scipy.
    """
    flat = []
    for v in (x, y):
        if hasattr(v, "detach"):
            v = v.detach().cpu().numpy()
        flat.append(np.asarray(v).reshape(-1))
    xs, ys = flat

    if xs.size < 2 or np.std(xs) < eps or np.std(ys) < eps:
        return torch.tensor(0.0)

    rho, _pval = _spearmanr_scipy(xs, ys)
    if np.isnan(rho):
        return torch.tensor(0.0)
    return torch.tensor(float(rho))
|
|
| |
| |
| |
def build_scheduler(optim, args):
    """Build a per-epoch LR schedule: optional linear warmup, then cosine decay.

    Reads ``args.epochs``, ``args.warmup_epochs`` and ``args.min_lr``. The
    warmup length is clamped to ``[0, epochs - 1]`` so at least one epoch
    remains for the cosine phase.

    * warmup > 0: LinearLR ramps from 0.1x to 1x the base LR over ``warm``
      epochs. A SequentialLR then hands over to CosineAnnealingLR, which
      decays to ``min_lr`` over the remaining epochs.
    * warmup == 0: a plain CosineAnnealingLR over all epochs, starting at the
      base LR.

    Returns:
        an LR scheduler to be stepped once per epoch.
    """
    total = int(args.epochs)
    warm = max(0, min(int(args.warmup_epochs), total - 1))
    cosine_len = (total - warm) if (total - warm) > 0 else 1
    eta_min = float(args.min_lr)

    if warm == 0:
        # Only the cosine scheduler is constructed here. Merely constructing a
        # LinearLR would immediately scale every group's lr by start_factor.
        # CosineAnnealingLR's recursive update would then decay from 10% of
        # the base LR for the whole run.
        return torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=cosine_len, eta_min=eta_min)

    warmup = torch.optim.lr_scheduler.LinearLR(optim, start_factor=0.1, total_iters=warm)
    cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=cosine_len, eta_min=eta_min)
    # SequentialLR resets the groups to their initial LR and re-runs the first
    # scheduler's initial step, undoing the side effects of constructing both.
    return torch.optim.lr_scheduler.SequentialLR(optim, schedulers=[warmup, cosine], milestones=[warm])
|
|
def com_loss(y_pred, y_true, reg_w=0.6, rank_w=1.0, huber_beta=1.0, margin=0.0):
    """Combined regression + pairwise ranking loss for score prediction.

    Both inputs are flattened to 1-D and must hold the same number of elements.

    * Regression term: mean Smooth-L1 (Huber, ``beta=huber_beta``), or mean L1
      when ``huber_beta`` is None, scaled by ``reg_w``.
    * Ranking term: for every ordered pair (i, j) whose targets differ, the
      hinge ``relu(margin - sign(t_i - t_j) * (p_i - p_j))``. It is averaged
      over those pairs and scaled by ``rank_w``. The term is skipped (0) when
      the batch has fewer than 2 items or ``rank_w == 0``.

    Returns:
        ``(total, reg_loss, rank_loss)``, where ``total = reg_loss + rank_loss``.

    Raises:
        ValueError: ``y_pred`` and ``y_true`` have different element counts.
    """
    # Flatten both sides. A [B, 1] prediction against a [B] target would
    # otherwise broadcast to [B, B] in the regression loss and to a 4-D tensor
    # in the pair matrix.
    y_pred = y_pred.reshape(-1)
    y_true = y_true.reshape(-1)
    if y_pred.numel() != y_true.numel():
        raise ValueError(
            f"com_loss: y_pred has {y_pred.numel()} elements but y_true has {y_true.numel()}"
        )

    if huber_beta is None:
        reg_loss = F.l1_loss(y_pred, y_true, reduction="mean")
    else:
        reg_loss = F.smooth_l1_loss(y_pred, y_true, beta=float(huber_beta), reduction="mean")
    reg_loss = reg_loss * float(reg_w)

    if y_true.numel() < 2 or float(rank_w) == 0.0:
        rank_loss = y_pred.new_tensor(0.0)
        return reg_loss + rank_loss, reg_loss, rank_loss

    # [B, B] pairwise differences: entry (i, j) = v_i - v_j.
    pred_diff = y_pred.unsqueeze(1) - y_pred.unsqueeze(0)
    true_order = torch.sign(y_true.unsqueeze(1) - y_true.unsqueeze(0))
    # Pairs with tied targets (including the diagonal) carry no ordering signal.
    mask = (true_order != 0).float()

    rank_mat = F.relu(float(margin) - true_order * pred_diff) * mask
    rank_loss = rank_mat.sum() / mask.sum().clamp_min(1.0)
    rank_loss = rank_loss * float(rank_w)

    return reg_loss + rank_loss, reg_loss, rank_loss
|
|
| |
| |
| |
class VQADataset(torch.utils.data.Dataset):
    """Video-quality dataset that decodes one video and its weight maps per item.

    Returns per item:
        rgb:   [3, T, H, W] float in [0, 1] (RGB)
        w_art: [1, T, H, W] float in [0, 1]
        w_str: [1, T, H, W] float in [0, 1]
        y:     scalar float (MOS, optionally normalized)
        vid:   str

    Args:
        rows: sequence of ``(vid, mos)`` pairs.
        db_path: directory holding ``<vid>.mp4`` / ``<vid>.avi`` / ``<vid>.mkv``.
        clip_len: number of frames T kept per item. It is also passed to
            ``process_video`` as ``num_anchors``.
        size, win, win_step: forwarded unchanged to ``process_video``.
        mos_mean, mos_std: when both are given, the target becomes
            ``(mos - mos_mean) / (mos_std + 1e-8)``.
    """

    # Extensions probed, in order, when locating a video file.
    _VIDEO_EXTS = ("mp4", "avi", "mkv")

    def __init__(self, rows, db_path, clip_len, size, win, win_step, mos_mean=None, mos_std=None):
        self.rows = rows
        self.db_path = str(db_path)
        self.clip_len = int(clip_len)
        self.size = int(size)
        self.win = int(win)
        self.win_step = int(win_step)
        self.mos_mean = mos_mean
        self.mos_std = mos_std

    def __len__(self):
        return len(self.rows)

    def _find_video(self, vid):
        """Return the first existing ``<db_path>/<vid>.<ext>``; raise FileNotFoundError if none exists."""
        base = str(Path(self.db_path) / vid)
        for ext in self._VIDEO_EXTS:
            candidate = Path(f"{base}.{ext}")
            if candidate.exists():
                return str(candidate)
        raise FileNotFoundError(f"Cannot find {vid} video")

    @staticmethod
    def _stack_maps(maps, T):
        """Stack the first T per-frame maps into a float32 tensor with a leading channel axis."""
        arr = np.stack([maps[i] for i in range(T)], axis=0).astype(np.float32)
        return torch.from_numpy(arr).unsqueeze(0)

    def __getitem__(self, idx):
        vid, mos = self.rows[int(idx)]
        video_path = self._find_video(vid)

        # Bound up-front so the finally clause never touches an unbound name
        # when VideoReader itself raises.
        vr = None
        try:
            vr = VideoReader(video_path, ctx=cpu(0))
            frame_all, w_art_all, w_str_all, _anchors_kept = process_video(
                vr,
                size=self.size, num_anchors=self.clip_len, win=self.win, win_step=self.win_step,
            )
            frames_np, w_art_np, w_str_np = compute_weight_map(frame_all, w_art_all, w_str_all)
        except Exception:
            # Identify the failing sample, then propagate the original error.
            print("\n[DATA ERROR]")
            print("idx:", idx)
            print("vid:", vid)
            raise
        finally:
            # Drop the reader reference as soon as decoding is done.
            del vr

        T = self.clip_len
        # Only the first T frames are used; fewer than T raises IndexError.
        rgb = torch.from_numpy(np.stack([frames_np[i] for i in range(T)], axis=0)).float()
        # [T, H, W, 3] -> [3, T, H, W], scaled to [0, 1].
        rgb = rgb.permute(3, 0, 1, 2).contiguous() / 255.0

        w_art = self._stack_maps(w_art_np, T)
        w_str = self._stack_maps(w_str_np, T)

        y = float(mos)
        if self.mos_mean is not None and self.mos_std is not None:
            y = (y - self.mos_mean) / (self.mos_std + 1e-8)
        return rgb, w_art, w_str, torch.tensor(y).float(), str(vid)
|
|
|
|
| |
| |
| |
@torch.no_grad()
def _gather_cat(xs):
    """Concatenate per-batch tensors along dim 0 without recording autograd history.

    An empty sequence yields ``torch.empty(0)`` instead of the error
    ``torch.cat`` would raise.
    """
    return torch.cat(xs, dim=0) if xs else torch.empty(0)
|
|
def run_epoch(model, loader, device, *, optim=None, amp=True, mos_mean=None, mos_std=None, desc="", show_pbar=True, log_interval=10):
    """Run one pass over ``loader``: a training epoch if ``optim`` is given, else evaluation.

    Each batch must unpack as ``(rgb, w_art, w_str, y, vid)``, and
    ``model(rgb, w_art, w_str)`` must return ``(prediction, aux)``. The loss is
    ``com_loss`` on the targets as delivered by the loader. In training mode
    the backward pass goes through a GradScaler, which is enabled only when
    ``amp`` is set and the device is CUDA. Evaluation runs under
    ``torch.inference_mode``.

    Args:
        model: network to run; switched to train/eval mode here.
        loader: iterable of batches. ``len(loader)`` is used for progress logging.
        device: target device for the batch tensors.
        optim: optimizer; ``None`` selects evaluation mode.
        amp: enable CUDA autocast and loss scaling.
        mos_mean, mos_std: when both are given, targets and predictions are
            mapped back with ``v * mos_std + mos_mean`` before computing metrics.
        desc, show_pbar, log_interval: tqdm progress-bar options.

    Returns:
        ``(avg_loss, plcc, srcc, rmse)``. ``avg_loss`` is the mean of the
        per-batch losses; the metrics are computed over every sample in the pass.
    """
    is_train = optim is not None
    model.train(is_train)

    device_type = "cuda" if str(device).startswith("cuda") else "cpu"
    use_amp = bool(amp) and device_type == "cuda"

    # A single GradScaler is cached on the function object and reused by every
    # call, so its dynamic loss scale carries over from epoch to epoch.
    scaler = getattr(run_epoch, "_scaler", None)
    if scaler is None:
        scaler = GradScaler(device_type, enabled=use_amp)
        run_epoch._scaler = scaler

    loss_sum = 0.0
    n_batches = 0
    y_parts = []
    yhat_parts = []

    it = tqdm(loader, desc=desc, leave=False, dynamic_ncols=True) if show_pbar else loader
    log_every = int(log_interval)

    for step, (rgb, w_art, w_str, y, _vid) in enumerate(it, start=1):
        rgb = rgb.to(device, non_blocking=True)
        w_art = w_art.to(device, non_blocking=True)
        w_str = w_str.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True).float()

        if is_train:
            optim.zero_grad(set_to_none=True)
            with autocast(device_type=device_type, enabled=use_amp):
                yhat, _aux = model(rgb, w_art, w_str)
                loss, _loss_reg, _loss_rank = com_loss(yhat, y)
            scaler.scale(loss).backward()
            scaler.step(optim)
            scaler.update()
        else:
            with torch.inference_mode(), autocast(device_type=device_type, enabled=use_amp):
                yhat, _aux = model(rgb, w_art, w_str)
                loss, _loss_reg, _loss_rank = com_loss(yhat, y)

        # Keep a running sum rather than re-stacking every past loss at each
        # log step, which was quadratic in the number of batches.
        loss_sum += loss.detach().float().item()
        n_batches += 1
        # Flatten so a [B, 1] prediction cannot broadcast against [B] targets
        # in the RMSE below.
        y_parts.append(y.detach().float().cpu().reshape(-1))
        yhat_parts.append(yhat.detach().float().cpu().reshape(-1))

        if show_pbar and (step % log_every == 0 or step == len(loader)):
            postfix = {"loss": f"{loss_sum / n_batches:.4f}"}
            if is_train and hasattr(optim, "param_groups") and optim.param_groups:
                lrs = [pg.get("lr", None) for pg in optim.param_groups]
                postfix["lrs"] = ",".join([f"{x:.2e}" for x in lrs if x is not None])
            it.set_postfix(postfix)

    y_all = _gather_cat(y_parts)
    yhat_all = _gather_cat(yhat_parts)

    # Metrics are reported on the original MOS scale when normalisation stats are known.
    if mos_mean is not None and mos_std is not None:
        y_all = y_all * mos_std + mos_mean
        yhat_all = yhat_all * mos_std + mos_mean

    n = y_all.numel()
    plcc = pearsonr(y_all, yhat_all).item() if n > 1 else 0.0
    srcc = spearmanr(y_all, yhat_all).item() if n > 1 else 0.0
    rmse = torch.sqrt(torch.mean((yhat_all - y_all) ** 2)).item() if n > 0 else 0.0

    avg_loss = loss_sum / n_batches if n_batches else 0.0
    return avg_loss, plcc, srcc, rmse
|
|
|
|
def _load_splits(args):
    """Return ``(train_rows, val_rows, test_rows)`` for ``args.csv_path``.

    * ``LSVQ_TRAIN_metadata.csv``: test rows come from the sibling
      ``LSVQ_TEST_metadata.csv``. Validation is a seeded random slice
      (``args.val_ratio``) of the train CSV.
    * ``KVQ_TRAIN_metadata.csv``: train/val/test are read as-is from the sibling
      ``KVQ_TRAIN``/``KVQ_VAL``/``KVQ_TEST`` metadata CSVs.
    * any other file: a single CSV split randomly with ``args.test_ratio``,
      ``args.val_ratio`` and ``args.split_seed``.
    """
    csv_path = Path(args.csv_path)
    if csv_path.name == "LSVQ_TRAIN_metadata.csv":
        test_csv = csv_path.parent / "LSVQ_TEST_metadata.csv"
        if not test_csv.exists():
            raise FileNotFoundError(f"Cannot find LSVQ test csv: {test_csv}")
        train_all = read_vid_mos_csv(str(csv_path))
        test_rows = read_vid_mos_csv(str(test_csv))
        train_rows, val_rows = split_train_val(
            train_all,
            seed=args.split_seed,
            val_ratio=args.val_ratio,
        )
        print(f"[LSVQ split] train={len(train_rows)} val={len(val_rows)} test={len(test_rows)}")
    elif csv_path.name == "KVQ_TRAIN_metadata.csv":
        val_csv = csv_path.parent / "KVQ_VAL_metadata.csv"
        test_csv = csv_path.parent / "KVQ_TEST_metadata.csv"
        # Same up-front existence check as the LSVQ branch, with a clearer message.
        for needed in (val_csv, test_csv):
            if not needed.exists():
                raise FileNotFoundError(f"Cannot find KVQ csv: {needed}")
        train_rows = read_vid_mos_csv(str(csv_path))
        val_rows = read_vid_mos_csv(str(val_csv))
        test_rows = read_vid_mos_csv(str(test_csv))
        print(f"[KVQ split] train={len(train_rows)} val={len(val_rows)} test={len(test_rows)}")
    else:
        rows = read_vid_mos_csv(str(csv_path))
        train_rows, val_rows, test_rows = split_rows(
            rows,
            seed=args.split_seed,
            test_ratio=args.test_ratio,
            val_ratio=args.val_ratio,
        )
        print(f"[random split] train={len(train_rows)} val={len(val_rows)} test={len(test_rows)}")
    return train_rows, val_rows, test_rows


def _mos_stats(train_rows):
    """Return ``(mean, std)`` of the training MOS, falling back to 0/1 when empty or constant."""
    mos_train = np.array([mos for _vid, mos in train_rows], dtype=np.float32)
    if len(mos_train) == 0:
        return 0.0, 1.0
    mos_mean = float(mos_train.mean())
    mos_std = float(mos_train.std())
    # A (near-)constant MOS would make the z-normalisation blow up.
    if mos_std <= 1e-8:
        mos_std = 1.0
    return mos_mean, mos_std


def _make_loader(ds, args, pin, *, shuffle, prefetch_factor=None):
    """Build a DataLoader with the shared batch/worker settings from ``args``."""
    workers = int(args.num_workers)
    return torch.utils.data.DataLoader(
        ds,
        batch_size=args.batch_size,
        shuffle=shuffle,
        num_workers=workers,
        pin_memory=pin,
        prefetch_factor=prefetch_factor if workers > 0 else None,
        # Keep workers alive between epochs instead of respawning them every pass.
        persistent_workers=workers > 0,
        drop_last=False,
    )


def main():
    """Train QD_MODEL, early-stop on validation SRCC, then report test metrics."""
    ap = argparse.ArgumentParser()
    # Data and splits.
    ap.add_argument("--csv_path", default="/home/xinyi/Project/FD-VQA/metadata/LSVQ_TRAIN_metadata.csv")
    ap.add_argument("--db_path", default="/media/xinyi/server/LSVQ/")
    ap.add_argument("--split_seed", type=int, default=42)
    ap.add_argument("--test_ratio", type=float, default=0.2)
    ap.add_argument("--val_ratio", type=float, default=0.1)
    # Clip sampling / weight maps.
    ap.add_argument("--clip_len", type=int, default=16)
    ap.add_argument("--resize", type=int, default=224)
    ap.add_argument("--win", type=int, default=6)
    ap.add_argument("--win_step", type=int, default=1)
    # Runtime.
    ap.add_argument("--batch_size", type=int, default=8)
    ap.add_argument("--num_workers", type=int, default=2)
    ap.add_argument("--device", type=str, default="cuda")
    ap.add_argument("--no_amp", action="store_true")
    # Optimisation.
    ap.add_argument("--epochs", type=int, default=35)
    ap.add_argument("--warmup_epochs", type=int, default=3)
    ap.add_argument("--lr", type=float, default=1e-5)
    ap.add_argument("--min_lr", type=float, default=1e-6)
    ap.add_argument("--finetune_lr", type=float, default=5e-5)
    ap.add_argument("--weight_decay", type=float, default=1e-2)
    ap.add_argument("--clip_unfreeze_blocks", type=int, default=4)
    ap.add_argument("--finetune_last_stage", action="store_true")
    ap.add_argument("--patience", type=int, default=6)
    # Checkpoints.
    ap.add_argument("--save_dir", type=str, default="checkpoints")
    ap.add_argument("--save_name", type=str, default="qd_model.pt")

    args = ap.parse_args()
    torch.manual_seed(args.split_seed)
    device = torch.device(args.device)
    amp = not bool(args.no_amp)

    train_rows, val_rows, test_rows = _load_splits(args)
    # Targets are z-normalised with training-set statistics only.
    mos_mean, mos_std = _mos_stats(train_rows)

    def _make_ds(rows):
        return VQADataset(
            rows, args.db_path,
            clip_len=args.clip_len,
            size=args.resize,
            win=args.win,
            win_step=args.win_step,
            mos_mean=mos_mean,
            mos_std=mos_std,
        )

    pin = str(device).startswith("cuda")
    loader_train = _make_loader(_make_ds(train_rows), args, pin, shuffle=True, prefetch_factor=4)
    loader_val = _make_loader(_make_ds(val_rows), args, pin, shuffle=False)
    loader_test = _make_loader(_make_ds(test_rows), args, pin, shuffle=False)

    model = QD_MODEL(
        clip_model="openai/clip-vit-base-patch16",
    ).to(device)

    # Stage A: start with the encoder frozen. Its parameters (names starting
    # with "encoder.", presumably the CLIP backbone) are still registered with
    # the optimizer, so a later unfreeze needs no optimizer rebuild.
    model.freeze_clip_all()

    clip_params_all = []
    other_params_all = []
    for name, p in model.named_parameters():
        if name.startswith("encoder."):
            clip_params_all.append(p)
        else:
            other_params_all.append(p)

    # Groups are tagged by name so the Stage-B LR reset does not depend on list position.
    base_lrs = {"other": float(args.lr), "encoder": float(args.finetune_lr)}
    param_groups = []
    if other_params_all:
        param_groups.append({"params": other_params_all, "lr": base_lrs["other"], "name": "other"})
    if clip_params_all:
        param_groups.append({"params": clip_params_all, "lr": base_lrs["encoder"], "name": "encoder"})

    optim = torch.optim.AdamW(param_groups, weight_decay=float(args.weight_decay))
    scheduler = build_scheduler(optim, args)

    save_dir = Path(args.save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    save_path = save_dir / args.save_name
    best_path = save_dir / (Path(args.save_name).stem + ".best.pt")
    best_weights_path = save_dir / (Path(args.save_name).stem + ".best_weights.pt")

    best_val_srcc = -1e18
    bad_epochs = 0
    did_unfreeze = False

    for epoch in tqdm(range(1, int(args.epochs) + 1), desc="Epochs", dynamic_ncols=True):
        # Stage B: at the first epoch after warmup, optionally unfreeze the last encoder blocks.
        if (not did_unfreeze) and bool(args.finetune_last_stage) and epoch == (int(args.warmup_epochs) + 1):
            model.unfreeze_clip_last_blocks(n_blocks=int(args.clip_unfreeze_blocks), also_unfreeze_ln=True)
            did_unfreeze = True

            for group in optim.param_groups:
                if group.get("name") in base_lrs:
                    group["lr"] = base_lrs[group["name"]]

            print(
                f"[Stage B] Unfroze CLIP last {int(args.clip_unfreeze_blocks)} blocks | "
                f"lr={float(args.lr)} finetune_lr={float(args.finetune_lr)}"
            )

        tr_loss, tr_plcc, tr_srcc, tr_rmse = run_epoch(
            model, loader_train, device,
            optim=optim,
            amp=amp,
            mos_mean=mos_mean,
            mos_std=mos_std,
            desc=f"Train e{epoch:03d}",
            show_pbar=True,
            log_interval=10,
        )
        va_loss, va_plcc, va_srcc, va_rmse = run_epoch(
            model, loader_val, device,
            optim=None,
            amp=amp,
            mos_mean=mos_mean,
            mos_std=mos_std,
            desc=f"Val e{epoch:03d}",
            show_pbar=True,
            log_interval=10,
        )
        scheduler.step()
        print(
            f"epoch {epoch:03d} | "
            f"train: loss={tr_loss:.4f} plcc={tr_plcc:.4f} srcc={tr_srcc:.4f} rmse={tr_rmse:.4f} | "
            f"val: loss={va_loss:.4f} plcc={va_plcc:.4f} srcc={va_srcc:.4f} rmse={va_rmse:.4f}"
        )

        # Update the best score before saving, so the "last" checkpoint records
        # the best SRCC including this epoch.
        improved = va_srcc > best_val_srcc
        if improved:
            best_val_srcc = va_srcc
            bad_epochs = 0
        else:
            bad_epochs += 1

        ckpt = {
            "epoch": epoch,
            "model": model.state_dict(),
            "optim": optim.state_dict(),
            # Saved so a resumed run can continue the LR schedule.
            "scheduler": scheduler.state_dict(),
            "mos_mean": mos_mean,
            "mos_std": mos_std,
            "args": vars(args),
            "best_val_srcc": best_val_srcc,
        }
        torch.save(ckpt, str(save_path))

        if improved:
            torch.save(ckpt, str(best_path))
            torch.save(model.state_dict(), str(best_weights_path))
            print(
                f"  [best] val_srcc={best_val_srcc:.4f} (val_rmse={va_rmse:.4f}) -> saved "
                f"{best_path} and {best_weights_path}"
            )
        elif bad_epochs >= int(args.patience):
            print(
                f"[EarlyStop] val_srcc did not improve for {bad_epochs} epochs. "
                f"Stop at epoch {epoch}."
            )
            break

    # Evaluate the best-on-validation weights, not the last epoch's.
    if best_weights_path.exists():
        sd = torch.load(str(best_weights_path), map_location=device, weights_only=True)
        model.load_state_dict(sd, strict=True)
        print(f"Loaded best weights: {best_weights_path}")
    elif best_path.exists():
        best = torch.load(str(best_path), map_location=device)
        model.load_state_dict(best["model"], strict=True)
        print(f"Loaded best checkpoint: {best_path} (val_srcc={best.get('best_val_srcc', None)})")

    te_loss, te_plcc, te_srcc, te_rmse = run_epoch(
        model, loader_test, device,
        optim=None,
        amp=amp,
        mos_mean=mos_mean,
        mos_std=mos_std,
    )
    print(f"TEST | loss={te_loss:.4f} plcc={te_plcc:.4f} srcc={te_srcc:.4f} rmse={te_rmse:.4f}")
|
|
|
|
if __name__ == "__main__":
    # Run training + evaluation only when executed as a script, not on import.
    main()