Harini1995 commited on
Commit
d487538
·
verified ·
1 Parent(s): 31a03de

Upload 17 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ demo/video_3.mp4 filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import torch
4
+ import face_alignment
5
+ import numpy as np
6
+ import mediapipe as mp
7
+ import tempfile
8
+ from multiprocessing import cpu_count
9
+ from PIL import Image
10
+ import streamlit as st
11
+ from test import test_single_image
12
+
13
+
14
# --- Paths ---------------------------------------------------------------
# Built with os.path.join so they work on both Windows and POSIX. The
# original mixed forward slashes with a raw backslash string, which on
# Linux creates a literal directory named "motion_transfer\dataset_single\test_heatmap".
reference_heatmap_dir = os.path.join("motion_transfer", "dataset_single", "reference_heatmap")
output_dir = os.path.join("motion_transfer", "dataset_single", "test_heatmap")
final_output = os.path.join("motion_transfer", "outputs", "final_result.mp4")
os.makedirs(output_dir, exist_ok=True)

# --- Tunables ------------------------------------------------------------
num_workers = min(cpu_count(), 4)   # not used in this script; kept for compatibility
target_size = 256                   # side length of the square image fed to the model
SIGMA = 2.0                         # Gaussian std-dev for keypoint heatmaps
NUM_FACE_POINTS = 68                # 68-point facial landmark convention

# --- Models (instantiated once at import time) ---------------------------
fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D,
                                  device="cuda" if torch.cuda.is_available() else "cpu")
mp_face_detection = mp.solutions.face_detection
detector = mp_face_detection.FaceDetection(model_selection=1, min_detection_confidence=0.5)
29
+
30
+
31
+
32
def gaussian_heatmaps(points, H, W, sigma=2.0):
    """Render one Gaussian heatmap per keypoint.

    points: (N, 2) array of (x, y) pixel coordinates.
    Returns a float32 array of shape (H, W, N), peak value 1 at each point.
    """
    grid_y, grid_x = np.meshgrid(np.arange(H), np.arange(W), indexing="ij")
    dx = grid_x[..., None] - points[:, 0]
    dy = grid_y[..., None] - points[:, 1]
    maps = np.exp(-(dx ** 2 + dy ** 2) / (2.0 * sigma ** 2))
    return maps.astype(np.float32)
37
+
38
+
39
def extract_keypoints(hmap):
    """Recover one (x, y) keypoint per channel as the centroid of the
    channel's positive support; (0, 0) for all-zero channels.

    hmap: (H, W, N) heatmap stack. Returns a float32 array of shape (N, 2).
    """
    centroids = []
    for ch in range(hmap.shape[2]):
        ys, xs = np.where(hmap[:, :, ch] > 0)
        if xs.size:
            centroids.append([xs.mean(), ys.mean()])
        else:
            centroids.append([0, 0])
    return np.array(centroids, dtype=np.float32)
48
+
49
+
50
def trim_video(input_path, output_path, max_seconds=7):
    """Copy at most the first `max_seconds` of `input_path` to `output_path`.

    The output is re-encoded with the mp4v codec at the source resolution
    and frame rate. Returns True on success, False when the input cannot
    be opened or reports unusable FPS metadata.
    """
    cap = cv2.VideoCapture(input_path)
    if not cap.isOpened():
        print("Error opening video")
        return False

    fps = int(round(cap.get(cv2.CAP_PROP_FPS)))
    if fps <= 0:
        # Some containers report 0 FPS; without it the trim length (and the
        # writer's timing) would be wrong, so bail out explicitly.
        print("Error opening video")
        cap.release()
        return False

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    max_frames = min(total_frames, fps * max_seconds)

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    try:
        count = 0
        while count < max_frames:
            ret, frame = cap.read()
            if not ret:
                break
            out.write(frame)
            count += 1
    finally:
        # Release handles even if a read/write raises mid-loop.
        cap.release()
        out.release()
    return True
74
+
75
+
76
def crop_head_with_bg(img_rgb, target_size=256, margin_top=0.6, margin_sides=0.3, margin_bottom=0.4):
    """Crop the detected head (with margins) and letterbox it onto a
    target_size x target_size canvas whose empty bands are a blurred copy
    of the crop. Returns the RGB canvas, or None when no face is found.
    """
    ih, iw, _ = img_rgb.shape
    detection_result = detector.process(img_rgb)
    if not detection_result.detections:
        return None

    rel = detection_result.detections[0].location_data.relative_bounding_box
    x1 = int(rel.xmin * iw)
    y1 = int(rel.ymin * ih)
    w = int(rel.width * iw)
    h = int(rel.height * ih)

    # Expand the detector box by the margins; the far edge is measured from
    # the (possibly clamped) near edge, so at the image border the window
    # shifts rather than shrinks.
    x1 = max(0, int(x1 - w * margin_sides))
    x2 = min(iw, int(x1 + w * (1 + 2 * margin_sides)))
    y1 = max(0, int(y1 - h * margin_top))
    y2 = min(ih, int(y1 + h * (1 + margin_bottom + margin_top)))

    head = img_rgb[y1:y2, x1:x2]
    ch, cw = head.shape[:2]

    # Fit the crop inside the square while preserving aspect ratio.
    scale = target_size / max(ch, cw)
    new_w, new_h = int(cw * scale), int(ch * scale)
    head_resized = cv2.resize(head, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)

    # A blurred, stretched copy of the crop fills the letterbox bands.
    canvas = cv2.resize(cv2.GaussianBlur(head_resized, (51, 51), 0),
                        (target_size, target_size), interpolation=cv2.INTER_AREA)

    y0 = (target_size - new_h) // 2
    x0 = (target_size - new_w) // 2
    canvas[y0:y0 + new_h, x0:x0 + new_w] = head_resized
    return canvas
113
+
114
+
115
# ---------------------------------------------------------------------------
# Streamlit UI: upload a face sketch, preview the cropped head, then drive it
# with the pre-computed reference motion and offer the result for download.
# ---------------------------------------------------------------------------
st.title("Sketch to Live")

src_img = st.file_uploader("Upload face sketch", type=["jpg", "png"])
cropped_head = None

if src_img is not None:
    pil_img = Image.open(src_img).convert("RGB")
    img_rgb = np.array(pil_img)
    ih, iw, _ = img_rgb.shape
    st.write(f"Uploaded image size: {iw}×{ih}")

    if ih < target_size or iw < target_size:
        # BUG FIX: this message was a plain string, so "{iw}" etc. rendered
        # literally in the UI — it needs the f-prefix to interpolate.
        st.warning(f"Image too small ({iw}×{ih}). Please upload one larger than {target_size}×{target_size}.")
    else:
        cropped_head = crop_head_with_bg(img_rgb, target_size=target_size)
        if cropped_head is None:
            st.warning("No face detected. Try another image.")
        else:
            st.subheader("Face Preview")
            st.image(
                cropped_head,
                caption="Cropped Head",
                width=256,
                channels="RGB",
                output_format="PNG",
            )

            # Persist losslessly; test_single_image() reads this file back.
            cv2.imwrite("cropped_head.png",
                        cv2.cvtColor(cropped_head, cv2.COLOR_RGB2BGR),
                        [cv2.IMWRITE_PNG_COMPRESSION, 0])


