Torchem committed on
Commit
589f3c3
Β·
verified Β·
1 Parent(s): d49b8df

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -97
app.py CHANGED
@@ -1,134 +1,105 @@
1
  import os
2
  import uuid
3
  import subprocess
4
- import glob
5
-
6
  import gradio as gr
7
  from PIL import Image
 
 
 
8
 
9
- # -------------------------
10
- # Download/prepare SadTalker
11
- # -------------------------
12
- from download_sadtalker_models import ensure_sadtalker
13
- ensure_sadtalker()
14
-
15
- # -------------------------
16
- # Bark (latest API)
17
- # -------------------------
18
- from bark import SAMPLE_RATE, generate_audio, preload_models
19
-
20
- # Load Bark SMALL (change to "large" later for upgrade)
21
- preload_models(model_type="small")
22
 
 
 
 
23
  RESULTS_DIR = "results"
24
  os.makedirs(RESULTS_DIR, exist_ok=True)
25
 
 
 
 
 
 
 
 
 
26
 
27
- # -----------------------------------------
28
- # Available Bark Male Voice Presets
29
- # -----------------------------------------
30
- BARK_MALE_VOICES = [
31
- "male_voice_1",
32
- "male_voice_2",
33
- "angry_male",
34
- "male_broadcast",
35
- "male_baritone",
36
- "us_male_0",
37
- "us_male_1",
38
- "male_host",
39
- "old_male",
40
- "rough_male",
41
- "male_voice_young",
42
- "announcer"
43
- ]
44
-
45
-
46
- # -------------------------
47
- # Generate Bark audio
48
- # -------------------------
49
  def generate_tts(script: str, speaker: str):
50
- """Generate a WAV file using Bark."""
51
- audio_path = os.path.join(RESULTS_DIR, f"audio_{uuid.uuid4().hex}.wav")
52
 
53
- # Bark’s updated API
54
- audio_array = generate_audio(
55
- text=script,
56
- speaker=speaker
57
- )
58
 
59
- import soundfile as sf
60
- sf.write(audio_path, audio_array, SAMPLE_RATE)
61
 
62
- return audio_path
 
 
63
 
64
 
65
- # -------------------------
66
- # Run SadTalker
67
- # -------------------------
68
- def run_sadtalker(image: Image.Image, audio_path: str):
69
- """Run SadTalker to generate a talking-head video."""
70
  img_path = os.path.join(RESULTS_DIR, f"torch_{uuid.uuid4().hex}.png")
71
- image.save(img_path)
72
 
73
- sadtalker_results = os.path.join("SadTalker", "results")
74
- os.makedirs(sadtalker_results, exist_ok=True)
75
 
76
  cmd = [
77
  "python", "inference.py",
78
- "--driven_audio", os.path.abspath(audio_path),
79
- "--source_image", os.path.abspath(img_path),
80
- "--result_dir", os.path.abspath(sadtalker_results),
81
- "--preprocess", "full",
82
- "--still"
83
  ]
84
 
85
- subprocess.run(cmd, cwd="SadTalker", check=True)
86
-
87
- mp4_files = glob.glob(os.path.join(sadtalker_results, "**", "*.mp4"), recursive=True)
88
- if not mp4_files:
89
- raise RuntimeError("SadTalker produced no output video.")
90
-
91
- latest = max(mp4_files, key=os.path.getmtime)
92
-
93
- out_path = os.path.join(RESULTS_DIR, f"torch_out_{uuid.uuid4().hex}.mp4")
94
- subprocess.run(["cp", latest, out_path])
95
-
96
- return out_path
97
 
98
 
99
- # -------------------------
100
- # Pipeline
101
- # -------------------------
102
  def pipeline(script, voice, image):
103
  if not script.strip():
104
  raise gr.Error("Script is empty.")
105
  if image is None:
106
- raise gr.Error("No image uploaded.")
107
 
108
  audio = generate_tts(script, voice)
109
- video = run_sadtalker(image, audio)
110
  return video
111
 
112
 
113
- # -------------------------
114
  # Gradio UI
115
- # -------------------------
116
  def build_ui():
117
  with gr.Blocks() as demo:
118
- gr.Markdown("## πŸ”₯ Torch Em β€” Bark TTS + SadTalker Lipsync")
119
 
