baenacoco committed on
Commit
2527849
·
verified ·
1 Parent(s): e291f1d

Upload folder using huggingface_hub

Browse files
Files changed (5) hide show
  1. README.md +5 -6
  2. app.py +573 -0
  3. hub_utils.py +64 -0
  4. packages.txt +6 -0
  5. requirements.txt +19 -0
README.md CHANGED
@@ -1,12 +1,11 @@
1
  ---
2
- title: Talking Head Generate
3
- emoji: 📊
4
- colorFrom: blue
5
  colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 6.9.0
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Talking Head - Generate
3
+ emoji: 🎬
4
+ colorFrom: red
5
  colorTo: yellow
6
  sdk: gradio
7
+ sdk_version: 5.9.1
8
  app_file: app.py
9
  pinned: false
10
+ hardware: a100-large
11
  ---
 
 
app.py ADDED
@@ -0,0 +1,573 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Space 5: Generate Video (F5-TTS + Flux.1 + MuseTalk)
2
+
3
+ Downloads trained models from Hub -> TTS -> Image gen -> Lip-sync -> saves video to Hub.
4
+ GPU: A100 (Flux.1 image gen + MuseTalk lip-sync)
5
+ """
6
+ import gc
7
+ import json
8
+ import logging
9
+ import os
10
+ import shutil
11
+ import subprocess
12
+ import sys
13
+ import traceback
14
+ from pathlib import Path
15
+
16
+ import gradio as gr
17
+ import numpy as np
18
+ import soundfile as sf
19
+ import torch
20
+
21
+ from hub_utils import download_step, upload_step
22
+
23
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s")
24
+ logger = logging.getLogger(__name__)
25
+
26
# ── Config ──
# Running inside a Hugging Face Space is detected via SPACE_ID (set by the platform).
IS_HF_SPACE = os.environ.get("SPACE_ID") is not None
_data_path = Path("/data")
# Prefer the persistent /data volume when it exists and is writable;
# otherwise fall back to a relative ./data directory.
if IS_HF_SPACE and _data_path.exists() and os.access(_data_path, os.W_OK):
    BASE_DIR = _data_path
else:
    BASE_DIR = Path("data")

VOICE_MODEL_DIR = BASE_DIR / "voice_model"    # fine-tuned F5-TTS checkpoint + reference.wav
LORA_MODEL_DIR = BASE_DIR / "lora_model"      # Flux LoRA weights + lora_config.json
GENERATED_VIDEO_DIR = BASE_DIR / "generated"  # final rendered videos
TEMP_DIR = BASE_DIR / "temp"                  # scratch space for intermediate files
HF_CACHE_DIR = BASE_DIR / "hf_cache"          # Hugging Face download cache

for d in [VOICE_MODEL_DIR, LORA_MODEL_DIR, GENERATED_VIDEO_DIR, TEMP_DIR, HF_CACHE_DIR]:
    d.mkdir(parents=True, exist_ok=True)

# Point HF caches at the (possibly persistent) base dir before any model download.
os.environ["HF_HOME"] = str(HF_CACHE_DIR)
os.environ["TRANSFORMERS_CACHE"] = str(HF_CACHE_DIR)

FLUX_MODEL_ID = "black-forest-labs/FLUX.1-dev"
F5_SPANISH_MODEL_ID = "jpgallegoar/F5-Spanish"
MUSETALK_REPO_ID = "TMElyralab/MuseTalk"
LORA_TRIGGER_WORD = "alvaro_person"  # token the LoRA was trained on; prepended to prompts

# Generation defaults (several are surfaced as UI sliders below).
IMAGE_WIDTH = 1024
IMAGE_HEIGHT = 1024
IMAGE_STEPS = 30
IMAGE_GUIDANCE = 3.5
TTS_SPEED = 1.0
MUSETALK_FPS = 30
MUSETALK_BBOX_SHIFT = 5
CHUNK_DURATION_S = 10        # target lip-sync chunk length for long audio (seconds)
CROSSFADE_DURATION_S = 0.5   # video crossfade between chunks (seconds)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
APP_VERSION = "1.0.0"

# ── Model state ──  (lazily loaded; loaders unload the other model first)
_f5_model = None
_flux_pipe = None
MUSETALK_DIR = Path("musetalk_repo")  # local clone of the MuseTalk code repo
68
+
69
+
70
+ def _clear_cache():
71
+ gc.collect()
72
+ if torch.cuda.is_available():
73
+ torch.cuda.empty_cache()
74
+ torch.cuda.synchronize()
75
+
76
+
77
def _unload_all():
    """Drop both the TTS model and the Flux pipeline, then reclaim GPU memory."""
    global _f5_model, _flux_pipe
    # Rebinding to None releases the last reference; the explicit del in the
    # original was redundant with the reassignment.
    _f5_model = None
    _flux_pipe = None
    _clear_cache()
86
+
87
+
88
+ # ── FFmpeg utils ──
89
+
90
+ def _ffmpeg_run(cmd, description):
91
+ result = subprocess.run(cmd, capture_output=True, text=True)
92
+ if result.returncode != 0:
93
+ raise RuntimeError(f"FFmpeg failed ({description}): {result.stderr[-500:]}")
94
+
95
+
96
def _get_duration(file_path):
    """Return the media duration of *file_path* in seconds, as reported by ffprobe."""
    probe = subprocess.run(
        [
            "ffprobe", "-v", "error",
            "-show_entries", "format=duration",
            "-of", "default=noprint_wrappers=1:nokey=1",
            file_path,
        ],
        capture_output=True, text=True, check=True,
    )
    return float(probe.stdout.strip())
101
+
102
+
103
def _concat_videos(video_paths, output_path):
    """Losslessly join clips with ffmpeg's concat demuxer (stream copy, no re-encode)."""
    manifest = Path(output_path).parent / "concat_list.txt"
    manifest.write_text("".join(f"file '{vp}'\n" for vp in video_paths))
    cmd = ["ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", str(manifest), "-c", "copy", output_path]
    _ffmpeg_run(cmd, "concat")
    manifest.unlink(missing_ok=True)
