# FGSVQA / src / train.py
# Xinyi Wang
# project files
# db25ead
import os
os.environ["DECORD_DUPLICATE_WARNING_THRESHOLD"] = "1.0"
import argparse
import csv
from pathlib import Path
import numpy as np
from scipy.stats import pearsonr as _pearsonr_scipy, spearmanr as _spearmanr_scipy
import torch
import torch.nn.functional as F
import cv2
from decord import VideoReader, cpu
from torch.amp import GradScaler, autocast
from tqdm import tqdm
from module.compute_weight_map import process_video, compute_weight_map
from model.qd_model import QD_MODEL
# ----------------------------
# Data utils
# ----------------------------
def read_vid_mos_csv(csv_path):
    """Load ``(vid, mos)`` pairs from a metadata CSV file.

    The file must start with a header row that names at least the columns
    ``vid`` and ``mos``; any other columns are ignored.

    Args:
        csv_path: Path of the CSV file to read.

    Returns:
        A list of ``(vid, mos)`` tuples in file order, where ``vid`` is the
        whitespace-stripped identifier string and ``mos`` is a float.

    Raises:
        RuntimeError: If the file has no header row.
        KeyError: If the header lacks the ``vid`` or ``mos`` column.
        ValueError: If a ``mos`` cell cannot be parsed as a float.
    """
    # "utf-8-sig" drops a leading byte-order mark (common in CSVs exported from
    # spreadsheet tools) so the first header is "vid" and not "\ufeffvid";
    # files without a BOM decode exactly as with plain "utf-8".
    # newline="" lets the csv module handle line endings itself, as its
    # documentation requires for correct parsing of quoted fields.
    with open(csv_path, "r", encoding="utf-8-sig", newline="") as f:
        reader = csv.DictReader(f)
        if not reader.fieldnames:
            raise RuntimeError("CSV has no header")
        return [(str(r["vid"]).strip(), float(r["mos"])) for r in reader]
def split_rows(rows, seed=42, test_ratio=0.2, val_ratio=0.1):
    """Shuffle ``rows`` deterministically and cut them into train/val/test.

    The test share is taken from the whole set; the validation share is then
    taken from what remains (train + val), so ``val_ratio`` is relative to the
    non-test portion.

    Args:
        rows: Sequence of items to split.
        seed: Seed for the NumPy generator that drives the shuffle.
        test_ratio: Fraction of all rows assigned to the test split.
        val_ratio: Fraction of the non-test rows assigned to validation.

    Returns:
        ``(train, val, test)`` as three lists.
    """
    total = len(rows)
    order = np.arange(total)
    np.random.default_rng(int(seed)).shuffle(order)

    test_count = int(round(total * test_ratio))
    trainval_count = total - test_count  # rows left for train + val
    val_count = int(round(trainval_count * val_ratio))

    def pick(indices):
        return [rows[i] for i in indices]

    # Layout of the shuffled order: [ val | train | test ]
    train_part = pick(order[val_count:trainval_count])
    val_part = pick(order[:val_count])
    test_part = pick(order[trainval_count:])
    return train_part, val_part, test_part
def split_train_val(rows, seed=42, val_ratio=0.1):
    """Shuffle ``rows`` deterministically and cut off a validation subset.

    Args:
        rows: Sequence of items to split.
        seed: Seed for the NumPy generator that drives the shuffle.
        val_ratio: Fraction of all rows assigned to validation.

    Returns:
        ``(train, val)`` as two lists.
    """
    order = np.arange(len(rows))
    np.random.default_rng(int(seed)).shuffle(order)

    val_count = int(round(len(rows) * val_ratio))
    # The first ``val_count`` shuffled indices form validation; the rest train.
    val_part = [rows[i] for i in order[:val_count]]
    train_part = [rows[i] for i in order[val_count:]]
    return train_part, val_part
