| | import os
|
| | import cv2
|
| | import numpy as np
|
| | import torch
|
| | import torch.nn.functional as F
|
| | import face_alignment
|
| | import lpips
|
| | import pandas as pd
|
| |
|
| | from unet_acc import DenseMotion, UNetGenerator, warp_image
|
| |
|
| |
|
# Prefer the GPU when torch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Instantiate the two networks and place them on the selected device.
dense_motion = DenseMotion(kp_channels=68).to(device)
generator = UNetGenerator(in_channels=4).to(device)

# A single checkpoint file holds the weights of both networks under
# the "dense_motion" and "generator" keys.
ckpt = torch.load("checkpoints/best.pth", map_location=device)
dense_motion.load_state_dict(ckpt["dense_motion"])
generator.load_state_dict(ckpt["generator"])

# This script only runs inference, so both networks go to eval mode.
dense_motion.eval()
generator.eval()

# Perceptual metric (AlexNet backbone) used by lpips_score.
lpips_fn = lpips.LPIPS(net="alex").to(device)

# 2D facial landmark detector used by landmark_distance.
fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, device=device)
|
| |
|
| |
|
| |
|
| |
|
def landmark_distance(pred, gt):
    """Return the normalised mean landmark error between two grayscale images.

    Both images are expanded from single-channel to RGB with OpenCV and passed
    to the module-level ``fa`` landmark detector. Only the first detection of
    each image is compared.

    Args:
        pred: Predicted image, accepted by ``cv2.cvtColor(..., COLOR_GRAY2RGB)``.
        gt: Ground-truth image in the same format as ``pred``.

    Returns:
        The mean per-landmark Euclidean distance divided by the distance
        between ground-truth landmarks 36 and 45, or ``None`` when the
        detector returns ``None`` for either image.
    """
    pred_rgb, gt_rgb = (
        cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) for image in (pred, gt)
    )

    pred_faces = fa.get_landmarks(pred_rgb)
    gt_faces = fa.get_landmarks(gt_rgb)
    if pred_faces is None or gt_faces is None:
        return None

    pred_pts = pred_faces[0]
    gt_pts = gt_faces[0]

    # Normaliser: distance between landmarks 36 and 45 (presumably the outer
    # eye corners of the 68-point layout — confirm). The epsilon avoids a
    # division by zero on degenerate detections.
    inter_eye = np.linalg.norm(gt_pts[36] - gt_pts[45]) + 1e-6
    per_point_error = np.linalg.norm(pred_pts - gt_pts, axis=1)
    return np.mean(per_point_error) / inter_eye
|
| |
|
| |
|
def lpips_score(pred, gt):
    """Return the LPIPS perceptual distance between two single-channel images.

    Args:
        pred: Predicted image tensor; the channel axis (dim 1) is tiled three
            times so the tensor matches the 3-channel input LPIPS works on.
            Callers in this file pass values in [0, 1].
        gt: Ground-truth image tensor in the same layout and range as ``pred``.

    Returns:
        The LPIPS distance as a Python float (``.item()`` requires the metric
        to yield a single value, i.e. a batch of one).
    """
    pred_rgb = pred.repeat(1, 3, 1, 1)
    gt_rgb = gt.repeat(1, 3, 1, 1)
    # LPIPS expects inputs in [-1, 1] by default; normalize=True tells it the
    # inputs are in [0, 1] and lets it rescale them before scoring.
    return lpips_fn(pred_rgb, gt_rgb, normalize=True).item()
|
| |
|
| |
|
def l1_score(pred, gt):
    """Return the mean absolute error between ``pred`` and ``gt`` as a float."""
    mean_abs_error = F.l1_loss(pred, gt)
    return mean_abs_error.item()
|
| |
|
| |
|
def temporal_jitter(frames):
    """Measure frame-to-frame change over a sequence of image tensors.

    For every pair of consecutive frames the mean absolute difference is
    computed; the spread and the average of those values are returned.

    Args:
        frames: Sequence (list/tuple) of tensors with matching shapes.

    Returns:
        ``(std, mean)`` of the consecutive-frame mean absolute differences.
        Both are NaN when fewer than two frames are given, because no
        difference can be formed.
    """
    if len(frames) < 2:
        # np.std([]) / np.mean([]) would also give NaN, but only after
        # emitting RuntimeWarnings; return the NaNs explicitly instead.
        return np.nan, np.nan

    diffs = [
        torch.mean(torch.abs(current - previous)).item()
        for previous, current in zip(frames, frames[1:])
    ]
    return np.std(diffs), np.mean(diffs)