110
+
111
+
112
def _crossfade_videos(v1, v2, output, duration=0.5):
    """Blend the tail of *v1* into the head of *v2* with an xfade of *duration* seconds.

    The fade starts `duration` seconds before the end of v1; output is re-encoded
    with libx264 (video only — audio is muxed separately by the caller).
    """
    offset = _get_duration(v1) - duration
    fade = f"[0:v][1:v]xfade=transition=fade:duration={duration}:offset={offset}[v]"
    cmd = [
        "ffmpeg", "-y", "-i", v1, "-i", v2,
        "-filter_complex", fade,
        "-map", "[v]", "-c:v", "libx264", "-pix_fmt", "yuv420p", output,
    ]
    _ffmpeg_run(cmd, "crossfade")
120
+
121
+
122
def _mux_audio_video(video, audio, output):
    """Combine a video stream with an audio track (AAC 192k), trimming to the shorter."""
    cmd = [
        "ffmpeg", "-y", "-i", video, "-i", audio,
        "-c:v", "copy", "-c:a", "aac", "-b:a", "192k",
        "-map", "0:v:0", "-map", "1:a:0", "-shortest", output,
    ]
    _ffmpeg_run(cmd, "mux")
128
+
129
+
130
+ # ── TTS ──
131
+
132
def _load_tts():
    """Load the F5-TTS model, preferring a fine-tuned checkpoint over the base model.

    Lookup order: model_last.pt in VOICE_MODEL_DIR, then any *.pt / *.safetensors
    there, finally the base F5-Spanish model from the Hub. Idempotent; unloads
    everything else first to free VRAM.
    """
    global _f5_model
    if _f5_model is not None:
        return

    _unload_all()
    from f5_tts.api import F5TTS

    ckpt = VOICE_MODEL_DIR / "model_last.pt"
    if not ckpt.exists():
        found = [*VOICE_MODEL_DIR.glob("*.pt"), *VOICE_MODEL_DIR.glob("*.safetensors")]
        ckpt = found[0] if found else None

    if ckpt and ckpt.exists():
        logger.info(f"Loading fine-tuned F5-TTS from {ckpt}")
        _f5_model = F5TTS(model_path=str(ckpt), device=DEVICE)
    else:
        logger.info(f"Loading base F5-Spanish from {F5_SPANISH_MODEL_ID}")
        _f5_model = F5TTS(model_name=F5_SPANISH_MODEL_ID, device=DEVICE)
    logger.info("F5-TTS loaded")