def pearsonr(x, y, eps=1e-12):
    """Pearson linear correlation (PLCC) computed with SciPy.

    Accepts torch tensors or array-likes; inputs are flattened to 1-D.

    Args:
        x, y: Paired samples.
        eps: Standard-deviation threshold below which an input is treated
            as constant.

    Returns:
        A scalar ``torch.Tensor`` (so call sites can use ``.item()``). The
        value is 0.0 when fewer than two samples are given, when either
        input is (near-)constant, or when SciPy yields NaN.
    """
    def _flat(values):
        # Tensors are detached and moved to host memory before conversion.
        if hasattr(values, "detach"):
            values = values.detach().cpu().numpy()
        return np.asarray(values).reshape(-1)

    a = _flat(x)
    b = _flat(y)

    # Correlation is undefined for constant or too-short inputs.
    degenerate = a.size < 2 or np.std(a) < eps or np.std(b) < eps
    if degenerate:
        return torch.tensor(0.0)

    coeff, _pvalue = _pearsonr_scipy(a, b)
    return torch.tensor(0.0 if np.isnan(coeff) else float(coeff))
def spearmanr(x, y, eps=1e-12):
    """Spearman rank correlation (SRCC) computed with SciPy.

    Accepts torch tensors or array-likes; inputs are flattened to 1-D.

    Args:
        x, y: Paired samples.
        eps: Standard-deviation threshold below which an input is treated
            as constant.

    Returns:
        A scalar ``torch.Tensor``. The value is 0.0 when fewer than two
        samples are given, when either input is (near-)constant, or when
        SciPy yields NaN.
    """
    def _flat(values):
        # Tensors are detached and moved to host memory before conversion.
        if hasattr(values, "detach"):
            values = values.detach().cpu().numpy()
        return np.asarray(values).reshape(-1)

    a = _flat(x)
    b = _flat(y)

    degenerate = a.size < 2 or np.std(a) < eps or np.std(b) < eps
    if degenerate:
        return torch.tensor(0.0)

    coeff, _pvalue = _spearmanr_scipy(a, b)
    return torch.tensor(0.0 if np.isnan(coeff) else float(coeff))
# ----------------------------
# Train utils
# ----------------------------
def build_scheduler(optim, args):
    """Build the per-epoch learning-rate schedule: linear warmup, then cosine decay.

    Args:
        optim: Optimizer whose parameter-group learning rates are scheduled.
        args: Object with ``warmup_epochs``, ``epochs`` and ``min_lr``
            attributes (e.g. the parsed CLI namespace).

    Returns:
        A ``SequentialLR`` (warmup followed by cosine) when at least one
        warmup epoch applies, otherwise a plain ``CosineAnnealingLR``.
    """
    total = int(args.epochs)
    # Keep at least one epoch for the cosine phase.
    warm = max(0, min(int(args.warmup_epochs), total - 1))

    if warm == 0:
        # Do not construct the warmup scheduler at all here: a scheduler's
        # constructor already applies its first step to the optimizer, so an
        # unused LinearLR would leave every group at 0.1 * lr and the cosine
        # schedule would then start from that scaled-down value.
        return torch.optim.lr_scheduler.CosineAnnealingLR(
            optim,
            T_max=max(1, total),
            eta_min=float(args.min_lr),
        )

    # Linear warmup from 0.1 * lr up to lr over ``warm`` epochs.
    warmup = torch.optim.lr_scheduler.LinearLR(
        optim,
        start_factor=0.1,
        total_iters=warm,
    )
    # Cosine decay down to ``min_lr`` over the remaining epochs.
    cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
        optim,
        T_max=max(1, total - warm),
        eta_min=float(args.min_lr),
    )
    return torch.optim.lr_scheduler.SequentialLR(
        optim,
        schedulers=[warmup, cosine],
        milestones=[warm],
    )
