Madras1 committed on
Commit
f4d455b
·
verified ·
1 Parent(s): 4c79f9d

Upload 3 files

Browse files
Files changed (1) hide show
  1. app.py +68 -90
app.py CHANGED
@@ -21,10 +21,11 @@ def setup_sadtalker():
21
  ], check=True)
22
 
23
  # Install SadTalker requirements
 
24
  subprocess.run([
25
- sys.executable, "-m", "pip", "install", "-r",
26
  f"{SADTALKER_DIR}/requirements.txt"
27
- ], check=True)
28
 
29
  # Download checkpoints from HuggingFace
30
  print("Downloading checkpoints...")
@@ -37,113 +38,90 @@ def setup_sadtalker():
37
 
38
  return True
39
 
40
- @spaces.GPU(duration=120) # Request GPU for up to 120 seconds
41
- def generate_video(image_path: str, audio_path: str) -> str:
42
- """
43
- Generate talking head video from image and audio
44
- Returns: base64 encoded video
45
- """
46
  setup_sadtalker()
47
 
48
  # Add SadTalker to path
49
  if SADTALKER_DIR not in sys.path:
50
  sys.path.insert(0, SADTALKER_DIR)
51
 
52
- with tempfile.TemporaryDirectory() as tmpdir:
53
- output_dir = os.path.join(tmpdir, "output")
54
- os.makedirs(output_dir, exist_ok=True)
55
-
56
- # Run SadTalker inference (GPU mode - no --cpu flag)
57
- cmd = [
58
- sys.executable, f"{SADTALKER_DIR}/inference.py",
59
- "--driven_audio", audio_path,
60
- "--source_image", image_path,
61
- "--result_dir", output_dir,
62
- "--still", # Less movement, faster
63
- "--preprocess", "crop",
64
- "--enhancer", "gfpgan" # Face enhancement
65
- ]
66
-
67
- print(f"Running: {' '.join(cmd)}")
68
- result = subprocess.run(
69
- cmd,
70
- capture_output=True,
71
- text=True,
72
- cwd=SADTALKER_DIR,
73
- env={**os.environ, "CUDA_VISIBLE_DEVICES": "0"}
74
- )
75
-
76
- if result.returncode != 0:
77
- print(f"STDOUT: {result.stdout}")
78
- print(f"STDERR: {result.stderr}")
79
- raise Exception(f"SadTalker failed: {result.stderr}")
80
-
81
- # Find generated video
82
- for root, dirs, files in os.walk(output_dir):
83
- for f in files:
84
- if f.endswith(".mp4"):
85
- video_path = os.path.join(root, f)
86
- with open(video_path, "rb") as vf:
87
- return base64.b64encode(vf.read()).decode("utf-8")
88
-
89
- raise Exception("No video generated")
90
 
91
  def gradio_generate(image, audio):
92
  """Gradio interface wrapper"""
93
  if image is None or audio is None:
94
- return None
95
 
96
  with tempfile.TemporaryDirectory() as tmpdir:
97
  # Save uploaded files
98
  image_path = os.path.join(tmpdir, "input.png")
99
- audio_path = audio # Gradio gives us filepath directly
 
 
100
 
101
- # Handle image
102
  if isinstance(image, str):
103
  shutil.copy(image, image_path)
104
  else:
105
  from PIL import Image
106
- Image.fromarray(image).save(image_path)
 
 
 
107
 
108
- try:
109
- # Generate video
110
- video_base64 = generate_video(image_path, audio_path)
111
-
112
- # Save to temp file for Gradio output
113
- output_path = os.path.join(tmpdir, "output.mp4")
114
- with open(output_path, "wb") as f:
115
- f.write(base64.b64decode(video_base64))
116
-
117
- # Copy to persistent location
118
- final_path = "/tmp/sadtalker_output.mp4"
119
- shutil.copy(output_path, final_path)
120
- return final_path
121
-
122
- except Exception as e:
123
- raise gr.Error(f"Generation failed: {str(e)}")
124
-
125
- # API function for external calls
126
- def api_generate(image_base64: str, audio_base64: str) -> dict:
127
- """API endpoint for generating video"""
128
- try:
129
- with tempfile.TemporaryDirectory() as tmpdir:
130
- # Save image
131
- image_path = os.path.join(tmpdir, "input.png")
132
- with open(image_path, "wb") as f:
133
- f.write(base64.b64decode(image_base64))
134
-
135
- # Save audio
136
- audio_path = os.path.join(tmpdir, "input.mp3")
137
- with open(audio_path, "wb") as f:
138
- f.write(base64.b64decode(audio_base64))
139
-
140
- # Generate video
141
- video_base64 = generate_video(image_path, audio_path)
142
-
143
- return {"success": True, "video_base64": video_base64}
144
-
145
- except Exception as e:
146
- return {"success": False, "error": str(e)}
147
 
