File size: 4,045 Bytes
dd52e36
 
 
 
 
174e322
 
dd52e36
174e322
 
 
 
 
 
 
 
 
 
 
dd52e36
174e322
dd52e36
 
 
 
174e322
 
 
 
 
dd52e36
 
 
 
 
 
174e322
dd52e36
 
 
 
 
 
 
 
 
174e322
dd52e36
174e322
 
 
dd52e36
174e322
 
 
 
 
 
dd52e36
174e322
 
 
 
 
 
 
 
dd52e36
174e322
 
 
 
 
dd52e36
 
 
174e322
 
 
 
 
 
dd52e36
 
 
174e322
 
 
 
 
 
 
dd52e36
 
174e322
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import gradio as gr
import torch
import yaml
import imageio
import numpy as np
from skimage.transform import resize
from skimage import img_as_ubyte
from tqdm import tqdm
import os

# 导入本地模块
try:
    from modules.generator import OcclusionAwareGenerator
    from modules.keypoint_detector import KPDetector
    from animate import normalize_kp
except ModuleNotFoundError as e:
    print(f"Module import error: {e}")
    print("Please ensure 'modules' directory with generator.py, keypoint_detector.py, util.py, and animate.py are in the project root.")
    raise

# Download the pretrained checkpoint and the model config on first run.
CHECKPOINT_PATH = "vox-cpk.pth.tar"
CONFIG_PATH = "config/vox-256.yaml"


def _download_if_missing(url, dest):
    """Fetch ``url`` into ``dest`` with wget unless ``dest`` already exists.

    The download goes to a temporary ``<dest>.part`` file that is renamed to
    ``dest`` only after wget exits successfully. A failed or interrupted
    download therefore never leaves a truncated file behind that a later run
    would mistake for a complete one and skip re-downloading.

    Args:
        url: Remote location to fetch.
        dest: Local file path to create; its parent directory is created if needed.

    Raises:
        subprocess.CalledProcessError: if wget exits with a non-zero status.
        FileNotFoundError: if the wget executable is not available.
    """
    import subprocess

    if os.path.exists(dest):
        return
    parent = os.path.dirname(dest)
    if parent:
        os.makedirs(parent, exist_ok=True)
    partial = dest + ".part"
    try:
        # Argument list (no shell), and check=True so a failed download is reported.
        subprocess.run(["wget", url, "-O", partial], check=True)
        os.replace(partial, dest)
    finally:
        # Only present here if wget or the rename failed; never keep partial data.
        if os.path.exists(partial):
            os.remove(partial)


_download_if_missing(
    "https://github.com/AliaksandrSiarohin/first-order-model/raw/master/vox-cpk.pth.tar",
    CHECKPOINT_PATH,
)
_download_if_missing(
    "https://raw.githubusercontent.com/AliaksandrSiarohin/first-order-model/master/config/vox-256.yaml",
    CONFIG_PATH,
)

# Read the model configuration from the downloaded YAML file.
with open(CONFIG_PATH) as config_file:
    config = yaml.safe_load(config_file)

# Build both networks from the config and move them to the inference device.
device = "cpu"
_model_params = config['model_params']
_common_params = _model_params['common_params']
generator = OcclusionAwareGenerator(
    **_model_params['generator_params'], **_common_params
).to(device)
kp_detector = KPDetector(
    **_model_params['kp_detector_params'], **_common_params
).to(device)

# Restore the pretrained weights. map_location places the loaded tensors on
# `device`. NOTE: torch.load unpickles the file, so only load trusted checkpoints.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
generator.load_state_dict(checkpoint['generator'])
kp_detector.load_state_dict(checkpoint['kp_detector'])

# Switch both modules to evaluation mode for inference.
for _network in (generator, kp_detector):
    _network.eval()

# Animation pipeline: source image + driving video -> generated MP4.

def _prepare_frame(image):
    """Resize ``image`` to 256x256 and return it with exactly three channels.

    ``skimage.transform.resize`` keeps a trailing channel axis if one exists.
    Grayscale inputs (2-D arrays, or a single trailing channel) are replicated
    to three channels. Extra channels such as alpha are dropped.
    """
    frame = resize(image, (256, 256))
    if frame.ndim == 2:
        frame = frame[..., np.newaxis]
    if frame.shape[-1] == 1:
        frame = np.repeat(frame, 3, axis=-1)
    return frame[..., :3]