def com_loss(y_pred, y_true, reg_w=0.6, rank_w=1.0, huber_beta=1.0, margin=0.0):
    """Combined regression + pairwise ranking loss.

    Args:
        y_pred: Predicted scores, one per sample.
        y_true: Target scores, one per sample.
        reg_w: Weight of the regression term.
        rank_w: Weight of the ranking term; 0 disables it.
        huber_beta: ``beta`` of the SmoothL1 (Huber-style) regression loss;
            ``None`` selects plain L1 instead.
        margin: Margin of the pairwise hinge.

    Returns:
        ``(total, reg_loss, rank_loss)`` where both components are already
        multiplied by their weights and ``total`` is their sum.
    """
    # --- regression term -------------------------------------------------
    if huber_beta is not None:
        reg_loss = F.smooth_l1_loss(y_pred, y_true, beta=float(huber_beta), reduction="mean")
    else:
        reg_loss = F.l1_loss(y_pred, y_true, reduction="mean")
    reg_loss = reg_loss * float(reg_w)

    # --- pairwise hinge ranking term -------------------------------------
    batch = y_true.shape[0]
    ranking_active = batch >= 2 and float(rank_w) != 0.0
    if not ranking_active:
        # Nothing to rank: a single sample, or the term is switched off.
        rank_loss = y_pred.new_tensor(0.0)
        return reg_loss + rank_loss, reg_loss, rank_loss

    d_pred = y_pred.unsqueeze(1) - y_pred.unsqueeze(0)  # [B, B]
    d_true = y_true.unsqueeze(1) - y_true.unsqueeze(0)  # [B, B]
    order = torch.sign(d_true)  # -1, 0 or +1

    # Pairs with identical targets carry no ordering information and are ignored.
    valid = (order != 0).float()

    # hinge_ij = max(0, margin - order_ij * (pred_i - pred_j))
    hinge = F.relu(float(margin) - order * d_pred) * valid
    n_pairs = valid.sum().clamp_min(1.0)
    rank_loss = (hinge.sum() / n_pairs) * float(rank_w)

    return reg_loss + rank_loss, reg_loss, rank_loss