|
| |
|
| |
|
# Keypoint channels that the "lock" modes copy from the driving frame:
# indices 36-47 followed by 48-67 (presumably the eye and mouth landmarks of
# the 68-point layout — confirm against the keypoint files).
LOCK_IDXS = [*range(36, 48), *range(48, 68)]
|
| |
|
def infer_no_warp(src):
    """Run the generator on the unwarped source (baseline without motion).

    The generator input is ``src`` concatenated along dim 1 with an all-zero
    2-channel flow and an all-one 1-channel occlusion map of the same batch
    and spatial size.

    Args:
        src: 4-D source tensor; its channel count plus 3 must match what the
            generator accepts.

    Returns:
        The generator output clamped to [0, 1].
    """
    batch, _channels, height, width = src.shape
    zero_flow = torch.zeros(batch, 2, height, width, device=device)
    full_occlusion = torch.ones(batch, 1, height, width, device=device)
    generator_input = torch.cat([src, zero_flow, full_occlusion], dim=1)
    return generator(generator_input).clamp(0, 1)
|
| |
|
| |
|
def infer_warp(src, src_kp, drv_kp):
    """Warp the source with the predicted motion, then refine it.

    ``dense_motion`` turns the source/driving keypoints into a flow and an
    occlusion map; the source is warped by that flow and the generator is fed
    the warped image, the flow and the occlusion map concatenated on dim 1.

    Args:
        src: Source image tensor.
        src_kp: Keypoint tensor of the source frame.
        drv_kp: Keypoint tensor of the driving frame.

    Returns:
        The generator output clamped to [0, 1].
    """
    motion_flow, occlusion = dense_motion(src_kp, drv_kp)
    warped_src = warp_image(src, motion_flow)
    generator_input = torch.cat([warped_src, motion_flow, occlusion], dim=1)
    return generator(generator_input).clamp(0, 1)
|
| |
|
| |
|
def infer_warp_lock(src, src_kp, drv_kp):
    """Warp-based inference where only the ``LOCK_IDXS`` keypoints are driven.

    A copy of the source keypoints is made and only the channels listed in
    ``LOCK_IDXS`` (dim 1) are replaced with the driving ones, so every other
    keypoint stays at its source value before the motion is estimated.

    Args:
        src: Source image tensor.
        src_kp: Keypoint tensor of the source frame (not modified).
        drv_kp: Keypoint tensor of the driving frame.

    Returns:
        The generator output clamped to [0, 1].
    """
    target_kp = src_kp.clone()
    target_kp[:, LOCK_IDXS] = drv_kp[:, LOCK_IDXS]

    motion_flow, occlusion = dense_motion(src_kp, target_kp)
    warped_src = warp_image(src, motion_flow)
    generator_input = torch.cat([warped_src, motion_flow, occlusion], dim=1)
    return generator(generator_input).clamp(0, 1)
|
| |
|
| |
|
def infer_warp_lock_mask(src, src_kp, drv_kp, mask):
    """Locked-keypoint inference whose output is blended back onto the source.

    Works like the locked warp (only the ``LOCK_IDXS`` channels are taken from
    the driving keypoints), but the raw generator output is kept only where
    ``mask`` is 1; where it is 0 the original source pixels are used.

    Args:
        src: Source image tensor.
        src_kp: Keypoint tensor of the source frame (not modified).
        drv_kp: Keypoint tensor of the driving frame.
        mask: Blend weights that must broadcast against ``src`` and the
            generator output.

    Returns:
        The blended image clamped to [0, 1].
    """
    target_kp = src_kp.clone()
    target_kp[:, LOCK_IDXS] = drv_kp[:, LOCK_IDXS]

    motion_flow, occlusion = dense_motion(src_kp, target_kp)
    warped_src = warp_image(src, motion_flow)
    generated = generator(torch.cat([warped_src, motion_flow, occlusion], dim=1))

    blended = generated * mask + src * (1 - mask)
    return blended.clamp(0, 1)
