|
|
import os
import urllib.request

import gradio as gr
import imageio
import numpy as np
import torch
import yaml
from skimage import img_as_ubyte
from skimage.transform import resize
from tqdm import tqdm
|
|
|
|
|
|
|
|
# First Order Motion Model (FOMM) components. These modules ship with the
# FOMM repository itself (not PyPI), so guard the import and fail fast with
# an actionable hint when the project files are missing.
try:
    from modules.generator import OcclusionAwareGenerator
    from modules.keypoint_detector import KPDetector
    from animate import normalize_kp
except ModuleNotFoundError as e:
    print(f"Module import error: {e}")
    print("Please ensure 'modules' directory with generator.py, keypoint_detector.py, util.py, and animate.py are in the project root.")
    # Re-raise so the app does not continue in a half-initialized state.
    raise
|
|
|
|
|
|
|
|
# Pretrained vox (face) checkpoint and its matching model configuration.
CHECKPOINT_PATH = "vox-cpk.pth.tar"
CONFIG_PATH = "config/vox-256.yaml"

_CHECKPOINT_URL = "https://github.com/AliaksandrSiarohin/first-order-model/raw/master/vox-cpk.pth.tar"
_CONFIG_URL = "https://raw.githubusercontent.com/AliaksandrSiarohin/first-order-model/master/config/vox-256.yaml"


def _download(url, dest):
    """Download *url* to *dest*, removing any partial file on failure.

    Previously this used ``os.system("wget ...")``, which fails silently on
    systems without wget and can leave a truncated file behind; a truncated
    checkpoint/config then makes ``torch.load``/``yaml.safe_load`` crash
    with a confusing parse error instead of a download error.
    """
    try:
        urllib.request.urlretrieve(url, dest)
    except Exception:
        if os.path.exists(dest):
            os.remove(dest)
        raise


if not os.path.exists(CHECKPOINT_PATH):
    _download(_CHECKPOINT_URL, CHECKPOINT_PATH)

if not os.path.exists(CONFIG_PATH):
    os.makedirs("config", exist_ok=True)
    _download(_CONFIG_URL, CONFIG_PATH)

# safe_load: no reason to allow arbitrary Python object construction here.
with open(CONFIG_PATH) as f:
    config = yaml.safe_load(f)
|
|
|
|
|
|
|
|
# CPU inference keeps the demo portable (no CUDA requirement).
device = "cpu"

# Build both networks from the shared model configuration.
_model_params = config['model_params']
_common_params = _model_params['common_params']
generator = OcclusionAwareGenerator(**_model_params['generator_params'], **_common_params).to(device)
kp_detector = KPDetector(**_model_params['kp_detector_params'], **_common_params).to(device)

# Restore the pretrained weights, then switch both nets to inference mode.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
generator.load_state_dict(checkpoint['generator'])
kp_detector.load_state_dict(checkpoint['kp_detector'])
for _net in (generator, kp_detector):
    _net.eval()
|
|
|
|
|
|
|
|
def make_animation(source_image, driving_video, relative=True, adapt_movement_scale=True, fps=30):
    """Animate *source_image* with the motion extracted from *driving_video*.

    Parameters
    ----------
    source_image : ndarray
        HxWxC image; resized to 256x256 and any alpha channel is dropped.
    driving_video : list of ndarray
        Motion-source frames; each is resized to 256x256.
    relative : bool
        Use keypoint movement relative to the first driving frame (forwarded
        to ``normalize_kp`` as both relative-movement and relative-jacobian).
    adapt_movement_scale : bool
        Rescale motion to the source face size (forwarded to ``normalize_kp``).
    fps : int
        Output frame rate. New keyword; defaults to the previously
        hard-coded 30, so existing callers are unaffected.

    Returns
    -------
    str
        Path of the rendered mp4 on success, or a human-readable error
        message string (kept for compatibility with the Gradio flow).
    """
    try:
        def _to_tensor(frame):
            # HWC float image -> 1xCxHxW float32 tensor on the target device.
            return torch.tensor(frame[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2).to(device)

        source_image = resize(source_image, (256, 256))[..., :3]
        driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]

        with torch.no_grad():
            predictions = []
            source = _to_tensor(source_image)
            kp_source = kp_detector(source)
            # Only the first frame is needed here; indexing the list directly
            # avoids copying the whole video into one big ndarray.
            kp_driving_initial = kp_detector(_to_tensor(driving_video[0]))

            for frame in tqdm(driving_video, desc="Generating frames"):
                kp_driving = kp_detector(_to_tensor(frame))
                kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
                                       kp_driving_initial=kp_driving_initial,
                                       use_relative_movement=relative,
                                       use_relative_jacobian=relative,
                                       adapt_movement_scale=adapt_movement_scale)
                out = generator(source, kp_source=kp_source, kp_driving=kp_norm)
                predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])

        output_video = "result.mp4"
        imageio.mimsave(output_video, [img_as_ubyte(frame) for frame in predictions], fps=fps)
        return output_video
    except Exception as e:
        # Matches the app's existing convention of surfacing errors as strings.
        return f"Animation generation failed: {str(e)}"
|
|
|
|
|
|
|
|
def animate(source_img, driving_vid):
    """Gradio callback: read the uploaded files and run the animation.

    Returns the output video path on success, or an error-message string
    if reading the inputs (or the animation itself) fails.
    """
    try:
        loaded_image = imageio.imread(source_img)
        loaded_frames = imageio.mimread(driving_vid)
        return make_animation(loaded_image, loaded_frames)
    except Exception as e:
        return f"Error processing inputs: {str(e)}"
|
|
|
|
|
# --- Gradio UI -----------------------------------------------------------
_source_input = gr.Image(type="filepath", label="Source Image (e.g., face photo, 256x256)")
_driving_input = gr.Video(label="Driving Video (e.g., motion video, short clip <10s)")
_result_output = gr.Video(label="Generated Animation")

iface = gr.Interface(
    fn=animate,
    inputs=[_source_input, _driving_input],
    outputs=_result_output,
    title="FOMM Image-to-Video Demo",
    description="Upload a source image and driving video to generate animation. CPU-based, may take a few minutes.",
)

if __name__ == "__main__":
    # Start the local Gradio server only when executed as a script.
    iface.launch()