File size: 5,662 Bytes
1e135d7
4c79f9d
1e135d7
 
 
 
 
4c79f9d
1e135d7
4c79f9d
1e135d7
 
 
 
 
 
 
 
 
 
 
 
4c79f9d
f4d455b
1e135d7
f4d455b
4c79f9d
f4d455b
1e135d7
4c79f9d
 
1e135d7
 
 
 
 
 
 
 
 
f4d455b
 
 
1e135d7
 
4c79f9d
 
 
 
f4d455b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e135d7
4c79f9d
 
fe0a2e1
 
 
 
 
 
 
 
 
 
4c79f9d
 
 
f4d455b
 
 
4c79f9d
fe0a2e1
 
 
 
 
 
f4d455b
fe0a2e1
 
 
 
f4d455b
fe0a2e1
 
 
 
 
 
 
 
 
 
 
 
 
 
f4d455b
fe0a2e1
 
f4d455b
fe0a2e1
 
 
 
 
 
 
 
 
 
 
1e135d7
4c79f9d
 
 
 
1e135d7
 
 
 
 
4c79f9d
1e135d7
 
4c79f9d
f4d455b
1e135d7
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import gradio as gr
import spaces
import subprocess
import tempfile
import base64
import os
import shutil
import sys

# Filesystem location where the SadTalker repo is cloned and its
# checkpoints are stored (created on first run by setup_sadtalker()).
SADTALKER_DIR = "/home/user/SadTalker"

def setup_sadtalker() -> bool:
    """Clone SadTalker, install its requirements, and fetch model checkpoints.

    Idempotent: the clone/install step is skipped when SADTALKER_DIR already
    exists, and the checkpoint download is skipped when the checkpoints
    directory is already present. Checking the checkpoints separately lets an
    interrupted first run (repo cloned, download failed) recover on retry.

    Returns:
        bool: True on success.

    Raises:
        subprocess.CalledProcessError: if git clone or pip install fails.
    """
    if not os.path.exists(SADTALKER_DIR):
        print("Cloning SadTalker...")
        subprocess.run([
            "git", "clone", "--depth", "1",
            "https://github.com/OpenTalker/SadTalker.git",
            SADTALKER_DIR
        ], check=True)

        # Install SadTalker requirements into the current interpreter.
        # check=True so a failed install surfaces here rather than as a
        # confusing ImportError later inside inference.py.
        print("Installing SadTalker requirements...")
        subprocess.run([
            sys.executable, "-m", "pip", "install", "-q", "-r",
            f"{SADTALKER_DIR}/requirements.txt"
        ], check=True)

    # Download checkpoints from HuggingFace if they are not present yet,
    # independently of whether the repo was just cloned.
    checkpoints_dir = os.path.join(SADTALKER_DIR, "checkpoints")
    if not os.path.isdir(checkpoints_dir):
        print("Downloading checkpoints...")
        from huggingface_hub import snapshot_download
        snapshot_download(
            repo_id="vinthony/SadTalker",
            local_dir=checkpoints_dir,
            local_dir_use_symlinks=False
        )

    return True

@spaces.GPU(duration=120)
def generate_video_gpu(image_path: str, audio_path: str, output_dir: str) -> str:
    """Run SadTalker inference on the GPU and return the produced video path.

    Args:
        image_path: path to the source face image.
        audio_path: path to the driving audio file.
        output_dir: directory where SadTalker writes its results.

    Returns:
        str: path to the first generated .mp4 found under output_dir.

    Raises:
        RuntimeError: if the SadTalker subprocess exits non-zero or no
            .mp4 is found afterwards. (RuntimeError is an Exception
            subclass, so existing `except Exception` callers still work.)
    """
    setup_sadtalker()

    # Make SadTalker's own modules importable by inference.py.
    if SADTALKER_DIR not in sys.path:
        sys.path.insert(0, SADTALKER_DIR)

    # --still keeps head pose stable; --preprocess crop crops to the face.
    cmd = [
        sys.executable, f"{SADTALKER_DIR}/inference.py",
        "--driven_audio", audio_path,
        "--source_image", image_path,
        "--result_dir", output_dir,
        "--still",
        "--preprocess", "crop",
    ]

    print(f"Running: {' '.join(cmd)}")
    result = subprocess.run(
        cmd,
        capture_output=True,
        text=True,
        cwd=SADTALKER_DIR
    )

    # Surface subprocess output in the Space logs for debugging.
    print(f"STDOUT: {result.stdout}")
    if result.stderr:
        print(f"STDERR: {result.stderr}")

    if result.returncode != 0:
        raise RuntimeError(f"SadTalker failed: {result.stderr}")

    # SadTalker writes into a timestamped subfolder of output_dir; walk it
    # and return the first .mp4. sorted() makes the pick deterministic when
    # several files exist.
    for root, dirs, files in os.walk(output_dir):
        for f in sorted(files):
            if f.endswith(".mp4"):
                return os.path.join(root, f)

    raise RuntimeError("No video generated")

