BIBLETUM commited on
Commit
44e7908
·
verified ·
1 Parent(s): 7242553

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +163 -53
app.py CHANGED
@@ -10,17 +10,15 @@ import gradio as gr
10
  OUTDIR = Path("outputs")
11
  OUTDIR.mkdir(parents=True, exist_ok=True)
12
 
13
-
14
  def slug(s: str) -> str:
15
  """Make a safe filename slug (ASCII, underscores)."""
16
  if s is None:
17
  s = ""
18
  return "".join(c if c.isalnum() else "_" for c in s)[:80].strip("_")
19
 
20
-
21
  def save_wav(path: Path, sr: int, audio):
22
  import numpy as np
23
- import scipy.io.wavfile as wav
24
 
25
  if hasattr(audio, "detach"):
26
  audio = audio.detach().cpu().numpy()
@@ -28,13 +26,12 @@ def save_wav(path: Path, sr: int, audio):
28
  a = np.squeeze(a)
29
  if a.ndim == 2 and a.shape[0] < a.shape[1]:
30
  a = a.T
31
- # normalize if needed
32
  max_abs = np.max(np.abs(a)) if a.size else 1.0
33
  if np.isfinite(max_abs) and max_abs > 1.0:
34
  a = a / max_abs
35
  wav.write(str(path), int(sr), a)
36
 
