import os

import cv2
import face_alignment
import lpips
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F

from unet_acc import DenseMotion, UNetGenerator, warp_image

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the trained motion-transfer models.
dense_motion = DenseMotion(kp_channels=68).to(device)
generator = UNetGenerator(in_channels=4).to(device)
ckpt = torch.load("checkpoints/best.pth", map_location=device)
dense_motion.load_state_dict(ckpt["dense_motion"])
generator.load_state_dict(ckpt["generator"])
dense_motion.eval()
generator.eval()

lpips_fn = lpips.LPIPS(net="alex").to(device)
fa = face_alignment.FaceAlignment(
    face_alignment.LandmarksType.TWO_D, device=device
)


# ---------------------------------------------------------------------------
# Metrics
# ---------------------------------------------------------------------------
def landmark_distance(pred, gt):
    """Mean landmark error normalized by the inter-ocular distance."""
    pred = cv2.cvtColor(pred, cv2.COLOR_GRAY2RGB)
    gt = cv2.cvtColor(gt, cv2.COLOR_GRAY2RGB)
    pl = fa.get_landmarks(pred)
    gl = fa.get_landmarks(gt)
    if pl is None or gl is None:
        return None
    pl, gl = pl[0], gl[0]
    # Landmarks 36 and 45 are the outer eye corners.
    eye_dist = np.linalg.norm(gl[36] - gl[45]) + 1e-6
    return np.mean(np.linalg.norm(pl - gl, axis=1)) / eye_dist


def lpips_score(pred, gt):
    # LPIPS expects 3-channel input; replicate the grayscale channel.
    pred = pred.repeat(1, 3, 1, 1)
    gt = gt.repeat(1, 3, 1, 1)
    # normalize=True tells LPIPS the inputs are in [0, 1] rather than [-1, 1].
    return lpips_fn(pred, gt, normalize=True).item()


def l1_score(pred, gt):
    return F.l1_loss(pred, gt).item()


def temporal_jitter(frames):
    """Std/mean of the mean absolute difference between consecutive frames."""
    diffs = []
    for i in range(1, len(frames)):
        diffs.append(torch.mean(torch.abs(frames[i] - frames[i - 1])).item())
    return np.std(diffs), np.mean(diffs)


# Landmarks that follow the driver: eyes (36-47) and mouth (48-67).
LOCK_IDXS = list(range(36, 48)) + list(range(48, 68))


# ---------------------------------------------------------------------------
# Inference variants for the ablation
# ---------------------------------------------------------------------------
def infer_no_warp(src):
    # Baseline: identity flow and full visibility, i.e. no motion transfer.
    B, _, H, W = src.shape
    flow = torch.zeros(B, 2, H, W, device=device)
    occ = torch.ones(B, 1, H, W, device=device)
    return torch.clamp(generator(torch.cat([src, flow, occ], 1)), 0, 1)


def infer_warp(src, src_kp, drv_kp):
    # Full warping: flow from all 68 source -> driver keypoints.
    flow, occ = dense_motion(src_kp, drv_kp)
    warped = warp_image(src, flow)
    return torch.clamp(generator(torch.cat([warped, flow, occ], 1)), 0, 1)


def infer_warp_lock(src, src_kp, drv_kp):
    # Locked warping: only the eye and mouth keypoints follow the driver,
    # so head pose stays fixed to the source frame.
    kp = src_kp.clone()
    kp[:, LOCK_IDXS] = drv_kp[:, LOCK_IDXS]
    flow, occ = dense_motion(src_kp, kp)
    warped = warp_image(src, flow)
    return torch.clamp(generator(torch.cat([warped, flow, occ], 1)), 0, 1)


def infer_warp_lock_mask(src, src_kp, drv_kp, mask):
    # Locked warping plus masked blending: outside the mask the source
    # pixels are kept verbatim.
    kp = src_kp.clone()
    kp[:, LOCK_IDXS] = drv_kp[:, LOCK_IDXS]
    flow, occ = dense_motion(src_kp, kp)
    warped = warp_image(src, flow)
    pred = generator(torch.cat([warped, flow, occ], 1))
    return torch.clamp(pred * mask + src * (1 - mask), 0, 1)


def evaluate_sequence(src, src_kp, drv_kps, gt_frames, mask, mode):
    preds_torch = []
    lmd, lp, l1 = [], [], []
    with torch.no_grad():
        for t, drv_kp in enumerate(drv_kps):
            if mode == "no_warp":
                pred = infer_no_warp(src)
            elif mode == "warp":
                pred = infer_warp(src, src_kp, drv_kp)
            elif mode == "warp_lock":
                pred = infer_warp_lock(src, src_kp, drv_kp)
            elif mode == "warp_lock_mask":
                pred = infer_warp_lock_mask(src, src_kp, drv_kp, mask)
            else:
                raise ValueError(f"Unknown mode: {mode}")

            gt = gt_frames[t]
            pred_np = (pred.cpu().squeeze().numpy() * 255).astype(np.uint8)
            gt_np = (gt.cpu().squeeze().numpy() * 255).astype(np.uint8)

            # Landmark distance is skipped for frames where detection fails.
            lm = landmark_distance(pred_np, gt_np)
            if lm is not None:
                lmd.append(lm)
            lp.append(lpips_score(pred, gt))
            l1.append(l1_score(pred, gt))
            preds_torch.append(pred)

    jit_std, _ = temporal_jitter(preds_torch)
    return {
        "LMD": np.mean(lmd) if len(lmd) > 0 else np.nan,
        "LPIPS": np.mean(lp),
        "L1": np.mean(l1),
        "Jitter": jit_std,
    }


def run_all(src, src_kp, drv_kps, gt_frames, mask):
    rows = []
    for mode in ["no_warp", "warp", "warp_lock", "warp_lock_mask"]:
        print(f"Evaluating {mode}")
        res = evaluate_sequence(src, src_kp, drv_kps, gt_frames, mask, mode)
        res["Method"] = mode
        rows.append(res)
    df = pd.DataFrame(rows)
    df = df[["Method", "LMD", "LPIPS", "L1", "Jitter"]]
    df.to_csv("ablation_results.csv", index=False)
    print(df)
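
# ---------------------------------------------------------------------------
# Reference sketch (not used above): `warp_image` comes from unet_acc, which
# is not shown here. A minimal backward-warping implementation would look
# roughly like the following, ASSUMING `flow` is a dense (B, 2, H, W) offset
# field in pixel units with (x, y) channel order; the actual unet_acc version
# may use a different convention (e.g. normalized coordinates).
# ---------------------------------------------------------------------------
def _warp_image_sketch(img, flow):
    B, _, H, W = img.shape
    # Base sampling grid in pixel coordinates.
    ys, xs = torch.meshgrid(
        torch.arange(H, device=img.device, dtype=img.dtype),
        torch.arange(W, device=img.device, dtype=img.dtype),
        indexing="ij",
    )
    grid = torch.stack([xs, ys], dim=0).unsqueeze(0) + flow  # (B, 2, H, W)
    # grid_sample wants (B, H, W, 2) coordinates normalized to [-1, 1].
    grid[:, 0] = 2.0 * grid[:, 0] / (W - 1) - 1.0
    grid[:, 1] = 2.0 * grid[:, 1] / (H - 1) - 1.0
    return F.grid_sample(img, grid.permute(0, 2, 3, 1), align_corners=True)
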
if __name__ == "__main__":
    frames_dir = r"motion_transfer\new_dataset\test\dataset\87\frames"
    kp_dir = r"motion_transfer\new_dataset\test\dataset\87\combined"

    # Source frame: first frame of the sequence, grayscale in [0, 1].
    src_img = cv2.imread(
        os.path.join(frames_dir, "00000.jpg"), cv2.IMREAD_GRAYSCALE
    )
    src = torch.tensor(
        src_img / 255.0, dtype=torch.float32
    ).unsqueeze(0).unsqueeze(0).to(device)

    # Source keypoint maps: channels-last (H, W, C) on disk -> (1, C, H, W).
    src_kp = torch.tensor(
        np.load(os.path.join(kp_dir, "00000.npy")), dtype=torch.float32
    ).permute(2, 0, 1).unsqueeze(0).to(device)

    # Driving keypoints and ground-truth frames for the whole sequence.
    drv_kps = []
    gt_frames = []
    for f in sorted(os.listdir(frames_dir)):
        if not f.endswith(".jpg"):
            continue
        gt = cv2.imread(os.path.join(frames_dir, f), cv2.IMREAD_GRAYSCALE)
        gt_frames.append(
            torch.tensor(
                gt / 255.0, dtype=torch.float32
            ).unsqueeze(0).unsqueeze(0).to(device)
        )
        kp = torch.tensor(
            np.load(os.path.join(kp_dir, f.replace(".jpg", ".npy"))),
            dtype=torch.float32,
        ).permute(2, 0, 1).unsqueeze(0).to(device)
        drv_kps.append(kp)

    # Trivial all-ones mask: with this, warp_lock_mask reduces to warp_lock.
    mask = torch.ones_like(src)
    run_all(src, src_kp, drv_kps, gt_frames, mask)
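
# ---------------------------------------------------------------------------
# Hypothetical helper (not wired into the run above, which uses an all-ones
# mask): a soft eye/mouth mask could instead be derived from the keypoint
# maps, ASSUMING src_kp stores one Gaussian-like heatmap per landmark in its
# channels. Sketch only; `threshold` and the feathering kernel are guesses
# to be tuned on the data.
# ---------------------------------------------------------------------------
def soft_lock_mask(src_kp, threshold=0.1):
    # Union of the locked (eye + mouth) landmark heatmaps.
    region = src_kp[:, LOCK_IDXS].sum(dim=1, keepdim=True)  # (B, 1, H, W)
    mask = (region > threshold).float()
    # Feather the edge so the blend in infer_warp_lock_mask has no hard seam.
    return F.avg_pool2d(mask, kernel_size=9, stride=1, padding=4)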