File size: 4,716 Bytes
a125901
 
900bbc0
a125901
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
900bbc0
a125901
 
 
 
 
900bbc0
a125901
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
900bbc0
a125901
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
import pathlib
import spaces  # Required for ZeroGPU
import gradio as gr
import torch
from PIL import Image

# One-time environment setup: fetch the upstream model code and its
# pretrained checkpoint, then run the rest of the app from inside the
# cloned repository (the relative paths used later depend on that).
repo_dir = pathlib.Path("Thin-Plate-Spline-Motion-Model").absolute()
if not repo_dir.exists():
    # os.system returns a non-zero status when the command fails; stop with a
    # clear message instead of failing later on the chdir below.
    if os.system("git clone https://github.com/yoyo-nb/Thin-Plate-Spline-Motion-Model") != 0:
        raise RuntimeError("git clone of Thin-Plate-Spline-Motion-Model failed")
# Change into the absolute path so the result does not depend on the working
# directory still being the one repo_dir was resolved against.
os.chdir(repo_dir)
checkpoint_dir = repo_dir / "checkpoints"
# Create the directory without spawning a shell; exist_ok makes it idempotent.
checkpoint_dir.mkdir(exist_ok=True)
if not (checkpoint_dir / "vox.pth.tar").exists():
    os.system("gdown 1-CKOjv_y_TzNe-dwQsjjeVxJUuyBAb5X -O checkpoints/vox.pth.tar")

# Markdown/HTML snippets rendered at the top and bottom of the page.
# The emoji were previously stored as mojibake (UTF-8 bytes decoded with a
# single-byte codepage); they are restored to the intended characters here.
# A space after '#' is required for Markdown to treat the title as a heading.
title = "# ✨ MotionMagicAI"
DESCRIPTION = '''### 🎥 <b>MotionMagicAI</b> Brings Images to Life! 🚀 Powered by Thin-Plate Spline Motion Model (CVPR 2022). Upload a face, add a video, and watch it dance or sing! 🕺💃 <a href='https://arxiv.org/abs/2203.14367'>[Paper]</a> <a href='https://github.com/yoyo-nb/Thin-Plate-Spline-Motion-Model'>[Code]</a>
<img id="overview" alt="overview" src="https://github.com/yoyo-nb/Thin-Plate-Spline-Motion-Model/raw/main/assets/vox.gif" />
'''
FOOTER = '<img id="visitor-badge" alt="visitor badge" src="https://visitor-badge.glitch.me/badge?page_id=gradio-blocks.Image-Animation-using-Thin-Plate-Spline-Motion-Model" />'

def get_style_image_path(style_name: str) -> str:
    """Map a style name to the relative path of its bundled asset file.

    Args:
        style_name: Either 'source' or 'driving'.

    Returns:
        The asset path, e.g. 'assets/source.png'.

    Raises:
        KeyError: If style_name is not one of the known names.
    """
    asset_by_style = {'source': 'source.png', 'driving': 'driving.mp4'}
    asset_file = asset_by_style[style_name]
    return '/'.join(('assets', asset_file))

def get_style_image_markdown_text(style_name: str) -> str:
    """Build the HTML <img> tag that displays the asset for a style name.

    The image source is whatever get_style_image_path returns for
    style_name; any exception raised by that lookup propagates unchanged.
    """
    return f'<img id="style-image" src="{get_style_image_path(style_name)}" alt="style image">'

def update_style_image(style_name: str) -> dict:
    """Return a Gradio update that sets a component's value to the style image tag.

    Args:
        style_name: Style name understood by get_style_image_markdown_text.

    Returns:
        The update payload produced by gr.update, carrying the new value.
    """
    text = get_style_image_markdown_text(style_name)
    # gr.update is the version-independent way to update a component;
    # the per-component classmethod gr.Markdown.update is not available in
    # newer Gradio releases.
    return gr.update(value=text)

@spaces.GPU(duration=300)  # Increased duration for animation (5 minutes)
def inference(img, vid):
    """Animate a face image using the motion of a driving video.

    Saves the image to temp/image.jpg, runs demo.py with the vox-256 config
    and checkpoint as a child process, and returns the path of the rendered
    video. All paths are relative to the current working directory.

    Args:
        img: Input image; must provide a PIL-style ``save(path, format)``.
        vid: Path of the driving video file.

    Returns:
        The string './temp/result.mp4'.

    Raises:
        gr.Error: If either input is missing.
        subprocess.CalledProcessError: If demo.py exits with a non-zero status.
    """
    # Local imports keep this block self-contained.
    import subprocess
    import sys

    if img is None or vid is None:
        raise gr.Error('Please provide both an input image and a driving video.')

    os.makedirs('temp', exist_ok=True)
    img.save("temp/image.jpg", "JPEG")

    # Pass arguments as a list (no shell) so a video path containing spaces
    # or shell metacharacters is handled safely.
    command = [
        sys.executable, 'demo.py',
        '--config', 'config/vox-256.yaml',
        '--checkpoint', './checkpoints/vox.pth.tar',
        '--source_image', 'temp/image.jpg',
        '--driving_video', str(vid),
        '--result_video', './temp/result.mp4',
    ]
    # Fall back to CPU inference when no CUDA device is visible.
    if not torch.cuda.is_available():
        command.append('--cpu')

    # check=True surfaces a failed run instead of returning a stale or
    # missing result file.
    subprocess.run(command, check=True)
    return './temp/result.mp4'

def main():
    """Build the three-step Gradio interface and launch it.

    The page shows the title and description, an image input with PNG
    examples from assets/, a video input with MP4 examples from assets/,
    and a Generate button wired to inference. The app is launched with a
    request queue limited to 10 entries.
    """
    # gr.Box is not available in newer Gradio releases; use it when present
    # and otherwise fall back to gr.Group so the layout still builds.
    section = getattr(gr, 'Box', gr.Group)

    with gr.Blocks(css='style.css') as demo:
        gr.Markdown(title)
        gr.Markdown(DESCRIPTION)

        with section():
            gr.Markdown('''## Step 1 (Provide Input Face Image)
- Drop an image containing a face to the **Input Image**.
    - If there are multiple faces in the image, use Edit button in the upper right corner and crop the input image beforehand.
''')
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        input_image = gr.Image(label='Input Image',
                                               type="pil")

            with gr.Row():
                image_paths = sorted(pathlib.Path('assets').glob('*.png'))
                gr.Examples(inputs=[input_image],
                            examples=[[path.as_posix()] for path in image_paths])

        with section():
            gr.Markdown('''## Step 2 (Select Driving Video)
- Select **Style Driving Video for the face image animation**.
''')
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        driving_video = gr.Video(label='Driving Video',
                                                 format="mp4")

            with gr.Row():
                video_paths = sorted(pathlib.Path('assets').glob('*.mp4'))
                gr.Examples(inputs=[driving_video],
                            examples=[[path.as_posix()] for path in video_paths])

        with section():
            gr.Markdown('''## Step 3 (Generate Animated Image based on the Video)
- Hit the **Generate** button. (Note: It may take a few minutes to generate results.)
''')
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        generate_button = gr.Button('Generate')

                with gr.Column():
                    result = gr.Video(label="Output")
        gr.Markdown(FOOTER)
        generate_button.click(fn=inference,
                              inputs=[
                                  input_image,
                                  driving_video
                              ],
                              outputs=result)

    demo.queue(max_size=10).launch()

# Build and launch the UI only when this file is executed as a script.
if __name__ == '__main__':
    main()