|
|
import argparse |
|
|
import os |
|
|
from pathlib import Path |
|
|
|
|
|
import cv2 |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import numpy as np |
|
|
import face_alignment |
|
|
import torchvision.transforms as transforms |
|
|
from PIL import Image |
|
|
from tqdm import tqdm |
|
|
|
|
|
from renderer.models import IMTRenderer |
|
|
|
|
|
|
|
|
# Enable cuDNN and let its autotuner benchmark conv algorithms; worthwhile here
# because every frame shares one fixed resolution, so the tuned choice is reused.
torch.backends.cudnn.enabled = True


torch.backends.cudnn.benchmark = True
|
|
|
|
|
|
|
|
class DataProcessor:
    """Face detection, cropping, and tensor preprocessing for inference.

    Uses the detector bundled with ``face_alignment`` to locate a face in the
    source image and crops a padded square region around it.
    """

    def __init__(self, opt):
        self.opt = opt
        self.input_size = opt.input_size
        # 2D landmark model; we only use its internal face detector here.
        self.fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, flip_input=False)

        # BUGFIX: was hard-coded to (512, 512); now honors --input_size,
        # consistent with the fallback path in process_img.
        self.transform = transforms.Compose([
            transforms.Resize((self.input_size, self.input_size)),
            transforms.ToTensor(),
        ])

    @torch.no_grad()
    def process_img(self, img):
        """Detect the most confident face and return a square PIL crop around it.

        Falls back to a plain resize (no crop) when no detection scores above
        0.95. Always returns a ``PIL.Image`` so the result can be fed to
        ``self.transform`` regardless of which path was taken.
        """
        if isinstance(img, Image.Image):
            img = np.array(img)

        # Normalize grayscale / RGBA inputs to 3-channel RGB.
        if img.ndim == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        elif img.shape[2] == 4:
            img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)

        bboxes = self.fa.face_detector.detect_from_image(img)

        valid_bboxes = [
            (int(x1), int(y1), int(x2), int(y2), score)
            for (x1, y1, x2, y2, score) in (bboxes or []) if score > 0.95
        ]

        if not valid_bboxes:
            print("[WARN] No face detected, falling back to center resize.")
            # BUGFIX: wrap in PIL.Image so the return type matches the crop
            # path below — transforms.Resize does not accept a bare ndarray.
            return Image.fromarray(cv2.resize(img, (self.input_size, self.input_size)))

        # BUGFIX: pick the highest-confidence detection, not merely the first.
        x1, y1, x2, y2, _ = max(valid_bboxes, key=lambda b: b[4])
        bsy, bsx = int((y2 - y1) / 2), int((x2 - x1) / 2)
        my, mx = int((y1 + y2) / 2), int((x1 + x2) / 2)
        # Half-size of the square crop: 1.6x the larger bbox half-extent,
        # giving context (hair, neck) around the detected face.
        bs = int(max(bsy, bsx) * 1.6)

        # Zero-pad by the crop half-size so the window never leaves the image.
        img = cv2.copyMakeBorder(img, bs, bs, bs, bs, cv2.BORDER_CONSTANT, value=0)

        # Shift the center into padded coordinates, then crop the square.
        my, mx = my + bs, mx + bs
        crop_img = img[my - bs:my + bs, mx - bs:mx + bs]
        return Image.fromarray(crop_img)

    def load_image(self, path):
        """Read an image from disk and return it as an RGB ``PIL.Image``.

        Raises:
            FileNotFoundError: if ``path`` cannot be read by OpenCV.
        """
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError(f"Image not found at {path}")
        # OpenCV reads BGR; convert before handing to PIL.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return Image.fromarray(img)
|
|
|
|
|
|
|
|
def save_video(vid_target_recon, save_path, fps):
    """Write a (T, C, H, W) float tensor in [0, 1] to ``save_path`` as mp4."""
    # Tensor -> uint8 HWC frame stack on the CPU.
    frames = (vid_target_recon.clamp(0, 1).permute(0, 2, 3, 1).cpu().numpy() * 255).astype(np.uint8)
    num_frames, height, width, _ = frames.shape

    writer = cv2.VideoWriter(
        save_path,
        cv2.VideoWriter_fourcc(*'mp4v'),
        fps,
        (width, height),
    )

    for idx in range(num_frames):
        # OpenCV expects BGR channel order.
        writer.write(cv2.cvtColor(frames[idx], cv2.COLOR_RGB2BGR))

    writer.release()
    print(f"[Success] Video saved to {save_path}")
