Upload 17 files
Browse files- .gitattributes +1 -0
- app.py +209 -0
- checkpoints/best.pth +3 -0
- checkpoints/best.safetensors +3 -0
- checkpoints/best_epoch_0.0145/best.pth +3 -0
- demo/image_1.png +0 -0
- demo/image_2.png +0 -0
- demo/image_3.jpeg +0 -0
- demo/video_1.mp4 +0 -0
- demo/video_2.mp4 +0 -0
- demo/video_3.mp4 +3 -0
- evaluation.py +179 -0
- model/unet_acc.py +236 -0
- requirements.txt.txt +22 -0
- test_acc_upgrade.py +210 -0
- utilis/Face_keypoints_generate.py +146 -0
- utilis/generate_heatmap.py +68 -0
- utilis/jitter.py +50 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
demo/video_3.mp4 filter=lfs diff=lfs merge=lfs -text
|
app.py
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import cv2
|
| 3 |
+
import torch
|
| 4 |
+
import face_alignment
|
| 5 |
+
import numpy as np
|
| 6 |
+
import mediapipe as mp
|
| 7 |
+
import tempfile
|
| 8 |
+
from multiprocessing import cpu_count
|
| 9 |
+
from PIL import Image
|
| 10 |
+
import streamlit as st
|
| 11 |
+
from test import test_single_image
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# Paths. Forward slashes work on both Windows and POSIX; the original mixed a
# raw back-slash literal (output_dir) with forward-slash paths, which breaks on
# non-Windows hosts.
reference_heatmap_dir = "motion_transfer/dataset_single/reference_heatmap"
output_dir = "motion_transfer/dataset_single/test_heatmap"
final_output = "motion_transfer/outputs/final_result.mp4"
os.makedirs(output_dir, exist_ok=True)

num_workers = min(cpu_count(), 4)  # capped worker count (informational in this script)
target_size = 256                  # side length (px) of the square head crop
SIGMA = 2.0                        # Gaussian std-dev (px) for keypoint heatmaps
NUM_FACE_POINTS = 68               # 68-point face-alignment landmark convention

# Heavy models are created once at module import. NOTE(review): Streamlit reruns
# the whole script on every interaction; consider st.cache_resource — TODO confirm.
fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D,
                                  device="cuda" if torch.cuda.is_available() else "cpu")
mp_face_detection = mp.solutions.face_detection
detector = mp_face_detection.FaceDetection(model_selection=1, min_detection_confidence=0.5)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def gaussian_heatmaps(points, H, W, sigma=2.0):
    """Render one unnormalized Gaussian heatmap per keypoint.

    Args:
        points: (K, 2) array of (x, y) pixel coordinates.
        H, W: output heatmap height and width.
        sigma: Gaussian standard deviation in pixels.

    Returns:
        float32 array of shape (H, W, K); each channel peaks at 1.0 at its keypoint.
    """
    ys, xs = np.meshgrid(np.arange(H), np.arange(W), indexing='ij')
    # Broadcast (H, W, 1) against (K,) to get per-keypoint squared distances.
    dx = xs[..., None] - points[:, 0]
    dy = ys[..., None] - points[:, 1]
    sq_dist = dx * dx + dy * dy
    return np.exp(-sq_dist / (2.0 * sigma ** 2)).astype(np.float32)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def extract_keypoints(hmap):
    """Recover one (x, y) keypoint per heatmap channel.

    Takes the centroid of all strictly-positive pixels in each channel; falls
    back to (0, 0) when a channel is entirely zero.

    Args:
        hmap: (H, W, K) heatmap array.

    Returns:
        float32 array of shape (K, 2) in (x, y) order.
    """
    kps = []
    for chan in range(hmap.shape[2]):
        ys, xs = np.nonzero(hmap[:, :, chan] > 0)
        if xs.size:
            kps.append([xs.mean(), ys.mean()])
        else:
            kps.append([0, 0])
    return np.array(kps, dtype=np.float32)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def trim_video(input_path, output_path, max_seconds=7):
    """Copy at most the first *max_seconds* of a video to *output_path*.

    Args:
        input_path: readable video file.
        output_path: destination .mp4 (mp4v codec, same resolution/FPS).
        max_seconds: maximum duration of the output clip.

    Returns:
        True on success, False if the input could not be opened.
    """
    cap = cv2.VideoCapture(input_path)
    if not cap.isOpened():
        print("Error opening video")
        return False
    fps = int(round(cap.get(cv2.CAP_PROP_FPS)))
    if fps <= 0:
        # Some containers report 0/NaN FPS; without this guard max_frames
        # would be 0 and the output would be empty.
        fps = 30
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    max_frames = min(total_frames, fps * max_seconds)

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    count = 0
    while cap.isOpened() and count < max_frames:
        ret, frame = cap.read()
        if not ret:
            break
        out.write(frame)
        count += 1
    cap.release()
    out.release()
    return True
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def crop_head_with_bg(img_rgb, target_size=256, margin_top=0.6, margin_sides=0.3, margin_bottom=0.4):
    """Detect a face, crop the head with margins, and letterbox onto a blurred background.

    Args:
        img_rgb: HxWx3 RGB uint8 image (MediaPipe expects RGB).
        target_size: side length of the square output.
        margin_top/margin_sides/margin_bottom: fractional margins added around
            the detected face bounding box.

    Returns:
        target_size x target_size RGB image, or None when no face is found
        (or the expanded crop is degenerate).
    """
    ih, iw, _ = img_rgb.shape
    results = detector.process(img_rgb)
    if not results.detections:
        return None

    det = results.detections[0]
    bbox = det.location_data.relative_bounding_box
    x = int(bbox.xmin * iw)
    y = int(bbox.ymin * ih)
    w = int(bbox.width * iw)
    h = int(bbox.height * ih)

    # Expand around the ORIGINAL detection box, then clamp. (The previous code
    # derived x2/y2 from the already-clamped x1/y1, so a face near the image
    # border shifted the whole crop window instead of just truncating it.)
    x1 = max(0, int(x - w * margin_sides))
    x2 = min(iw, int(x + w * (1 + margin_sides)))
    y1 = max(0, int(y - h * margin_top))
    y2 = min(ih, int(y + h * (1 + margin_bottom)))

    cropped = img_rgb[y1:y2, x1:x2]
    ch, cw = cropped.shape[:2]
    if ch == 0 or cw == 0:
        # Degenerate detection box — avoid a divide-by-zero below.
        return None

    # Scale the crop to fit inside the square while preserving aspect ratio.
    scale = target_size / max(ch, cw)
    new_w, new_h = int(cw * scale), int(ch * scale)
    resized = cv2.resize(cropped, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)

    # Fill the letterbox area with a blurred version of the crop.
    blurred_bg = cv2.GaussianBlur(resized, (51, 51), 0)
    background = cv2.resize(blurred_bg, (target_size, target_size), interpolation=cv2.INTER_AREA)

    y_offset = (target_size - new_h) // 2
    x_offset = (target_size - new_w) // 2
    background[y_offset:y_offset + new_h, x_offset:x_offset + new_w] = resized

    return background
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
# Streamlit UI: upload a sketch, preview the cropped head, then animate it by
# retargeting reference motion onto the sketch's landmarks.
st.title("Sketch to Live")

src_img = st.file_uploader("Upload face sketch", type=["jpg", "png"])
cropped_head = None