def gradio_generate(image, audio):
    """Gradio handler: validate inputs, stage them in a temp dir, run SadTalker.

    Accepts the image as a filepath string, a file-like object with .name,
    or a numpy array; accepts the audio as a filepath string, a file-like
    object, or a (sample_rate, data) tuple.

    Returns:
        str: path to the generated video, copied to a persistent location
        (the temp dir is deleted when this function returns).

    Raises:
        gr.Error: on missing/unusable inputs or generation failure.
    """
    # Debug logging
    print(f"=== RECEIVED ===")
    print(f"Image type: {type(image)}, value: {image}")
    print(f"Audio type: {type(audio)}, value: {audio}")

    # Better validation
    if not image:
        raise gr.Error("Envie uma imagem primeiro!")
    if not audio:
        raise gr.Error("Envie um áudio primeiro!")

    with tempfile.TemporaryDirectory() as tmpdir:
        image_path = os.path.join(tmpdir, "input.png")
        audio_path = os.path.join(tmpdir, "input.wav")
        output_dir = os.path.join(tmpdir, "output")
        os.makedirs(output_dir, exist_ok=True)

        # Handle image
        try:
            if isinstance(image, str) and os.path.exists(image):
                shutil.copy(image, image_path)
            elif hasattr(image, 'name'):  # File object
                shutil.copy(image.name, image_path)
            else:
                # assumes a numpy array from gr.Image — TODO confirm dtype/shape
                from PIL import Image as PILImage
                PILImage.fromarray(image).save(image_path)
        except Exception as e:
            raise gr.Error(f"Erro ao processar imagem: {e}")

        # Handle audio
        try:
            if isinstance(audio, str) and os.path.exists(audio):
                shutil.copy(audio, audio_path)
            elif hasattr(audio, 'name'):  # File object
                shutil.copy(audio.name, audio_path)
            elif isinstance(audio, tuple):
                import scipy.io.wavfile as wav
                sr, data = audio
                wav.write(audio_path, sr, data)
            else:
                raise gr.Error(f"Formato de áudio não reconhecido: {type(audio)}")
        except gr.Error:
            # Already a user-facing error — re-raise as-is instead of letting
            # the generic handler below re-wrap it ("Erro ao processar áudio:
            # Formato de áudio não reconhecido...").
            raise
        except Exception as e:
            raise gr.Error(f"Erro ao processar áudio: {e}")

        print(f"Image saved: {image_path}, exists: {os.path.exists(image_path)}")
        print(f"Audio saved: {audio_path}, exists: {os.path.exists(audio_path)}")

        # Generate video
        try:
            video_path = generate_video_gpu(image_path, audio_path, output_dir)

            # Copy out of the TemporaryDirectory before it is deleted.
            # NOTE(review): fixed path means concurrent requests overwrite
            # each other's output — consider a unique filename per request.
            final_path = "/tmp/sadtalker_output.mp4"
            shutil.copy(video_path, final_path)

            return final_path
        except Exception as e:
            raise gr.Error(f"Erro na geração: {e}")

# Create Gradio app: two-column layout — inputs (image, audio, button) on
# the left, generated video on the right.
with gr.Blocks(title="SadTalker API") as demo:
    gr.Markdown("# 🎭 SadTalker API")
    gr.Markdown("Generate talking head videos from image + audio (ZeroGPU)")
    
    with gr.Row():
        with gr.Column():
            # type="filepath" makes Gradio hand gradio_generate plain path
            # strings, matching the isinstance(..., str) branches there.
            image_input = gr.Image(label="Face Image", type="filepath")
            audio_input = gr.Audio(label="Audio", type="filepath")
            generate_btn = gr.Button("🎬 Generate Video", variant="primary", size="lg")
        
        with gr.Column():
            video_output = gr.Video(label="Generated Video")
            gr.Markdown("⏱️ Takes ~30-60 seconds with GPU")
    
    # Wire the button to the handler: (image, audio) -> video path.
    generate_btn.click(
        fn=gradio_generate,
        inputs=[image_input, audio_input],
        outputs=video_output
    )

if __name__ == "__main__":
    # Bind to all interfaces on the standard HF Spaces port.
    demo.launch(server_name="0.0.0.0", server_port=7860)