# ----------------------------
# Dataset
# ----------------------------
class VQADataset(torch.utils.data.Dataset):
    """Video-quality dataset that decodes one clip per ``(vid, mos)`` row.

    Each item is a 5-tuple:
        rgb:   [3, T, H, W] float tensor, frame values divided by 255 (RGB)
        w_art: [1, T, H, W] float tensor produced from ``compute_weight_map``
        w_str: [1, T, H, W] float tensor produced from ``compute_weight_map``
        y:     scalar float tensor holding the MOS (standardised when both
               ``mos_mean`` and ``mos_std`` are set)
        vid:   str identifier of the video

    The weight maps are presumably in [0, 1]; that range is set by
    ``compute_weight_map`` and is not enforced here.
    """

    def __init__(self, rows, db_path, clip_len, size, win, win_step, mos_mean=None, mos_std=None):
        """Store the sample list and the clip-extraction settings.

        Args:
            rows: Sequence of ``(vid, mos)`` pairs.
            db_path: Directory that contains the video files.
            clip_len: Number of frames per clip (passed as ``num_anchors``).
            size: Spatial size forwarded to ``process_video``.
            win: Window length forwarded to ``process_video``.
            win_step: Window step forwarded to ``process_video``.
            mos_mean: Mean used to standardise the MOS, or ``None``.
            mos_std: Standard deviation used to standardise the MOS, or ``None``.
        """
        self.rows = rows
        self.db_path = str(db_path)
        self.clip_len, self.size = int(clip_len), int(size)
        self.win, self.win_step = int(win), int(win_step)
        self.mos_mean, self.mos_std = mos_mean, mos_std

    def __len__(self):
        """Number of ``(vid, mos)`` rows."""
        return len(self.rows)

    def __getitem__(self, idx):
        """Decode the clip for row ``idx`` and return ``(rgb, w_art, w_str, y, vid)``.

        Raises:
            FileNotFoundError: If no ``.mp4``/``.avi``/``.mkv`` file exists for
                the video id under ``db_path``.
        """
        vid, mos = self.rows[int(idx)]
        n_frames = self.clip_len

        # Locate the file: first existing candidate among the known extensions.
        stem = Path(self.db_path) / vid
        candidates = (Path(str(stem) + f".{ext}") for ext in ("mp4", "avi", "mkv"))
        found = next((c for c in candidates if c.exists()), None)
        if found is None:
            raise FileNotFoundError(f"Cannot find {vid} video")
        video_path = str(found)

        reader = None
        try:
            reader = VideoReader(video_path, ctx=cpu(0))
            frame_all, w_art_all, w_str_all, _anchors = process_video(
                reader,
                size=self.size, num_anchors=n_frames, win=self.win, win_step=self.win_step,
            )
            frames_np, w_art_np, w_str_np = compute_weight_map(frame_all, w_art_all, w_str_all)
        except Exception:
            # Report which sample failed, then let the error propagate.
            print("\n[DATA ERROR]")
            print("idx:", idx)
            print("vid:", vid)
            raise
        finally:
            # Drop the decord reader reference as soon as decoding is finished.
            del reader

        def _first_frames(seq):
            # Stack the first ``n_frames`` entries along a new leading time axis.
            return np.stack([seq[i] for i in range(n_frames)], axis=0)

        # Frames: [T, H, W, 3] -> [3, T, H, W], scaled by 1/255.
        rgb = torch.from_numpy(_first_frames(frames_np)).float()
        rgb = rgb.permute(3, 0, 1, 2).contiguous() / 255.0

        # Weight maps: [T, H, W] -> [1, T, H, W].
        w_art = torch.from_numpy(_first_frames(w_art_np).astype(np.float32)).unsqueeze(0).float()
        w_str = torch.from_numpy(_first_frames(w_str_np).astype(np.float32)).unsqueeze(0).float()

        # Target: raw MOS, standardised only when both statistics are available.
        target = float(mos)
        if self.mos_mean is not None and self.mos_std is not None:
            target = (target - self.mos_mean) / (self.mos_std + 1e-8)

        return rgb, w_art, w_str, torch.tensor(target).float(), str(vid)
# ----------------------------
# Train
# ----------------------------
@torch.no_grad()
def _gather_cat(xs):
    """Concatenate a list of tensors along dim 0; an empty list gives an empty tensor."""
    return torch.cat(xs, dim=0) if xs else torch.empty(0)