if src_img is not None:
    pil_img = Image.open(src_img).convert("RGB")
    img_rgb = np.array(pil_img)
    ih, iw, _ = img_rgb.shape
    st.write(f"Uploaded image size: {iw}×{ih}")

    if ih < target_size or iw < target_size:
        # BUGFIX: the original string lacked the f-prefix, so users saw the
        # literal text "{iw}×{ih}" instead of the numbers.
        st.warning(f"Image too small ({iw}×{ih}). Please upload one larger than {target_size}×{target_size}.")
    else:
        cropped_head = crop_head_with_bg(img_rgb, target_size=target_size)
        if cropped_head is None:
            st.warning("No face detected. Try another image.")
        else:
            st.subheader("Face Preview")
            st.image(
                cropped_head,
                caption="Cropped Head",
                width=256,
                channels="RGB",
                output_format="PNG",
            )

            # Save lossless PNG — test_single_image reads it back from disk.
            cv2.imwrite("cropped_head.png",
                        cv2.cvtColor(cropped_head, cv2.COLOR_RGB2BGR),
                        [cv2.IMWRITE_PNG_COMPRESSION, 0])


if st.button("Lively Sketch"):
    if cropped_head is None:
        st.error("Please upload a face image.")
    else:
        progress_text = st.empty()
        progress_bar = st.progress(0)
        frame_preview = st.empty()
        progress_text.text("Processing")

        H, W = cropped_head.shape[:2]
        fa_out = fa.get_landmarks(cropped_head)
        if fa_out is None or len(fa_out) == 0:
            st.error(" No face landmarks detected.")
        else:
            face68 = fa_out[0].astype(np.float32)
            single_heatmap = gaussian_heatmaps(face68, H, W, sigma=SIGMA)
            single_kp = face68

            ref_files = sorted([f for f in os.listdir(reference_heatmap_dir) if f.endswith(".npy")])
            if len(ref_files) == 0:
                st.error(" No reference heatmaps found!")
            else:
                # Motion is expressed as per-frame keypoint displacement
                # relative to the FIRST reference frame.
                ref_heatmaps = [np.load(os.path.join(reference_heatmap_dir, f)) for f in ref_files]
                ref_kp_list = [extract_keypoints(hm) for hm in ref_heatmaps]
                ref_base_kp = ref_kp_list[0]
                motion_vectors = [kp - ref_base_kp for kp in ref_kp_list]

                os.makedirs(output_dir, exist_ok=True)
                total_frames = len(motion_vectors)

                # Apply each displacement to the sketch's own keypoints and
                # re-render the driving heatmaps.
                for frame_idx, displacement in enumerate(motion_vectors):
                    moved_kp = single_kp + displacement
                    new_heatmap = gaussian_heatmaps(moved_kp, H, W, sigma=SIGMA)
                    np.save(os.path.join(output_dir, f"{frame_idx:05d}.npy"), new_heatmap)

                    frame_preview.image(cropped_head, width=128)
                    progress_bar.progress(int((frame_idx + 1) / total_frames * 100))

                temp_img_path = "cropped_head.png"
                test_single_image(temp_img_path, output_dir, final_output)

                trimmed_output = "trimmed_result.mp4"
                trim_video(final_output, trimmed_output, max_seconds=7)

                progress_bar.progress(100)
                progress_text.text("Done!")
                frame_preview.empty()

                st.success("Sketch-to-Live")
                with open(trimmed_output, "rb") as f:
                    st.download_button("Download Result Video", f, file_name="Sketch.mp4")