148
  # Create Gradio app
149
  with gr.Blocks(title="SadTalker API") as demo:
@@ -158,7 +136,7 @@ with gr.Blocks(title="SadTalker API") as demo:
158
 
159
  with gr.Column():
160
  video_output = gr.Video(label="Generated Video")
161
- gr.Markdown("⏱️ Takes ~20-40 seconds with GPU")
162
 
163
  generate_btn.click(
164
  fn=gradio_generate,
 
21
  ], check=True)
22
 
23
  # Install SadTalker requirements
24
+ print("Installing SadTalker requirements...")
25
  subprocess.run([
26
+ sys.executable, "-m", "pip", "install", "-q", "-r",
27
  f"{SADTALKER_DIR}/requirements.txt"
28
+ ])
29
 
30
  # Download checkpoints from HuggingFace
31
  print("Downloading checkpoints...")
 
38
 
39
  return True
40
 
41
+ @spaces.GPU(duration=120)
42
+ def generate_video_gpu(image_path: str, audio_path: str, output_dir: str) -> str:
43
+ """GPU-accelerated video generation"""
 
 
 
44
  setup_sadtalker()
45
 
46
  # Add SadTalker to path
47
  if SADTALKER_DIR not in sys.path:
48
  sys.path.insert(0, SADTALKER_DIR)
49
 
50
+ # Run SadTalker inference
51
+ cmd = [
52
+ sys.executable, f"{SADTALKER_DIR}/inference.py",
53
+ "--driven_audio", audio_path,
54
+ "--source_image", image_path,
55
+ "--result_dir", output_dir,
56
+ "--still",
57
+ "--preprocess", "crop",
58
+ ]
59
+
60
+ print(f"Running: {' '.join(cmd)}")
61
+ result = subprocess.run(
62
+ cmd,
63
+ capture_output=True,
64
+ text=True,
65
+ cwd=SADTALKER_DIR
66
+ )
67
+
68
+ print(f"STDOUT: {result.stdout}")
69
+ if result.stderr:
70
+ print(f"STDERR: {result.stderr}")
71
+
72
+ if result.returncode != 0:
73
+ raise Exception(f"SadTalker failed: {result.stderr}")
74
+
75
+ # Find generated video
76
+ for root, dirs, files in os.walk(output_dir):
77
+ for f in files:
78
+ if f.endswith(".mp4"):
79
+ return os.path.join(root, f)
80
+
81
+ raise Exception("No video generated")
 
 
 
 
 
 
82
 
83
  def gradio_generate(image, audio):
84
  """Gradio interface wrapper"""
85
  if image is None or audio is None:
86
+ raise gr.Error("Por favor, envie uma imagem e um áudio")
87
 
88
  with tempfile.TemporaryDirectory() as tmpdir:
89
  # Save uploaded files
90
  image_path = os.path.join(tmpdir, "input.png")
91
+ audio_path = os.path.join(tmpdir, "input.wav")
92
+ output_dir = os.path.join(tmpdir, "output")
93
+ os.makedirs(output_dir, exist_ok=True)
94
 
95
+ # Handle image - Gradio gives filepath
96
  if isinstance(image, str):
97
  shutil.copy(image, image_path)
98
  else:
99
  from PIL import Image
100
+ if hasattr(image, 'save'):
101
+ image.save(image_path)
102
+ else:
103
+ Image.fromarray(image).save(image_path)
104
 
105
+ # Handle audio - Gradio gives filepath
106
+ if isinstance(audio, str):
107
+ shutil.copy(audio, audio_path)
108
+ elif isinstance(audio, tuple):
109
+ # (sample_rate, audio_data) format
110
+ import scipy.io.wavfile as wav
111
+ sr, data = audio
112
+ wav.write(audio_path, sr, data)
113
+
114
+ print(f"Image: {image_path}, exists: {os.path.exists(image_path)}")
115
+ print(f"Audio: {audio_path}, exists: {os.path.exists(audio_path)}")
116
+
117
+ # Generate video with GPU
118
+ video_path = generate_video_gpu(image_path, audio_path, output_dir)
119
+
120
+ # Copy to persistent location for Gradio
121
+ final_path = "/tmp/sadtalker_output.mp4"
122
+ shutil.copy(video_path, final_path)
123
+
124
+ return final_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
  # Create Gradio app
127
  with gr.Blocks(title="SadTalker API") as demo:
 
136
 
137
  with gr.Column():
138
  video_output = gr.Video(label="Generated Video")
139
+ gr.Markdown("⏱️ Takes ~30-60 seconds with GPU")
140
 
141
  generate_btn.click(
142
  fn=gradio_generate,