Borio047 committed on
Commit
c23ecfc
·
verified ·
1 Parent(s): c8ea599

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -27
app.py CHANGED
@@ -1,64 +1,87 @@
1
  import gradio as gr
2
- from transformers import pipeline
3
  import numpy as np
4
  import soundfile as sf
5
  import os
6
  import uuid
7
 
8
- # Load TTS pipeline once at startup
9
- TTS_MODEL_ID = "suno/bark-small"
10
- tts = pipeline("text-to-speech", model=TTS_MODEL_ID)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  def generate_speech(text: str) -> str:
13
  """
14
- Takes input text and returns a filepath to a WAV file
15
- for gr.Audio(type="filepath").
 
16
  """
17
  if not text or text.strip() == "":
18
- raise gr.Error("Please enter some text to synthesize 🙂")
 
 
 
 
 
19
 
20
- # Run the model
21
- output = tts(text)
22
 
23
- # Expecting {"audio": np.ndarray or list, "sampling_rate": int}
24
- audio = np.asarray(output["audio"], dtype=np.float32)
25
- sr = int(output["sampling_rate"])
26
 
27
- # Ensure mono or stereo is fine; soundfile can handle it
28
- if audio.ndim > 1:
29
- audio = audio.squeeze()
30
 
31
- # Create a unique temporary path
 
 
 
 
32
  tmp_dir = "/tmp"
33
  os.makedirs(tmp_dir, exist_ok=True)
34
  filename = f"tts_{uuid.uuid4().hex}.wav"
35
  filepath = os.path.join(tmp_dir, filename)
36
 
37
- # Write WAV using soundfile (no pydub, no wave header issues)
38
- sf.write(filepath, audio, sr)
39
 
40
- # Return the path; gr.Audio(type="filepath") will use it directly
41
  return filepath
42
 
 
43
  with gr.Blocks() as demo:
44
- gr.Markdown("# 🗣️ Simple Text-to-Speech Demo (Bark Small)")
45
  gr.Markdown(
46
- "Type some English text, click **Generate speech**, and listen to the audio.\n"
47
- "Model: `suno/bark-small` via 🤗 Transformers TTS pipeline."
48
  )
49
 
50
  with gr.Row():
51
  with gr.Column(scale=2):
52
  text_input = gr.Textbox(
53
- label="Input text",
54
- placeholder="Type something like: Hello, this is my first TTS Space!",
55
- lines=4,
56
  )
57
  generate_button = gr.Button("Generate speech", variant="primary")
58
  with gr.Column(scale=1):
59
  audio_output = gr.Audio(
60
  label="Generated audio",
61
- type="filepath", # we are returning a path string
62
  )
63
 
64
  generate_button.click(
@@ -68,5 +91,4 @@ with gr.Blocks() as demo:
68
  )
69
 
70
  if __name__ == "__main__":
71
- # Disable SSR to avoid async quirks
72
  demo.launch(ssr_mode=False)
 
1
  import gradio as gr
 
2
  import numpy as np
3
  import soundfile as sf
4
  import os
5
  import uuid
6
 
7
+ import torch
8
+ from transformers import VitsModel, VitsTokenizer, set_seed
9
+
10
+ # 1. Load MMS-TTS English model (lighter than Bark)
11
+ MODEL_ID = "facebook/mms-tts-eng"
12
+
13
+ tokenizer = VitsTokenizer.from_pretrained(MODEL_ID)
14
+ model = VitsModel.from_pretrained(MODEL_ID)
15
+
16
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
+ model = model.to(device)
18
+
19
+ # Optional: make outputs deterministic
20
+ set_seed(555)
21
+
22
+
23
+ MAX_CHARS = 150 # keep text short for speed and stability
24
+
25
 
26
  def generate_speech(text: str) -> str:
27
  """
28
+ Take text, synthesize speech with MMS-TTS,
29
+ save to a WAV file, and return the filepath
30
+ (for gr.Audio(type="filepath")).
31
  """
32
  if not text or text.strip() == "":
33
+ raise gr.Error("Please enter some text 🙂")
34
+
35
+ text = text.strip()
36
+ if len(text) > MAX_CHARS:
37
+ text = text[:MAX_CHARS]
38
+ # You could also show a warning text if you like.
39
 
40
+ # MMS-TTS is trained on lowercased, unpunctuated text → simple normalization
41
+ normalized_text = text.lower()
42
 
43
+ # 1) Tokenize
44
+ inputs = tokenizer(text=normalized_text, return_tensors="pt").to(device)
 
45
 
46
+ # 2) Forward pass
47
+ with torch.no_grad():
48
+ outputs = model(**inputs)
49
 
50
+ # 3) Get waveform and sampling rate
51
+ waveform = outputs.waveform[0].cpu().numpy().astype(np.float32)
52
+ sr = model.config.sampling_rate # typically 16000
53
+
54
+ # 4) Save to /tmp as WAV
55
  tmp_dir = "/tmp"
56
  os.makedirs(tmp_dir, exist_ok=True)
57
  filename = f"tts_{uuid.uuid4().hex}.wav"
58
  filepath = os.path.join(tmp_dir, filename)
59
 
60
+ sf.write(filepath, waveform, sr)
 
61
 
62
+ # 5) Return file path for gr.Audio(type="filepath")
63
  return filepath
64
 
65
+
66
  with gr.Blocks() as demo:
67
+ gr.Markdown("# 🗣️ Simple TTS with facebook/mms-tts-eng")
68
  gr.Markdown(
69
+ "Type a short English sentence, click **Generate speech**, and listen to the audio.\n\n"
70
+ "Model: `facebook/mms-tts-eng` (MMS-TTS, VITS-based)."
71
  )
72
 
73
  with gr.Row():
74
  with gr.Column(scale=2):
75
  text_input = gr.Textbox(
76
+ label="Text to synthesize",
77
+ placeholder="Example: hello, this is my text-to-speech demo",
78
+ lines=3,
79
  )
80
  generate_button = gr.Button("Generate speech", variant="primary")
81
  with gr.Column(scale=1):
82
  audio_output = gr.Audio(
83
  label="Generated audio",
84
+ type="filepath", # we return a path string
85
  )
86
 
87
  generate_button.click(
 
91
  )
92
 
93
  if __name__ == "__main__":
 
94
  demo.launch(ssr_mode=False)