if st.button("Lively Sketch"):
    if cropped_head is None:
        st.error("Please upload a face image.")
    else:
        progress_text = st.empty()
        progress_bar = st.progress(0)
        frame_preview = st.empty()
        progress_text.text("Processing")

        H, W = cropped_head.shape[:2]
        fa_out = fa.get_landmarks(cropped_head)
        if fa_out is None or len(fa_out) == 0:
            st.error(" No face landmarks detected.")
        else:
            face68 = fa_out[0].astype(np.float32)
            single_heatmap = gaussian_heatmaps(face68, H, W, sigma=SIGMA)
            single_kp = face68

            ref_files = sorted([f for f in os.listdir(reference_heatmap_dir) if f.endswith(".npy")])
            if len(ref_files) == 0:
                st.error(" No reference heatmaps found!")
            else:
                # Motion = per-frame landmark displacement relative to the
                # first reference frame, re-applied to the uploaded face.
                ref_heatmaps = [np.load(os.path.join(reference_heatmap_dir, f)) for f in ref_files]
                ref_kp_list = [extract_keypoints(hm) for hm in ref_heatmaps]
                ref_base_kp = ref_kp_list[0]
                motion_vectors = [kp - ref_base_kp for kp in ref_kp_list]

                os.makedirs(output_dir, exist_ok=True)
                total_frames = len(motion_vectors)

                for frame_idx, displacement in enumerate(motion_vectors):
                    moved_kp = single_kp + displacement
                    new_heatmap = gaussian_heatmaps(moved_kp, H, W, sigma=SIGMA)
                    np.save(os.path.join(output_dir, f"{frame_idx:05d}.npy"), new_heatmap)

                    frame_preview.image(cropped_head, width=128)
                    progress_bar.progress(int((frame_idx + 1) / total_frames * 100))

                # Render the animation from the saved heatmaps, then trim it.
                temp_img_path = "cropped_head.png"
                test_single_image(temp_img_path, output_dir, final_output)

                trimmed_output = "trimmed_result.mp4"
                trim_video(final_output, trimmed_output, max_seconds=7)

                progress_bar.progress(100)
                progress_text.text("Done!")
                frame_preview.empty()

                st.success("Sketch-to-Live")
                with open(trimmed_output, "rb") as f:
                    st.download_button("Download Result Video", f, file_name="Sketch.mp4")


st.markdown("""
<div style='position: fixed; bottom: 10px; right: 10px; color: gray; font-size: 12px;'>
<b>Inspired by prior explicit motion transfer methods.</b>
</div>
""", unsafe_allow_html=True)
checkpoints/best.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f0c6eb8c5a3aa5f789f2dede6d1af99c90af21a4debe90af8a8c18ff9e8e07ce
3
+ size 33537446
checkpoints/best.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8b51ee2b0821c1eaacb4f0b1d41a626cce85e04e6ba0290e326084328efec2b5
3
+ size 11171232
checkpoints/best_epoch_0.0145/best.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8aee4c0acb8f77244b9b5bd5bd6f1037ba010b1f4f4b542fa38c021d6bc06fd1
3
+ size 66736953
demo/image_1.png ADDED
demo/image_2.png ADDED
demo/image_3.jpeg ADDED
demo/video_1.mp4 ADDED
Binary file (40.6 kB). View file
 
demo/video_2.mp4 ADDED
Binary file (85.4 kB). View file
 
demo/video_3.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:43d9b05b2135038d80f0b4438f6e20084fb509653f9da9870bc70da6f4dd7746
3
+ size 105725
evaluation.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import face_alignment
7
+ import lpips
8
+ import pandas as pd
9
+
10
+ from unet_acc import DenseMotion, UNetGenerator, warp_image
11
+
12
+
13
+ device = "cuda" if torch.cuda.is_available() else "cpu"
14
+
15
+ dense_motion = DenseMotion(kp_channels=68).to(device)
16
+ generator = UNetGenerator(in_channels=4).to(device)
17
+
18
+ ckpt = torch.load("checkpoints/best.pth", map_location=device)
19
+ dense_motion.load_state_dict(ckpt["dense_motion"])
20
+ generator.load_state_dict(ckpt["generator"])
21
+
22
+ dense_motion.eval()
23
+ generator.eval()
24
+
25
+ lpips_fn = lpips.LPIPS(net="alex").to(device)
26
+
27
+ fa = face_alignment.FaceAlignment(
28
+ face_alignment.LandmarksType.TWO_D,
29
+ device=device
30
+ )
31
+
32
+
33
+ # Metrics
34
+
35
def landmark_distance(pred, gt):
    """Normalized landmark distance between two grayscale frames.

    Runs the 68-point detector on both frames and returns the mean L2
    landmark error divided by the ground-truth inter-ocular distance,
    or None when detection fails on either frame.
    """
    pred_rgb = cv2.cvtColor(pred, cv2.COLOR_GRAY2RGB)
    gt_rgb = cv2.cvtColor(gt, cv2.COLOR_GRAY2RGB)

    pred_lms = fa.get_landmarks(pred_rgb)
    gt_lms = fa.get_landmarks(gt_rgb)
    if pred_lms is None or gt_lms is None:
        return None

    p, g = pred_lms[0], gt_lms[0]
    # Points 36 and 45 are the outer eye corners; the epsilon guards
    # against division by zero on degenerate detections.
    inter_ocular = np.linalg.norm(g[36] - g[45]) + 1e-6
    return np.mean(np.linalg.norm(p - g, axis=1)) / inter_ocular
47
+
48
+
49
def lpips_score(pred, gt):
    """LPIPS (AlexNet backbone) between two single-channel image tensors.

    Inputs are (B, 1, H, W); the channel is tiled to 3 because LPIPS
    expects RGB. Returns a Python float.
    """
    pred_rgb = pred.repeat(1, 3, 1, 1)
    gt_rgb = gt.repeat(1, 3, 1, 1)
    return lpips_fn(pred_rgb, gt_rgb).item()
53
+
54
+
55
def l1_score(pred, gt):
    """Mean absolute error between two tensors, as a Python float."""
    return F.l1_loss(pred, gt).item()
57
+
58
+
59
def temporal_jitter(frames):
    """Frame-to-frame stability of a clip.

    frames: sequence of tensors. Returns (std, mean) of the mean absolute
    difference between consecutive frames; a lower std means steadier motion.
    """
    step_diffs = [
        torch.mean(torch.abs(cur - prev)).item()
        for prev, cur in zip(frames, frames[1:])
    ]
    return np.std(step_diffs), np.mean(step_diffs)