152
+
153
+
154
def _get_reference_audio():
    """Return the path to the speaker reference WAV.

    Raises FileNotFoundError when the voice model has not been downloaded yet.
    """
    ref_path = VOICE_MODEL_DIR / "reference.wav"
    if not ref_path.exists():
        raise FileNotFoundError("No reference audio found. Download voice model first.")
    return str(ref_path)
159
+
160
+
161
def generate_speech(text, output_path=None):
    """Synthesize *text* with the loaded F5-TTS voice and write it as a WAV file.

    When output_path is None, writes to TEMP_DIR/tts_output.wav. Returns the path.
    """
    _load_tts()
    ref = _get_reference_audio()
    if output_path is None:
        output_path = str(TEMP_DIR / "tts_output.wav")
    wav, sample_rate = _f5_model.infer(ref_file=ref, ref_text="", gen_text=text, speed=TTS_SPEED)
    sf.write(output_path, wav, sample_rate)
    logger.info(f"Generated speech: {output_path} ({len(wav)/sample_rate:.1f}s)")
    return output_path
170
+
171
+
172
def _unload_tts():
    """Release the F5-TTS model and reclaim GPU memory."""
    global _f5_model
    # Rebinding to None drops the last reference (the original's del + assign
    # was equivalent); the cache is cleared either way.
    _f5_model = None
    _clear_cache()
178
+
179
+
180
+ # ── Image generation ──
181
+
182
def _load_flux():
    """Load the Flux.1 text-to-image pipeline, plus LoRA weights if downloaded.

    Idempotent: returns immediately when the pipeline is already resident.
    Unloads the TTS model first so the large Flux weights fit in VRAM.
    """
    global _flux_pipe
    if _flux_pipe is not None:
        return

    _unload_tts()

    from diffusers import FluxPipeline

    logger.info(f"Loading Flux.1 from {FLUX_MODEL_ID}...")
    _flux_pipe = FluxPipeline.from_pretrained(
        FLUX_MODEL_ID, torch_dtype=torch.bfloat16,
        token=os.environ.get("HF_TOKEN"),  # FLUX.1-dev is a gated repo
    ).to(DEVICE)

    # LoRA may be shipped as *.safetensors or PEFT-style adapter_model.* files.
    lora_weights = list(LORA_MODEL_DIR.glob("*.safetensors"))
    if not lora_weights:
        lora_weights = list(LORA_MODEL_DIR.glob("adapter_model.*"))
    if lora_weights:
        try:
            _flux_pipe.load_lora_weights(str(LORA_MODEL_DIR))
            logger.info("LoRA weights loaded")
        except Exception as e:
            # Best effort: fall back to the base model when the LoRA is incompatible.
            logger.warning(f"Could not load LoRA: {e}")

    # NOTE(review): .to(DEVICE) followed by enable_model_cpu_offload() — offload
    # moves idle sub-modules back to CPU; confirm the explicit .to() is intended.
    _flux_pipe.enable_model_cpu_offload()
    logger.info("Flux.1 pipeline loaded")
209
+
210
+
211
def _unload_flux():
    """Release the Flux.1 pipeline and reclaim GPU memory."""
    global _flux_pipe
    # Rebinding to None drops the last reference; equivalent to the original's
    # del + assign pair.
    _flux_pipe = None
    _clear_cache()
217
+
218
+
219
def generate_image(prompt, num_steps, guidance_scale, seed, output_path=None):
    """Render the avatar image with Flux.1 (+ LoRA if loaded).

    Reads the LoRA trigger word from lora_config.json when present (falling back
    to LORA_TRIGGER_WORD) and prepends it to the prompt if missing. A seed >= 0
    makes generation deterministic. Returns the saved PNG path.
    """
    _load_flux()

    trigger = LORA_TRIGGER_WORD
    cfg_file = LORA_MODEL_DIR / "lora_config.json"
    if cfg_file.exists():
        with open(cfg_file) as fh:
            trigger = json.load(fh).get("trigger_word", LORA_TRIGGER_WORD)

    if trigger and trigger not in prompt:
        prompt = f"{trigger}, {prompt}"

    rng = torch.Generator(device=DEVICE).manual_seed(seed) if seed >= 0 else None

    if output_path is None:
        output_path = str(TEMP_DIR / "generated_avatar.png")

    out = _flux_pipe(
        prompt=prompt,
        width=IMAGE_WIDTH,
        height=IMAGE_HEIGHT,
        num_inference_steps=num_steps,
        guidance_scale=guidance_scale,
        generator=rng,
    )
    out.images[0].save(output_path)
    logger.info(f"Image saved: {output_path}")
    return output_path