37
-
38
  MODEL_NAMES = {
39
  "suno/bark-small": "bark",
40
  "facebook/mms-tts-rus": "mms",
@@ -44,7 +41,6 @@ MODEL_NAMES = {
44
  _model_cache: Dict[str, object] = {}
45
  _device_hint = "auto"
46
 
47
-
48
  def _load_bark():
49
  from transformers import pipeline
50
  pipe = pipeline("text-to-speech", model="suno/bark-small", device_map=_device_hint)
@@ -57,7 +53,6 @@ def _load_bark():
57
 
58
  return generate
59
 
60
-
61
  def _load_mms():
62
  from transformers import pipeline
63
  pipe = pipeline("text-to-speech", model="facebook/mms-tts-rus", device_map=_device_hint)
@@ -70,7 +65,6 @@ def _load_mms():
70
 
71
  return generate
72
 
73
-
74
  def _load_seamless():
75
  import torch
76
  import numpy as np
@@ -81,7 +75,6 @@ def _load_seamless():
81
 
82
  device = "cuda" if torch.cuda.is_available() else "cpu"
83
 
84
- # КЛЮЧЕВОЕ: use_fast=False, чтобы не требовался tiktoken
85
  proc = AutoProcessor.from_pretrained(
86
  "facebook/seamless-m4t-v2-large",
87
  use_fast=False
@@ -98,7 +91,6 @@ def _load_seamless():
98
 
99
  return generate
100
 
101
-
102
  def get_generator(kind: str):
103
  if kind in _model_cache:
104
  return _model_cache[kind]
@@ -113,25 +105,22 @@ def get_generator(kind: str):
113
  _model_cache[kind] = gen
114
  return gen
115
 
116
-
117
  DEFAULT_PROMPTS = (
118
  "Привет! Это короткий тест русского TTS.\n"
119
  "Сегодня мы проверяем интонации, паузы и четкость дикции.\n"
120
  "Немного сложнее: числа 3.14 и 2025 читаем правильно."
121
  )
122
 
123
-
124
  def run_tts(
125
  prompts_text: str,
126
  split_lines: bool,
127
  model_choice: str,
128
- ) -> tuple:
129
- """Main Gradio callback.
130
-
131
  Returns:
132
- files: list[str] — файловые пути для скачивания
133
- df: pd.DataFrame — таблица с метаданными
134
- last_audio: tuple[int, np.ndarray] | None — предпросмотр последнего файла
135
  """
136
  text_items: List[str] = []
137
  if split_lines:
@@ -147,12 +136,12 @@ def run_tts(
147
  kind = MODEL_NAMES[model_choice]
148
  gen = get_generator(kind)
149
 
150
- stamp_dir = OUTDIR / time.strftime("%Y%m%d-%H%M%S")
151
  stamp_dir.mkdir(parents=True, exist_ok=True)
152
 
153
  rows = []
154
  file_paths: List[str] = []
155
- last_audio_payload = None
156
 
157
  for p in text_items:
158
  t0 = time.time()
@@ -162,6 +151,7 @@ def run_tts(
162
  save_wav(path, sr, audio)
163
 
164
  rows.append({
 
165
  "model": model_choice,
166
  "prompt": p,
167
  "file": str(path),
@@ -169,57 +159,177 @@ def run_tts(
169
  "gen_time_s": round(dt, 3),
170
  })
171
  file_paths.append(str(path))
172
- last_audio_payload = str(path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
 
174
  df = pd.DataFrame(rows)
175
- return file_paths, df, last_audio_payload
176
 
177
 
178
- description_md = (
179
  """
180
  Russian TTS Bench: выберите модель и введите один или несколько промптов.\
181
- По умолчанию каждая строка — отдельный промпт. Результаты сохраняются в `outputs/…`.
182
 
183
  **Модели:**
184
  - `suno/bark-small` — небольшой мультиязычный TTS.
185
  - `facebook/mms-tts-rus` — русская TTS из проекта MMS.
186
- - `facebook/seamless-m4t-v2-large` — крупная модель перевода/говорения; тяжёлая для CPU.\
 
 
 
 
 
 
 
 
 
 
 
187
  """
188
  )
189
 
190
- with gr.Blocks(title="Russian TTS Bench") as demo:
191
- gr.Markdown("# 🗣️ Russian TTS Bench")
192
- gr.Markdown(description_md)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
 
194
- with gr.Row():
195
- model_choice = gr.Dropdown(
196
- label="Модель",
197
- choices=list(MODEL_NAMES.keys()),
198
- value="suno/bark-small",
199
  )
200
- split_lines = gr.Checkbox(value=True, label="Одна строка = один промпт")
201
 
202
- prompts = gr.Textbox(
203
- label="Промпты",
204
- value=DEFAULT_PROMPTS,
205
- lines=6,
206
- placeholder="Каждая строка — отдельный промпт…",
207
- )
 
 
 
 
 
 
 
 
 
 
 
 
208
 
209
- run_btn = gr.Button("Сгенерировать", variant="primary")
210
 
211
- with gr.Row():
212
- files = gr.Files(label="Файлы .wav для скачивания")
213
- with gr.Row():
214
- df_out = gr.Dataframe(label="Таблица результатов", interactive=False)
215
- with gr.Row():
216
- preview = gr.Audio(label="Предпросмотр последнего семпла", autoplay=False)
217
 
218
- run_btn.click(
219
- fn=run_tts,
220
- inputs=[prompts, split_lines, model_choice],
221
- outputs=[files, df_out, preview],
222
- )
223
 
224
  if __name__ == "__main__":
225
- demo.launch()
 
10
  OUTDIR = Path("outputs")
11
  OUTDIR.mkdir(parents=True, exist_ok=True)
12
 
 
13
  def slug(s: str) -> str:
14
  """Make a safe filename slug (ASCII, underscores)."""
15
  if s is None:
16
  s = ""
17
  return "".join(c if c.isalnum() else "_" for c in s)[:80].strip("_")
18
 
 
19
  def save_wav(path: Path, sr: int, audio):
20
  import numpy as np
21
+ from scipy.io import wavfile as wav
22
 
23
  if hasattr(audio, "detach"):
24
  audio = audio.detach().cpu().numpy()
 
26
  a = np.squeeze(a)
27
  if a.ndim == 2 and a.shape[0] < a.shape[1]:
28
  a = a.T
29
+ # normalize if needed (safety)
30
  max_abs = np.max(np.abs(a)) if a.size else 1.0
31
  if np.isfinite(max_abs) and max_abs > 1.0:
32
  a = a / max_abs
33
  wav.write(str(path), int(sr), a)
34
 
 
35
  MODEL_NAMES = {
36
  "suno/bark-small": "bark",
37
  "facebook/mms-tts-rus": "mms",
 
41
  _model_cache: Dict[str, object] = {}
42
  _device_hint = "auto"
43
 
 
44
  def _load_bark():
45
  from transformers import pipeline
46
  pipe = pipeline("text-to-speech", model="suno/bark-small", device_map=_device_hint)
 
53
 
54
  return generate
55
 
 
56
  def _load_mms():
57
  from transformers import pipeline
58
  pipe = pipeline("text-to-speech", model="facebook/mms-tts-rus", device_map=_device_hint)
 
65
 
66
  return generate
67
 
 
68
  def _load_seamless():
69
  import torch
70
  import numpy as np
 
75
 
76
  device = "cuda" if torch.cuda.is_available() else "cpu"
77
 
 
78
  proc = AutoProcessor.from_pretrained(
79
  "facebook/seamless-m4t-v2-large",
80
  use_fast=False
 
91
 
92
  return generate
93
 
 
94
  def get_generator(kind: str):
95
  if kind in _model_cache:
96
  return _model_cache[kind]
 
105
  _model_cache[kind] = gen
106
  return gen
107
 
 
108
  DEFAULT_PROMPTS = (
109
  "Привет! Это короткий тест русского TTS.\n"
110
  "Сегодня мы проверяем интонации, паузы и четкость дикции.\n"
111
  "Немного сложнее: числа 3.14 и 2025 читаем правильно."
112
  )
113
 
 
114
  def run_tts(
115
  prompts_text: str,
116
  split_lines: bool,
117
  model_choice: str,
118
+ ):
119
+ """Main Gradio callback: TTS.
 
120
  Returns:
121
+ files: list[str] — пути к wav
122
+ df: pd.DataFrame — таблица метаданных
123
+ last_audio: str | None — путь к последнему файлу для предпросмотра
124
  """
125
  text_items: List[str] = []
126
  if split_lines:
 
136
  kind = MODEL_NAMES[model_choice]
137
  gen = get_generator(kind)
138
 
139
+ stamp_dir = OUTDIR / "tts" / time.strftime("%Y%m%d-%H%M%S")
140
  stamp_dir.mkdir(parents=True, exist_ok=True)
141
 
142
  rows = []
143
  file_paths: List[str] = []
144
+ last_audio_path = None
145
 
146
  for p in text_items:
147
  t0 = time.time()
 
151
  save_wav(path, sr, audio)
152
 
153
  rows.append({
154
+ "task": "tts",
155
  "model": model_choice,
156
  "prompt": p,
157
  "file": str(path),
 
159
  "gen_time_s": round(dt, 3),
160
  })
161
  file_paths.append(str(path))
162
+ last_audio_path = str(path)
163
+
164
+ df = pd.DataFrame(rows)
165
+ return file_paths, df, last_audio_path
166
+
167
+ _music_pipes: Dict[str, object] = {}
168
+
169
+ MUSIC_MODELS = [
170
+ "facebook/musicgen-small",
171
+ ]
172
+
173
+ def get_music_pipe(model_name: str):
174
+ if model_name in _music_pipes:
175
+ return _music_pipes[model_name]
176
+ from transformers import pipeline
177
+ pipe = pipeline("text-to-audio", model=model_name, device_map=_device_hint)
178
+ _music_pipes[model_name] = pipe
179
+ return pipe
180
+
181
+ MUSIC_DEFAULT_PROMPTS = (
182
+ "High-energy 90s rock track with distorted electric guitars, driving bass, and hard-hitting acoustic drums\n"
183
+ "Modern electronic dance track with punchy kick, bright synth lead, and sidechained pads, 128 BPM\n"
184
+ "Dark industrial electro with gritty bass, sharp snares, and mechanical percussion"
185
+ )
186
+
187
+ def run_music(
188
+ prompts_text: str,
189
+ split_lines: bool,
190
+ model_name: str,
191
+ do_sample: bool,
192
+ ):
193
+ """Main Gradio callback: MusicGen."""
194
+ text_items: List[str] = []
195
+ if split_lines:
196
+ for line in [s.strip() for s in prompts_text.splitlines()]:
197
+ if line:
198
+ text_items.append(line)
199
+ else:
200
+ text_items = [prompts_text.strip()] if prompts_text.strip() else []
201
+
202
+ if not text_items:
203
+ return [], pd.DataFrame(), None
204
+
205
+ pipe = get_music_pipe(model_name)
206
+
207
+ stamp_dir = OUTDIR / "music" / slug(model_name) / time.strftime("%Y%m%d-%H%M%S")
208
+ stamp_dir.mkdir(parents=True, exist_ok=True)
209
+
210
+ rows = []
211
+ file_paths: List[str] = []
212
+ last_audio_path = None
213
+
214
+ for p in text_items:
215
+ t0 = time.time()
216
+ # Параметры генерации держим минимальными и совместимыми
217
+ out = pipe(p, forward_params={"do_sample": bool(do_sample)})
218
+ dt = time.time() - t0
219
+
220
+ sr = int(out["sampling_rate"])
221
+ audio = np.asarray(out["audio"], dtype=np.float32)
222
+
223
+ path = stamp_dir / f"{slug(p)}.wav"
224
+ save_wav(path, sr, audio)
225
+
226
+ rows.append({
227
+ "task": "music",
228
+ "model": model_name,
229
+ "prompt": p,
230
+ "file": str(path),
231
+ "sr": sr,
232
+ "gen_time_s": round(dt, 3),
233
+ })
234
+ file_paths.append(str(path))
235
+ last_audio_path = str(path)
236
 
237
  df = pd.DataFrame(rows)
238
+ return file_paths, df, last_audio_path
239
 
240
 
241
+ tts_description_md = (
242
  """
243
  Russian TTS Bench: выберите модель и введите один или несколько промптов.\
244
+ По умолчанию каждая строка — отдельный промпт. Результаты сохраняются в `outputs/tts/…`.
245
 
246
  **Модели:**
247
  - `suno/bark-small` — небольшой мультиязычный TTS.
248
  - `facebook/mms-tts-rus` — русская TTS из проекта MMS.
249
+ - `facebook/seamless-m4t-v2-large` — крупная модель перевода/говорения; тяжёлая для CPU.
250
+ """
251
+ )
252
+
253
+ music_description_md = (
254
+ """
255
+ **Music Gen:** текст → музыка на базе MusicGen. По умолчанию каждая строка — отдельный промпт.\
256
+ Результаты сохраняются в `outputs/music/<model>/…`.
257
+
258
+ **Модели:**
259
+ - `facebook/musicgen-small`
260
+ - (опционально) `facebook/musicgen-stereo-small` — раскомментируйте в коде.
261
  """
262
  )
263
 
264
+ with gr.Blocks(title="Speech & Music Bench") as demo:
265
+ gr.Markdown("# 🎙️🪄 Speech & Music Bench")
266
+
267
+ with gr.Tab("🗣️ TTS"):
268
+ gr.Markdown(tts_description_md)
269
+
270
+ with gr.Row():
271
+ model_choice = gr.Dropdown(
272
+ label="Модель TTS",
273
+ choices=list(MODEL_NAMES.keys()),
274
+ value="suno/bark-small",
275
+ )
276
+ split_lines_tts = gr.Checkbox(value=True, label="Одна строка = один промпт")
277
+
278
+ prompts_tts = gr.Textbox(
279
+ label="Промпты",
280
+ value=DEFAULT_PROMPTS,
281
+ lines=6,
282
+ placeholder="Каждая строка — отдельный промпт…",
283
+ )
284
+
285
+ run_btn_tts = gr.Button("Сгенерировать речь", variant="primary")
286
+
287
+ with gr.Row():
288
+ files_tts = gr.Files(label="Файлы .wav для скачивания")
289
+ with gr.Row():
290
+ df_out_tts = gr.Dataframe(label="Таблица результатов", interactive=False)
291
+ with gr.Row():
292
+ preview_tts = gr.Audio(label="Предпросмотр последнего семпла", autoplay=False)
293
 
294
+ run_btn_tts.click(
295
+ fn=run_tts,
296
+ inputs=[prompts_tts, split_lines_tts, model_choice],
297
+ outputs=[files_tts, df_out_tts, preview_tts],
 
298
  )
 
299
 
300
+ with gr.Tab("🎵 Music"):
301
+ gr.Markdown(music_description_md)
302
+
303
+ with gr.Row():
304
+ music_model = gr.Dropdown(
305
+ label="Модель MusicGen",
306
+ choices=MUSIC_MODELS,
307
+ value=MUSIC_MODELS[0],
308
+ )
309
+ split_lines_music = gr.Checkbox(value=True, label="Одна строка = один промпт")
310
+ do_sample = gr.Checkbox(value=True, label="do_sample")
311
+
312
+ prompts_music = gr.Textbox(
313
+ label="Музыкальные промпты",
314
+ value=MUSIC_DEFAULT_PROMPTS,
315
+ lines=6,
316
+ placeholder="Каждая строка — отдельный промпт…",
317
+ )
318
 
319
+ run_btn_music = gr.Button("Сгенерировать музыку", variant="primary")
320
 
321
+ with gr.Row():
322
+ files_music = gr.Files(label="Файлы .wav для скачивания")
323
+ with gr.Row():
324
+ df_out_music = gr.Dataframe(label="Таблица результатов", interactive=False)
325
+ with gr.Row():
326
+ preview_music = gr.Audio(label="Предпросмотр последнего трека", autoplay=False)
327
 
328
+ run_btn_music.click(
329
+ fn=run_music,
330
+ inputs=[prompts_music, split_lines_music, music_model, do_sample],
331
+ outputs=[files_music, df_out_music, preview_music],
332
+ )
333
 
334
  if __name__ == "__main__":
335
+ demo.launch()