Madras1 commited on
Commit
4c79f9d
·
verified ·
1 Parent(s): 30610b2

Upload 3 files

Browse files
Files changed (3) hide show
  1. README.md +6 -32
  2. app.py +68 -52
  3. requirements.txt +4 -3
README.md CHANGED
@@ -9,41 +9,15 @@ app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- # SadTalker API
13
 
14
- Talking head generation API using SadTalker in CPU mode.
15
 
16
  ## Features
17
- - Generates video from face image + audio
18
- - Runs on CPU (no GPU required)
19
- - Returns base64 encoded video
20
 
21
  ## Usage
22
 
23
- ### Via UI
24
- Upload an image and audio file, click Generate.
25
-
26
- ### Via API
27
- ```python
28
- import requests
29
- import base64
30
-
31
- # Read files
32
- with open("face.png", "rb") as f:
33
- image_b64 = base64.b64encode(f.read()).decode()
34
-
35
- with open("audio.mp3", "rb") as f:
36
- audio_b64 = base64.b64encode(f.read()).decode()
37
-
38
- # Call API
39
- response = requests.post(
40
- "https://your-space.hf.space/api/predict",
41
- json={"data": [image_b64, audio_b64]}
42
- )
43
-
44
- video_b64 = response.json()["data"][0]
45
- ```
46
-
47
- ## Notes
48
- - First run will download ~2GB of model weights
49
- - Each generation takes 1-2 minutes on CPU
 
9
  pinned: false
10
  ---
11
 
12
+ # SadTalker API 🎭
13
 
14
+ Talking head generation using SadTalker with **ZeroGPU**.
15
 
16
  ## Features
17
+ - GPU-accelerated (~20-40 seconds)
18
+ - 🎨 Face enhancement with GFPGAN
19
+ - 📹 Returns MP4 video
20
 
21
  ## Usage
22
 
23
+ Upload a face image and audio file, click Generate.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py CHANGED
@@ -1,11 +1,13 @@
1
  import gradio as gr
 
2
  import subprocess
3
  import tempfile
4
  import base64
5
  import os
6
  import shutil
 
7
 
8
- # Clone SadTalker on first run
9
  SADTALKER_DIR = "/home/user/SadTalker"
10
 
11
  def setup_sadtalker():
@@ -18,15 +20,14 @@ def setup_sadtalker():
18
  SADTALKER_DIR
19
  ], check=True)
20
 
21
- # Download checkpoints
22
- print("Downloading checkpoints...")
23
- os.makedirs(f"{SADTALKER_DIR}/checkpoints", exist_ok=True)
24
-
25
- # Download from HuggingFace
26
  subprocess.run([
27
- "pip", "install", "huggingface_hub"
 
28
  ], check=True)
29
 
 
 
30
  from huggingface_hub import snapshot_download
