ollui commited on
Commit
42712e0
·
verified ·
1 Parent(s): 2db380e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -6
app.py CHANGED
@@ -1,24 +1,25 @@
1
  import torch
 
2
  from transformers import AutoProcessor, AutoModelForTextToWaveform
3
  import gradio as gr
4
 
5
  device = "cuda" if torch.cuda.is_available() else "cpu"
6
 
7
- model_id = "facebook/mms-tts-sah" # Yakut TTS
8
  processor = AutoProcessor.from_pretrained(model_id)
9
  model = AutoModelForTextToWaveform.from_pretrained(model_id).to(device)
10
 
11
  def yakut_tts(text):
12
  inputs = processor(text=text, return_tensors="pt").to(device)
13
  with torch.no_grad():
14
- waveform = model.generate(**inputs)
15
- audio = waveform.cpu().numpy()[0]
16
- return (16000, audio)
17
 
18
  gr.Interface(
19
  fn=yakut_tts,
20
- inputs=gr.Textbox(label="Yakut Text", placeholder="Введите текст на якутском..."),
21
- outputs=gr.Audio(label="Generated Audio", type="numpy"), # FIXED type
22
  title="Yakut Text-to-Speech",
23
  description="Enter Yakut (Sakha) text and generate speech using facebook/mms-tts-sah model."
24
  ).launch()
 
1
  import torch
2
+ import torchaudio
3
  from transformers import AutoProcessor, AutoModelForTextToWaveform
4
  import gradio as gr
5
 
6
  device = "cuda" if torch.cuda.is_available() else "cpu"
7
 
8
+ model_id = "facebook/mms-tts-sah"
9
  processor = AutoProcessor.from_pretrained(model_id)
10
  model = AutoModelForTextToWaveform.from_pretrained(model_id).to(device)
11
 
12
  def yakut_tts(text):
13
  inputs = processor(text=text, return_tensors="pt").to(device)
14
  with torch.no_grad():
15
+ outputs = model(**inputs)
16
+ waveform = outputs.waveform.squeeze().cpu()
17
+ return (model.config.sampling_rate, waveform.numpy())
18
 
19
  gr.Interface(
20
  fn=yakut_tts,
21
+ inputs=gr.Textbox(label="Yakut Text", placeholder="Саха тыла"),
22
+ outputs=gr.Audio(label="Generated Audio", type="numpy"),
23
  title="Yakut Text-to-Speech",
24
  description="Enter Yakut (Sakha) text and generate speech using facebook/mms-tts-sah model."
25
  ).launch()