246
+
247
+
248
+ # ── MuseTalk lip-sync ──
249
+
250
def _ensure_mm_packages():
    """Install the OpenMMLab stack via `mim` when mmcv is not importable.

    Best effort: install failures are ignored here (output is captured and the
    return code is not checked); the MuseTalk import will surface real problems.
    """
    try:
        import mmcv  # noqa: F401
        return
    except ImportError:
        pass
    logger.info("Installing mmcv, mmdet, mmpose via mim...")
    specs = ("mmengine", "mmcv>=2.0.0", "mmdet>=3.1.0", "mmpose>=1.1.0")
    for spec in specs:
        subprocess.run(
            [sys.executable, "-m", "mim", "install", spec],
            capture_output=True, text=True, timeout=600,
        )
260
+
261
+
262
def _ensure_musetalk():
    """Make the MuseTalk code and model weights available locally.

    Installs the OpenMMLab prerequisites, clones the GitHub repo on first use
    (falling back to the HF snapshot when git clone fails), then fetches weights.
    """
    _ensure_mm_packages()
    if not MUSETALK_DIR.exists():
        logger.info("Cloning MuseTalk repository...")
        try:
            clone_cmd = ["git", "clone", "https://github.com/TMElyralab/MuseTalk.git", str(MUSETALK_DIR)]
            subprocess.run(clone_cmd, capture_output=True, text=True, timeout=300, check=True)
        except Exception:
            # git missing or clone failed: mirror the repo from the Hub instead.
            from huggingface_hub import snapshot_download
            snapshot_download(repo_id=MUSETALK_REPO_ID, local_dir=str(MUSETALK_DIR), repo_type="model")
    _download_musetalk_models()
275
+
276
+
277
def _download_musetalk_models():
    """Fetch the MuseTalk model weights from the Hub into MUSETALK_DIR.

    Files already present on disk are skipped. Download failures are logged and
    tolerated (best effort) — a later inference step will fail loudly if a
    required file is actually missing.
    """
    from huggingface_hub import hf_hub_download
    models = [
        ("TMElyralab/MuseTalk", "models/musetalk/musetalk.json"),
        ("TMElyralab/MuseTalk", "models/musetalk/pytorch_model.bin"),
        ("TMElyralab/MuseTalk", "models/dwpose/dw-ll_ucoco_384.onnx"),
        ("TMElyralab/MuseTalk", "models/face-parse-bisenet/79999_iter.pth"),
        ("TMElyralab/MuseTalk", "models/sd-vae-ft-mse/config.json"),
        ("TMElyralab/MuseTalk", "models/sd-vae-ft-mse/diffusion_pytorch_model.bin"),
        ("TMElyralab/MuseTalk", "models/whisper/tiny.pt"),
    ]
    for repo_id, filename in models:
        local_path = MUSETALK_DIR / filename
        if not local_path.exists():
            try:
                hf_hub_download(repo_id=repo_id, filename=filename, local_dir=str(MUSETALK_DIR))
            except Exception as e:
                # Fix: report WHICH file failed — the original logged "(unknown)".
                logger.warning(f"Could not download {filename}: {e}")
