FGSVQA / src /transfer.py
Xinyi Wang
project files
db25ead
import argparse
from pathlib import Path
import numpy as np
import torch
from train import (
read_vid_mos_csv,
split_rows,
VQADataset,
run_epoch,
build_scheduler,
)
from model.qd_model import QD_MODEL
def load_pretrained_weights(model, pretrained_path, device):
p = Path(pretrained_path)
obj = torch.load(str(p), map_location=device, weights_only=True)
# *.pt: dict checkpoint
if isinstance(obj, dict) and "model" in obj:
model.load_state_dict(obj["model"], strict=True)
return obj
# best_weights.pt: state_dict
else:
model.load_state_dict(obj, strict=True)
return None
def make_loaders(rows_train, rows_val, rows_test, args, mos_mean, mos_std, device):
ds_train = VQADataset(
rows_train, args.db_path,
clip_len=args.clip_len, size=args.resize, win=args.win, win_step=args.win_step,
mos_mean=mos_mean, mos_std=mos_std,
)
ds_val = VQADataset(
rows_val, args.db_path,
clip_len=args.clip_len, size=args.resize, win=args.win, win_step=args.win_step,
mos_mean=mos_mean, mos_std=mos_std,
)
ds_test = VQADataset(
rows_test, args.db_path,
clip_len=args.clip_len, size=args.resize, win=args.win, win_step=args.win_step,
mos_mean=mos_mean, mos_std=mos_std,
)
pin = str(device).startswith("cuda")
loader_train = torch.utils.data.DataLoader(
ds_train, batch_size=args.batch_size, shuffle=True,
num_workers=args.num_workers, pin_memory=pin,
persistent_workers=(args.num_workers > 0),
prefetch_factor=4 if args.num_workers > 0 else None,
drop_last=False,
)
loader_val = torch.utils.data.DataLoader(
ds_val, batch_size=args.batch_size, shuffle=False,
num_workers=args.num_workers, pin_memory=pin, drop_last=False,
)
loader_test = torch.utils.data.DataLoader(
ds_test, batch_size=args.batch_size, shuffle=False,
num_workers=args.num_workers, pin_memory=pin, drop_last=False,
)
return loader_train, loader_val, loader_test
def main():
ap = argparse.ArgumentParser()
# ----- mode -----
ap.add_argument("--mode", choices=["finetune", "test_only"], required=True)
ap.add_argument("--pretrained", default="/home/xinyi/Project/FD-VQA/src/checkpoints/kvq/qd_model.best.pt", help="pretrain model path")
# ----- data -----
ap.add_argument("--csv_path", default="/home/xinyi/Project/FD-VQA/metadata/SHORTS-SDR-DATASET_metadata.csv")
ap.add_argument("--db_path", default="/media/xinyi/server/video_dataset/shorts-hdr-dataset/sdr/")
ap.add_argument("--split_seed", type=int, default=0)
ap.add_argument("--test_ratio", type=float, default=0.2)
ap.add_argument("--val_ratio", type=float, default=0.1) # train 80%
# ----- video processing -----
ap.add_argument("--clip_len", type=int, default=16)
ap.add_argument("--resize", type=int, default=224)
ap.add_argument("--win", type=int, default=6)
ap.add_argument("--win_step", type=int, default=1)
# ----- runtime -----
ap.add_argument("--batch_size", type=int, default=8)
ap.add_argument("--num_workers", type=int, default=4)
ap.add_argument("--device", type=str, default="cuda")
ap.add_argument("--no_amp", action="store_true")
# ----- finetune hyperparams -----
ap.add_argument("--epochs", type=int, default=10)
ap.add_argument("--warmup_epochs", type=int, default=1)
ap.add_argument("--lr", type=float, default=1e-5)
ap.add_argument("--min_lr", type=float, default=1e-6)
ap.add_argument("--finetune_lr", type=float, default=5e-5)
ap.add_argument("--weight_decay", type=float, default=1e-2)
ap.add_argument("--clip_unfreeze_blocks", type=int, default=4)
ap.add_argument("--finetune_last_stage", action="store_true")
ap.add_argument("--patience", type=int, default=6)
# ----- save -----
ap.add_argument("--save_dir", type=str, default="checkpoints_transfer")
ap.add_argument("--save_name", type=str, default="transfer.pt")
# ----- MOS normalization control -----
# finetune: use target train mean/std
ap.add_argument("--test_only_norm", choices=["none", "use_source_ckpt"], default="use_source_ckpt")
args = ap.parse_args()
torch.manual_seed(args.split_seed)
device = torch.device(args.device)
amp = not bool(args.no_amp)
# ----------------------------
# Load target rows and split
# ----------------------------
rows = read_vid_mos_csv(args.csv_path)
if args.mode == "finetune":
csv_path = Path(args.csv_path)
if csv_path.name == "KVQ_TRAIN_metadata.csv":
# KVQ challenge split
val_csv = csv_path.parent / "KVQ_VAL_metadata.csv"
test_csv = csv_path.parent / "KVQ_TEST_metadata.csv"
train_rows = read_vid_mos_csv(str(csv_path))
val_rows = read_vid_mos_csv(str(val_csv))
test_rows = read_vid_mos_csv(str(test_csv))
print(f"[KVQ split] train={len(train_rows)} val={len(val_rows)} test={len(test_rows)}")
else:
train_rows, val_rows, test_rows = split_rows(
rows, seed=args.split_seed, test_ratio=args.test_ratio, val_ratio=args.val_ratio
)
mos_train = np.array([mos for _vid, mos in train_rows], dtype=np.float32)
mos_mean = float(mos_train.mean()) if len(mos_train) else 0.0
mos_std = float(mos_train.std()) if len(mos_train) else 1.0
if mos_std <= 1e-8:
mos_std = 1.0
else:
# test_only
train_rows, val_rows, test_rows = [], [], rows
mos_mean, mos_std = None, None
# ----------------------------
# Build model, load pretrained
# ----------------------------
model = QD_MODEL(
clip_model="openai/clip-vit-base-patch16",
).to(device)
ckpt = load_pretrained_weights(model, args.pretrained, device)
print(f"Loaded pretrained: {args.pretrained}")
# ----------------------------
# Decide normalization for test_only
# ----------------------------
if args.mode == "test_only":
if args.test_only_norm == "use_source_ckpt" and isinstance(ckpt, dict):
# 注:这样对 SRCC/PLCC 没影响
mos_mean = ckpt.get("mos_mean", None)
mos_std = ckpt.get("mos_std", None)
if mos_mean is None or mos_std is None:
mos_mean, mos_std = None, None
print("[warn] pretrained ckpt has no mos_mean/std, fallback to no normalization.")
else:
print(f"test_only uses source mos_mean/std from ckpt: mean={mos_mean:.4f}, std={mos_std:.4f}")
else:
print("test_only uses no MOS normalization.")
# ----------------------------
# Mode: test_only
# ----------------------------
if args.mode == "test_only":
ds_test = VQADataset(
test_rows, args.db_path,
clip_len=args.clip_len, size=args.resize, win=args.win, win_step=args.win_step,
mos_mean=mos_mean, mos_std=mos_std,
)
pin = str(device).startswith("cuda")
loader_test = torch.utils.data.DataLoader(
ds_test, batch_size=args.batch_size, shuffle=False,
num_workers=args.num_workers, pin_memory=pin, drop_last=False,
)
print("num_test_rows =", len(test_rows))
print("len(ds_test) =", len(ds_test))
print("len(loader_test) =", len(loader_test), "batch_size =", args.batch_size)
te_loss, te_plcc, te_srcc, te_rmse = run_epoch(
model, loader_test, device,
optim=None, amp=amp, mos_mean=mos_mean, mos_std=mos_std,
desc="TestOnly", show_pbar=True
)
print(f"TEST_ONLY | loss={te_loss:.4f} plcc={te_plcc:.4f} srcc={te_srcc:.4f} rmse={te_rmse:.4f}")
return
# ----------------------------
# DataLoaders
# ----------------------------
loader_train, loader_val, loader_test = make_loaders(
train_rows, val_rows, test_rows, args,
mos_mean=mos_mean, mos_std=mos_std, device=device
)
# ----------------------------
# Mode: finetune
# ----------------------------
model.freeze_clip_all()
did_unfreeze = False
clip_params_all = []
other_params_all = []
for name, p in model.named_parameters():
if name.startswith("encoder."):
clip_params_all.append(p)
else:
other_params_all.append(p)
param_groups = []
if other_params_all:
param_groups.append({"params": other_params_all, "lr": float(args.lr)})
if clip_params_all:
param_groups.append({"params": clip_params_all, "lr": float(args.finetune_lr)})
optim = torch.optim.AdamW(param_groups, weight_decay=float(args.weight_decay))
scheduler = build_scheduler(optim, args)
save_dir = Path(args.save_dir)
save_dir.mkdir(parents=True, exist_ok=True)
last_path = save_dir / args.save_name
best_path = save_dir / (Path(args.save_name).stem + ".best.pt")
best_weights_path = save_dir / (Path(args.save_name).stem + ".best_weights.pt")
best_val_srcc = -1e18
bad_epochs = 0
for epoch in range(1, int(args.epochs) + 1):
if (not did_unfreeze) and bool(args.finetune_last_stage) and epoch == (int(args.warmup_epochs) + 1):
model.unfreeze_clip_last_blocks(n_blocks=int(args.clip_unfreeze_blocks), also_unfreeze_ln=True)
did_unfreeze = True
print(f"[Finetune] Unfroze CLIP last {int(args.clip_unfreeze_blocks)} blocks")
tr_loss, tr_plcc, tr_srcc, tr_rmse = run_epoch(
model, loader_train, device,
optim=optim, amp=amp, mos_mean=mos_mean, mos_std=mos_std,
desc=f"FT Train e{epoch:03d}", show_pbar=True
)
va_loss, va_plcc, va_srcc, va_rmse = run_epoch(
model, loader_val, device,
optim=None, amp=amp, mos_mean=mos_mean, mos_std=mos_std,
desc=f"FT Val e{epoch:03d}", show_pbar=True
)
scheduler.step()
print(
f"epoch {epoch:03d} | "
f"train: loss={tr_loss:.4f} plcc={tr_plcc:.4f} srcc={tr_srcc:.4f} rmse={tr_rmse:.4f} | "
f"val: loss={va_loss:.4f} plcc={va_plcc:.4f} srcc={va_srcc:.4f} rmse={va_rmse:.4f}"
)
ckpt_out = {
"epoch": epoch,
"model": model.state_dict(),
"optim": optim.state_dict(),
"mos_mean": mos_mean,
"mos_std": mos_std,
"args": vars(args),
"best_val_srcc": best_val_srcc,
}
torch.save(ckpt_out, str(last_path))
if va_srcc > best_val_srcc:
best_val_srcc = va_srcc
bad_epochs = 0
ckpt_out["best_val_srcc"] = best_val_srcc
torch.save(ckpt_out, str(best_path))
torch.save(model.state_dict(), str(best_weights_path))
print(f" [best] val_srcc={best_val_srcc:.4f} -> saved {best_weights_path}")
else:
bad_epochs += 1
if bad_epochs >= int(args.patience):
print(f"[EarlyStop] val_srcc not improved for {bad_epochs} epochs. Stop.")
break
# load best and test
if best_weights_path.exists():
sd = torch.load(str(best_weights_path), map_location=device, weights_only=True)
model.load_state_dict(sd, strict=True)
print(f"Loaded best weights: {best_weights_path}")
te_loss, te_plcc, te_srcc, te_rmse = run_epoch(
model, loader_test, device,
optim=None, amp=amp, mos_mean=mos_mean, mos_std=mos_std,
desc="FT Test", show_pbar=True
)
print(f"FINETUNE TEST | loss={te_loss:.4f} plcc={te_plcc:.4f} srcc={te_srcc:.4f} rmse={te_rmse:.4f}")
if __name__ == "__main__":
main()