FresherDifference committed on
Commit
f48358e
·
verified ·
1 Parent(s): 4335f64

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -65
app.py CHANGED
@@ -1,72 +1,101 @@
1
  import gradio as gr
2
- import tempfile
3
- import soundfile as sf
4
-
5
  from pocket_tts import TTSModel
6
 
7
- # -------------------------------------------------
8
- # Load model ONCE
9
- # -------------------------------------------------
10
- model = TTSModel.load_model()
11
-
12
- # -------------------------------------------------
13
- # HF-safe catalog voices
14
- # -------------------------------------------------
15
- VOICES = [
16
- "alba",
17
- "marius",
18
- "javert",
19
- "jean",
20
- "fantine",
21
- "cosette",
22
- "eponine",
23
- "azelma",
24
- ]
25
-
26
- def generate_tts(text, voice):
27
  if not text.strip():
28
- return None
29
-
30
- # Step 1: get model state from catalog voice
31
- state = model.get_state_for_voice(voice)
32
-
33
- # ✅ Step 2: generate audio from state + text
34
- audio = model.generate_audio(state, text)
35
-
36
- tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
37
- sf.write(tmp.name, audio, samplerate=24000)
38
-
39
- return tmp.name
40
-
41
-
42
- with gr.Blocks(title="Pocket TTS (Correct API)") as demo:
43
- gr.Markdown(
44
- """
45
- # 🗣️ Pocket TTS
46
- **HF Spaces compatible – catalog voices**
47
- """
48
- )
49
-
50
- voice_select = gr.Dropdown(
51
- choices=VOICES,
52
- value="alba",
53
- label="Voice"
54
- )
55
-
56
- text_input = gr.Textbox(
57
- label="Text",
58
- lines=4,
59
- placeholder="Type something to hear it spoken"
60
- )
61
-
62
- generate_btn = gr.Button("Generate")
63
-
64
- audio_output = gr.Audio(label="Output")
65
-
66
- generate_btn.click(
67
- fn=generate_tts,
68
- inputs=[text_input, voice_select],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  outputs=audio_output
70
  )
71
 
72
- demo.launch()
 
 
1
  import gradio as gr
2
+ import numpy as np
 
 
3
  from pocket_tts import TTSModel
4
 
5
+ # 1. Load the model once at startup (Global scope)
6
+ # This prevents reloading the 100M parameters on every click, making it much faster.
7
+ print("Loading Pocket-TTS model...")
8
+ tts = TTSModel.load_model()
9
+ print("Model loaded successfully.")
10
+
11
+ # Define some preset voices available in the Kyutai library
12
+ # Note: You can find more voices or exact paths in the kyutai/tts-voices repo
13
+ PRESET_VOICES = {
14
+ "Alba (American English)": "hf://kyutai/tts-voices/alba-mackenna/casual.wav",
15
+ "Marius (French Accent)": "hf://kyutai/tts-voices/marius-reynaud/casual.wav",
16
+ "Jean (Narrator)": "hf://kyutai/tts-voices/jean-dormeuil/casual.wav",
17
+ "Fantine": "hf://kyutai/tts-voices/fantine-chevallier/casual.wav",
18
+ }
19
+
20
+ def generate_speech(text, voice_choice, custom_voice_file):
21
+ """
22
+ Generates audio from text using either a preset voice or a custom uploaded file.
23
+ """
 
24
  if not text.strip():
25
+ raise gr.Error("Please enter some text to generate speech.")
26
+
27
+ # Determine which voice to use
28
+ voice_path = None
29
+
30
+ # Priority: Custom file > Preset selection
31
+ if custom_voice_file is not None:
32
+ print(f"Using custom voice cloning from: {custom_voice_file}")
33
+ voice_path = custom_voice_file
34
+ else:
35
+ print(f"Using preset voice: {voice_choice}")
36
+ voice_path = PRESET_VOICES.get(voice_choice)
37
+
38
+ if not voice_path:
39
+ raise gr.Error("Please select a voice or upload a reference audio file.")
40
+
41
+ # 2. Process the voice prompt
42
+ # This converts the wav file (or HF path) into the conditioning vector
43
+ try:
44
+ voice_state = tts.get_state_for_audio_prompt(voice_path)
45
+ except Exception as e:
46
+ raise gr.Error(f"Error loading voice: {str(e)}")
47
+
48
+ # 3. Generate Audio
49
+ # The output is a torch tensor, so we need to convert it to numpy for Gradio
50
+ try:
51
+ audio_tensor = tts.generate_audio(voice_state, text)
52
+ except Exception as e:
53
+ raise gr.Error(f"Generation failed: {str(e)}")
54
+
55
+ # Convert torch tensor to numpy array
56
+ # pocket-tts usually returns a tensor of shape (samples,); Gradio expects a (sample_rate, data) tuple
57
+ audio_numpy = audio_tensor.numpy()
58
+
59
+ # Return tuple (sample_rate, audio_data)
60
+ return (tts.sample_rate, audio_numpy)
61
+
62
+ # 4. Build the Gradio Interface
63
+ with gr.Blocks(title="Pocket-TTS Demo") as demo:
64
+ gr.Markdown("# 🗣️ Pocket-TTS on CPU")
65
+ gr.Markdown("A lightweight, 100M parameter text-to-speech model that runs purely on CPU.")
66
+
67
+ with gr.Row():
68
+ with gr.Column():
69
+ text_input = gr.Textbox(
70
+ label="Text to Speak",
71
+ placeholder="Type something here...",
72
+ lines=4,
73
+ value="Pocket TTS is amazing because it runs efficiently on consumer hardware!"
74
+ )
75
+
76
+ with gr.Accordion("Voice Settings", open=True):
77
+ voice_dropdown = gr.Dropdown(
78
+ choices=list(PRESET_VOICES.keys()),
79
+ value="Alba (American English)",
80
+ label="Choose a Preset Voice"
81
+ )
82
+ gr.Markdown("**OR**")
83
+ voice_upload = gr.Audio(
84
+ label="Clone a Custom Voice (Upload .wav)",
85
+ type="filepath"
86
+ )
87
+
88
+ submit_btn = gr.Button("Generate Audio", variant="primary")
89
+
90
+ with gr.Column():
91
+ audio_output = gr.Audio(label="Generated Speech", type="numpy")
92
+
93
+ # Connect the button
94
+ submit_btn.click(
95
+ fn=generate_speech,
96
+ inputs=[text_input, voice_dropdown, voice_upload],
97
  outputs=audio_output
98
  )
99
 
100
+ # Launch the app
101
+ demo.launch()