64
+
65
+
66
+ LOCK_IDXS = list(range(36, 48)) + list(range(48, 68))
67
+
68
def infer_no_warp(src):
    """Ablation baseline: run the generator with zero flow and full
    visibility, i.e. no motion transfer at all. Output clamped to [0, 1]."""
    B, _, H, W = src.shape
    zero_flow = torch.zeros(B, 2, H, W).to(device)
    full_occ = torch.ones(B, 1, H, W).to(device)
    gen_in = torch.cat([src, zero_flow, full_occ], 1)
    return torch.clamp(generator(gen_in), 0, 1)
73
+
74
+
75
def infer_warp(src, src_kp, drv_kp):
    """Full pipeline: dense motion -> warp -> generator, clamped to [0, 1]."""
    flow, occ = dense_motion(src_kp, drv_kp)
    gen_in = torch.cat([warp_image(src, flow), flow, occ], 1)
    return torch.clamp(generator(gen_in), 0, 1)
79
+
80
+
81
def infer_warp_lock(src, src_kp, drv_kp):
    """Like infer_warp, but only the eye/mouth keypoints (LOCK_IDXS) are
    taken from the driving frame; all others keep the source positions."""
    mixed_kp = src_kp.clone()
    mixed_kp[:, LOCK_IDXS] = drv_kp[:, LOCK_IDXS]
    flow, occ = dense_motion(src_kp, mixed_kp)
    gen_in = torch.cat([warp_image(src, flow), flow, occ], 1)
    return torch.clamp(generator(gen_in), 0, 1)
87
+
88
+
89
def infer_warp_lock_mask(src, src_kp, drv_kp, mask):
    """infer_warp_lock plus compositing: generated pixels are used only
    inside `mask`; everything else is copied from the source frame."""
    mixed_kp = src_kp.clone()
    mixed_kp[:, LOCK_IDXS] = drv_kp[:, LOCK_IDXS]
    flow, occ = dense_motion(src_kp, mixed_kp)
    gen_out = generator(torch.cat([warp_image(src, flow), flow, occ], 1))
    return torch.clamp(gen_out * mask + src * (1 - mask), 0, 1)
96
+
97
+
98
def evaluate_sequence(src, src_kp, drv_kps, gt_frames, mask, mode):
    """Run one ablation `mode` over a driving sequence and aggregate metrics.

    Returns a dict with mean LMD (NaN if no frame yielded landmarks), mean
    LPIPS, and temporal jitter (std of consecutive-frame differences).
    Raises ValueError for an unknown mode.
    """
    generated = []
    lmd_vals, lpips_vals, l1_vals = [], [], []

    with torch.no_grad():
        for t, drv_kp in enumerate(drv_kps):
            if mode == "no_warp":
                pred = infer_no_warp(src)
            elif mode == "warp":
                pred = infer_warp(src, src_kp, drv_kp)
            elif mode == "warp_lock":
                pred = infer_warp_lock(src, src_kp, drv_kp)
            elif mode == "warp_lock_mask":
                pred = infer_warp_lock_mask(src, src_kp, drv_kp, mask)
            else:
                raise ValueError

            gt = gt_frames[t]

            pred_u8 = (pred.detach().cpu().squeeze().numpy() * 255).astype(np.uint8)
            gt_u8 = (gt.detach().cpu().squeeze().numpy() * 255).astype(np.uint8)

            # Landmark detection can fail on a frame; such frames are simply
            # excluded from the LMD average.
            frame_lmd = landmark_distance(pred_u8, gt_u8)
            if frame_lmd is not None:
                lmd_vals.append(frame_lmd)

            lpips_vals.append(lpips_score(pred, gt))
            # NOTE: L1 is accumulated but not reported, matching original behavior.
            l1_vals.append(l1_score(pred, gt))
            generated.append(pred)

    jitter_std, _ = temporal_jitter(generated)

    return {
        "LMD": np.mean(lmd_vals) if len(lmd_vals) > 0 else np.nan,
        "LPIPS": np.mean(lpips_vals),
        "Jitter": jitter_std,
    }
135
+
136
def run_all(src, src_kp, drv_kps, gt_frames, mask):
    """Evaluate every ablation mode, print the result table, and write it
    to ablation_results.csv."""
    records = []
    for mode in ["no_warp", "warp", "warp_lock", "warp_lock_mask"]:
        print(f"Evaluating {mode}")
        metrics = evaluate_sequence(src, src_kp, drv_kps, gt_frames, mask, mode)
        metrics["Method"] = mode
        records.append(metrics)

    table = pd.DataFrame(records)[["Method", "LMD", "LPIPS", "Jitter"]]
    table.to_csv("ablation_results.csv", index=False)
    print(table)
148
+ if __name__ == "__main__":
149
+ src_img = cv2.imread(r"motion_transfer\new_dataset\test\dataset\87\frames\00000.jpg", cv2.IMREAD_GRAYSCALE)
150
+ src = torch.tensor(
151
+ src_img / 255.0,
152
+ dtype=torch.float32
153
+ ).unsqueeze(0).unsqueeze(0).to(device)
154
+
155
+ src_kp = torch.tensor(
156
+ np.load(r"motion_transfer\new_dataset\test\dataset\87\combined\00000.npy"),
157
+ dtype=torch.float32
158
+ ).permute(2, 0, 1).unsqueeze(0).to(device)
159
+
160
+ drv_kps = []
161
+ gt_frames = []
162
+
163
+ for f in sorted(os.listdir(r"motion_transfer\new_dataset\test\dataset\87\frames")):
164
+ gt = cv2.imread(os.path.join(r"motion_transfer\new_dataset\test\dataset\87\frames", f), cv2.IMREAD_GRAYSCALE)
165
+ gt_frames.append(
166
+ torch.tensor(
167
+ gt / 255.0,
168
+ dtype=torch.float32
169
+ ).unsqueeze(0).unsqueeze(0).to(device)
170
+ )
171
+
172
+ kp = torch.tensor(
173
+ np.load(os.path.join(r"motion_transfer\new_dataset\test\dataset\87\combined", f.replace(".jpg", ".npy"))),
174
+ dtype=torch.float32
175
+ ).permute(2, 0, 1).unsqueeze(0).to(device)
176
+ drv_kps.append(kp)
177
+
178
+ mask = torch.ones_like(src)
179
+ run_all(src, src_kp, drv_kps, gt_frames, mask)
model/unet_acc.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ from torch.utils.data import Dataset, DataLoader
6
+ from PIL import Image
7
+ import numpy as np
8
+ import cv2
9
+
10
+
11
class SketchMotionDataset(Dataset):
    """Pairs each person's single source frame with every driving
    frame/heatmap of that person.

    Expects data_root/<person>/<person>.jpg, .../frames/*.jpg and
    .../combined/*.npy; the first heatmap of each person doubles as the
    source-pose heatmap.
    """

    def __init__(self, data_root, transform=None):
        self.transform = transform  # accepted for API parity; not applied
        self.data = []

        for person in sorted(os.listdir(data_root)):
            person_dir = os.path.join(data_root, person)
            source_frame = os.path.join(person_dir, f"{person}.jpg")
            frames_dir = os.path.join(person_dir, "frames")
            heatmap_dir = os.path.join(person_dir, "combined")

            frame_files = sorted(os.listdir(frames_dir))
            heatmap_files = sorted(os.listdir(heatmap_dir))

            for frame_name, heatmap_name in zip(frame_files, heatmap_files):
                self.data.append({
                    "source_frame": source_frame,
                    "driving_frame": os.path.join(frames_dir, frame_name),
                    "driving_heatmap": os.path.join(heatmap_dir, heatmap_name),
                    # First frame's heatmap is the canonical source pose.
                    "source_heatmap": os.path.join(heatmap_dir, heatmap_files[0]),
                })

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        src_img = Image.open(sample["source_frame"]).convert("L")
        drv_img = Image.open(sample["driving_frame"]).convert("L")

        # Grayscale images -> (1, H, W) float tensors in [0, 1].
        src_img = torch.tensor(np.array(src_img) / 255.0, dtype=torch.float32).unsqueeze(0)
        drv_img = torch.tensor(np.array(drv_img) / 255.0, dtype=torch.float32).unsqueeze(0)

        # Heatmaps stored as (H, W, C) .npy -> (C, H, W) tensors.
        src_kp = torch.tensor(np.load(sample["source_heatmap"]), dtype=torch.float32).permute(2, 0, 1)
        drv_kp = torch.tensor(np.load(sample["driving_heatmap"]), dtype=torch.float32).permute(2, 0, 1)

        return src_img, drv_img, src_kp, drv_kp