def run_epoch(model, loader, device, *, optim=None, amp=True, mos_mean=None, mos_std=None, desc="", show_pbar=True, log_interval=10):
    """Run one pass over ``loader`` in training or evaluation mode.

    Training mode is selected by passing an optimizer; without one the model
    is put in eval mode and the pass runs under ``torch.inference_mode``.

    Args:
        model: Callable as ``model(rgb, w_art, w_str)`` returning
            ``(yhat, aux)``.
        loader: Iterable of ``(rgb, w_art, w_str, y, vid)`` batches.
        device: Target device for the batch tensors.
        optim: Optimizer to step after each batch, or ``None`` for evaluation.
        amp: Enable mixed precision (only takes effect on CUDA devices).
        mos_mean: Mean used to de-standardise targets/predictions before the
            metrics are computed, or ``None``.
        mos_std: Standard deviation used for the same purpose, or ``None``.
        desc: Progress-bar label.
        show_pbar: Wrap the loader in a tqdm progress bar.
        log_interval: Refresh the progress-bar postfix every this many steps
            (values below 1 are treated as 1).

    Returns:
        ``(avg_loss, plcc, srcc, rmse)`` as Python floats. ``avg_loss`` is the
        mean of the per-batch losses; the metrics are computed over all
        samples, on the original MOS scale when both stats are given.
    """
    is_train = optim is not None
    model.train(is_train)

    # Device/AMP decisions do not change during the pass, so make them once.
    device_type = "cuda" if str(device).startswith("cuda") else "cpu"
    use_amp = bool(amp) and device_type == "cuda"

    # The GradScaler must survive across epochs (it carries the loss scale),
    # so it is cached on the function. It is keyed by (device type, AMP flag)
    # so that a later call with a different configuration does not silently
    # reuse a scaler built for the first call.
    scalers = getattr(run_epoch, "_scalers", None)
    if scalers is None:
        scalers = run_epoch._scalers = {}
    scaler_key = (device_type, use_amp)
    if scaler_key not in scalers:
        scalers[scaler_key] = GradScaler(device_type, enabled=use_amp)
    scaler = scalers[scaler_key]

    # Guard against log_interval <= 0, which would break the modulo below.
    log_every = max(1, int(log_interval))

    losses = []
    y_all = []
    yhat_all = []

    # ---- tqdm progress bar ----
    it = tqdm(loader, desc=desc, leave=False, dynamic_ncols=True) if show_pbar else loader

    for step, (rgb, w_art, w_str, y, _vid) in enumerate(it, start=1):
        rgb = rgb.to(device, non_blocking=True)      # [B,3,T,H,W]
        w_art = w_art.to(device, non_blocking=True)  # [B,1,T,H,W]
        w_str = w_str.to(device, non_blocking=True)  # [B,1,T,H,W]
        y = y.to(device, non_blocking=True).float()  # [B]

        if is_train:
            optim.zero_grad(set_to_none=True)
            # Forward pass and loss both run under autocast, as in the
            # PyTorch AMP recipe; backward and the optimizer step do not.
            with autocast(device_type=device_type, enabled=use_amp):
                yhat, _aux = model(rgb, w_art, w_str)  # yhat: [B]
                loss, _loss_reg, _loss_rank = com_loss(yhat, y)
            scaler.scale(loss).backward()
            scaler.step(optim)
            scaler.update()
        else:
            with torch.inference_mode(), autocast(device_type=device_type, enabled=use_amp):
                yhat, _aux = model(rgb, w_art, w_str)
                loss, _loss_reg, _loss_rank = com_loss(yhat, y)

        losses.append(loss.detach().float().cpu())
        y_all.append(y.detach().float().cpu())
        yhat_all.append(yhat.detach().float().cpu())

        # ---- update bar every ``log_every`` steps and on the last step ----
        if show_pbar and (step % log_every == 0 or step == len(loader)):
            postfix = {"loss": f"{torch.stack(losses).mean().item():.4f}"}
            if is_train and hasattr(optim, "param_groups") and optim.param_groups:
                lrs = [pg.get("lr", None) for pg in optim.param_groups]
                postfix["lrs"] = ",".join([f"{x:.2e}" for x in lrs if x is not None])
            it.set_postfix(postfix)

    y_all = _gather_cat(y_all)
    yhat_all = _gather_cat(yhat_all)

    # ---- de-standardise: compute the metrics on the original MOS scale ----
    if mos_mean is not None and mos_std is not None:
        y_all = y_all * mos_std + mos_mean
        yhat_all = yhat_all * mos_std + mos_mean

    has_pairs = y_all.numel() > 1
    plcc = pearsonr(y_all, yhat_all).item() if has_pairs else 0.0
    srcc = spearmanr(y_all, yhat_all).item() if has_pairs else 0.0
    rmse = torch.sqrt(torch.mean((yhat_all - y_all) ** 2)).item() if y_all.numel() > 0 else 0.0
    avg_loss = torch.stack(losses).mean().item() if losses else 0.0
    return avg_loss, plcc, srcc, rmse