295
+
296
+
297
def _generate_lipsync(image_path, audio_path, output_path, bbox_shift):
    """Run MuseTalk lip-sync on a still image + audio clip, writing *output_path*.

    Tries the in-process Python API first; on any failure falls back to the
    repo's CLI inference script as a subprocess. All other models are unloaded
    first to free GPU memory.

    Raises RuntimeError when the CLI fallback fails or produces no video.
    """
    _unload_all()
    _ensure_musetalk()

    # Try Python API
    try:
        sys.path.insert(0, str(MUSETALK_DIR))
        from musetalk.models.musetalk import MuseTalk
        model = MuseTalk()
        model.load_model(str(MUSETALK_DIR / "models"))
        # NOTE(review): a still image is passed as video_path — presumably
        # MuseTalk accepts an image and animates it; confirm against its API.
        result = model.inference(
            video_path=image_path, audio_path=audio_path,
            bbox_shift=bbox_shift, result_dir=str(Path(output_path).parent),
        )
        if result and Path(result).exists():
            if str(result) != output_path:
                shutil.move(result, output_path)
            return output_path
    except Exception as e:
        logger.warning(f"Python MuseTalk failed: {e}, trying CLI...")

    # Fallback to CLI
    result_dir = TEMP_DIR / "musetalk_output"
    result_dir.mkdir(parents=True, exist_ok=True)
    cmd = [
        sys.executable, "-m", "scripts.inference",
        "--video_path", image_path, "--audio_path", audio_path,
        "--bbox_shift", str(bbox_shift), "--result_dir", str(result_dir),
        "--fps", str(MUSETALK_FPS), "--batch_size", "8",
    ]
    env = os.environ.copy()
    # Make the cloned repo importable by the inference script.
    env["PYTHONPATH"] = str(MUSETALK_DIR) + ":" + env.get("PYTHONPATH", "")
    proc = subprocess.run(cmd, capture_output=True, text=True, cwd=str(MUSETALK_DIR), env=env, timeout=1800)
    if proc.returncode != 0:
        raise RuntimeError(f"MuseTalk failed: {proc.stderr[-500:]}")
    # The newest .mp4 anywhere under result_dir is taken as the generated clip.
    outputs = sorted(result_dir.glob("**/*.mp4"), key=lambda p: p.stat().st_mtime, reverse=True)
    if not outputs:
        raise RuntimeError("MuseTalk did not produce output")
    shutil.move(str(outputs[0]), output_path)
    shutil.rmtree(result_dir, ignore_errors=True)
    return output_path
338
+
339
+
340
+ # ── Video composition ──
341
+
342
def _find_silence_boundaries(audio, sr, chunk_duration):
    """Choose split times (seconds) near each chunk_duration mark, snapped to silences.

    Writes the samples to a temp WAV so pydub can run silence detection, then
    greedily picks, for each target time, the silence midpoint closest to it
    (constrained to be at least 3 s past the previous cut and 1 s before the
    end); when no silence qualifies the exact target time is used. The returned
    list always starts at 0.0 and ends at the total duration.
    """
    from pydub import AudioSegment
    from pydub.silence import detect_silence
    temp_path = str(TEMP_DIR / "_temp_silence.wav")
    sf.write(temp_path, audio, sr)
    sound = AudioSegment.from_wav(temp_path)
    # [start_ms, end_ms] pairs of runs >= 300 ms quieter than -35 dBFS.
    silences = detect_silence(sound, min_silence_len=300, silence_thresh=-35)
    total_duration = len(audio) / sr
    boundaries = [0.0]
    current = 0.0
    while current + chunk_duration < total_duration:
        target = current + chunk_duration
        best_split = target
        best_dist = float("inf")
        for start_ms, end_ms in silences:
            mid = (start_ms + end_ms) / 2000.0  # silence midpoint, ms -> s
            if current + 3.0 < mid < total_duration - 1.0:
                dist = abs(mid - target)
                if dist < best_dist:
                    best_dist = dist
                    best_split = mid
        boundaries.append(best_split)
        current = best_split
    boundaries.append(total_duration)
    Path(temp_path).unlink(missing_ok=True)
    return boundaries