31
  snapshot_download(
32
  repo_id="vinthony/SadTalker",
@@ -36,33 +37,45 @@ def setup_sadtalker():
36
 
37
  return True
38
 
 
39
  def generate_video(image_path: str, audio_path: str) -> str:
40
  """
41
  Generate talking head video from image and audio
42
- Returns: path to generated video
43
  """
44
  setup_sadtalker()
45
 
 
 
 
 
46
  with tempfile.TemporaryDirectory() as tmpdir:
47
  output_dir = os.path.join(tmpdir, "output")
48
  os.makedirs(output_dir, exist_ok=True)
49
 
50
- # Run SadTalker inference
51
  cmd = [
52
- "python", f"{SADTALKER_DIR}/inference.py",
53
  "--driven_audio", audio_path,
54
  "--source_image", image_path,
55
  "--result_dir", output_dir,
56
  "--still", # Less movement, faster
57
  "--preprocess", "crop",
58
- "--cpu" # Force CPU mode
59
  ]
60
 
61
  print(f"Running: {' '.join(cmd)}")
62
- result = subprocess.run(cmd, capture_output=True, text=True, cwd=SADTALKER_DIR)
 
 
 
 
 
 
63
 
64
  if result.returncode != 0:
65
- print(f"Error: {result.stderr}")
 
66
  raise Exception(f"SadTalker failed: {result.stderr}")
67
 
68
  # Find generated video
@@ -70,12 +83,46 @@ def generate_video(image_path: str, audio_path: str) -> str:
70
  for f in files:
71
  if f.endswith(".mp4"):
72
  video_path = os.path.join(root, f)
73
- # Read and return as base64
74
  with open(video_path, "rb") as vf:
75
  return base64.b64encode(vf.read()).decode("utf-8")
76
 
77
  raise Exception("No video generated")
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  def api_generate(image_base64: str, audio_base64: str) -> dict:
80
  """API endpoint for generating video"""
81
  try:
@@ -98,50 +145,20 @@ def api_generate(image_base64: str, audio_base64: str) -> dict:
98
  except Exception as e:
99
  return {"success": False, "error": str(e)}
100
 
101
- # Gradio interface for testing
102
- def gradio_generate(image, audio):
103
- """Gradio interface wrapper"""
104
- if image is None or audio is None:
105
- return None
106
-
107
- with tempfile.TemporaryDirectory() as tmpdir:
108
- # Save uploaded files
109
- image_path = os.path.join(tmpdir, "input.png")
110
- audio_path = os.path.join(tmpdir, "input.mp3")
111
-
112
- # Handle image (could be numpy array or path)
113
- if isinstance(image, str):
114
- shutil.copy(image, image_path)
115
- else:
116
- from PIL import Image
117
- Image.fromarray(image).save(image_path)
118
-
119
- # Handle audio
120
- shutil.copy(audio, audio_path)
121
-
122
- # Generate
123
- video_base64 = generate_video(image_path, audio_path)
124
-
125
- # Save to temp file for Gradio
126
- output_path = os.path.join(tmpdir, "output.mp4")
127
- with open(output_path, "wb") as f:
128
- f.write(base64.b64decode(video_base64))
129
-
130
- return output_path
131
-
132
- # Create Gradio app with API
133
- with gr.Blocks() as demo:
134
- gr.Markdown("# SadTalker API 🎭")
135
- gr.Markdown("Generate talking head videos from image + audio")
136
 
137
  with gr.Row():
138
  with gr.Column():
139
  image_input = gr.Image(label="Face Image", type="filepath")
140
  audio_input = gr.Audio(label="Audio", type="filepath")
141
- generate_btn = gr.Button("Generate", variant="primary")
142
 
143
  with gr.Column():
144
- video_output = gr.Video(label="Result")
 
145
 
146
  generate_btn.click(
147
  fn=gradio_generate,
@@ -149,6 +166,5 @@ with gr.Blocks() as demo:
149
  outputs=video_output
150
  )
151
 
152
- # Launch with API enabled
153
  if __name__ == "__main__":
154
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
  import gradio as gr
2
+ import spaces
3
  import subprocess
4
  import tempfile
5
  import base64
6
  import os
7
  import shutil
8
+ import sys
9
 
10
+ # SadTalker path
11
  SADTALKER_DIR = "/home/user/SadTalker"
12
 
13
  def setup_sadtalker():
 
20
  SADTALKER_DIR
21
  ], check=True)
22
 
23
+ # Install SadTalker requirements
 
 
 
 
24
  subprocess.run([
25
+ sys.executable, "-m", "pip", "install", "-r",
26
+ f"{SADTALKER_DIR}/requirements.txt"
27
  ], check=True)
28
 
29
+ # Download checkpoints from HuggingFace
30
+ print("Downloading checkpoints...")
31
  from huggingface_hub import snapshot_download
