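"""Ablation evaluation for the Lively_sketch motion-transfer model.

Loads the DenseMotion + UNetGenerator checkpoint, runs four inference variants
(no_warp, warp, warp_lock, warp_lock_mask) over a test sequence, and reports
landmark distance (LMD), LPIPS, L1, and temporal jitter to ablation_results.csv.
"""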
import os
import cv2
import numpy as np
import torch
import torch.nn.functional as F
import face_alignment
import lpips
import pandas as pd
from unet_acc import DenseMotion, UNetGenerator, warp_image
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the trained motion network and generator from the best checkpoint.
dense_motion = DenseMotion(kp_channels=68).to(device)
generator = UNetGenerator(in_channels=4).to(device)
ckpt = torch.load("checkpoints/best.pth", map_location=device)
dense_motion.load_state_dict(ckpt["dense_motion"])
generator.load_state_dict(ckpt["generator"])
dense_motion.eval()
generator.eval()
# LPIPS (AlexNet backbone) and a 2D face-landmark detector for the metrics.
lpips_fn = lpips.LPIPS(net="alex").to(device)
fa = face_alignment.FaceAlignment(
    face_alignment.LandmarksType.TWO_D,
    device=device,
)
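# Convention below: images travel as (1, 1, H, W) float tensors in [0, 1];
# keypoints are 68-channel landmark maps of shape (1, 68, H, W), matching
# DenseMotion's kp_channels=68.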
# Metrics
def landmark_distance(pred, gt):
    """Normalized landmark distance (LMD): mean landmark error divided by the
    ground-truth inter-ocular distance. Returns None if detection fails."""
    pred = cv2.cvtColor(pred, cv2.COLOR_GRAY2RGB)
    gt = cv2.cvtColor(gt, cv2.COLOR_GRAY2RGB)
    pl = fa.get_landmarks(pred)
    gl = fa.get_landmarks(gt)
    if pl is None or gl is None:
        return None
    pl, gl = pl[0], gl[0]
    # Points 36 and 45 are the outer eye corners in the 68-point scheme.
    eye_dist = np.linalg.norm(gl[36] - gl[45]) + 1e-6
    return np.mean(np.linalg.norm(pl - gl, axis=1)) / eye_dist
def lpips_score(pred, gt):
    # LPIPS expects 3-channel input; replicate the grayscale channel.
    pred = pred.repeat(1, 3, 1, 1)
    gt = gt.repeat(1, 3, 1, 1)
    # normalize=True rescales our [0, 1] inputs to the [-1, 1] range LPIPS expects.
    return lpips_fn(pred, gt, normalize=True).item()
def l1_score(pred, gt):
    return F.l1_loss(pred, gt).item()
def temporal_jitter(frames):
    """Mean absolute frame-to-frame difference; the std captures flicker."""
    diffs = []
    for i in range(1, len(frames)):
        diffs.append(torch.mean(torch.abs(frames[i] - frames[i - 1])).item())
    return np.std(diffs), np.mean(diffs)
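# Inference variants for the ablation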
# Landmarks driven by the driving frame: eyes (36-47) and mouth (48-67).
LOCK_IDXS = list(range(36, 48)) + list(range(48, 68))
def infer_no_warp(src):
    # Baseline: identity flow and full visibility, i.e. no motion transfer.
    B, _, H, W = src.shape
    flow = torch.zeros(B, 2, H, W, device=device)
    occ = torch.ones(B, 1, H, W, device=device)
    return torch.clamp(generator(torch.cat([src, flow, occ], 1)), 0, 1)
def infer_warp(src, src_kp, drv_kp):
    # Full warp: flow and occlusion predicted from all 68 driving keypoints.
    flow, occ = dense_motion(src_kp, drv_kp)
    warped = warp_image(src, flow)
    return torch.clamp(generator(torch.cat([warped, flow, occ], 1)), 0, 1)
def infer_warp_lock(src, src_kp, drv_kp):
    # Only the locked regions (eyes, mouth) follow the driving keypoints;
    # all other landmarks keep their source positions.
    kp = src_kp.clone()
    kp[:, LOCK_IDXS] = drv_kp[:, LOCK_IDXS]
    flow, occ = dense_motion(src_kp, kp)
    warped = warp_image(src, flow)
    return torch.clamp(generator(torch.cat([warped, flow, occ], 1)), 0, 1)
def infer_warp_lock_mask(src, src_kp, drv_kp, mask):
    # Same as infer_warp_lock, but composites: generated pixels inside the
    # mask, untouched source pixels outside it.
    kp = src_kp.clone()
    kp[:, LOCK_IDXS] = drv_kp[:, LOCK_IDXS]
    flow, occ = dense_motion(src_kp, kp)
    warped = warp_image(src, flow)
    pred = generator(torch.cat([warped, flow, occ], 1))
    return torch.clamp(pred * mask + src * (1 - mask), 0, 1)
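# A hypothetical sketch, not wired into the script: one way to build the soft
# region mask that infer_warp_lock_mask composites with, by stamping filled
# circles at landmark pixel coordinates and feathering them. The radius and the
# (N, 2) coordinate input are assumptions; the .npy keypoints used below are
# 68-channel maps, so coordinates would first have to be recovered from them
# (e.g. a per-channel argmax).
def make_landmark_mask(kp_xy, H, W, radius=15):
    m = np.zeros((H, W), dtype=np.float32)
    for x, y in kp_xy:
        cv2.circle(m, (int(x), int(y)), radius, 1.0, -1)
    m = cv2.GaussianBlur(m, (2 * radius + 1, 2 * radius + 1), radius / 2)
    return torch.from_numpy(m).clamp(0, 1)[None, None].to(device)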
def evaluate_sequence(src, src_kp, drv_kps, gt_frames, mask, mode):
    preds_torch = []
    lmd, lp, l1 = [], [], []
    with torch.no_grad():
        for t, drv_kp in enumerate(drv_kps):
            if mode == "no_warp":
                pred = infer_no_warp(src)
            elif mode == "warp":
                pred = infer_warp(src, src_kp, drv_kp)
            elif mode == "warp_lock":
                pred = infer_warp_lock(src, src_kp, drv_kp)
            elif mode == "warp_lock_mask":
                pred = infer_warp_lock_mask(src, src_kp, drv_kp, mask)
            else:
                raise ValueError(f"Unknown mode: {mode}")
            gt = gt_frames[t]
            pred_np = (pred.cpu().squeeze().numpy() * 255).astype(np.uint8)
            gt_np = (gt.cpu().squeeze().numpy() * 255).astype(np.uint8)
            # Skip LMD on frames where landmark detection fails.
            lm = landmark_distance(pred_np, gt_np)
            if lm is not None:
                lmd.append(lm)
            lp.append(lpips_score(pred, gt))
            l1.append(l1_score(pred, gt))
            preds_torch.append(pred)
    jit_std, _ = temporal_jitter(preds_torch)
    return {
        "LMD": np.mean(lmd) if len(lmd) > 0 else np.nan,
        "LPIPS": np.mean(lp),
        "L1": np.mean(l1),
        "Jitter": jit_std,
    }
def run_all(src, src_kp, drv_kps, gt_frames, mask):
    rows = []
    for mode in ["no_warp", "warp", "warp_lock", "warp_lock_mask"]:
        print(f"Evaluating {mode}")
        res = evaluate_sequence(src, src_kp, drv_kps, gt_frames, mask, mode)
        res["Method"] = mode
        rows.append(res)
    df = pd.DataFrame(rows)
    df = df[["Method", "LMD", "LPIPS", "L1", "Jitter"]]
    df.to_csv("ablation_results.csv", index=False)
    print(df)
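# ablation_results.csv ends up with one row per method and the columns
# Method, LMD, LPIPS, L1, Jitter; lower is better on all four metrics.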
if __name__ == "__main__":
    SEQ_DIR = r"motion_transfer\new_dataset\test\dataset\87"
    FRAMES_DIR = os.path.join(SEQ_DIR, "frames")
    KP_DIR = os.path.join(SEQ_DIR, "combined")

    src_img = cv2.imread(os.path.join(FRAMES_DIR, "00000.jpg"), cv2.IMREAD_GRAYSCALE)
    src = torch.tensor(
        src_img / 255.0,
        dtype=torch.float32,
    ).unsqueeze(0).unsqueeze(0).to(device)
    src_kp = torch.tensor(
        np.load(os.path.join(KP_DIR, "00000.npy")),
        dtype=torch.float32,
    ).permute(2, 0, 1).unsqueeze(0).to(device)

    drv_kps = []
    gt_frames = []
    for f in sorted(os.listdir(FRAMES_DIR)):
        if not f.endswith(".jpg"):
            continue
        gt = cv2.imread(os.path.join(FRAMES_DIR, f), cv2.IMREAD_GRAYSCALE)
        gt_frames.append(
            torch.tensor(
                gt / 255.0,
                dtype=torch.float32,
            ).unsqueeze(0).unsqueeze(0).to(device)
        )
        kp = torch.tensor(
            np.load(os.path.join(KP_DIR, f.replace(".jpg", ".npy"))),
            dtype=torch.float32,
        ).permute(2, 0, 1).unsqueeze(0).to(device)
        drv_kps.append(kp)

    # All-ones placeholder mask: with it, warp_lock_mask reduces to warp_lock.
    mask = torch.ones_like(src)
    run_all(src, src_kp, drv_kps, gt_frames, mask)