368
+
369
+
370
def compose_long_video(image_path, audio_path, output_path, bbox_shift, progress_callback=None):
    """Lip-sync possibly-long audio by chunking at silences and stitching the clips.

    Short audio (<= 1.5x CHUNK_DURATION_S) gets a single MuseTalk pass. Longer
    audio is split at silence boundaries, each chunk lip-synced separately, the
    chunk videos crossfaded (or hard-concatenated as fallback), and the original
    full audio muxed back over the joined video. progress_callback, if given,
    receives (fraction, message) updates. Returns output_path.
    """
    audio, sr = sf.read(audio_path)
    if audio.ndim > 1:
        audio = audio.mean(axis=1)  # downmix to mono
    total_duration = len(audio) / sr

    # Single-pass fast path for short clips.
    if total_duration <= CHUNK_DURATION_S * 1.5:
        if progress_callback:
            progress_callback(0.1, "Generando lip-sync...")
        return _generate_lipsync(image_path, audio_path, output_path, bbox_shift)

    # Fresh scratch directory for chunk WAVs / MP4s.
    work_dir = TEMP_DIR / "compose_work"
    if work_dir.exists():
        shutil.rmtree(work_dir)
    work_dir.mkdir(parents=True)

    if progress_callback:
        progress_callback(0.05, "Buscando puntos de corte...")
    boundaries = _find_silence_boundaries(audio, sr, CHUNK_DURATION_S)
    n_chunks = len(boundaries) - 1

    # Lip-sync each chunk independently (chunk generation spans 10%-80% of progress).
    chunk_videos = []
    for i in range(n_chunks):
        if progress_callback:
            progress_callback(0.1 + (i / n_chunks) * 0.7, f"Generando chunk {i+1}/{n_chunks}...")
        start_sample = int(boundaries[i] * sr)
        end_sample = int(boundaries[i + 1] * sr)
        chunk_audio = audio[start_sample:end_sample]
        chunk_audio_path = str(work_dir / f"chunk_{i:03d}.wav")
        sf.write(chunk_audio_path, chunk_audio, sr)
        chunk_video_path = str(work_dir / f"chunk_{i:03d}.mp4")
        _generate_lipsync(image_path, chunk_audio_path, chunk_video_path, bbox_shift)
        chunk_videos.append(chunk_video_path)

    if progress_callback:
        progress_callback(0.85, "Componiendo video final...")

    if len(chunk_videos) == 1:
        final_video = chunk_videos[0]
    elif CROSSFADE_DURATION_S > 0:
        # Left-fold: crossfade each next chunk onto the running result,
        # falling back to a hard cut (concat) when xfade fails.
        current = chunk_videos[0]
        for i in range(1, len(chunk_videos)):
            merged = str(work_dir / f"merged_{i:03d}.mp4")
            try:
                _crossfade_videos(current, chunk_videos[i], merged, CROSSFADE_DURATION_S)
                current = merged
            except Exception:
                _concat_videos([current, chunk_videos[i]], merged)
                current = merged
        final_video = current
    else:
        final_video = str(work_dir / "concat.mp4")
        _concat_videos(chunk_videos, final_video)

    # Re-attach the original continuous audio over the stitched (video-only) result.
    _mux_audio_video(final_video, audio_path, output_path)
    shutil.rmtree(work_dir, ignore_errors=True)
    return output_path
427
+
428
+
429
+ # ── Gradio handlers ──
430
+
431
def _fetch_model_step(name, step_folder, dest_dir):
    """Download one trained-model step from the Hub into *dest_dir*.

    Clears dest_dir, downloads {name}/{step_folder} into the BASE_DIR staging
    area, then moves the files into dest_dir. Returns True when files arrived.
    """
    if dest_dir.exists():
        shutil.rmtree(dest_dir)
    dest_dir.mkdir(parents=True)
    download_step(name, step_folder, str(BASE_DIR))
    src = BASE_DIR / name / step_folder
    if not src.exists():
        return False
    for f in src.iterdir():
        shutil.move(str(f), str(dest_dir / f.name))
    return True


def download_models_from_hub(project_name, progress=gr.Progress()):
    """Gradio handler: download the project's trained voice and LoRA models.

    Returns a Spanish status string (success with the list of downloaded model
    kinds, or an error message). The per-step logic is shared via
    _fetch_model_step instead of the previous duplicated copy-paste blocks.
    """
    if not project_name or not project_name.strip():
        return "Error: Debes introducir un nombre de proyecto"
    name = project_name.strip()
    try:
        status_parts = []

        # Voice model (fine-tuned F5-TTS checkpoint + reference audio).
        if _fetch_model_step(name, "step3_voice", VOICE_MODEL_DIR):
            status_parts.append("voz")

        # LoRA model (Flux adapter weights).
        if _fetch_model_step(name, "step4_lora", LORA_MODEL_DIR):
            status_parts.append("LoRA")

        # Remove the staging copy left by download_step.
        shutil.rmtree(BASE_DIR / name, ignore_errors=True)
        return f"OK - Descargados modelos: {', '.join(status_parts)}"
    except Exception as e:
        return f"Error: {e}"