def main():
    """Command-line entry point: train, validate and test the QD model.

    Steps:
        1. Parse CLI arguments and build the train/val/test row lists
           (LSVQ official split, KVQ challenge split, or a random split).
        2. Compute MOS mean/std on the training rows and build the datasets
           and data loaders.
        3. Train with CLIP frozen; optionally unfreeze its last blocks after
           the warmup epochs (``--finetune_last_stage``).
        4. After every epoch save a "last" checkpoint, and save the "best"
           checkpoint/weights whenever validation SRCC improves; stop early
           after ``--patience`` epochs without improvement.
        5. Reload the best weights (if present) and report test metrics.

    Raises:
        FileNotFoundError: If a companion split CSV required by the LSVQ or
            KVQ layout is missing.
    """
    ap = argparse.ArgumentParser()
    # ----- data -----
    ap.add_argument("--csv_path", default="/home/xinyi/Project/FD-VQA/metadata/LSVQ_TRAIN_metadata.csv")
    ap.add_argument("--db_path", default="/media/xinyi/server/LSVQ/")
    ap.add_argument("--split_seed", type=int, default=42)
    ap.add_argument("--test_ratio", type=float, default=0.2)
    # Fraction of the non-test rows that is held out for validation.
    ap.add_argument("--val_ratio", type=float, default=0.1)
    # ----- video processing -----
    ap.add_argument("--clip_len", type=int, default=16)
    ap.add_argument("--resize", type=int, default=224)
    ap.add_argument("--win", type=int, default=6)
    ap.add_argument("--win_step", type=int, default=1)
    # ----- runtime -----
    ap.add_argument("--batch_size", type=int, default=8)
    ap.add_argument("--num_workers", type=int, default=2)
    ap.add_argument("--device", type=str, default="cuda")
    ap.add_argument("--no_amp", action="store_true")
    # ----- hyperparams -----
    ap.add_argument("--epochs", type=int, default=35)
    ap.add_argument("--warmup_epochs", type=int, default=3)
    ap.add_argument("--lr", type=float, default=1e-5)
    ap.add_argument("--min_lr", type=float, default=1e-6)
    ap.add_argument("--finetune_lr", type=float, default=5e-5)
    ap.add_argument("--weight_decay", type=float, default=1e-2)
    ap.add_argument("--clip_unfreeze_blocks", type=int, default=4)
    ap.add_argument("--finetune_last_stage", action="store_true")
    ap.add_argument("--patience", type=int, default=6)
    # ----- save -----
    ap.add_argument("--save_dir", type=str, default="checkpoints")
    ap.add_argument("--save_name", type=str, default="qd_model.pt")
    args = ap.parse_args()

    torch.manual_seed(args.split_seed)
    device = torch.device(args.device)
    amp = not bool(args.no_amp)

    # ----------------------------
    # Load rows and split
    # ----------------------------
    csv_path = Path(args.csv_path)
    if csv_path.name == "LSVQ_TRAIN_metadata.csv":
        # LSVQ official split: the test set comes from its own CSV and the
        # validation set is carved out of the training CSV.
        test_csv = csv_path.parent / "LSVQ_TEST_metadata.csv"
        if not test_csv.exists():
            raise FileNotFoundError(f"Cannot find LSVQ test csv: {test_csv}")
        train_all = read_vid_mos_csv(str(csv_path))
        test_rows = read_vid_mos_csv(str(test_csv))
        train_rows, val_rows = split_train_val(
            train_all,
            seed=args.split_seed,
            val_ratio=args.val_ratio,
        )
        print(f"[LSVQ split] train={len(train_rows)} val={len(val_rows)} test={len(test_rows)}")
    elif csv_path.name == "KVQ_TRAIN_metadata.csv":
        # KVQ challenge split: train/val/test each come from their own CSV.
        val_csv = csv_path.parent / "KVQ_VAL_metadata.csv"
        test_csv = csv_path.parent / "KVQ_TEST_metadata.csv"
        # Fail early with a clear message, mirroring the LSVQ branch.
        for required_csv in (val_csv, test_csv):
            if not required_csv.exists():
                raise FileNotFoundError(f"Cannot find KVQ split csv: {required_csv}")
        train_rows = read_vid_mos_csv(str(csv_path))
        val_rows = read_vid_mos_csv(str(val_csv))
        test_rows = read_vid_mos_csv(str(test_csv))
        print(f"[KVQ split] train={len(train_rows)} val={len(val_rows)} test={len(test_rows)}")
    else:
        # Default: random split of a single CSV.
        rows = read_vid_mos_csv(str(csv_path))
        train_rows, val_rows, test_rows = split_rows(
            rows,
            seed=args.split_seed,
            test_ratio=args.test_ratio,
            val_ratio=args.val_ratio,
        )

    # MOS normalization stats come from the training split only.
    mos_train = np.array([mos for _vid, mos in train_rows], dtype=np.float32)
    mos_mean = float(mos_train.mean()) if len(mos_train) else 0.0
    mos_std = float(mos_train.std()) if len(mos_train) else 1.0
    if mos_std <= 1e-8:
        # Avoid dividing by (almost) zero when all training MOS are equal.
        mos_std = 1.0

    # ----------------------------
    # Datasets and loaders
    # ----------------------------
    ds_kwargs = dict(
        clip_len=args.clip_len,
        size=args.resize,
        win=args.win,
        win_step=args.win_step,
        mos_mean=mos_mean,
        mos_std=mos_std,
    )
    ds_train = VQADataset(train_rows, args.db_path, **ds_kwargs)
    ds_val = VQADataset(val_rows, args.db_path, **ds_kwargs)
    ds_test = VQADataset(test_rows, args.db_path, **ds_kwargs)

    pin = str(device).startswith("cuda")
    loader_train = torch.utils.data.DataLoader(
        ds_train,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=pin,
        # prefetch_factor may only be set when worker processes are used.
        prefetch_factor=4 if args.num_workers > 0 else None,
        drop_last=False,
    )
    loader_val = torch.utils.data.DataLoader(
        ds_val,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=pin,
        drop_last=False,
    )
    loader_test = torch.utils.data.DataLoader(
        ds_test,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=pin,
        drop_last=False,
    )

    # ----------------------------
    # Model
    # ----------------------------
    model = QD_MODEL(
        clip_model="openai/clip-vit-base-patch16",
    ).to(device)

    # Stage A: freeze CLIP.
    model.freeze_clip_all()

    # Two parameter groups: everything else at --lr, and the parameters whose
    # names start with "encoder." at --finetune_lr. Both are registered up
    # front so Stage B can unfreeze without rebuilding the optimizer.
    clip_params_all = []
    other_params_all = []
    for name, p in model.named_parameters():
        if name.startswith("encoder."):
            clip_params_all.append(p)
        else:
            other_params_all.append(p)

    param_groups = []
    if other_params_all:
        param_groups.append({"params": other_params_all, "lr": float(args.lr)})
    if clip_params_all:
        param_groups.append({"params": clip_params_all, "lr": float(args.finetune_lr)})

    optim = torch.optim.AdamW(param_groups, weight_decay=float(args.weight_decay))
    scheduler = build_scheduler(optim, args)

    # ----------------------------
    # Train loop
    # ----------------------------
    save_dir = Path(args.save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    save_path = save_dir / args.save_name
    best_path = save_dir / (Path(args.save_name).stem + ".best.pt")
    best_weights_path = save_dir / (Path(args.save_name).stem + ".best_weights.pt")

    best_val_srcc = -1e18
    bad_epochs = 0
    did_unfreeze = False

    for epoch in tqdm(range(1, int(args.epochs) + 1), desc="Epochs", dynamic_ncols=True):
        # Stage B: optional CLIP finetune, starting right after the warmup epochs.
        if (not did_unfreeze) and bool(args.finetune_last_stage) and epoch == (int(args.warmup_epochs) + 1):
            model.unfreeze_clip_last_blocks(n_blocks=int(args.clip_unfreeze_blocks), also_unfreeze_ln=True)
            did_unfreeze = True
            # Do not rebuild the optimizer: keep the Adam state and only
            # update the learning rates.
            # NOTE(review): this overwrites the values the scheduler had set
            # for these groups at this epoch - confirm that is intended.
            if hasattr(optim, "param_groups") and len(optim.param_groups) >= 2:
                optim.param_groups[0]["lr"] = float(args.lr)
                optim.param_groups[1]["lr"] = float(args.finetune_lr)
            print(
                f"[Stage B] Unfroze CLIP last {int(args.clip_unfreeze_blocks)} blocks | "
                f"lr={float(args.lr)} finetune_lr={float(args.finetune_lr)}"
            )

        tr_loss, tr_plcc, tr_srcc, tr_rmse = run_epoch(
            model, loader_train, device,
            optim=optim,
            amp=amp,
            mos_mean=mos_mean,
            mos_std=mos_std,
            desc=f"Train e{epoch:03d}",
            show_pbar=True,
            log_interval=10,
        )
        va_loss, va_plcc, va_srcc, va_rmse = run_epoch(
            model, loader_val, device,
            optim=None,
            amp=amp,
            mos_mean=mos_mean,
            mos_std=mos_std,
            desc=f"Val e{epoch:03d}",
            show_pbar=True,
            log_interval=10,
        )
        scheduler.step()

        print(
            f"epoch {epoch:03d} | "
            f"train: loss={tr_loss:.4f} plcc={tr_plcc:.4f} srcc={tr_srcc:.4f} rmse={tr_rmse:.4f} | "
            f"val: loss={va_loss:.4f} plcc={va_plcc:.4f} srcc={va_srcc:.4f} rmse={va_rmse:.4f}"
        )

        # Decide on improvement BEFORE writing any checkpoint so that the
        # "last" checkpoint also records the up-to-date best validation SRCC.
        improved = va_srcc > best_val_srcc
        if improved:
            best_val_srcc = va_srcc
            bad_epochs = 0

        # Save the "last" checkpoint every epoch.
        ckpt = {
            "epoch": epoch,
            "model": model.state_dict(),
            "optim": optim.state_dict(),
            "mos_mean": mos_mean,
            "mos_std": mos_std,
            "args": vars(args),
            "best_val_srcc": best_val_srcc,
        }
        torch.save(ckpt, str(save_path))

        # Save "best" by validation SRCC (higher is better).
        if improved:
            torch.save(ckpt, str(best_path))
            torch.save(model.state_dict(), str(best_weights_path))
            print(
                f" [best] val_srcc={best_val_srcc:.4f} (val_rmse={va_rmse:.4f}) -> saved "
                f"{best_path} and {best_weights_path}"
            )
        else:
            bad_epochs += 1
            if bad_epochs >= int(args.patience):
                print(
                    f"[EarlyStop] val_srcc did not improve for {bad_epochs} epochs. "
                    f"Stop at epoch {epoch}."
                )
                break

    # ----------------------------
    # Test (load best)
    # ----------------------------
    if best_weights_path.exists():
        sd = torch.load(str(best_weights_path), map_location=device, weights_only=True)
        model.load_state_dict(sd, strict=True)
        print(f"Loaded best weights: {best_weights_path}")
    elif best_path.exists():
        best = torch.load(str(best_path), map_location=device)
        model.load_state_dict(best["model"], strict=True)
        print(f"Loaded best checkpoint: {best_path} (val_srcc={best.get('best_val_srcc', None)})")

    te_loss, te_plcc, te_srcc, te_rmse = run_epoch(
        model, loader_test, device,
        optim=None,
        amp=amp,
        mos_mean=mos_mean,
        mos_std=mos_std,
    )
    print(f"TEST | loss={te_loss:.4f} plcc={te_plcc:.4f} srcc={te_srcc:.4f} rmse={te_rmse:.4f}")
# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()