import gradio as gr
import torch
import yaml
import imageio
import numpy as np
from skimage.transform import resize
from skimage import img_as_ubyte
from tqdm import tqdm
import os
# Import local modules
try:
from modules.generator import OcclusionAwareGenerator
from modules.keypoint_detector import KPDetector
from animate import normalize_kp
except ModuleNotFoundError as e:
print(f"Module import error: {e}")
print("Please ensure 'modules' directory with generator.py, keypoint_detector.py, util.py, and animate.py are in the project root.")
raise
# Download the pretrained checkpoint and config if they are missing
CHECKPOINT_PATH = "vox-cpk.pth.tar"
CONFIG_PATH = "config/vox-256.yaml"
if not os.path.exists(CHECKPOINT_PATH):
os.system(f"wget https://github.com/AliaksandrSiarohin/first-order-model/raw/master/vox-cpk.pth.tar -O {CHECKPOINT_PATH}")
if not os.path.exists(CONFIG_PATH):
os.makedirs("config", exist_ok=True)
os.system(f"wget https://raw.githubusercontent.com/AliaksandrSiarohin/first-order-model/master/config/vox-256.yaml -O {CONFIG_PATH}")
# Load the model config (vox-256.yaml holds the network hyperparameters)
with open(CONFIG_PATH) as f:
config = yaml.safe_load(f)
# Load the models (CPU only on this Space)
device = "cpu"
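# KPDetector predicts sparse keypoints (with local affine Jacobians) per frame;
# OcclusionAwareGenerator warps the source image along the keypoint motion and
# inpaints regions that the motion occludes.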
generator = OcclusionAwareGenerator(**config['model_params']['generator_params'], **config['model_params']['common_params']).to(device)
kp_detector = KPDetector(**config['model_params']['kp_detector_params'], **config['model_params']['common_params']).to(device)
# weights_only=False is required on torch >= 2.6, where safe weights-only
# loading became the default and would reject this pickled checkpoint
# (the keyword itself needs torch >= 1.13).
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=False)
generator.load_state_dict(checkpoint['generator'])
kp_detector.load_state_dict(checkpoint['kp_detector'])
generator.eval()
kp_detector.eval()
# Core animation routine
def make_animation(source_image, driving_video, relative=True, adapt_movement_scale=True, fps=30):
    source_image = resize(source_image, (256, 256))[..., :3]
    driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]
    with torch.no_grad():
        predictions = []
        source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2).to(device)
        kp_source = kp_detector(source)
        # Keypoints of the first driving frame anchor the relative-motion transfer.
        driving_initial = torch.tensor(driving_video[0][np.newaxis].astype(np.float32)).permute(0, 3, 1, 2).to(device)
        kp_driving_initial = kp_detector(driving_initial)
        for frame in tqdm(driving_video, desc="Generating frames"):
            driving = torch.tensor(frame[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2).to(device)
            kp_driving = kp_detector(driving)
            # Express the driving keypoints relative to the first frame so only
            # the motion, not the driving identity, is transferred to the source.
            kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
                                   kp_driving_initial=kp_driving_initial, use_relative_movement=relative,
                                   use_relative_jacobian=relative, adapt_movement_scale=adapt_movement_scale)
            out = generator(source, kp_source=kp_source, kp_driving=kp_norm)
            predictions.append(np.transpose(out['prediction'].cpu().numpy(), [0, 2, 3, 1])[0])
    output_video = "result.mp4"
    imageio.mimsave(output_video, [img_as_ubyte(frame) for frame in predictions], fps=fps)
    return output_video
# Gradio interface
def animate(source_img, driving_vid):
    try:
        source = imageio.imread(source_img)
        # get_reader avoids mimread's default in-memory size limit and lets us
        # keep the driving video's own frame rate in the output.
        reader = imageio.get_reader(driving_vid)
        fps = reader.get_meta_data().get('fps', 30)
        driving = [frame for frame in reader]
        reader.close()
        return make_animation(source, driving, fps=fps)
    except Exception as e:
        # gr.Error surfaces the message in the UI; returning a plain string to a
        # Video output would be misinterpreted as a file path.
        raise gr.Error(f"Error processing inputs: {e}")
iface = gr.Interface(
fn=animate,
inputs=[
gr.Image(type="filepath", label="Source Image (e.g., face photo, 256x256)"),
gr.Video(label="Driving Video (e.g., motion video, short clip <10s)")
],
outputs=gr.Video(label="Generated Animation"),
title="FOMM Image-to-Video Demo",
    description="Upload a source image and a driving video to generate an animation. Runs on CPU, so expect a few minutes per clip."
)
if __name__ == "__main__":
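    # launch() defaults work on Hugging Face Spaces; for a local containerized
    # run, launch(server_name="0.0.0.0") makes the app reachable from outside.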
iface.launch()