120
  with gr.Row():
121
  with gr.Column():
122
  script = gr.Textbox(
123
  label="Script",
124
- lines=4,
125
- placeholder="Type any dialogue here…"
126
  )
127
 
128
  voice = gr.Dropdown(
129
- label="Male Voice Preset",
130
- choices=BARK_MALE_VOICES,
131
- value="male_voice_1"
132
  )
133
 
134
  image = gr.Image(
@@ -136,15 +107,15 @@ def build_ui():
136
  type="pil"
137
  )
138
 
139
- generate_btn = gr.Button("Generate Video")
140
 
141
  with gr.Column():
142
- output_video = gr.Video(label="Output")
143
 
144
- generate_btn.click(
145
  pipeline,
146
  inputs=[script, voice, image],
147
- outputs=output_video
148
  )
149
 
150
  return demo
@@ -153,9 +124,9 @@ def build_ui():
153
  demo = build_ui()
154
 
155
 
156
- # -------------------------
157
- # API endpoint for n8n
158
- # -------------------------
159
  from fastapi import FastAPI, UploadFile, Form
160
  from fastapi.responses import FileResponse
161
  import uvicorn
@@ -168,19 +139,18 @@ async def generate_api(
168
  voice: str = Form(...),
169
  image: UploadFile = Form(...)
170
  ):
171
- img_path = f"tmp_{uuid.uuid4().hex}.png"
172
- with open(img_path, "wb") as f:
173
  f.write(await image.read())
174
 
175
- pil_img = Image.open(img_path).convert("RGB")
176
 
177
  audio = generate_tts(script, voice)
178
- video = run_sadtalker(pil_img, audio)
179
 
180
  return FileResponse(video)
181
 
182
 
183
- # Mount Gradio under FastAPI
184
  app = gr.mount_gradio_app(api, demo, path="/")
185
 
186
 
 
1
  import os
2
  import uuid
3
  import subprocess
 
 
4
  import gradio as gr
5
  from PIL import Image
6
+ import torch
7
+ import soundfile as sf
8
+ import numpy as np
9
 
10
+ from transformers import AutoProcessor, AutoModelForTextToWaveform
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
# -------------------------------
# Setup output folder
# -------------------------------
RESULTS_DIR = "results"  # all generated audio/image/video artifacts land here
os.makedirs(RESULTS_DIR, exist_ok=True)

# Prefer GPU when available; the TTS model and its inputs are moved to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# -------------------------------
# Load Parler-TTS
# -------------------------------
# NOTE(review): "facebook/parler-tts-mini-en" does not match the published
# Parler-TTS checkpoint ids (the mini model is published under the
# "parler-tts" org) — confirm the model id and that it is actually loadable
# through AutoModelForTextToWaveform before deploying.
model_name = "facebook/parler-tts-mini-en" # HuggingFace-native, stable
processor = AutoProcessor.from_pretrained(model_name)
model = AutoModelForTextToWaveform.from_pretrained(model_name).to(device)

# -------------------------------
# Parler male voices (all)
# -------------------------------
# NOTE(review): transformers processors do not normally expose a `.speakers`
# mapping; verify this attribute exists for the chosen checkpoint, otherwise
# this line raises AttributeError at import time.
PARLER_MALE_VOICES = processor.speakers["male"] # all male speakers


# -------------------------------
# TTS function
# -------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
36
def generate_tts(script: str, speaker: str):
    """Synthesize *script* into a WAV file with the module-level TTS model.

    Args:
        script: Text to speak.
        speaker: Speaker preset name forwarded to the processor.

    Returns:
        Path to the generated WAV file inside ``RESULTS_DIR``.
    """
    # NOTE(review): a `speaker=` kwarg on the processor is model-specific —
    # confirm the loaded checkpoint's processor actually accepts it.
    inputs = processor(text=script, speaker=speaker, return_tensors="pt").to(device)

    with torch.no_grad():
        audio_values = model.generate(**inputs)

    # Move off the accelerator and drop the batch dimension. Cast to float32
    # because soundfile cannot serialize the float16 arrays produced by
    # half-precision models.
    audio = audio_values.cpu().numpy().squeeze().astype(np.float32)
    sample_rate = model.config.sampling_rate

    out_path = os.path.join(RESULTS_DIR, f"audio_{uuid.uuid4().hex}.wav")
    sf.write(out_path, audio, sample_rate)
    return out_path