def _to_tensor(frame):
    """Convert an H x W x C array to a 1 x C x H x W float32 tensor on ``device``."""
    return torch.tensor(frame[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2).to(device)


def make_animation(source_image, driving_video, relative=True, adapt_movement_scale=True,
                   output_path="result.mp4", fps=30):
    """Animate ``source_image`` with the motion of ``driving_video`` and save an MP4.

    Every frame is resized to 256x256 and reduced to three channels before it
    is passed to the module-level ``kp_detector`` and ``generator``.

    Args:
        source_image: Image array, H x W or H x W x C (presumably as returned
            by ``imageio.imread``).
        driving_video: Sequence of frame arrays with the same layout.
        relative: Forwarded to ``normalize_kp`` as both
            ``use_relative_movement`` and ``use_relative_jacobian``.
        adapt_movement_scale: Forwarded to ``normalize_kp``.
        output_path: File the generated video is written to.
        fps: Frame rate written into the output video.

    Returns:
        ``output_path`` on success. On any failure, returns an error-message string
        starting with "Animation generation failed:" instead of raising.
    """
    try:
        source_image = _prepare_frame(source_image)
        driving_video = [_prepare_frame(frame) for frame in driving_video]
        if not driving_video:
            raise ValueError("driving video contains no frames")

        with torch.no_grad():
            predictions = []
            source = _to_tensor(source_image)
            kp_source = kp_detector(source)
            # Keypoints of the first driving frame, passed to normalize_kp as the
            # reference frame (presumably used for relative motion; see normalize_kp).
            kp_driving_initial = kp_detector(_to_tensor(driving_video[0]))

            for frame in tqdm(driving_video, desc="Generating frames"):
                kp_driving = kp_detector(_to_tensor(frame))
                kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
                                       kp_driving_initial=kp_driving_initial, use_relative_movement=relative,
                                       use_relative_jacobian=relative, adapt_movement_scale=adapt_movement_scale)
                out = generator(source, kp_source=kp_source, kp_driving=kp_norm)
                # 1 x C x H x W tensor -> H x W x C array for the video writer.
                predictions.append(np.transpose(out['prediction'].cpu().numpy(), [0, 2, 3, 1])[0])

        imageio.mimsave(output_path, [img_as_ubyte(frame) for frame in predictions], fps=fps)
        return output_path
    except Exception as e:
        # NOTE(review): this string is handed to a gr.Video output by the caller;
        # consider raising gr.Error so the message is actually shown in the UI.
        return f"Animation generation failed: {str(e)}"

# Gradio callback.
def animate(source_img, driving_vid):
    """Read the uploaded source image and driving video, then run ``make_animation``.

    Args:
        source_img: File path of the source image (the gr.Image input uses
            ``type="filepath"``).
        driving_vid: Value from the gr.Video input (presumably a file path;
            confirm for the installed Gradio version).

    Returns:
        The result of ``make_animation`` (output video path or its error message),
        or a string starting with "Error processing inputs:" if the inputs are
        missing or cannot be read.
    """
    if not source_img or not driving_vid:
        return "Error processing inputs: please upload both a source image and a driving video."
    try:
        source = imageio.imread(source_img)
        # mimread raises once the decoded frames exceed its default 256 MB
        # memtest limit, which short clips easily do; lift the limit explicitly.
        driving = imageio.mimread(driving_vid, memtest=False)
        return make_animation(source, driving)
    except Exception as e:
        return f"Error processing inputs: {str(e)}"

# Web UI: a source image and a driving video in, the generated video out.
_source_input = gr.Image(type="filepath", label="Source Image (e.g., face photo, 256x256)")
_driving_input = gr.Video(label="Driving Video (e.g., motion video, short clip <10s)")
_result_output = gr.Video(label="Generated Animation")

iface = gr.Interface(
    fn=animate,
    inputs=[_source_input, _driving_input],
    outputs=_result_output,
    title="FOMM Image-to-Video Demo",
    description="Upload a source image and driving video to generate animation. CPU-based, may take a few minutes.",
)

if __name__ == "__main__":
    # Start the server only when executed as a script, not when imported.
    iface.launch()