464
+
465
+
466
def generate_video_handler(
    project_name, text, scene_prompt, bbox_shift,
    img_steps, guidance, seed, progress=gr.Progress(),
):
    """Gradio handler: run the full TTS -> image -> lip-sync pipeline.

    Returns (video_path, status_message); video_path is None on error.
    """
    if not project_name or not project_name.strip():
        return None, "Error: Debes introducir un nombre de proyecto"
    if not text.strip():
        return None, "Error: Introduce texto para hablar"

    logger.info(f"=== Video Generation Started === text='{text[:50]}...'")

    try:
        # Step 1: TTS
        progress(0.0, desc="Generando voz con TTS...")
        audio_path = generate_speech(text)

        # Step 2: Image generation
        progress(0.2, desc="Generando imagen con Flux.1 + LoRA...")
        image_path = generate_image(
            prompt=scene_prompt, num_steps=int(img_steps),
            guidance_scale=guidance, seed=int(seed),
        )

        # Unload Flux before MuseTalk (frees VRAM for the lip-sync models)
        _unload_flux()

        # Step 3: Lip-sync — the remaining 60% of the progress bar is delegated
        # to compose_long_video via the callback below.
        progress(0.4, desc="Generando lip-sync con MuseTalk...")
        output_path = str(GENERATED_VIDEO_DIR / "final_output.mp4")
        compose_long_video(
            image_path=image_path, audio_path=audio_path,
            output_path=output_path, bbox_shift=int(bbox_shift),
            progress_callback=lambda p, m: progress(0.4 + p * 0.6, desc=m),
        )

        logger.info("=== Video Generation Complete ===")
        return output_path, "OK - Video generado!"

    except Exception as e:
        logger.error(f"=== Video Generation Failed ===\n{traceback.format_exc()}")
        return None, f"Error: {e}"
507
+
508
+
509
def save_to_hub(project_name):
    """Upload every generated .mp4 to the project's step5_video folder on the Hub.

    Returns a Spanish status string (success or error); never raises.
    """
    if not project_name or not project_name.strip():
        return "Error: Debes introducir un nombre de proyecto"
    name = project_name.strip()
    if not list(GENERATED_VIDEO_DIR.glob("*.mp4")):
        return "Error: No hay video para guardar."
    try:
        return upload_step(name, "step5_video", str(GENERATED_VIDEO_DIR))
    except Exception as e:
        return f"Error: {e}"
520
+
521
+
522
# ── UI ──

with gr.Blocks(title="Talking Head - Generate", theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"# Talking Head - Generar Video `v{APP_VERSION}`\nTTS + Imagen + Lip-sync con modelos entrenados")

    # The project name doubles as the folder name inside the Hub dataset repo.
    project_name = gr.Textbox(
        label="Nombre del proyecto",
        placeholder="mi_proyecto",
        info="Obligatorio. Se usa como carpeta en el Hub.",
    )

    # Step 1: pull the trained voice + LoRA models for this project.
    gr.Markdown("### 1. Descargar modelos del Hub")
    download_btn = gr.Button("Descargar modelos del Hub", variant="secondary")
    download_status = gr.Textbox(label="Estado descarga", interactive=False)

    # Step 2: text + scene prompt + tuning sliders -> generated video.
    gr.Markdown("### 2. Generar video")
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="Texto a hablar (espanol)",
                placeholder="Hola, soy un avatar digital hiperrealista...",
                lines=4,
            )
            scene_prompt = gr.Textbox(
                label="Prompt de escena",
                value="portrait photo, professional lighting, neutral background",
            )
            with gr.Row():
                bbox_shift = gr.Slider(-20, 20, value=MUSETALK_BBOX_SHIFT, step=1, label="Bbox Shift")
                img_steps = gr.Slider(10, 50, value=IMAGE_STEPS, step=5, label="Image Steps")
            with gr.Row():
                guidance = gr.Slider(1.0, 10.0, value=IMAGE_GUIDANCE, step=0.5, label="Guidance Scale")
                seed_input = gr.Number(value=-1, label="Seed (-1=random)")
            gen_btn = gr.Button("Generar Video", variant="primary")
        with gr.Column():
            video_output = gr.Video(label="Video generado")
            gen_status = gr.Textbox(label="Estado", interactive=False)

    # Step 3: upload the result back to the Hub dataset repo.
    gr.Markdown("### 3. Guardar video en Hub")
    save_btn = gr.Button("Guardar en Hub", variant="secondary")
    save_status = gr.Textbox(label="Estado guardado", interactive=False)

    # Wire handlers to buttons.
    download_btn.click(download_models_from_hub, inputs=[project_name], outputs=[download_status])
    gen_btn.click(
        generate_video_handler,
        inputs=[project_name, text_input, scene_prompt, bbox_shift, img_steps, guidance, seed_input],
        outputs=[video_output, gen_status],
    )
    save_btn.click(save_to_hub, inputs=[project_name], outputs=[save_status])