48
 
49
 
50
+ # -------------------------------
51
+ # Wav2Lip function
52
+ # -------------------------------
53
def run_wav2lip(image: Image.Image, audio_path: str):
    """Lip-sync *image* to *audio_path* with Wav2Lip and return the MP4 path.

    Args:
        image: Source face image (saved to ``RESULTS_DIR`` before inference).
        audio_path: Path to the driving audio file.

    Returns:
        Path to the rendered MP4 inside ``RESULTS_DIR``.

    Raises:
        subprocess.CalledProcessError: if the Wav2Lip inference script fails.
    """
    img_path = os.path.join(RESULTS_DIR, f"torch_{uuid.uuid4().hex}.png")
    video_out = os.path.join(RESULTS_DIR, f"torch_out_{uuid.uuid4().hex}.mp4")

    image.save(img_path)

    # inference.py is executed with cwd="Wav2Lip", so paths relative to this
    # app's directory (e.g. "results/...") would be resolved inside the
    # Wav2Lip directory and not be found — pass absolute paths instead.
    cmd = [
        "python", "inference.py",
        "--face", os.path.abspath(img_path),
        "--audio", os.path.abspath(audio_path),
        "--outfile", os.path.abspath(video_out),
    ]

    subprocess.run(cmd, cwd="Wav2Lip", check=True)
    return video_out
 
 
 
 
 
 
 
 
 
 
68
 
69
 
70
+ # -------------------------------
71
+ # Full pipeline
72
+ # -------------------------------
73
def pipeline(script, voice, image):
    """Run the full text -> speech -> talking-head flow for one request."""
    # Guard clauses: surface friendly validation errors in the Gradio UI.
    if not script.strip():
        raise gr.Error("Script is empty.")
    if image is None:
        raise gr.Error("Upload your Torch PNG first.")

    # Synthesize the narration, then drive the lip-sync with it.
    wav_path = generate_tts(script, voice)
    return run_wav2lip(image, wav_path)
82
 
83
 
84
+ # -------------------------------
85
  # Gradio UI
86
+ # -------------------------------
87
  def build_ui():
88
  with gr.Blocks() as demo:
89
+ gr.Markdown("## πŸ”₯ Torch Em β€” Parler-TTS + Wav2Lip (Stable)")
90
 
91
  with gr.Row():
92
  with gr.Column():
93
  script = gr.Textbox(
94
  label="Script",
95
+ lines=3,
96
+ placeholder="Enter 2–3 second intro line…"
97
  )
98
 
99
  voice = gr.Dropdown(
100
+ label="Voice",
101
+ choices=PARLER_MALE_VOICES,
102
+ value=PARLER_MALE_VOICES[0]
103
  )
104
 
105
  image = gr.Image(
 
107
  type="pil"
108
  )
109
 
110
+ btn = gr.Button("Generate Video")
111
 
112
  with gr.Column():
113
+ output = gr.Video(label="Output Video")
114
 
115
+ btn.click(
116
  pipeline,
117
  inputs=[script, voice, image],
118
+ outputs=output
119
  )
120
 
121
  return demo
 
124
  demo = build_ui()
125
 
126
 
127
+ # -------------------------------
128
+ # FastAPI endpoint for n8n
129
+ # -------------------------------
130
  from fastapi import FastAPI, UploadFile, Form
131
  from fastapi.responses import FileResponse
132
  import uvicorn
 
139
  voice: str = Form(...),
140
  image: UploadFile = Form(...)
141
  ):
142
+ tmp_img = f"tmp_{uuid.uuid4().hex}.png"
143
+ with open(tmp_img, "wb") as f:
144
  f.write(await image.read())
145
 
146
+ pil_img = Image.open(tmp_img).convert("RGB")
147
 
148
  audio = generate_tts(script, voice)
149
+ video = run_wav2lip(pil_img, audio)
150
 
151
  return FileResponse(video)
152
 
153
 
 
154
  app = gr.mount_gradio_app(api, demo, path="/")
155
 
156