|
| |
|
| |
|
def evaluate_sequence(src, src_kp, drv_kps, gt_frames, mask, mode):
    """Run one inference mode over a driving sequence and aggregate its metrics.

    Args:
        src: Source image tensor handed to the selected ``infer_*`` function.
        src_kp: Keypoint tensor of the source frame.
        drv_kps: Sequence of driving keypoint tensors, one per frame.
        gt_frames: Ground-truth frame tensors, indexed in step with
            ``drv_kps``.
        mask: Blend mask; only used by the ``"warp_lock_mask"`` mode.
        mode: One of ``"no_warp"``, ``"warp"``, ``"warp_lock"`` or
            ``"warp_lock_mask"``.

    Returns:
        dict with the sequence averages ``"LMD"``, ``"LPIPS"`` and ``"L1"``
        and the ``"Jitter"`` value (std of consecutive-frame differences of
        the predictions). An average is NaN when no frame contributed to it
        (e.g. no landmarks were detected, or the sequence is empty).

    Raises:
        ValueError: If ``mode`` is not one of the supported names.
    """
    infer_by_mode = {
        "no_warp": lambda drv_kp: infer_no_warp(src),
        "warp": lambda drv_kp: infer_warp(src, src_kp, drv_kp),
        "warp_lock": lambda drv_kp: infer_warp_lock(src, src_kp, drv_kp),
        "warp_lock_mask": lambda drv_kp: infer_warp_lock_mask(
            src, src_kp, drv_kp, mask
        ),
    }
    # Validate once, before any work, so a bad mode fails even on an empty
    # sequence and the error says what was wrong.
    if mode not in infer_by_mode:
        raise ValueError(
            f"Unknown mode {mode!r}; expected one of {sorted(infer_by_mode)}"
        )
    infer = infer_by_mode[mode]

    preds_torch = []
    lmd, lp, l1 = [], [], []

    with torch.no_grad():
        for t, drv_kp in enumerate(drv_kps):
            pred = infer(drv_kp)
            gt = gt_frames[t]

            # The landmark detector works on 8-bit images, so convert the
            # [0, 1] tensors to uint8 arrays with singleton axes removed.
            pred_np = (pred.detach().cpu().squeeze().numpy() * 255).astype(np.uint8)
            gt_np = (gt.detach().cpu().squeeze().numpy() * 255).astype(np.uint8)

            # Frames without a detection are skipped for LMD only.
            lm = landmark_distance(pred_np, gt_np)
            if lm is not None:
                lmd.append(lm)

            lp.append(lpips_score(pred, gt))
            l1.append(l1_score(pred, gt))
            preds_torch.append(pred)

    jit_std, _ = temporal_jitter(preds_torch)

    return {
        "LMD": np.mean(lmd) if lmd else np.nan,
        "LPIPS": np.mean(lp) if lp else np.nan,
        # Previously computed per frame and then dropped; now reported.
        "L1": np.mean(l1) if l1 else np.nan,
        "Jitter": jit_std,
    }
|
| |
|
def run_all(src, src_kp, drv_kps, gt_frames, mask, out_csv="ablation_results.csv"):
    """Evaluate every ablation mode on one sequence and save a summary table.

    Args:
        src: Source image tensor.
        src_kp: Keypoint tensor of the source frame.
        drv_kps: Sequence of driving keypoint tensors.
        gt_frames: Ground-truth frame tensors matching ``drv_kps``.
        mask: Blend mask forwarded to the ``"warp_lock_mask"`` mode.
        out_csv: Destination of the CSV summary. Defaults to the previously
            hard-coded ``"ablation_results.csv"``.

    Returns:
        The summary ``pandas.DataFrame`` with the columns ``Method``, ``LMD``,
        ``LPIPS`` and ``Jitter`` (one row per mode), which is also written to
        ``out_csv`` and printed.
    """
    rows = []
    for mode in ("no_warp", "warp", "warp_lock", "warp_lock_mask"):
        print(f"Evaluating {mode}")
        res = evaluate_sequence(src, src_kp, drv_kps, gt_frames, mask, mode)
        res["Method"] = mode
        rows.append(res)

    # Fix the column order (and the CSV schema) regardless of dict key order.
    df = pd.DataFrame(rows)[["Method", "LMD", "LPIPS", "Jitter"]]
    df.to_csv(out_csv, index=False)
    print(df)
    return df
|
def _load_gray_tensor(path):
    """Read an image as grayscale and return it as a (1, 1, H, W) float tensor scaled to [0, 1]."""
    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread reports a missing/unreadable file by returning None,
        # which would otherwise surface later as an obscure TypeError.
        raise OSError(f"Could not read image: {path}")
    return (
        torch.tensor(img / 255.0, dtype=torch.float32)
        .unsqueeze(0)
        .unsqueeze(0)
        .to(device)
    )


def _load_kp_tensor(path):
    """Load a 3-D keypoint array from ``path``, move its last axis first and add a batch axis."""
    return (
        torch.tensor(np.load(path), dtype=torch.float32)
        .permute(2, 0, 1)
        .unsqueeze(0)
        .to(device)
    )


if __name__ == "__main__":
    # Build the paths from components so the script also runs on systems
    # where "\" is not a path separator.
    sample_dir = os.path.join("motion_transfer", "new_dataset", "test", "dataset", "87")
    frames_dir = os.path.join(sample_dir, "frames")
    kp_dir = os.path.join(sample_dir, "combined")

    # The first frame of the sequence serves as the source image.
    src = _load_gray_tensor(os.path.join(frames_dir, "00000.jpg"))
    src_kp = _load_kp_tensor(os.path.join(kp_dir, "00000.npy"))

    # Every .jpg frame (in sorted order) is a ground-truth target, paired
    # with the keypoint file that shares its base name.
    frame_names = sorted(
        name for name in os.listdir(frames_dir) if name.lower().endswith(".jpg")
    )

    drv_kps = []
    gt_frames = []
    for name in frame_names:
        gt_frames.append(_load_gray_tensor(os.path.join(frames_dir, name)))
        kp_name = os.path.splitext(name)[0] + ".npy"
        drv_kps.append(_load_kp_tensor(os.path.join(kp_dir, kp_name)))

    # An all-ones mask means the masked mode keeps the full generator output.
    mask = torch.ones_like(src)
    run_all(src, src_kp, drv_kps, gt_frames, mask)
|
| |
|