if __name__ == "__main__":
    # queue() serializes jobs on the single GPU; Spaces expects 0.0.0.0:7860.
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)
hub_utils.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Hub utilities for uploading/downloading step data to HF Dataset repo."""
2
+ import os
3
+ import logging
4
+ from pathlib import Path
5
+ from huggingface_hub import HfApi, hf_hub_download, list_repo_tree
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+ HF_DATASET_REPO_ID = "baenacoco/talking-head-avatar"
10
+
11
+
12
+ def _get_api():
13
+ token = os.environ.get("HF_TOKEN")
14
+ if not token:
15
+ raise ValueError("HF_TOKEN no encontrado en variables de entorno")
16
+ api = HfApi(token=token)
17
+ api.create_repo(repo_id=HF_DATASET_REPO_ID, repo_type="dataset", exist_ok=True)
18
+ return api
19
+
20
+
21
def upload_step(name: str, step_folder: str, local_dir: str):
    """Push a local directory to {name}/{step_folder}/ inside the dataset repo.

    Returns a Spanish confirmation string for display in the UI.
    """
    hub = _get_api()
    hub.upload_folder(
        folder_path=local_dir,
        repo_id=HF_DATASET_REPO_ID,
        repo_type="dataset",
        path_in_repo=f"{name}/{step_folder}",
    )
    logger.info(f"Uploaded {local_dir} -> {name}/{step_folder}")
    return f"Subido a Hub: {name}/{step_folder}"
32
+
33
+
34
def download_step(name: str, step_folder: str, local_dir: str):
    """Fetch {name}/{step_folder}/ from the dataset repo into *local_dir*.

    Only files matching the step's path pattern are downloaded. Returns a
    Spanish confirmation string for display in the UI.
    """
    from huggingface_hub import snapshot_download
    snapshot_download(
        repo_id=HF_DATASET_REPO_ID,
        repo_type="dataset",
        local_dir=local_dir,
        allow_patterns=[f"{name}/{step_folder}/**"],
        token=os.environ.get("HF_TOKEN"),
    )
    logger.info(f"Downloaded {name}/{step_folder} -> {local_dir}")
    return f"Descargado de Hub: {name}/{step_folder}"
47
+
48
+
49
def list_projects() -> list[str]:
    """List project names (top-level folders) in the dataset repo.

    Returns [] on any error (missing repo, network failure, bad token).
    """
    token = os.environ.get("HF_TOKEN")
    try:
        api = HfApi(token=token)
        # Only the repo root is listed; each entry is a top-level file or folder.
        entries = list(api.list_repo_tree(
            repo_id=HF_DATASET_REPO_ID, repo_type="dataset", path_in_repo="",
        ))
        # NOTE(review): entries normally expose `.path`; the `rfilename` branch
        # looks like a fallback for other hub versions — confirm which attribute
        # applies, since `hasattr(e, "path")` likely makes the filter a no-op.
        return sorted(set(
            e.rfilename.split("/")[0] if hasattr(e, "rfilename") else e.path.split("/")[0]
            for e in entries
            if ("/" in getattr(e, "rfilename", "")) or hasattr(e, "path")
        ))
    except Exception as e:
        logger.warning(f"Could not list projects: {e}")
        return []
packages.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ ffmpeg
2
+ libgl1-mesa-glx
3
+ libglib2.0-0
4
+ libsm6
5
+ libxext6
6
+ libxrender-dev
requirements.txt ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ setuptools>=69.0.0
2
+ gradio>=5.9.1
3
+ torch>=2.1.0
4
+ torchaudio>=2.1.0
5
+ torchvision>=0.16.0
6
+ transformers>=4.36.0,<5.0.0
7
+ diffusers>=0.25.0
8
+ accelerate>=0.25.0
9
+ safetensors>=0.4.0
10
+ peft>=0.7.0
11
+ huggingface_hub>=0.20.0
12
+ numpy>=1.24.0
13
+ Pillow>=10.0.0
14
+ soundfile>=0.12.0
15
+ pydub>=0.25.1
16
+ f5-tts>=0.3.0
17
+ sentencepiece>=0.1.99
18
+ protobuf>=3.20.0
19
+ openmim>=0.3.9