st.markdown("""
<div style='position: fixed; bottom: 10px; right: 10px; color: gray; font-size: 12px;'>
<b>Inspired by prior explicit motion transfer methods.</b>
</div>
""", unsafe_allow_html=True)
|
checkpoints/best.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f0c6eb8c5a3aa5f789f2dede6d1af99c90af21a4debe90af8a8c18ff9e8e07ce
|
| 3 |
+
size 33537446
|
checkpoints/best.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8b51ee2b0821c1eaacb4f0b1d41a626cce85e04e6ba0290e326084328efec2b5
|
| 3 |
+
size 11171232
|
checkpoints/best_epoch_0.0145/best.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8aee4c0acb8f77244b9b5bd5bd6f1037ba010b1f4f4b542fa38c021d6bc06fd1
|
| 3 |
+
size 66736953
|
demo/image_1.png
ADDED
|
demo/image_2.png
ADDED
|
demo/image_3.jpeg
ADDED
|
demo/video_1.mp4
ADDED
|
Binary file (40.6 kB). View file
|
|
|
demo/video_2.mp4
ADDED
|
Binary file (85.4 kB). View file
|
|
|
demo/video_3.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:43d9b05b2135038d80f0b4438f6e20084fb509653f9da9870bc70da6f4dd7746
|
| 3 |
+
size 105725
|
evaluation.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
import face_alignment
|
| 7 |
+
import lpips
|
| 8 |
+
import pandas as pd
|
| 9 |
+
|
| 10 |
+
from unet_acc import DenseMotion, UNetGenerator, warp_image
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 14 |
+
|
| 15 |
+
dense_motion = DenseMotion(kp_channels=68).to(device)
|
| 16 |
+
generator = UNetGenerator(in_channels=4).to(device)
|
| 17 |
+
|
| 18 |
+
ckpt = torch.load("checkpoints/best.pth", map_location=device)
|
| 19 |
+
dense_motion.load_state_dict(ckpt["dense_motion"])
|
| 20 |
+
generator.load_state_dict(ckpt["generator"])
|
| 21 |
+
|
| 22 |
+
dense_motion.eval()
|
| 23 |
+
generator.eval()
|
| 24 |
+
|
| 25 |
+
lpips_fn = lpips.LPIPS(net="alex").to(device)
|
| 26 |
+
|
| 27 |
+
fa = face_alignment.FaceAlignment(
|
| 28 |
+
face_alignment.LandmarksType.TWO_D,
|
| 29 |
+
device=device
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# Metrics
|
| 34 |
+
|
| 35 |
+
def landmark_distance(pred, gt):
    """Inter-ocular-normalized mean landmark error between two grayscale frames.

    Args:
        pred, gt: HxW uint8 grayscale images.

    Returns:
        Mean per-landmark Euclidean distance divided by the ground-truth
        inter-ocular distance, or None when landmark detection fails on
        either image.
    """
    # face_alignment expects 3-channel input.
    pred_rgb = cv2.cvtColor(pred, cv2.COLOR_GRAY2RGB)
    gt_rgb = cv2.cvtColor(gt, cv2.COLOR_GRAY2RGB)

    pred_lms = fa.get_landmarks(pred_rgb)
    gt_lms = fa.get_landmarks(gt_rgb)
    if pred_lms is None or gt_lms is None:
        return None

    p, g = pred_lms[0], gt_lms[0]
    # Landmarks 36/45 are the outer eye corners; epsilon avoids division by zero.
    inter_ocular = np.linalg.norm(g[36] - g[45]) + 1e-6
    return np.mean(np.linalg.norm(p - g, axis=1)) / inter_ocular
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def lpips_score(pred, gt):
    """LPIPS (AlexNet backbone) between two single-channel image tensors.

    Both inputs are (B, 1, H, W); channels are tiled to 3 because LPIPS
    expects RGB-shaped input.
    """
    pred3 = pred.repeat(1, 3, 1, 1)
    gt3 = gt.repeat(1, 3, 1, 1)
    return lpips_fn(pred3, gt3).item()
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def l1_score(pred, gt):
    """Mean absolute error between prediction and ground truth, as a float."""
    diff = F.l1_loss(pred, gt)
    return diff.item()
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def temporal_jitter(frames):
    """Temporal-stability metric over a sequence of frame tensors.

    Computes mean absolute difference between consecutive frames; jitter is the
    standard deviation of those differences (lower = smoother).

    Args:
        frames: list of tensors with identical shapes.

    Returns:
        (std_of_diffs, mean_of_diffs) as floats. (0.0, 0.0) for sequences with
        fewer than 2 frames — previously np.std([]) produced NaN plus a
        RuntimeWarning.
    """
    if len(frames) < 2:
        return 0.0, 0.0
    diffs = [
        torch.mean(torch.abs(frames[i] - frames[i - 1])).item()
        for i in range(1, len(frames))
    ]
    return np.std(diffs), np.mean(diffs)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
LOCK_IDXS = list(range(36, 48)) + list(range(48, 68))
|
| 67 |
+
|
| 68 |
+
def infer_no_warp(src):
    """Ablation baseline: run the generator with identity flow and full visibility.

    Args:
        src: (B, 1, H, W) source image tensor.

    Returns:
        Generator output clamped to [0, 1].
    """
    batch, _, height, width = src.shape
    zero_flow = torch.zeros(batch, 2, height, width).to(device)
    full_occ = torch.ones(batch, 1, height, width).to(device)
    net_in = torch.cat([src, zero_flow, full_occ], 1)
    return torch.clamp(generator(net_in), 0, 1)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def infer_warp(src, src_kp, drv_kp):
    """Full pipeline: dense-motion flow + occlusion, warp, then refine with the generator."""
    flow, occ = dense_motion(src_kp, drv_kp)
    warped_src = warp_image(src, flow)
    net_in = torch.cat([warped_src, flow, occ], 1)
    return torch.clamp(generator(net_in), 0, 1)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def infer_warp_lock(src, src_kp, drv_kp):
    """Warp-based inference where only LOCK_IDXS keypoints (eyes + mouth) follow the driver.

    All other keypoints stay at the source positions, keeping head pose fixed.
    """
    locked_kp = src_kp.clone()
    locked_kp[:, LOCK_IDXS] = drv_kp[:, LOCK_IDXS]
    flow, occ = dense_motion(src_kp, locked_kp)
    warped_src = warp_image(src, flow)
    return torch.clamp(generator(torch.cat([warped_src, flow, occ], 1)), 0, 1)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def infer_warp_lock_mask(src, src_kp, drv_kp, mask):
    """Like infer_warp_lock, but composites the prediction onto the source via *mask*.

    Inside the mask the generator output is used; outside it the original
    source pixels are kept (mask is expected in [0, 1]).
    """
    locked_kp = src_kp.clone()
    locked_kp[:, LOCK_IDXS] = drv_kp[:, LOCK_IDXS]
    flow, occ = dense_motion(src_kp, locked_kp)
    warped_src = warp_image(src, flow)
    raw_pred = generator(torch.cat([warped_src, flow, occ], 1))
    blended = raw_pred * mask + src * (1 - mask)
    return torch.clamp(blended, 0, 1)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def evaluate_sequence(src, src_kp, drv_kps, gt_frames, mask, mode):
    """Run one ablation *mode* over a driving sequence and aggregate metrics.

    Args:
        src: (1, 1, H, W) source frame tensor.
        src_kp: (1, K, H, W) source keypoint heatmaps.
        drv_kps: list of driving keypoint heatmap tensors, one per frame.
        gt_frames: list of ground-truth frame tensors aligned with drv_kps.
        mask: blending mask, used only by "warp_lock_mask".
        mode: one of "no_warp", "warp", "warp_lock", "warp_lock_mask".

    Returns:
        Dict with mean LMD, LPIPS, L1 and the temporal Jitter std.

    Raises:
        ValueError: on an unknown mode.
    """
    preds_torch = []
    lmd, lp, l1 = [], [], []

    with torch.no_grad():
        for t, drv_kp in enumerate(drv_kps):
            if mode == "no_warp":
                pred = infer_no_warp(src)
            elif mode == "warp":
                pred = infer_warp(src, src_kp, drv_kp)
            elif mode == "warp_lock":
                pred = infer_warp_lock(src, src_kp, drv_kp)
            elif mode == "warp_lock_mask":
                pred = infer_warp_lock_mask(src, src_kp, drv_kp, mask)
            else:
                raise ValueError(f"Unknown mode: {mode!r}")

            gt = gt_frames[t]

            # uint8 copies for the landmark detector; tensors for LPIPS/L1.
            pred_np = (pred.detach().cpu().squeeze().numpy() * 255).astype(np.uint8)
            gt_np = (gt.detach().cpu().squeeze().numpy() * 255).astype(np.uint8)

            lm = landmark_distance(pred_np, gt_np)
            if lm is not None:  # detection can fail; skip those frames for LMD
                lmd.append(lm)

            lp.append(lpips_score(pred, gt))
            l1.append(l1_score(pred, gt))
            preds_torch.append(pred)

    jit_std, _ = temporal_jitter(preds_torch)

    return {
        "LMD": np.mean(lmd) if len(lmd) > 0 else np.nan,
        "LPIPS": np.mean(lp),
        # BUGFIX: L1 was computed per-frame but never reported.
        "L1": np.mean(l1) if len(l1) > 0 else np.nan,
        "Jitter": jit_std
    }
|
| 135 |
+
|
| 136 |
+
def run_all(src, src_kp, drv_kps, gt_frames, mask):
    """Evaluate every ablation mode and write a CSV summary to ablation_results.csv."""
    records = []
    for mode in ("no_warp", "warp", "warp_lock", "warp_lock_mask"):
        print(f"Evaluating {mode}")
        metrics = evaluate_sequence(src, src_kp, drv_kps, gt_frames, mask, mode)
        metrics["Method"] = mode
        records.append(metrics)

    # Fixed column order for the report.
    table = pd.DataFrame(records)
    table = table[["Method", "LMD", "LPIPS", "Jitter"]]
    table.to_csv("ablation_results.csv", index=False)
    print(table)
|
| 148 |
+
if __name__ == "__main__":
    # Build paths once with os.path.join — the original repeated Windows
    # raw-string back-slash literals four times, which breaks on POSIX.
    base_dir = os.path.join("motion_transfer", "new_dataset", "test", "dataset", "87")
    frames_dir = os.path.join(base_dir, "frames")
    combined_dir = os.path.join(base_dir, "combined")

    # Source: first frame of the sequence, grayscale in [0, 1], shape (1, 1, H, W).
    src_img = cv2.imread(os.path.join(frames_dir, "00000.jpg"), cv2.IMREAD_GRAYSCALE)
    src = torch.tensor(
        src_img / 255.0,
        dtype=torch.float32
    ).unsqueeze(0).unsqueeze(0).to(device)

    # Source keypoint heatmaps: stored HWC on disk, converted to (1, K, H, W).
    src_kp = torch.tensor(
        np.load(os.path.join(combined_dir, "00000.npy")),
        dtype=torch.float32
    ).permute(2, 0, 1).unsqueeze(0).to(device)

    drv_kps = []
    gt_frames = []

    for f in sorted(os.listdir(frames_dir)):
        gt = cv2.imread(os.path.join(frames_dir, f), cv2.IMREAD_GRAYSCALE)
        gt_frames.append(
            torch.tensor(
                gt / 255.0,
                dtype=torch.float32
            ).unsqueeze(0).unsqueeze(0).to(device)
        )

        kp = torch.tensor(
            np.load(os.path.join(combined_dir, f.replace(".jpg", ".npy"))),
            dtype=torch.float32
        ).permute(2, 0, 1).unsqueeze(0).to(device)
        drv_kps.append(kp)

    mask = torch.ones_like(src)  # full-frame mask: warp_lock_mask degenerates to full blend
    run_all(src, src_kp, drv_kps, gt_frames, mask)
|
model/unet_acc.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.optim as optim
|
| 5 |
+
from torch.utils.data import Dataset, DataLoader
|
| 6 |
+
from PIL import Image
|
| 7 |
+
import numpy as np
|
| 8 |
+
import cv2
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class SketchMotionDataset(Dataset):
    """Dataset of (source frame, driving frame, source heatmap, driving heatmap) tuples.

    Expected layout under *data_root*: one directory per person containing
    `<person>.jpg` (the source), a `frames/` directory of driving frames, and a
    `combined/` directory of per-frame keypoint heatmaps (.npy, HWC). The first
    heatmap of each person is used as that person's source heatmap.
    """

    def __init__(self, data_root, transform=None):
        self.transform = transform  # kept for API compatibility; not applied here
        self.data = []

        for person in sorted(os.listdir(data_root)):
            person_dir = os.path.join(data_root, person)
            source_frame = os.path.join(person_dir, f"{person}.jpg")
            frames_dir = os.path.join(person_dir, "frames")
            heatmap_dir = os.path.join(person_dir, "combined")

            frame_files = sorted(os.listdir(frames_dir))
            heatmap_files = sorted(os.listdir(heatmap_dir))

            for frame_name, heatmap_name in zip(frame_files, heatmap_files):
                self.data.append({
                    "source_frame": source_frame,
                    "driving_frame": os.path.join(frames_dir, frame_name),
                    "driving_heatmap": os.path.join(heatmap_dir, heatmap_name),
                    "source_heatmap": os.path.join(heatmap_dir, heatmap_files[0])  # first frame heatmap
                })

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]

        # Grayscale images normalized to [0, 1], shape (1, H, W).
        source = Image.open(sample["source_frame"]).convert("L")
        driving = Image.open(sample["driving_frame"]).convert("L")
        source = torch.tensor(np.array(source) / 255.0, dtype=torch.float32).unsqueeze(0)
        driving = torch.tensor(np.array(driving) / 255.0, dtype=torch.float32).unsqueeze(0)

        # Heatmaps stored HWC on disk -> CHW tensors.
        src_kp = torch.tensor(np.load(sample["source_heatmap"]), dtype=torch.float32).permute(2, 0, 1)
        drv_kp = torch.tensor(np.load(sample["driving_heatmap"]), dtype=torch.float32).permute(2, 0, 1)

        return source, driving, src_kp, drv_kp
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class DenseMotion(nn.Module):
    """Predicts a dense 2-channel flow field and a 1-channel occlusion map.

    Input: source and driving keypoint heatmaps, concatenated on the channel
    axis (2 * kp_channels channels). The occlusion channel is squashed through
    a sigmoid so it lies in (0, 1). Attribute name `conv` is preserved so
    existing checkpoints (`conv.0.weight`, ...) keep loading.
    """

    def __init__(self, kp_channels=68):
        super().__init__()
        layers = [
            nn.Conv2d(kp_channels * 2, 128, 7, padding=3),
            nn.ReLU(),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 3, 3, padding=1),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, src_kp, drv_kp):
        """Return (flow, occlusion): (B, 2, H, W) and (B, 1, H, W)."""
        stacked = torch.cat([src_kp, drv_kp], dim=1)
        raw = self.conv(stacked)
        flow = raw[:, :2, :, :]
        occ = torch.sigmoid(raw[:, 2:3, :, :])
        return flow, occ