49
+
50
+
51
class DenseMotion(nn.Module):
    """Predict a dense flow field and an occlusion map from heatmap stacks.

    Input: source and driving keypoint heatmaps, each (B, kp_channels, H, W).
    Output: flow (B, 2, H, W) and sigmoid-activated occlusion (B, 1, H, W).
    """

    def __init__(self, kp_channels=68):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(kp_channels * 2, 128, 7, padding=3),
            nn.ReLU(),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 3, 3, padding=1),
        )

    def forward(self, src_kp, drv_kp):
        stacked = torch.cat([src_kp, drv_kp], dim=1)
        raw = self.conv(stacked)
        # First two channels are the (dx, dy) flow; third is the occlusion logit.
        flow = raw[:, :2, :, :]
        occ = torch.sigmoid(raw[:, 2:3, :, :])
        return flow, occ
68
+
69
class UNetGenerator(nn.Module):
    """Two-level U-Net mapping a 4-channel input (warped image + 2-channel
    flow + occlusion) to a single-channel image.

    H and W must be divisible by 4 (two 2x poolings). Attribute names are
    kept as-is so existing checkpoints' state_dict keys still match.
    """

    def __init__(self, in_channels=4, out_channels=1):
        super().__init__()

        def conv_block(in_c, out_c):
            # Two 3x3 conv+ReLU layers at constant spatial resolution.
            return nn.Sequential(
                nn.Conv2d(in_c, out_c, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_c, out_c, 3, padding=1),
                nn.ReLU(inplace=True),
            )

        self.enc1 = conv_block(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = conv_block(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.bottleneck = conv_block(128, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = conv_block(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = conv_block(128, 64)
        self.final = nn.Conv2d(64, out_channels, 1)

    def forward(self, x):
        skip1 = self.enc1(x)
        skip2 = self.enc2(self.pool1(skip1))
        bottom = self.bottleneck(self.pool2(skip2))
        up = self.dec2(torch.cat([self.up2(bottom), skip2], dim=1))
        up = self.dec1(torch.cat([self.up1(up), skip1], dim=1))
        return self.final(up)
101
+
102
def warp_image(img, flow):
    """Backward-warp `img` with a dense pixel-displacement field.

    img:  (B, C, H, W) tensor; flow: (B, 2, H, W) with channel 0 = dx and
    channel 1 = dy, in pixels. Returns the bilinearly warped image.

    Fixes over the original: torch.meshgrid is called with indexing="ij"
    (same layout the old default produced, but without the deprecation
    warning), and the grid / normalization tensors are created directly on
    the input's device and dtype instead of on CPU and moved.
    """
    B, C, H, W = img.shape
    # Identity sampling grid in grid_sample's normalized [-1, 1] coordinates.
    grid_y, grid_x = torch.meshgrid(
        torch.linspace(-1, 1, H, device=img.device, dtype=img.dtype),
        torch.linspace(-1, 1, W, device=img.device, dtype=img.dtype),
        indexing="ij",
    )
    grid = torch.stack((grid_x, grid_y), 2).unsqueeze(0).repeat(B, 1, 1, 1)
    # Convert pixel displacements to the normalized coordinate scale.
    scale = torch.tensor([W / 2, H / 2], device=img.device, dtype=img.dtype)
    flow_norm = flow.permute(0, 2, 3, 1) / scale
    return nn.functional.grid_sample(img, grid + flow_norm, align_corners=True)
109
+
110
+
111
+
112
def save_checkpoint(state, is_best, checkpoint_dir="checkpoints", filename="last.pth", best_filename="best.pth"):
    """Write `state` to <checkpoint_dir>/<filename>; when `is_best`, also
    mirror it to <checkpoint_dir>/<best_filename>."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    last_path = os.path.join(checkpoint_dir, filename)
    torch.save(state, last_path)
    if not is_best:
        print(f"Saved checkpoint: {last_path}")
        return
    best_path = os.path.join(checkpoint_dir, best_filename)
    torch.save(state, best_path)
    print(f"Saved new best checkpoint: {best_path}")
122
+
123
def train(data_root, epochs=500, resume_checkpoint="checkpoints/last.pth"):
    """Joint L1 training of DenseMotion + UNetGenerator.

    Resumes models, optimizer, epoch counter and best loss from
    `resume_checkpoint` when it exists; saves "last" every epoch and
    mirrors to "best" whenever the epoch loss improves.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    loader = DataLoader(SketchMotionDataset(data_root), batch_size=4, shuffle=True)

    dense_motion = DenseMotion(kp_channels=68).to(device)
    generator = UNetGenerator(in_channels=4).to(device)

    optimizer = optim.Adam(
        list(dense_motion.parameters()) + list(generator.parameters()), lr=1e-4
    )
    criterion = nn.L1Loss()

    start_epoch = 0
    best_loss = float("inf")

    if os.path.exists(resume_checkpoint):
        print(f"Resuming training from {resume_checkpoint}")
        checkpoint = torch.load(resume_checkpoint, map_location=device)
        dense_motion.load_state_dict(checkpoint["dense_motion"])
        generator.load_state_dict(checkpoint["generator"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        start_epoch = checkpoint.get("epoch", 0)
        best_loss = checkpoint.get("loss", float("inf"))
        print(f"Resumed from epoch {start_epoch}, last loss = {best_loss:.4f}")
    else:
        print("Starting new training")

    # Training
    for epoch in range(start_epoch, epochs):
        running_loss = 0.0

        for src_img, drv_img, src_kp, drv_kp in loader:
            src_img, drv_img = src_img.to(device), drv_img.to(device)
            src_kp, drv_kp = src_kp.to(device), drv_kp.to(device)

            # motion estimation -> warp -> generator refinement
            flow, occ = dense_motion(src_kp, drv_kp)
            warped = warp_image(src_img, flow)
            pred = generator(torch.cat([warped, flow, occ], dim=1))

            loss = criterion(pred, drv_img)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        avg_loss = running_loss / len(loader)
        print(f"Epoch {epoch+1}/{epochs} - Loss: {avg_loss:.4f}")

        # Persist every epoch; tag improvements as "best".
        is_best = avg_loss < best_loss
        best_loss = min(avg_loss, best_loss)
        save_checkpoint({
            "epoch": epoch + 1,
            "dense_motion": dense_motion.state_dict(),
            "generator": generator.state_dict(),
            "optimizer": optimizer.state_dict(),
            "loss": avg_loss,
        }, is_best, checkpoint_dir="checkpoints")

    print("Training completed")
188
+
189
+
190
+
191
def generate_video(data_root, output_dir="outputs"):
    """Render one .avi per person with the best checkpoint: each person's
    source frame is driven by every heatmap in their "combined" folder."""
    os.makedirs(output_dir, exist_ok=True)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    dense_motion = DenseMotion(kp_channels=68).to(device)
    generator = UNetGenerator(in_channels=4).to(device)

    # Load best checkpoint
    checkpoint = torch.load("checkpoints/best.pth", map_location=device)
    dense_motion.load_state_dict(checkpoint["dense_motion"])
    generator.load_state_dict(checkpoint["generator"])
    dense_motion.eval()
    generator.eval()

    for person in sorted(os.listdir(data_root)):
        person_dir = os.path.join(data_root, person)
        source_frame = os.path.join(person_dir, f"{person}.jpg")
        heatmap_dir = os.path.join(person_dir, "combined")

        gray = Image.open(source_frame).convert("L")
        src_img = torch.tensor(
            np.array(gray) / 255.0, dtype=torch.float32
        ).unsqueeze(0).unsqueeze(0).to(device)

        heatmap_names = sorted(os.listdir(heatmap_dir))

        def load_kp(name):
            # (H, W, C) .npy -> (1, C, H, W) tensor on the right device.
            arr = np.load(os.path.join(heatmap_dir, name))
            return torch.tensor(arr, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0).to(device)

        src_kp = load_kp(heatmap_names[0])

        frames = []
        for name in heatmap_names:
            drv_kp = load_kp(name)
            flow, occ = dense_motion(src_kp, drv_kp)
            warped = warp_image(src_img, flow)
            pred = generator(torch.cat([warped, flow, occ], dim=1))
            frames.append(pred.detach().cpu().squeeze().numpy())

        H, W = frames[0].shape
        out_path = os.path.join(output_dir, f"{person}_sketch.avi")
        writer = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*"XVID"), 15, (W, H), False)
        for frame in frames:
            writer.write((frame * 255).astype(np.uint8))
        writer.release()
        print(f"Video saved: {out_path}")
232
+
233
+ if __name__ == "__main__":
234
+ data_root = "motion_transfer/dataset/"
235
+ train(data_root, epochs=500, resume_checkpoint="checkpoints/last.pth")
236
+ generate_video(data_root, output_dir="outputs")
requirements.txt.txt ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diffusers==0.35.2
2
+ face-alignment==1.4.1
3
+ ffmpeg-python==0.2.0
4
+ huggingface-hub==0.36.0
5
+ imageio==2.37.0
6
+ imageio-ffmpeg==0.4.7
7
+ matplotlib==3.10.5
8
+ mediapipe==0.10.21
9
+ numpy==2.2.6
10
+ opencv-contrib-python==4.11.0.86
11
+ opencv-python==4.12.0.88
12
+ safetensors==0.6.2
13
+ scikit-image==0.25.2
14
+ scikit-learn==1.7.1
15
+ scipy==1.15.3
16
+ sentencepiece==0.2.0
17
+ torch==2.5.1+cu121
18
+ torchaudio==2.5.1+cu121
19
+ torchvision==0.20.1+cu121
20
+ tqdm==4.64.1
21
+ transformers==4.57.1
22
+ urllib3==2.5.0
test_acc_upgrade.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import torch
4
+ import numpy as np
5
+ from PIL import Image
6
+ import mediapipe as mp
7
+ from safetensors.torch import load_file
8
+
9
+ from unet_acc import DenseMotion, UNetGenerator, warp_image
10
+
11
+
12
+
13
+ # Face detector
14
+
15
+ mp_face_detection = mp.solutions.face_detection
16
+ detector = mp_face_detection.FaceDetection(
17
+ model_selection=1,
18
+ min_detection_confidence=0.5
19
+ )
20
+
21
+
22
def crop_head_with_bg(img_rgb, target_size=256,
                      margin_top=0.6, margin_sides=0.3, margin_bottom=0.4):
    """Crop the first detected head with margins and center it on a square
    canvas; the uncovered bands show a blurred copy of the crop.

    Returns an RGB (target_size, target_size, 3) array, or None when the
    MediaPipe face detector finds no face.
    """
    ih, iw, _ = img_rgb.shape
    detection_result = detector.process(img_rgb)
    if not detection_result.detections:
        return None

    rel = detection_result.detections[0].location_data.relative_bounding_box

    x1 = int(rel.xmin * iw)
    y1 = int(rel.ymin * ih)
    w = int(rel.width * iw)
    h = int(rel.height * ih)

    # Widen the box by the margins, clamping to the image bounds. The far
    # edge is measured from the clamped near edge, so at image borders the
    # window shifts instead of shrinking.
    x1 = max(0, int(x1 - w * margin_sides))
    x2 = min(iw, int(x1 + w * (1 + 2 * margin_sides)))
    y1 = max(0, int(y1 - h * margin_top))
    y2 = min(ih, int(y1 + h * (1 + margin_bottom + margin_top)))

    head = img_rgb[y1:y2, x1:x2]
    ch, cw = head.shape[:2]

    # Scale the crop to fit inside the square while keeping aspect ratio.
    scale = target_size / max(ch, cw)
    new_w, new_h = int(cw * scale), int(ch * scale)
    head_resized = cv2.resize(head, (new_w, new_h),
                              interpolation=cv2.INTER_LANCZOS4)

    # Blurred, stretched copy of the crop serves as the background.
    background = cv2.resize(
        cv2.GaussianBlur(head_resized, (51, 51), 0),
        (target_size, target_size),
        interpolation=cv2.INTER_AREA,
    )

    y_off = (target_size - new_h) // 2
    x_off = (target_size - new_w) // 2
    background[y_off:y_off + new_h, x_off:x_off + new_w] = head_resized

    return background
61
+
62
+
63
+
64
+ # MediaPipe face mesh
65
+
66
def get_mediapipe_keypoints(img_rgb):
    """Return all FaceMesh landmarks of the first face as an (N, 2)
    float32 array of pixel coordinates.

    Raises RuntimeError when no face is found.
    """
    h, w = img_rgb.shape[:2]
    face_mesh_module = mp.solutions.face_mesh
    with face_mesh_module.FaceMesh(static_image_mode=True, max_num_faces=1) as mesh:
        result = mesh.process(img_rgb)
    if not result.multi_face_landmarks:
        raise RuntimeError("No face landmarks detected")
    landmarks = result.multi_face_landmarks[0].landmark
    return np.array([(lm.x * w, lm.y * h) for lm in landmarks], dtype=np.float32)
76
+
77
+
78
+
79
+
80
+
81
def create_eye_mouth_mask(image_shape, keypoints):
    """Build a soft [0, 1] mask over the eye and mouth regions.

    image_shape: (H, W); keypoints: FaceMesh landmark array in pixel
    coordinates. The region polygons are filled, dilated, then Gaussian
    feathered so the composite blends smoothly.

    NOTE(review): the mouth index ranges below look hand-picked rather
    than the canonical FaceMesh lip contours — verify against the
    MediaPipe FaceMesh landmark map.

    BUG FIX: the original called cv2.dilate(mask, kernel, 1), passing 1
    into the positional `dst` slot; the intended argument is iterations=1.
    """
    H, W = image_shape
    mask = np.zeros((H, W), dtype=np.uint8)

    regions = [
        [33, 133, 160, 159, 158, 157, 173],     # left eye
        [362, 263, 387, 386, 385, 384, 398],    # right eye
        list(range(61, 79)),                    # outer mouth
        list(range(308, 325)),                  # inner mouth
    ]
    for indices in regions:
        poly = keypoints[indices].astype(np.int32).reshape(-1, 1, 2)
        cv2.fillPoly(mask, [poly], 255)

    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    mask = cv2.dilate(mask, kernel, iterations=1)
    mask = cv2.GaussianBlur(mask, (7, 7), 2)
    return mask.astype(np.float32) / 255.0
105
+
106
+
107
+
108
+ # Inference
109
+
110
def test_single_image(source_image_path, heatmap_dir, output_path):
    """Animate a single source image with a directory of driving heatmaps.

    Loads the safetensors checkpoint, crops the head, builds an eye/mouth
    mask, then writes a 15-fps grayscale video to `output_path` in which
    only eye/mouth regions come from the generated frames; everything else
    is copied from the (static) source.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    # --- models ----------------------------------------------------------
    ckpt = load_file("checkpoints/best.safetensors")

    dense_motion = DenseMotion(kp_channels=68).to(device)
    generator = UNetGenerator(in_channels=4).to(device)

    # The checkpoint stores both sub-models with prefixed keys; split them.
    dm_state = {k.replace("dense_motion.", ""): v
                for k, v in ckpt.items() if k.startswith("dense_motion.")}
    gen_state = {k.replace("generator.", ""): v
                 for k, v in ckpt.items() if k.startswith("generator.")}

    dense_motion.load_state_dict(dm_state, strict=False)
    generator.load_state_dict(gen_state, strict=False)
    dense_motion.eval()
    generator.eval()

    print("Model loaded")

    # --- source image ----------------------------------------------------
    img_bgr = cv2.imread(source_image_path)
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

    cropped = crop_head_with_bg(img_rgb, target_size=256)
    if cropped is None:
        raise RuntimeError("Face not detected")

    src_gray = cv2.cvtColor(cropped, cv2.COLOR_RGB2GRAY)
    src_np = src_gray.astype(np.float32) / 255.0
    H, W = src_np.shape
    src_tensor = torch.tensor(src_np).unsqueeze(0).unsqueeze(0).to(device)

    # --- compositing mask ------------------------------------------------
    keypoints = get_mediapipe_keypoints(cropped)
    mask_np = create_eye_mouth_mask((H, W), keypoints)
    mask_tensor = torch.tensor(mask_np).unsqueeze(0).unsqueeze(0).to(device)

    # --- driving heatmaps ------------------------------------------------
    heatmap_files = sorted(
        f for f in os.listdir(heatmap_dir) if f.endswith(".npy")
    )
    if not heatmap_files:
        raise RuntimeError("No heatmaps found")

    def load_heatmap(name):
        # (H, W, C) .npy -> (1, C, H, W) tensor on the right device.
        arr = np.load(os.path.join(heatmap_dir, name))
        return torch.tensor(arr, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0).to(device)

    src_kp = load_heatmap(heatmap_files[0])

    out = cv2.VideoWriter(
        output_path,
        cv2.VideoWriter_fourcc(*"mp4v"),
        15,
        (W, H),
        False,
    )

    # --- inference -------------------------------------------------------
    drive_idxs = list(range(36, 48)) + list(range(48, 68))  # eyes + mouth
    with torch.no_grad():
        for i, hfile in enumerate(heatmap_files):
            drv_kp = load_heatmap(hfile)

            # Drive only eye/mouth keypoints; the rest keep the source pose.
            combined_kp = src_kp.clone()
            for idx in drive_idxs:
                combined_kp[:, idx] = drv_kp[:, idx]

            flow, occ = dense_motion(src_kp, combined_kp)
            warped = torch.clamp(warp_image(src_tensor, flow), 0, 1)
            pred = torch.clamp(
                generator(torch.cat([warped, flow, occ], dim=1)),
                0, 1,
            )

            # Composite: generated pixels inside the mask, source elsewhere.
            final_frame = pred * mask_tensor + src_tensor * (1 - mask_tensor)
            frame_np = (final_frame.cpu().squeeze().numpy() * 255).astype(np.uint8)

            out.write(frame_np)

            if i == 0:
                cv2.imwrite("preview_streamlit_matched.png", frame_np)

    out.release()
    print(f"Output saved: {output_path}")
202
+
203
+
204
+
205
+ if __name__ == "__main__":
206
+ test_single_image(
207
+ source_image_path="motion_transfer/test/87.jpg",
208
+ heatmap_dir="motion_transfer/test/combined/",
209
+ output_path="outputs/final_streamlit_matched.mp4"
210
+ )
utilis/Face_keypoints_generate.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import glob
import cv2
import numpy as np
import torch
import face_alignment
from multiprocessing import Pool, cpu_count
print("hello")  # NOTE(review): leftover debug print — consider removing

# Input stills, driving videos, and per-person dataset output root.
# NOTE(review): raw Windows-style paths — not portable across OSes.
sources_dir = r"motion_transfer\new_dataset\test\image"
videos_dir = r"motion_transfer\new_dataset\test\video"
output_root = r"motion_transfer\new_dataset\test\dataset"

num_workers = min(cpu_count(), 4)  # cap worker processes at 4
target_size = 256                  # square output resolution (pixels)
SIGMA = 2.0                        # Gaussian heatmap std-dev (pixels)

NUM_FACE_POINTS = 68               # 68-landmark face model

# Temporal smoothing: EMA weight given to the *current* frame in ema()
SMOOTH_ALPHA = 0.7

# Models: 2-D landmark detector, GPU if available
fa = face_alignment.FaceAlignment(
    face_alignment.LandmarksType.TWO_D,
    device='cuda' if torch.cuda.is_available() else 'cpu'
)
30
def resize_with_gradient_padding(img, target_size):
    """Scale img so its longest side equals target_size, then centre it
    on a horizontally-interpolated background gradient.

    The gradient runs between the mean colours of the two top corners
    of the resized image, so the padding blends with the content.
    Assumes a 3-channel image (indexing uses a colour axis).
    """
    height, width = img.shape[:2]
    factor = target_size / max(height, width)
    img = cv2.resize(img, (int(width * factor), int(height * factor)),
                     interpolation=cv2.INTER_AREA)
    height, width = img.shape[:2]

    # Sample 5x5 mean colours at the top-left / top-right corners.
    corner_left = np.mean(img[0:5, 0:5, :], axis=(0, 1))
    corner_right = np.mean(img[0:5, -5:, :], axis=(0, 1))
    ramp = np.linspace(corner_left, corner_right, target_size)
    canvas = np.tile(ramp, (target_size, 1, 1)).astype(np.uint8)

    # Paste the resized image into the centre of the gradient canvas.
    off_y = (target_size - height) // 2
    off_x = (target_size - width) // 2
    canvas[off_y:off_y + height, off_x:off_x + width] = img
    return canvas
43
+
44
def clip_xy(x, y, w, h):
    """Clamp a point into the valid pixel range [0, w-1] x [0, h-1],
    returning plain Python floats."""
    cx = np.clip(x, 0, w - 1)
    cy = np.clip(y, 0, h - 1)
    return float(cx), float(cy)
46
+
47
def ema(prev_pts, curr_pts, alpha=SMOOTH_ALPHA):
    """Exponentially blend current landmarks with the previous frame's.

    Returns curr_pts unchanged when there is no usable history (first
    frame, or the landmark count changed between frames). Note alpha
    weights the *current* points.
    """
    no_history = prev_pts is None or prev_pts.shape != curr_pts.shape
    if no_history:
        return curr_pts
    return (1.0 - alpha) * prev_pts + alpha * curr_pts
51
+
52
def gaussian_heatmaps(points, H, W, sigma=2.0):
    """Render one Gaussian heatmap channel per 2-D point.

    Args:
        points: (N, 2) array of (x, y) pixel coordinates.
        H, W: output spatial size.
        sigma: Gaussian standard deviation in pixels.

    Returns:
        float32 array of shape (H, W, N); each channel peaks (~1.0) at
        the pixel nearest its point.
    """
    pts = np.asarray(points, dtype=np.float32)
    yy, xx = np.mgrid[0:H, 0:W].astype(np.float32)
    # Broadcast over all N points at once instead of a Python loop —
    # identical values, one vectorized pass.
    dx = xx[:, :, None] - pts[:, 0][None, None, :]
    dy = yy[:, :, None] - pts[:, 1][None, None, :]
    return np.exp(-(dx ** 2 + dy ** 2) / (2 * (sigma ** 2)))
61
+
62
+
63
def process_person(person_name):
    """Extract frames and 68-point Gaussian heatmaps for one identity.

    Looks up <person_name>.{png,jpg,jpeg} in sources_dir and
    <person_name>.mp4 in videos_dir, then writes a resized reference
    image, per-frame JPEGs, and per-frame (H, W, 68) heatmap .npy files
    under output_root/<person_name>/.
    """
    # Locate the source still image; first matching extension wins.
    source_path = None
    for ext in [".png", ".jpg", ".jpeg"]:
        p = os.path.join(sources_dir, person_name + ext)
        if os.path.exists(p):
            source_path = p
            break
    video_path = os.path.join(videos_dir, f"{person_name}.mp4")
    # BUG FIX: os.path.isfile(None) raises TypeError when no image was
    # found; check for a missing source explicitly before stat-ing.
    if source_path is None or not os.path.isfile(video_path):
        print(f"Missing files for {person_name}")
        return

    print(f" Processing {person_name}...")
    person_root = os.path.join(output_root, person_name)
    frames_dir = os.path.join(person_root, "frames")
    combined_dir = os.path.join(person_root, "combined")
    keypoints_preview_dir = os.path.join(person_root, "keypoints_preview")
    os.makedirs(frames_dir, exist_ok=True)
    os.makedirs(combined_dir, exist_ok=True)
    os.makedirs(keypoints_preview_dir, exist_ok=True)

    # Save a resized copy of the reference still next to the outputs.
    src_img = cv2.imread(source_path)
    if src_img is not None:
        src_ref = resize_with_gradient_padding(src_img, target_size)
        cv2.imwrite(os.path.join(person_root, f"{person_name}.jpg"), src_ref)

    cap = cv2.VideoCapture(video_path)
    frame_idx = 0
    prev_points = None

    while True:
        ok, frame_bgr = cap.read()
        if not ok:
            break

        # Resize to the square working resolution (RGB for the detector).
        frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
        frame_rgb = resize_with_gradient_padding(frame_rgb, target_size)
        H, W = frame_rgb.shape[:2]

        # 68 face landmarks; skip frames where detection fails.
        fa_out = fa.get_landmarks(frame_rgb)
        if fa_out is None or len(fa_out) == 0 or fa_out[0].shape[0] < NUM_FACE_POINTS:
            frame_idx += 1
            continue
        face68 = fa_out[0][:NUM_FACE_POINTS].astype(np.float32)

        # EMA smoothing for temporal stability.
        face68 = ema(prev_points, face68, alpha=SMOOTH_ALPHA)
        prev_points = face68.copy()

        # Heatmaps: one Gaussian channel per landmark.
        heatmap = gaussian_heatmaps(face68, H, W, sigma=SIGMA)

        # Landmark preview (write below is currently disabled).
        vis = frame_rgb.copy()
        for (x, y) in face68.astype(int):
            cv2.circle(vis, (x, y), 2, (0, 255, 0), -1)
        vis_bgr = cv2.cvtColor(vis, cv2.COLOR_RGB2BGR)

        fname = f"{frame_idx:05d}"
        frame_file = f"{frame_idx:05d}.jpg"
        # NOTE(review): frame_rgb is in RGB order but cv2.imwrite expects
        # BGR, so the saved JPEGs have swapped channels — confirm whether
        # the training loader compensates before changing this.
        cv2.imwrite(os.path.join(frames_dir, frame_file), frame_rgb)
        #cv2.imwrite(os.path.join(keypoints_preview_dir, f"{fname}.png"), vis_bgr)
        np.save(os.path.join(combined_dir, f"{fname}.npy"), heatmap)

        frame_idx += 1

    cap.release()
    print(f"Done: {person_name} | Frames: {frame_idx} | Points/frame: {NUM_FACE_POINTS}")
134
+
135
+
136
if __name__ == "__main__":
    # Every still image in sources_dir names one person to process.
    image_files = []
    for pattern in ("*.png", "*.jpg", "*.jpeg"):
        image_files.extend(glob.glob(os.path.join(sources_dir, pattern)))

    people = [os.path.splitext(os.path.basename(path))[0] for path in image_files]

    print("People found:", people)
    # Fan the per-person work out over a small process pool.
    with Pool(num_workers) as pool:
        pool.map(process_person, people)
utilis/generate_heatmap.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import numpy as np
import cv2


# Heatmap of the single source frame whose keypoints will be animated.
single_heatmap_path = "motion_transfer/dataset_single/reference_heatmap/00000.npy"
reference_heatmap_dir = "motion_transfer/dataset_single/reference_heatmap"  # contains 150 npy files
output_dir = "motion_transfer/dataset_single/test_heatmap"
preview_video = "motion_transfer/dataset_single/simulated_motion.mp4"

os.makedirs(output_dir, exist_ok=True)

# (H, W, C): spatial size and number of keypoint channels.
single_heatmap = np.load(single_heatmap_path)
H, W, C = single_heatmap.shape
15
+
16
# Extract keypoints
def extract_keypoints(hmap):
    """Recover one (x, y) coordinate per channel of an (H, W, N) heatmap.

    Each point is the centroid of the pixels with strictly positive
    response; a channel with no positive pixels yields (0, 0).
    NOTE(review): a Gaussian's positive support can span a large radius,
    so this centroid is biased near image borders — confirm an argmax
    decode isn't what was intended.
    """
    coords = []
    for ch in range(hmap.shape[2]):
        ys, xs = np.nonzero(hmap[:, :, ch] > 0)
        if xs.size:
            coords.append([xs.mean(), ys.mean()])
        else:
            coords.append([0, 0])
    return np.array(coords, dtype=np.float32)
26
+
27
single_kp = extract_keypoints(single_heatmap)

# reference motion: per-frame keypoints decoded from the driving heatmaps
ref_files = sorted([f for f in os.listdir(reference_heatmap_dir) if f.endswith(".npy")])
ref_heatmaps = [np.load(os.path.join(reference_heatmap_dir, f)) for f in ref_files]
ref_kp_list = [extract_keypoints(hm) for hm in ref_heatmaps]

# Compute motion relative to first reference frame
ref_base_kp = ref_kp_list[0]
motion_vectors = [kp - ref_base_kp for kp in ref_kp_list]

# Apply motion to single input keypoints; open a preview video writer.
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
video_writer = cv2.VideoWriter(preview_video, fourcc, 30, (W, H))
41
+
42
def gaussian_heatmaps(points, H, W, sigma=2.0):
    """Render an (H, W, N) float32 stack of unit-peak Gaussians, one
    channel per (x, y) point."""
    grid_y, grid_x = np.mgrid[0:H, 0:W].astype(np.float32)
    denom = 2 * (sigma ** 2)
    layers = [
        np.exp(-((grid_x - px) ** 2 + (grid_y - py) ** 2) / denom)
        for (px, py) in points
    ]
    return np.stack(layers, axis=-1)
51
+
52
for frame_idx, displacement in enumerate(motion_vectors):
    # Re-target the reference motion onto the single source's keypoints.
    moved_kp = single_kp + displacement

    # Generate Gaussian heatmap for all points
    new_heatmap = gaussian_heatmaps(moved_kp, H, W, sigma=2.0)
    np.save(os.path.join(output_dir, f"{frame_idx:05d}.npy"), new_heatmap)


    # Draw the moved keypoints on a black canvas for the preview video.
    frame_vis = np.zeros((H, W, 3), dtype=np.uint8)
    for (x, y) in moved_kp.astype(int):
        cv2.circle(frame_vis, (x, y), 2, (0, 255, 0), -1)
    video_writer.write(frame_vis)

video_writer.release()

print(f"Simulated motion heatmaps saved in '{output_dir}'")
print(f"Preview video saved as '{preview_video}'")
utilis/jitter.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import numpy as np

# Smooth keypoint jitter in the generated heatmap sequence.
heatmap_dir = "motion_transfer/dataset_single/test_heatmap"
smoothed_dir = "motion_transfer/dataset_single/smoothed_heatmaps"
os.makedirs(smoothed_dir, exist_ok=True)

# Load existing heatmaps and extract keypoints
heatmap_files = sorted([f for f in os.listdir(heatmap_dir) if f.endswith(".npy")])
# NOTE: kp_list holds raw heatmap arrays here; it is re-bound to actual
# keypoint arrays below, after extract_kpoints() is defined.
kp_list = [np.load(os.path.join(heatmap_dir, f)) for f in heatmap_files]
11
+
12
# Extract keypoints from each heatmap
def extract_kpoints(hm):
    """Decode an (H, W, N) heatmap into N (x, y) centroids.

    Uses the mean position of strictly-positive pixels per channel and
    falls back to (0, 0) for empty channels.
    """
    pts = []
    for ch in range(hm.shape[2]):
        rows, cols = np.nonzero(hm[:, :, ch] > 0)
        pts.append([cols.mean(), rows.mean()] if cols.size else [0, 0])
    return np.array(pts, dtype=np.float32)
22
+
23
# Re-bind: convert every loaded heatmap into its (N, 2) keypoint array.
kp_list = [extract_kpoints(hm) for hm in kp_list]
24
+
25
# Apply temporal smoothing
def temporal_smoothing(kp_list, alpha=0.7):
    """EMA-smooth a sequence of keypoint arrays.

    Each output frame is alpha * previous_smoothed + (1 - alpha) * current,
    so a larger alpha means heavier smoothing. (Note: alpha weights the
    *previous* frame here, the opposite convention to ema() in
    Face_keypoints_generate.py.)
    """
    smoothed = [kp_list[0].copy()]
    for curr in kp_list[1:]:
        smoothed.append(alpha * smoothed[-1] + (1 - alpha) * curr)
    return smoothed
32
+
33
smoothed_kp_list = temporal_smoothing(kp_list, alpha=0.7)

# Recompute heatmaps with same dimensions as the originals.
H, W, C = np.load(os.path.join(heatmap_dir, heatmap_files[0])).shape
37
+
38
def gaussian_heatmaps(points, H, W, sigma=2.0):
    """Render an (H, W, N) float32 stack of unit-peak Gaussians, one
    channel per (x, y) point."""
    grid_y, grid_x = np.mgrid[0:H, 0:W].astype(np.float32)
    out = np.zeros((H, W, points.shape[0]), dtype=np.float32)
    denom = 2 * (sigma ** 2)
    for ch, (px, py) in enumerate(points):
        sq_dist = (grid_x - px) ** 2 + (grid_y - py) ** 2
        out[..., ch] = np.exp(-sq_dist / denom)
    return out
47
+
48
# Write the smoothed heatmaps back out under the original file names.
for idx, kp in enumerate(smoothed_kp_list):
    new_hm = gaussian_heatmaps(kp, H, W, sigma=2.0)
    np.save(os.path.join(smoothed_dir, heatmap_files[idx]), new_hm)