|
|
|
|
|
|
|
|
class Demo(nn.Module):
    """End-to-end inference driver: animate a source image with a driving video.

    Wraps a pretrained generator (presumably an ``IMTRenderer`` — see the
    construction site) and handles image/video I/O around it.
    """

    def __init__(self, args, gen):
        super().__init__()
        self.args = args
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        print('==> Loading model...')
        self.gen = gen.to(self.device)
        self.gen.eval()

        self.save_path = args.save_path
        os.makedirs(self.save_path, exist_ok=True)
        self.processor = DataProcessor(args)

    @torch.no_grad()
    def process_single(self, source_path, driving_path):
        """Animate ``source_path`` with the motion of ``driving_path``.

        Writes ``<source>_<driving>.mp4`` into ``self.save_path``. Prints an
        error and returns early when the driving video cannot be opened or
        yields no frames.
        """
        print(f"==> Processing: {source_path} + {driving_path}")

        # --- Encode the (optionally face-cropped) source image once. ---
        source_img = self.processor.load_image(source_path)
        if self.args.crop:
            source_img = self.processor.process_img(source_img)

        source_tensor = self.processor.transform(source_img).unsqueeze(0).to(self.device)

        # Appearance features and identity code of the reference image,
        # plus its own motion adapted to that identity (used by decode below).
        f_r, i_r = self.gen.app_encode(source_tensor)
        t_r = self.gen.mot_encode(source_tensor)
        ta_r = self.gen.adapt(t_r, i_r)
        ma_r = self.gen.mot_decode(ta_r)

        # --- Stream the driving video frame by frame. ---
        cap = cv2.VideoCapture(driving_path)
        if not cap.isOpened():
            # BUGFIX: fail loudly instead of silently looping zero frames.
            print(f"[Error] Cannot open driving video: {driving_path}")
            return

        fps = cap.get(cv2.CAP_PROP_FPS) if self.args.fps is None else self.args.fps
        if not fps or fps <= 0:
            # BUGFIX: CAP_PROP_FPS can be 0 for some containers; a zero fps
            # would make VideoWriter emit a broken file. Fall back to 25.
            fps = 25.0
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        vid_target_recon = []

        pbar = tqdm(total=frame_count, desc="Inferencing")
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame_pil = Image.fromarray(frame)

            frame_tensor = self.processor.transform(frame_pil).unsqueeze(0).to(self.device)

            # Transfer the driving frame's motion onto the reference identity.
            t_c = self.gen.mot_encode(frame_tensor)
            ta_c = self.gen.adapt(t_c, i_r)
            ma_c = self.gen.mot_decode(ta_c)
            out = self.gen.decode(ma_c, ma_r, f_r)

            # Move to CPU immediately so GPU memory stays bounded.
            vid_target_recon.append(out.cpu())
            pbar.update(1)

        cap.release()
        pbar.close()

        if not vid_target_recon:
            print("[Error] No frames generated.")
            return

        vid_target_recon = torch.cat(vid_target_recon, dim=0)
        save_name = f"{Path(source_path).stem}_{Path(driving_path).stem}.mp4"
        save_video(vid_target_recon, os.path.join(self.save_path, save_name), fps)

    def process_batch(self, root_dir):
        """Run ``process_single`` on every subfolder of ``root_dir``.

        Each subfolder must contain at least one image and one video; the
        first of each (directory listing order) is used. Per-folder failures
        are logged and do not abort the batch.
        """
        subdirs = [
            os.path.join(root_dir, d) for d in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, d))
        ]

        for sub in subdirs:
            img_files = [f for f in os.listdir(sub) if f.lower().endswith((".png", ".jpg", ".jpeg"))]
            vid_files = [f for f in os.listdir(sub) if f.lower().endswith((".mp4", ".avi", ".mov"))]

            if not img_files or not vid_files:
                print(f"[Skip] {sub} missing image or video.")
                continue

            img_path = os.path.join(sub, img_files[0])
            vid_path = os.path.join(sub, vid_files[0])

            try:
                self.process_single(img_path, vid_path)
            except Exception as e:
                # Best-effort batch: report which folder failed and continue.
                print(f"[Error] Failed processing {sub}: {e}")

    def run(self):
        """Dispatch to single-pair or batch mode based on CLI arguments."""
        if self.args.source_path and self.args.driving_path:
            self.process_single(self.args.source_path, self.args.driving_path)
        elif self.args.data_dir:
            self.process_batch(self.args.data_dir)
        else:
            raise ValueError("Usage: Provide either --source_path & --driving_path OR --data_dir")
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Image Animation Inference Demo")

    # I/O paths.
    parser.add_argument("--source_path", type=str, help="Path to source image")
    parser.add_argument("--driving_path", type=str, help="Path to driving video")
    parser.add_argument("--data_dir", type=str, help="Batch directory containing subfolders")
    parser.add_argument("--save_path", type=str, default="./results", help="Output directory")

    # Model / architecture options (consumed by IMTRenderer).
    parser.add_argument("--renderer_path", type=str, default="./checkpoints/renderer.ckpt", help="Checkpoint path")
    parser.add_argument("--input_size", type=int, default=512, help="Resolution")
    parser.add_argument('--swin_res_threshold', type=int, default=128)
    parser.add_argument('--num_heads', type=int, default=8)
    parser.add_argument('--window_size', type=int, default=8)

    # Runtime options.
    parser.add_argument("--fps", type=int, default=None, help="Output FPS (default: same as input)")
    parser.add_argument("--crop", action="store_true", help="Crop face from source image")

    args = parser.parse_args()

    model = IMTRenderer(args)
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(args.renderer_path, map_location="cpu")

    # Lightning-style checkpoints nest the weights under "state_dict" and
    # prefix generator keys with the attribute name "gen.".
    state_dict = checkpoint.get("state_dict", checkpoint)
    # BUGFIX: replace only the leading occurrence (count=1); the old
    # replace("gen.", "") mangled any key containing "gen." elsewhere.
    clean_state_dict = {
        k.replace("gen.", "", 1): v for k, v in state_dict.items() if k.startswith("gen.")
    }
    if not clean_state_dict:
        # BUGFIX: a checkpoint without "gen."-prefixed keys previously loaded
        # *nothing* (strict=False swallowed the mismatch). Use it as-is.
        clean_state_dict = state_dict

    # Surface partial loads instead of discarding the result.
    missing, unexpected = model.load_state_dict(clean_state_dict, strict=False)
    if missing or unexpected:
        print(f"[WARN] load_state_dict: {len(missing)} missing, {len(unexpected)} unexpected keys")

    demo = Demo(args, model)
    demo.run()
|
|
|