|
| 68 |
+
|
| 69 |
+
class UNetGenerator(nn.Module):
    """Two-level U-Net: 64 -> 128 -> 256 bottleneck with skip connections.

    Input channels default to 4 (warped image + 2 flow channels + occlusion);
    output is a single-channel image with no final activation. All submodule
    attribute names are kept identical to the original so saved state_dicts
    remain loadable.
    """

    def __init__(self, in_channels=4, out_channels=1):
        super().__init__()
        self.enc1 = self._double_conv(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = self._double_conv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.bottleneck = self._double_conv(128, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = self._double_conv(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = self._double_conv(128, 64)
        self.final = nn.Conv2d(64, out_channels, 1)

    @staticmethod
    def _double_conv(in_c, out_c):
        """Conv-ReLU-Conv-ReLU block with 3x3 kernels, spatial size preserved."""
        return nn.Sequential(
            nn.Conv2d(in_c, out_c, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_c, out_c, 3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        skip1 = self.enc1(x)
        skip2 = self.enc2(self.pool1(skip1))
        bottom = self.bottleneck(self.pool2(skip2))
        # Decoder: upsample, concatenate the matching encoder skip, refine.
        merged = self.dec2(torch.cat([self.up2(bottom), skip2], dim=1))
        merged = self.dec1(torch.cat([self.up1(merged), skip1], dim=1))
        return self.final(merged)
|
| 101 |
+
|
| 102 |
+
def warp_image(img, flow):
    """Backward-warp *img* by a pixel-space flow field using bilinear sampling.

    Args:
        img: (B, C, H, W) tensor.
        flow: (B, 2, H, W) tensor of (x, y) displacements in pixels.

    Returns:
        Warped (B, C, H, W) tensor.
    """
    B, C, H, W = img.shape
    # Explicit indexing='ij' matches the historical default and silences the
    # torch.meshgrid deprecation warning; build grids directly on img's device.
    grid_y, grid_x = torch.meshgrid(
        torch.linspace(-1, 1, H, device=img.device),
        torch.linspace(-1, 1, W, device=img.device),
        indexing="ij",
    )
    grid = torch.stack((grid_x, grid_y), 2).unsqueeze(0).repeat(B, 1, 1, 1)
    # Convert pixel displacements to grid_sample's normalized [-1, 1] space.
    scale = torch.tensor([W / 2, H / 2], device=img.device, dtype=flow.dtype)
    flow_norm = flow.permute(0, 2, 3, 1) / scale
    return nn.functional.grid_sample(img, grid + flow_norm, align_corners=True)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def save_checkpoint(state, is_best, checkpoint_dir="checkpoints", filename="last.pth", best_filename="best.pth"):
    """Persist *state* to checkpoint_dir/filename; duplicate to best_filename when is_best."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    filepath = os.path.join(checkpoint_dir, filename)
    torch.save(state, filepath)

    if not is_best:
        print(f"Saved checkpoint: {filepath}")
        return

    bestpath = os.path.join(checkpoint_dir, best_filename)
    torch.save(state, bestpath)
    print(f"Saved new best checkpoint: {bestpath}")
|
| 122 |
+
|
| 123 |
+
def train(data_root, epochs=500, resume_checkpoint="checkpoints/last.pth"):
    """Jointly train DenseMotion and UNetGenerator with an L1 reconstruction loss.

    Resumes from *resume_checkpoint* when the file exists; saves a checkpoint
    every epoch via save_checkpoint, tracking the best average epoch loss.

    Args:
        data_root: root directory consumed by SketchMotionDataset.
        epochs: total number of epochs (resume counts toward this total).
        resume_checkpoint: path to an optional checkpoint to resume from.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    dataset = SketchMotionDataset(data_root)
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

    dense_motion = DenseMotion(kp_channels=68).to(device)
    generator = UNetGenerator(in_channels=4).to(device)

    # Single optimizer over both networks so they are updated together.
    optimizer = optim.Adam(list(dense_motion.parameters()) + list(generator.parameters()), lr=1e-4)
    criterion = nn.L1Loss()

    start_epoch = 0
    best_loss = float("inf")

    if os.path.exists(resume_checkpoint):
        print(f"Resuming training from {resume_checkpoint}")
        checkpoint = torch.load(resume_checkpoint, map_location=device)
        dense_motion.load_state_dict(checkpoint["dense_motion"])
        generator.load_state_dict(checkpoint["generator"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        start_epoch = checkpoint.get("epoch", 0)
        # NOTE(review): "best" is seeded from the checkpoint's LAST epoch loss,
        # not a separately tracked best — confirm this is intended.
        best_loss = checkpoint.get("loss", float("inf"))
        print(f"Resumed from epoch {start_epoch}, last loss = {best_loss:.4f}")
    else:
        print("Starting new training")

    # Training
    for epoch in range(start_epoch, epochs):
        epoch_loss = 0.0

        for src_img, drv_img, src_kp, drv_kp in dataloader:
            src_img, drv_img = src_img.to(device), drv_img.to(device)
            src_kp, drv_kp = src_kp.to(device), drv_kp.to(device)

            # Predict flow/occlusion from keypoint heatmaps, warp the source,
            # then let the U-Net refine the warped result.
            flow, occ = dense_motion(src_kp, drv_kp)
            warped_src = warp_image(src_img, flow)

            unet_input = torch.cat([warped_src, flow, occ], dim=1)
            pred = generator(unet_input)

            loss = criterion(pred, drv_img)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        avg_loss = epoch_loss / len(dataloader)
        print(f"Epoch {epoch+1}/{epochs} - Loss: {avg_loss:.4f}")

        # Save checkpoint
        is_best = avg_loss < best_loss
        best_loss = min(avg_loss, best_loss)

        save_checkpoint({
            "epoch": epoch + 1,
            "dense_motion": dense_motion.state_dict(),
            "generator": generator.state_dict(),
            "optimizer": optimizer.state_dict(),
            "loss": avg_loss,
        }, is_best, checkpoint_dir="checkpoints")

    print("Training completed")
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def generate_video(data_root, output_dir="outputs"):
    """Render one animated video per person using the best checkpoint.

    For each person directory under *data_root*, loads the source frame and the
    first heatmap, then drives the model with every heatmap in `combined/` and
    writes the frames to `<output_dir>/<person>_sketch.avi` (grayscale XVID, 15 FPS).
    """
    os.makedirs(output_dir, exist_ok=True)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    dense_motion = DenseMotion(kp_channels=68).to(device)
    generator = UNetGenerator(in_channels=4).to(device)

    # Load best checkpoint
    checkpoint = torch.load("checkpoints/best.pth", map_location=device)
    dense_motion.load_state_dict(checkpoint["dense_motion"])
    generator.load_state_dict(checkpoint["generator"])
    dense_motion.eval()
    generator.eval()

    for person in sorted(os.listdir(data_root)):
        person_dir = os.path.join(data_root, person)
        source_frame = os.path.join(person_dir, f"{person}.jpg")
        heatmap_dir = os.path.join(person_dir, "combined")

        src_img = Image.open(source_frame).convert("L")
        src_img = torch.tensor(np.array(src_img)/255.0, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)

        src_kp = torch.tensor(np.load(os.path.join(heatmap_dir, sorted(os.listdir(heatmap_dir))[0])), dtype=torch.float32).permute(2,0,1).unsqueeze(0).to(device)

        generated_frames = []
        with torch.no_grad():  # pure inference — skip autograd bookkeeping
            for h_file in sorted(os.listdir(heatmap_dir)):
                drv_kp = torch.tensor(np.load(os.path.join(heatmap_dir, h_file)), dtype=torch.float32).permute(2,0,1).unsqueeze(0).to(device)

                flow, occ = dense_motion(src_kp, drv_kp)
                warped_src = warp_image(src_img, flow)
                unet_input = torch.cat([warped_src, flow, occ], dim=1)
                pred = generator(unet_input)
                generated_frames.append(pred.detach().cpu().squeeze().numpy())

        H, W = generated_frames[0].shape
        out_path = os.path.join(output_dir, f"{person}_sketch.avi")
        out = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*"XVID"), 15, (W, H), False)
        for f in generated_frames:
            # BUGFIX: the generator has no output activation, so values can fall
            # outside [0, 1]; casting without clipping wrapped around in uint8.
            out.write((np.clip(f, 0.0, 1.0) * 255).astype(np.uint8))
        out.release()
        print(f"Video saved: {out_path}")
|
| 232 |
+
|
| 233 |
+
if __name__ == "__main__":
    data_root = "motion_transfer/dataset/"
    # Train (resumes from checkpoints/last.pth when present), then render one
    # video per person from the best checkpoint.
    train(data_root, epochs=500, resume_checkpoint="checkpoints/last.pth")
    generate_video(data_root, output_dir="outputs")
|
requirements.txt.txt
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
diffusers==0.35.2
|
| 2 |
+
face-alignment==1.4.1
|
| 3 |
+
ffmpeg-python==0.2.0
|
| 4 |
+
huggingface-hub==0.36.0
|
| 5 |
+
imageio==2.37.0
|
| 6 |
+
imageio-ffmpeg==0.4.7
|
| 7 |
+
matplotlib==3.10.5
|
| 8 |
+
mediapipe==0.10.21
|
| 9 |
+
numpy==2.2.6
|
| 10 |
+
opencv-contrib-python==4.11.0.86
|
| 11 |
+
opencv-python==4.12.0.88
|
| 12 |
+
safetensors==0.6.2
|
| 13 |
+
scikit-image==0.25.2
|
| 14 |
+
scikit-learn==1.7.1
|
| 15 |
+
scipy==1.15.3
|
| 16 |
+
sentencepiece==0.2.0
|
| 17 |
+
torch==2.5.1+cu121
|
| 18 |
+
torchaudio==2.5.1+cu121
|
| 19 |
+
torchvision==0.20.1+cu121
|
| 20 |
+
tqdm==4.64.1
|
| 21 |
+
transformers==4.57.1
|
| 22 |
+
urllib3==2.5.0
|
test_acc_upgrade.py
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import cv2
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import mediapipe as mp
|
| 7 |
+
from safetensors.torch import load_file
|
| 8 |
+
|
| 9 |
+
from unet_acc import DenseMotion, UNetGenerator, warp_image
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# Face detector

# MediaPipe face detection. model_selection=1 selects the full-range model
# (better for faces farther from the camera); detections with confidence
# below 0.5 are discarded.
mp_face_detection = mp.solutions.face_detection
detector = mp_face_detection.FaceDetection(
    model_selection=1,
    min_detection_confidence=0.5
)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def crop_head_with_bg(img_rgb, target_size=256,
                      margin_top=0.6, margin_sides=0.3, margin_bottom=0.4):
    """Crop the head region of `img_rgb` and fit it into a square canvas.

    Detects the most confident face, expands its bounding box by the given
    margins, resizes the crop to fit `target_size` (aspect preserved) and
    centers it on a blurred version of itself used as background.
    Returns the (target_size, target_size, 3) image, or None when no face
    is detected.
    """
    ih, iw, _ = img_rgb.shape
    results = detector.process(img_rgb)
    if not results.detections:
        return None

    box = results.detections[0].location_data.relative_bounding_box
    x1 = int(box.xmin * iw)
    y1 = int(box.ymin * ih)
    w = int(box.width * iw)
    h = int(box.height * ih)

    # Expand the box by the margins, clamping to the image bounds.
    # NOTE(review): x2/y2 are measured from the already-clamped x1/y1, so the
    # box keeps its intended width even when pushed against an edge.
    x1 = max(0, int(x1 - w * margin_sides))
    x2 = min(iw, int(x1 + w * (1 + 2 * margin_sides)))
    y1 = max(0, int(y1 - h * margin_top))
    y2 = min(ih, int(y1 + h * (1 + margin_bottom + margin_top)))

    cropped = img_rgb[y1:y2, x1:x2]
    ch, cw = cropped.shape[:2]

    # Scale so the longer side matches target_size.
    scale = target_size / max(ch, cw)
    new_w, new_h = int(cw * scale), int(ch * scale)
    resized = cv2.resize(cropped, (new_w, new_h),
                         interpolation=cv2.INTER_LANCZOS4)

    # Background: a heavily blurred copy of the crop stretched to the canvas.
    background = cv2.resize(
        cv2.GaussianBlur(resized, (51, 51), 0),
        (target_size, target_size),
        interpolation=cv2.INTER_AREA
    )

    # Paste the sharp crop centered over the blurred canvas.
    y_off = (target_size - new_h) // 2
    x_off = (target_size - new_w) // 2
    background[y_off:y_off + new_h, x_off:x_off + new_w] = resized

    return background
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# MediaPipe face mesh
|
| 65 |
+
|
| 66 |
+
def get_mediapipe_keypoints(img_rgb):
    """Detect FaceMesh landmarks and return them as pixel coordinates.

    Runs MediaPipe FaceMesh once in static-image mode on `img_rgb` and
    returns an (N, 2) float32 array of (x, y) pixel positions.
    Raises RuntimeError when no face is found.
    """
    h, w = img_rgb.shape[:2]
    mp_face = mp.solutions.face_mesh
    with mp_face.FaceMesh(static_image_mode=True, max_num_faces=1) as mesh:
        res = mesh.process(img_rgb)
        if not res.multi_face_landmarks:
            raise RuntimeError("No face landmarks detected")
        # Landmarks come back normalized to [0, 1]; scale to pixels.
        landmarks = res.multi_face_landmarks[0].landmark
        coords = [(lm.x * w, lm.y * h) for lm in landmarks]
        return np.array(coords, dtype=np.float32)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def create_eye_mouth_mask(image_shape, keypoints):
    """Build a soft float mask over the eye and mouth regions.

    Fills polygons defined by fixed FaceMesh landmark indices, dilates and
    blurs them, and returns an (H, W) float32 mask in [0, 1].
    """
    H, W = image_shape
    mask = np.zeros((H, W), dtype=np.uint8)

    # FaceMesh landmark index groups.
    # NOTE(review): the mouth index ranges look approximate — confirm them
    # against the FaceMesh topology before relying on exact coverage.
    regions = (
        [33, 133, 160, 159, 158, 157, 173],    # left eye
        [362, 263, 387, 386, 385, 384, 398],   # right eye
        list(range(61, 79)),                   # outer mouth
        list(range(308, 325)),                 # inner mouth
    )
    for indices in regions:
        poly = keypoints[indices].astype(np.int32)
        cv2.fillPoly(mask, [poly.reshape(-1, 1, 2)], 255)

    # Grow the regions slightly, then soften the edges.
    # NOTE(review): the positional `1` below binds to dilate's `dst` slot,
    # not `iterations` — iterations stays at its default of 1. Confirm intent.
    mask = cv2.dilate(mask,
                      cv2.getStructuringElement(
                          cv2.MORPH_ELLIPSE, (5, 5)),
                      1)
    mask = cv2.GaussianBlur(mask, (7, 7), 2)
    return mask.astype(np.float32) / 255.0
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
# Inference
|
| 109 |
+
|
| 110 |
+
def test_single_image(source_image_path, heatmap_dir, output_path):
    """Animate one source image with precomputed driving heatmaps.

    Loads the dense-motion network and U-Net generator weights from
    checkpoints/best.safetensors, crops the head from `source_image_path`,
    then for every driving heatmap in `heatmap_dir` warps the (grayscale)
    source and writes the generated frames to `output_path` as an mp4.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    # Load model
    ckpt = load_file("checkpoints/best.safetensors")

    dense_motion = DenseMotion(kp_channels=68).to(device)
    generator = UNetGenerator(in_channels=4).to(device)

    # The checkpoint stores both sub-models in one flat dict; split by prefix.
    dm_state = {k.replace("dense_motion.", ""): v
                for k, v in ckpt.items()
                if k.startswith("dense_motion.")}
    gen_state = {k.replace("generator.", ""): v
                 for k, v in ckpt.items()
                 if k.startswith("generator.")}

    # NOTE(review): strict=False silently ignores missing/unexpected keys —
    # confirm the checkpoint layout really matches the model definitions.
    dense_motion.load_state_dict(dm_state, strict=False)
    generator.load_state_dict(gen_state, strict=False)
    dense_motion.eval()
    generator.eval()

    print("Model loaded")

    img_bgr = cv2.imread(source_image_path)
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

    cropped = crop_head_with_bg(img_rgb, target_size=256)
    if cropped is None:
        raise RuntimeError("Face not detected")

    # The model operates on normalized grayscale frames in [0, 1].
    src_gray = cv2.cvtColor(cropped, cv2.COLOR_RGB2GRAY)
    src_np = src_gray.astype(np.float32) / 255.0
    H, W = src_np.shape

    src_tensor = torch.tensor(src_np).unsqueeze(0).unsqueeze(0).to(device)

    # Mask
    # Soft eye/mouth mask: only those regions take generated pixels below.
    keypoints = get_mediapipe_keypoints(cropped)
    mask_np = create_eye_mouth_mask((H, W), keypoints)
    mask_tensor = torch.tensor(mask_np).unsqueeze(0).unsqueeze(0).to(device)

    # Heatmaps
    heatmap_files = sorted(
        f for f in os.listdir(heatmap_dir) if f.endswith(".npy")
    )
    if not heatmap_files:
        raise RuntimeError("No heatmaps found")

    # The first heatmap is treated as the source pose.
    src_kp = torch.tensor(
        np.load(os.path.join(heatmap_dir, heatmap_files[0])),
        dtype=torch.float32
    ).permute(2, 0, 1).unsqueeze(0).to(device)

    # Grayscale mp4 writer at 15 fps (isColor=False).
    out = cv2.VideoWriter(
        output_path,
        cv2.VideoWriter_fourcc(*"mp4v"),
        15,
        (W, H),
        False
    )

    # Inference
    with torch.no_grad():
        for i, hfile in enumerate(heatmap_files):
            drv_kp = torch.tensor(
                np.load(os.path.join(heatmap_dir, hfile)),
                dtype=torch.float32
            ).permute(2, 0, 1).unsqueeze(0).to(device)

            # Drive only the eye (36-47) and mouth (48-67) landmark channels;
            # every other channel keeps the source pose.
            combined_kp = src_kp.clone()
            for idx in list(range(36, 48)) + list(range(48, 68)):
                combined_kp[:, idx] = drv_kp[:, idx]

            flow, occ = dense_motion(src_kp, combined_kp)
            warped = torch.clamp(warp_image(src_tensor, flow), 0, 1)

            pred = torch.clamp(
                generator(torch.cat([warped, flow, occ], dim=1)),
                0, 1
            )

            # Composite: generated pixels inside the mask, source elsewhere.
            final_frame = pred * mask_tensor + src_tensor * (1 - mask_tensor)
            frame_np = (final_frame.cpu().squeeze().numpy() * 255).astype(np.uint8)

            out.write(frame_np)

            # Keep the first frame as a still preview image.
            if i == 0:
                cv2.imwrite("preview_streamlit_matched.png", frame_np)

    out.release()
    print(f"Output saved: {output_path}")
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
if __name__ == "__main__":
    # Animate a single test image using precomputed driving heatmaps.
    test_single_image(
        source_image_path="motion_transfer/test/87.jpg",
        heatmap_dir="motion_transfer/test/combined/",
        output_path="outputs/final_streamlit_matched.mp4"
    )
|
utilis/Face_keypoints_generate.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import glob
|
| 3 |
+
import cv2
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import face_alignment
|
| 7 |
+
from multiprocessing import Pool, cpu_count
|
| 8 |
+
# FIX: removed stray debug statement `print("hello")` left over from development.

# --- Paths (Windows-style, relative to the working directory) ---
sources_dir = r"motion_transfer\new_dataset\test\image"
videos_dir = r"motion_transfer\new_dataset\test\video"
output_root = r"motion_transfer\new_dataset\test\dataset"

# --- Processing parameters ---
num_workers = min(cpu_count(), 4)  # cap the number of worker processes
target_size = 256                  # output frame side length (square canvas)
SIGMA = 2.0                        # Gaussian std-dev for keypoint heatmaps

NUM_FACE_POINTS = 68               # standard 68-point facial landmark set

# Temporal smoothing (EMA weight given to the current frame's landmarks)
SMOOTH_ALPHA = 0.7

# Models
# 2D 68-point landmark detector; runs on GPU when available.
fa = face_alignment.FaceAlignment(
    face_alignment.LandmarksType.TWO_D,
    device='cuda' if torch.cuda.is_available() else 'cpu'
)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def resize_with_gradient_padding(img, target_size):
    """Fit `img` into a target_size square over a gradient background.

    The image is resized with preserved aspect ratio, then centered on a
    canvas filled with a horizontal color gradient interpolated between the
    mean colors of the top-left and top-right 5x5 patches.
    Assumes a 3-channel image — TODO confirm callers never pass grayscale.
    """
    h, w = img.shape[:2]
    scale = target_size / max(h, w)
    img = cv2.resize(img, (int(w * scale), int(h * scale)),
                     interpolation=cv2.INTER_AREA)
    h, w = img.shape[:2]

    # Sample the two top corners to anchor the gradient.
    top_left = np.mean(img[0:5, 0:5, :], axis=(0, 1))
    top_right = np.mean(img[0:5, -5:, :], axis=(0, 1))
    gradient_row = np.linspace(top_left, top_right, target_size)
    bg = np.tile(gradient_row, (target_size, 1, 1)).astype(np.uint8)

    # Center the resized image on the gradient canvas.
    y0 = (target_size - h) // 2
    x0 = (target_size - w) // 2
    bg[y0:y0 + h, x0:x0 + w] = img
    return bg
|
| 43 |
+
|
| 44 |
+
def clip_xy(x, y, w, h):
    """Clamp a point to the valid pixel range [0, w-1] x [0, h-1], as floats."""
    cx = float(np.clip(x, 0, w - 1))
    cy = float(np.clip(y, 0, h - 1))
    return cx, cy
|
| 46 |
+
|
| 47 |
+
def ema(prev_pts, curr_pts, alpha=SMOOTH_ALPHA):
    """Exponentially smooth landmark arrays across frames.

    Returns `curr_pts` unchanged when there is no usable history (first
    frame, or a shape mismatch); otherwise blends with weight `alpha` on
    the current frame and `1 - alpha` on the previous one.
    """
    history_usable = prev_pts is not None and prev_pts.shape == curr_pts.shape
    if not history_usable:
        return curr_pts
    return (1.0 - alpha) * prev_pts + alpha * curr_pts
|
| 51 |
+
|
| 52 |
+
def gaussian_heatmaps(points, H, W, sigma=2.0):
    """Render one unit-peak Gaussian per keypoint.

    Returns an (H, W, N) float32 array where channel i peaks at 1.0 at
    points[i] = (x, y) and falls off with standard deviation `sigma`.
    """
    num_pts = points.shape[0]
    yy, xx = np.mgrid[0:H, 0:W].astype(np.float32)
    heat = np.zeros((H, W, num_pts), dtype=np.float32)
    s2 = 2 * (sigma ** 2)
    for idx in range(num_pts):
        px, py = points[idx]
        sq_dist = (xx - px) ** 2 + (yy - py) ** 2
        heat[..., idx] = np.exp(-sq_dist / s2)
    return heat
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def process_person(person_name):
    """Build the training data for one person.

    Looks for <sources_dir>/<person_name>.{png,jpg,jpeg} and
    <videos_dir>/<person_name>.mp4, then writes under
    <output_root>/<person_name>/:
      - the padded/resized source image,
      - frames/  : every detected video frame resized to target_size,
      - combined/: per-frame (H, W, 68) Gaussian landmark heatmaps (.npy).
    Frames with no face (or fewer than 68 landmarks) are skipped but still
    advance the frame index.
    """
    # Locate the source still image; first matching extension wins.
    source_path = None
    for ext in [".png", ".jpg", ".jpeg"]:
        p = os.path.join(sources_dir, person_name + ext)
        if os.path.exists(p):
            source_path = p
            break
    video_path = os.path.join(videos_dir, f"{person_name}.mp4")
    # FIX: guard against source_path being None — os.path.isfile(None)
    # raises TypeError instead of reporting the missing file.
    if source_path is None or not (os.path.isfile(source_path) and os.path.isfile(video_path)):
        print(f"Missing files for {person_name}")
        return

    print(f" Processing {person_name}...")
    person_root = os.path.join(output_root, person_name)
    frames_dir = os.path.join(person_root, "frames")
    combined_dir = os.path.join(person_root, "combined")
    keypoints_preview_dir = os.path.join(person_root, "keypoints_preview")
    os.makedirs(frames_dir, exist_ok=True)
    os.makedirs(combined_dir, exist_ok=True)
    os.makedirs(keypoints_preview_dir, exist_ok=True)

    # Save resized reference image for this person.
    src_img = cv2.imread(source_path)
    if src_img is not None:
        src_ref = resize_with_gradient_padding(src_img, target_size)
        cv2.imwrite(os.path.join(person_root, f"{person_name}.jpg"), src_ref)

    cap = cv2.VideoCapture(video_path)
    frame_idx = 0
    prev_points = None

    while True:
        ok, frame_bgr = cap.read()
        if not ok:
            break

        # resize (RGB ordering for the landmark detector)
        frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
        frame_rgb = resize_with_gradient_padding(frame_rgb, target_size)
        H, W = frame_rgb.shape[:2]

        # 68 face landmarks
        fa_out = fa.get_landmarks(frame_rgb)
        if fa_out is None or len(fa_out) == 0 or fa_out[0].shape[0] < NUM_FACE_POINTS:
            frame_idx += 1
            continue
        face68 = fa_out[0][:NUM_FACE_POINTS].astype(np.float32)

        # Smooth for temporal stability
        face68 = ema(prev_points, face68, alpha=SMOOTH_ALPHA)
        prev_points = face68.copy()

        # Heatmaps
        heatmap = gaussian_heatmaps(face68, H, W, sigma=SIGMA)

        # Save preview + .npy
        vis = frame_rgb.copy()
        for (x, y) in face68.astype(int):
            cv2.circle(vis, (x, y), 2, (0, 255, 0), -1)
        vis_bgr = cv2.cvtColor(vis, cv2.COLOR_RGB2BGR)

        fname = f"{frame_idx:05d}"
        frame_file = f"{frame_idx:05d}.jpg"
        # NOTE(review): frame_rgb is written via cv2.imwrite, which expects
        # BGR — saved frames have swapped R/B channels unless the training
        # loader compensates. Confirm against the dataset loader.
        cv2.imwrite(os.path.join(frames_dir, frame_file), frame_rgb)
        #cv2.imwrite(os.path.join(keypoints_preview_dir, f"{fname}.png"), vis_bgr)
        np.save(os.path.join(combined_dir, f"{fname}.npy"), heatmap)

        frame_idx += 1

    cap.release()
    print(f"Done: {person_name} | Frames: {frame_idx} | Points/frame: {NUM_FACE_POINTS}")
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
if __name__ == "__main__":
    # Every image in sources_dir defines one person to process.
    image_files = []
    for pattern in ("*.png", "*.jpg", "*.jpeg"):
        image_files.extend(glob.glob(os.path.join(sources_dir, pattern)))

    people = [os.path.splitext(os.path.basename(path))[0]
              for path in image_files]

    print("People found:", people)
    # Fan the per-person work out over worker processes.
    with Pool(num_workers) as pool:
        pool.map(process_person, people)
|
utilis/generate_heatmap.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import cv2
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
# Heatmap of the single source frame whose keypoints will be animated.
single_heatmap_path = "motion_transfer/dataset_single/reference_heatmap/00000.npy"
reference_heatmap_dir = "motion_transfer/dataset_single/reference_heatmap" # contains 150 npy files
output_dir = "motion_transfer/dataset_single/test_heatmap"
preview_video = "motion_transfer/dataset_single/simulated_motion.mp4"

os.makedirs(output_dir, exist_ok=True)

single_heatmap = np.load(single_heatmap_path)
# (height, width, number of keypoint channels)
H, W, C = single_heatmap.shape
|
| 15 |
+
|
| 16 |
+
# Extract keypoints
|
| 17 |
+
def extract_keypoints(hmap):
    """Recover one (x, y) keypoint per heatmap channel.

    Each keypoint is the centroid of that channel's nonzero pixels;
    channels with no nonzero pixels map to (0, 0). Returns an (N, 2)
    float32 array.
    """
    centroids = []
    for ch in range(hmap.shape[2]):
        ys, xs = np.nonzero(hmap[:, :, ch] > 0)
        if xs.size:
            centroids.append([np.mean(xs), np.mean(ys)])
        else:
            centroids.append([0, 0])
    return np.array(centroids, dtype=np.float32)
|
| 26 |
+
|
| 27 |
+
# Keypoints of the single source heatmap that will be animated.
single_kp = extract_keypoints(single_heatmap)

# reference motion
ref_files = sorted([f for f in os.listdir(reference_heatmap_dir) if f.endswith(".npy")])
ref_heatmaps = [np.load(os.path.join(reference_heatmap_dir, f)) for f in ref_files]
ref_kp_list = [extract_keypoints(hm) for hm in ref_heatmaps]

# Compute motion relative to first reference frame
ref_base_kp = ref_kp_list[0]
motion_vectors = [kp - ref_base_kp for kp in ref_kp_list]

# Apply motion to single input keypoints
# Preview video writer: 30 fps, frame size matches the heatmap resolution.
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
video_writer = cv2.VideoWriter(preview_video, fourcc, 30, (W, H))
|
| 41 |
+
|
| 42 |
+
def gaussian_heatmaps(points, H, W, sigma=2.0):
    """Build an (H, W, N) float32 stack of Gaussian bumps.

    Channel i has value 1.0 at points[i] = (x, y) and decays with
    standard deviation `sigma`.
    """
    total = points.shape[0]
    yy, xx = np.mgrid[0:H, 0:W].astype(np.float32)
    heat = np.zeros((H, W, total), dtype=np.float32)
    two_sigma_sq = 2 * (sigma ** 2)
    for idx, (px, py) in enumerate(points):
        sq_dist = (xx - px) ** 2 + (yy - py) ** 2
        heat[..., idx] = np.exp(-sq_dist / two_sigma_sq)
    return heat
|
| 51 |
+
|
| 52 |
+
# For each reference-motion displacement, shift the source keypoints and
# regenerate Gaussian heatmaps, saving one .npy per simulated frame.
for frame_idx, displacement in enumerate(motion_vectors):
    moved_kp = single_kp + displacement

    # Generate Gaussian heatmap for all points
    new_heatmap = gaussian_heatmaps(moved_kp, H, W, sigma=2.0)
    np.save(os.path.join(output_dir, f"{frame_idx:05d}.npy"), new_heatmap)

    # Draw the moved keypoints on a black canvas for the preview video.
    frame_vis = np.zeros((H, W, 3), dtype=np.uint8)
    for (x, y) in moved_kp.astype(int):
        cv2.circle(frame_vis, (x, y), 2, (0, 255, 0), -1)
    video_writer.write(frame_vis)

video_writer.release()

print(f"Simulated motion heatmaps saved in '{output_dir}'")
print(f"Preview video saved as '{preview_video}'")
|
utilis/jitter.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
heatmap_dir = "motion_transfer/dataset_single/test_heatmap"
smoothed_dir = "motion_transfer/dataset_single/smoothed_heatmaps"
os.makedirs(smoothed_dir, exist_ok=True)

# Load existing heatmaps and extract keypoints
heatmap_files = sorted([f for f in os.listdir(heatmap_dir) if f.endswith(".npy")])
# NOTE(review): kp_list first holds raw (H, W, N) heatmaps and is replaced
# below by per-frame keypoint arrays — a separate name would be clearer.
kp_list = [np.load(os.path.join(heatmap_dir, f)) for f in heatmap_files]
|
| 11 |
+
|
| 12 |
+
# Extract keypoints from each heatmap
|
| 13 |
+
def extract_kpoints(hm):
    """Return (N, 2) float32 (x, y) centroids, one per heatmap channel.

    The centroid is taken over each channel's nonzero pixels; channels
    with no nonzero pixels yield (0, 0).
    """
    centroids = []
    for ch in range(hm.shape[2]):
        ys, xs = np.nonzero(hm[:, :, ch] > 0)
        point = [np.mean(xs), np.mean(ys)] if xs.size else [0, 0]
        centroids.append(point)
    return np.array(centroids, dtype=np.float32)
|
| 22 |
+
|
| 23 |
+
# Replace the raw heatmaps with their extracted (N, 2) keypoint arrays.
kp_list = [extract_kpoints(hm) for hm in kp_list]
|
| 24 |
+
|
| 25 |
+
#Apply temporal smoothing
|
| 26 |
+
def temporal_smoothing(kp_list, alpha=0.7):
    """Smooth a sequence of keypoint arrays with an exponential moving average.

    `alpha` weights the accumulated history and `1 - alpha` the incoming
    frame (note: opposite convention to a "weight of current frame" EMA).
    Returns a new list of the same length; the first entry is a copy.
    """
    smoothed = [kp_list[0].copy()]
    for curr in kp_list[1:]:
        blended = alpha * smoothed[-1] + (1 - alpha) * curr
        smoothed.append(blended)
    return smoothed
|
| 32 |
+
|
| 33 |
+
smoothed_kp_list = temporal_smoothing(kp_list, alpha=0.7)

# Recompute heatmaps with same dimensions
# (H, W and channel count are taken from the first existing heatmap file)
H, W, C = np.load(os.path.join(heatmap_dir, heatmap_files[0])).shape
|
| 37 |
+
|
| 38 |
+
def gaussian_heatmaps(points, H, W, sigma=2.0):
    """Render each (x, y) point as a unit-peak Gaussian channel.

    Output shape is (H, W, N), dtype float32; channel i equals 1.0 at
    points[i] and decays with standard deviation `sigma`.
    """
    count = points.shape[0]
    yy, xx = np.mgrid[0:H, 0:W].astype(np.float32)
    heat = np.zeros((H, W, count), dtype=np.float32)
    denom = 2 * (sigma ** 2)
    for idx in range(count):
        px, py = points[idx]
        heat[..., idx] = np.exp(-(((xx - px) ** 2 + (yy - py) ** 2)) / denom)
    return heat
|
| 47 |
+
|
| 48 |
+
# Re-render each smoothed keypoint set as heatmaps, reusing the original
# filenames so the smoothed set mirrors the input directory.
for idx, kp in enumerate(smoothed_kp_list):
    new_hm = gaussian_heatmaps(kp, H, W, sigma=2.0)
    np.save(os.path.join(smoothed_dir, heatmap_files[idx]), new_hm)
|