import os

import gradio as gr
import imageio
import numpy as np
import torch
import yaml
from skimage import img_as_ubyte
from skimage.transform import resize
from tqdm import tqdm

# Import local modules
try:
    from modules.generator import OcclusionAwareGenerator
    from modules.keypoint_detector import KPDetector
    from animate import normalize_kp
except ModuleNotFoundError as e:
    print(f"Module import error: {e}")
    print("Please ensure the 'modules' directory (generator.py, keypoint_detector.py, "
          "util.py) and animate.py are in the project root.")
    raise

# Download the pretrained checkpoint and config on first run (requires wget)
CHECKPOINT_PATH = "vox-cpk.pth.tar"
CONFIG_PATH = "config/vox-256.yaml"
if not os.path.exists(CHECKPOINT_PATH):
    os.system(f"wget https://github.com/AliaksandrSiarohin/first-order-model/raw/master/vox-cpk.pth.tar -O {CHECKPOINT_PATH}")
if not os.path.exists(CONFIG_PATH):
    os.makedirs("config", exist_ok=True)
    os.system(f"wget https://raw.githubusercontent.com/AliaksandrSiarohin/first-order-model/master/config/vox-256.yaml -O {CONFIG_PATH}")

# Load config
with open(CONFIG_PATH) as f:
    config = yaml.safe_load(f)

# Load models
device = "cpu"  # set to "cuda" if a GPU is available; the rest of the code is device-agnostic
generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],
                                    **config['model_params']['common_params']).to(device)
kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
                         **config['model_params']['common_params']).to(device)
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
generator.load_state_dict(checkpoint['generator'])
kp_detector.load_state_dict(checkpoint['kp_detector'])
generator.eval()
kp_detector.eval()


# Animation function
def make_animation(source_image, driving_video, relative=True, adapt_movement_scale=True, fps=30):
    try:
        source_image = resize(source_image, (256, 256))[..., :3]
        driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]
        with torch.no_grad():
            predictions = []
            # HWC float frames -> NCHW float32 tensors
            source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2).to(device)
            kp_source = kp_detector(source)
            driving_initial = torch.tensor(driving_video[0][np.newaxis].astype(np.float32)).permute(0, 3, 1, 2).to(device)
            kp_driving_initial = kp_detector(driving_initial)
            for frame in tqdm(driving_video, desc="Generating frames"):
                driving = torch.tensor(frame[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2).to(device)
                kp_driving = kp_detector(driving)
                # Re-express driving keypoints relative to the first driving frame
                kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
                                       kp_driving_initial=kp_driving_initial,
                                       use_relative_movement=relative,
                                       use_relative_jacobian=relative,
                                       adapt_movement_scale=adapt_movement_scale)
                out = generator(source, kp_source=kp_source, kp_driving=kp_norm)
                predictions.append(np.transpose(out['prediction'].cpu().numpy(), [0, 2, 3, 1])[0])
        output_video = "result.mp4"
        imageio.mimsave(output_video, [img_as_ubyte(frame) for frame in predictions], fps=fps)
        return output_video
    except Exception as e:
        # Raise gr.Error so the message is shown in the UI; returning a plain
        # string to a gr.Video output would fail.
        raise gr.Error(f"Animation generation failed: {e}")


# Gradio interface
def animate(source_img, driving_vid):
    try:
        source = imageio.imread(source_img)
        # memtest=False lifts imageio's in-memory size guard so longer clips
        # decode; keep driving videos short regardless.
        driving = imageio.mimread(driving_vid, memtest=False)
        # Match the output frame rate to the driving video (fall back to 30 fps)
        fps = imageio.get_reader(driving_vid).get_meta_data().get('fps', 30)
        return make_animation(source, driving, fps=fps)
    except gr.Error:
        raise
    except Exception as e:
        raise gr.Error(f"Error processing inputs: {e}")


iface = gr.Interface(
    fn=animate,
    inputs=[
        gr.Image(type="filepath", label="Source Image (e.g., face photo, 256x256)"),
        gr.Video(label="Driving Video (e.g., motion video, short clip <10s)")
    ],
    outputs=gr.Video(label="Generated Animation"),
    title="FOMM Image-to-Video Demo",
    description="Upload a source image and driving video to generate animation. CPU-based, may take a few minutes."
)
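# Optional (an assumption, not part of the original script): enable Gradio's
# request queue so long CPU-bound generations don't hit HTTP timeouts.
# iface = iface.queue()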
if __name__ == "__main__":
    iface.launch()
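# Programmatic use (a minimal sketch; file names are hypothetical and this
# assumes the script is saved as app.py):
#   from app import make_animation
#   import imageio
#   src = imageio.imread("face.png")
#   drv = imageio.mimread("motion.mp4", memtest=False)
#   print(make_animation(src, drv))  # writes and returns "result.mp4"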