32
  snapshot_download(
33
  repo_id="vinthony/SadTalker",
 
37
 
38
  return True
39
 
40
+ @spaces.GPU(duration=120) # Request GPU for up to 120 seconds
41
  def generate_video(image_path: str, audio_path: str) -> str:
42
  """
43
  Generate talking head video from image and audio
44
+ Returns: base64 encoded video
45
  """
46
  setup_sadtalker()
47
 
48
+ # Add SadTalker to path
49
+ if SADTALKER_DIR not in sys.path:
50
+ sys.path.insert(0, SADTALKER_DIR)
51
+
52
  with tempfile.TemporaryDirectory() as tmpdir:
53
  output_dir = os.path.join(tmpdir, "output")
54
  os.makedirs(output_dir, exist_ok=True)
55
 
56
+ # Run SadTalker inference (GPU mode - no --cpu flag)
57
  cmd = [
58
+ sys.executable, f"{SADTALKER_DIR}/inference.py",
59
  "--driven_audio", audio_path,
60
  "--source_image", image_path,
61
  "--result_dir", output_dir,
62
  "--still", # Less movement, faster
63
  "--preprocess", "crop",
64
+ "--enhancer", "gfpgan" # Face enhancement
65
  ]
66
 
67
  print(f"Running: {' '.join(cmd)}")
68
+ result = subprocess.run(
69
+ cmd,
70
+ capture_output=True,
71
+ text=True,
72
+ cwd=SADTALKER_DIR,
73
+ env={**os.environ, "CUDA_VISIBLE_DEVICES": "0"}
74
+ )
75
 
76
  if result.returncode != 0:
77
+ print(f"STDOUT: {result.stdout}")
78
+ print(f"STDERR: {result.stderr}")
79
  raise Exception(f"SadTalker failed: {result.stderr}")
80
 
81
  # Find generated video
 
83
  for f in files:
84
  if f.endswith(".mp4"):
85
  video_path = os.path.join(root, f)
 
86
  with open(video_path, "rb") as vf:
87
  return base64.b64encode(vf.read()).decode("utf-8")
88
 
89
  raise Exception("No video generated")
90
 
91
+ def gradio_generate(image, audio):
92
+ """Gradio interface wrapper"""
93
+ if image is None or audio is None:
94
+ return None
95
+
96
+ with tempfile.TemporaryDirectory() as tmpdir:
97
+ # Save uploaded files
98
+ image_path = os.path.join(tmpdir, "input.png")
99
+ audio_path = audio # Gradio gives us filepath directly
100
+
101
+ # Handle image
102
+ if isinstance(image, str):
103
+ shutil.copy(image, image_path)
104
+ else:
105
+ from PIL import Image
106
+ Image.fromarray(image).save(image_path)
107
+
108
+ try:
109
+ # Generate video
110
+ video_base64 = generate_video(image_path, audio_path)
111
+
112
+ # Save to temp file for Gradio output
113
+ output_path = os.path.join(tmpdir, "output.mp4")
114
+ with open(output_path, "wb") as f:
115
+ f.write(base64.b64decode(video_base64))
116
+
117
+ # Copy to persistent location
118
+ final_path = "/tmp/sadtalker_output.mp4"
119
+ shutil.copy(output_path, final_path)
120
+ return final_path
121
+
122
+ except Exception as e:
123
+ raise gr.Error(f"Generation failed: {str(e)}")
124
+
125
+ # API function for external calls
126
  def api_generate(image_base64: str, audio_base64: str) -> dict:
127
  """API endpoint for generating video"""
128
  try:
 
145
  except Exception as e:
146
  return {"success": False, "error": str(e)}
147
 
148
+ # Create Gradio app
149
+ with gr.Blocks(title="SadTalker API") as demo:
150
+ gr.Markdown("# 🎭 SadTalker API")
151
+ gr.Markdown("Generate talking head videos from image + audio (ZeroGPU)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
 
153
  with gr.Row():
154
  with gr.Column():
155
  image_input = gr.Image(label="Face Image", type="filepath")
156
  audio_input = gr.Audio(label="Audio", type="filepath")
157
+ generate_btn = gr.Button("🎬 Generate Video", variant="primary", size="lg")
158
 
159
  with gr.Column():
160
+ video_output = gr.Video(label="Generated Video")
161
+ gr.Markdown("⏱️ Takes ~20-40 seconds with GPU")
162
 
163
  generate_btn.click(
164
  fn=gradio_generate,
 
166
  outputs=video_output
167
  )
168
 
 
169
  if __name__ == "__main__":
170
  demo.launch(server_name="0.0.0.0", server_port=7860)
requirements.txt CHANGED
@@ -1,14 +1,14 @@
1
  # Core
2
  gradio==4.44.0
3
  huggingface_hub==0.25.0
 
4
 
5
- # PyTorch CPU
6
- --extra-index-url https://download.pytorch.org/whl/cpu
7
  torch
8
  torchvision
9
  torchaudio
10
 
11
- # SadTalker deps
12
  numpy<2.0.0
13
  scipy
14
  opencv-python-headless
@@ -29,3 +29,4 @@ basicsr
29
  facexlib
30
  kornia
31
  safetensors
 
 
1
  # Core
2
  gradio==4.44.0
3
  huggingface_hub==0.25.0
4
+ spaces
5
 
6
+ # PyTorch CUDA (ZeroGPU will handle this)
 
7
  torch
8
  torchvision
9
  torchaudio
10
 
11
+ # SadTalker deps
12
  numpy<2.0.0
13
  scipy
14
  opencv-python-headless
 
29
  facexlib
30
  kornia
31
  safetensors
32
+ gfpgan