artificialguybr committed on
Commit
84a2fb0
Β·
verified Β·
1 Parent(s): 6795799

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +201 -0
app.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import traceback
4
+ from pathlib import Path
5
+
6
+ import gradio as gr
7
+ import numpy as np
8
+ import spaces
9
+ import torch
10
+ from huggingface_hub import snapshot_download
11
+
12
+ sys.path.append(str(Path(__file__).parent))
13
+
14
+ from fish_speech.models.text2semantic.inference import (
15
+ init_model,
16
+ generate_long,
17
+ load_codec_model,
18
+ decode_to_audio,
19
+ encode_audio
20
+ )
21
+
22
+
23
+ device = "cuda" if torch.cuda.is_available() else "cpu"
24
+ precision = torch.bfloat16
25
+
26
+ print("Downloading Fish Audio S2 Pro weights...")
27
+ checkpoint_dir = snapshot_download(repo_id="fishaudio/s2-pro")
28
+
29
+ print("Loading LLAMA model...")
30
+ llama_model, decode_one_token = init_model(
31
+ checkpoint_path=checkpoint_dir,
32
+ device=device,
33
+ precision=precision,
34
+ compile=False
35
+ )
36
+
37
+ with torch.device(device):
38
+ llama_model.setup_caches(
39
+ max_batch_size=1,
40
+ max_seq_len=llama_model.config.max_seq_len,
41
+ dtype=next(llama_model.parameters()).dtype,
42
+ )
43
+
44
+ print("Loading Codec model...")
45
+ codec_checkpoint = os.path.join(checkpoint_dir, "codec.pth")
46
+ codec_model = load_codec_model(codec_checkpoint, device=device, precision=precision)
47
+
48
+ print("All models loaded successfully!")
49
+
50
+
51
+
52
@spaces.GPU(duration=120)
def tts_inference(
    text,
    ref_audio,
    ref_text,
    max_new_tokens,
    chunk_length,
    top_p,
    repetition_penalty,
    temperature
):
    """Generate speech audio from text, optionally cloning a reference voice.

    Decorated with @spaces.GPU so a GPU is allocated only while this runs.

    Args:
        text: Text to synthesize (may contain inline tags like [laugh]).
        ref_audio: Filepath of a reference audio clip for cloning, or None.
        ref_text: Transcription of the reference audio; cloning requires
            both this and ref_audio.
        max_new_tokens: Generation cap (0 = unlimited, per the UI label).
        chunk_length: Chunk size for iterative prompting in generate_long.
        top_p: Nucleus-sampling threshold.
        repetition_penalty: Penalty applied to repeated tokens.
        temperature: Sampling temperature.

    Returns:
        A (sample_rate, numpy_waveform) tuple for a gr.Audio "numpy" output.

    Raises:
        gr.Error: When no audio is produced or inference fails.
    """
    try:
        # Voice cloning needs BOTH the audio clip and its transcript. Gate
        # prompt_text and prompt_tokens on the same flag so they stay
        # consistent (previously a lone ref_text produced a prompt_text
        # list with no matching prompt_tokens).
        use_reference = ref_audio is not None and bool(ref_text)

        prompt_tokens_list = None
        if use_reference:
            # Encode the reference clip to codec tokens; keep on CPU until needed.
            prompt_tokens_list = [encode_audio(ref_audio, codec_model, device).cpu()]

        generator = generate_long(
            model=llama_model,
            device=device,
            decode_one_token=decode_one_token,
            text=text,
            num_samples=1,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            top_k=30,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            compile=False,
            iterative_prompt=True,
            chunk_length=chunk_length,
            prompt_text=[ref_text] if use_reference else None,
            prompt_tokens=prompt_tokens_list,
        )

        # Accumulate code chunks until the generator moves to the next sample.
        codes = []
        for response in generator:
            if response.action == "sample":
                codes.append(response.codes)
            elif response.action == "next":
                break

        if not codes:
            raise gr.Error("No audio generated. Please check your text.")

        # Concatenate chunks along the time axis, then decode to a waveform.
        merged_codes = torch.cat(codes, dim=1)
        audio_waveform = decode_to_audio(merged_codes.to(device), codec_model)
        audio_np = audio_waveform.cpu().float().numpy()

        return (codec_model.sample_rate, audio_np)

    except gr.Error:
        # Already a user-facing error (e.g. "No audio generated"); re-raise
        # as-is instead of re-wrapping it as "Inference Error: ...".
        raise
    except Exception as e:
        traceback.print_exc()
        raise gr.Error(f"Inference Error: {str(e)}")
111
+
112
# Blue/indigo variant of the Soft theme using the Inter font, with slightly
# heavier block titles and soft shadows on blocks and buttons.
_soft_base = gr.themes.Soft(
    primary_hue="blue",
    secondary_hue="indigo",
    font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
)
custom_theme = _soft_base.set(
    block_title_text_weight="600",
    block_border_width="1px",
    block_shadow="0px 2px 4px rgba(0, 0, 0, 0.05)",
    button_shadow="0px 2px 4px rgba(0, 0, 0, 0.1)",
)
122
+
123
# ---- Gradio UI -------------------------------------------------------------
with gr.Blocks(theme=custom_theme, title="Fish Audio S2 Pro") as app:

    # Static page header describing the model's capabilities.
    gr.Markdown(
        """
        <div style="text-align: center; max-width: 800px; margin: 0 auto; padding: 20px 0;">
            <h1 style="font-size: 2.5rem; font-weight: 800; color: #1E3A8A; margin-bottom: 10px;">
                🐟 Fish Audio S2 Pro
            </h1>
            <p style="font-size: 1.1rem; color: #4B5563;">
                State-of-the-Art Dual-Autoregressive Text-to-Speech.
                Supports 80+ languages, emotional inline control (e.g., <code>[laugh]</code>, <code>[whisper]</code>), and zero-shot voice cloning.
            </p>
        </div>
        """
    )

    with gr.Row():
        # Left column: text input, optional cloning reference, and settings.
        with gr.Column(scale=5):
            gr.Markdown("### ✍️ Text Input")
            text_input = gr.Textbox(
                show_label=False,
                placeholder="Enter the text you want to synthesize here.\nTry adding tags like [laugh], [whisper], or [angry]!",
                lines=7
            )

            # Optional zero-shot voice-cloning inputs (audio + its transcript).
            with gr.Accordion("🎙️ Voice Cloning (Optional Reference)", open=False):
                gr.Markdown("Upload a 5-10 second clear audio clip and type its exact transcription to clone the voice.")
                ref_audio = gr.Audio(label="Reference Audio", type="filepath")
                ref_text = gr.Textbox(label="Reference Text", placeholder="Transcription of the reference audio...")

            # Sampling / generation hyper-parameters (defaults match the examples below).
            with gr.Accordion("⚙️ Advanced Settings", open=False):
                with gr.Row():
                    max_new_tokens = gr.Slider(0, 2048, 1024, step=8, label="Max New Tokens (0 = unlimited)")
                    chunk_length = gr.Slider(100, 400, 200, step=8, label="Chunk Length")
                with gr.Row():
                    top_p = gr.Slider(0.1, 1.0, 0.7, step=0.01, label="Top-P")
                    repetition_penalty = gr.Slider(0.9, 2.0, 1.2, step=0.01, label="Repetition Penalty")
                    temperature = gr.Slider(0.1, 1.0, 0.7, step=0.01, label="Temperature")

            generate_btn = gr.Button("🚀 Generate Speech", variant="primary", size="lg")

        # Right column: generated audio playback and usage tips.
        with gr.Column(scale=4):
            gr.Markdown("### 🎧 Output")
            audio_output = gr.Audio(label="Generated Audio", type="numpy", interactive=False, autoplay=True)

            gr.Markdown(
                """
                <div style="background-color: #EFF6FF; padding: 15px; border-radius: 8px; margin-top: 20px;">
                    <h4 style="margin-top: 0; color: #1D4ED8;">💡 Pro Tips</h4>
                    <ul style="margin-bottom: 0; color: #1E3A8A; font-size: 0.95rem;">
                        <li>You don't need phonemes, the model understands raw text seamlessly.</li>
                        <li>Try wrapping specific words in brackets for inline emotional control.</li>
                        <li>For cloning, the closer the transcript matches the audio, the better the result.</li>
                    </ul>
                </div>
                """
            )

    # Clickable example prompts; not cached because inference needs the GPU.
    gr.Markdown("### 🌟 Examples")
    gr.Examples(
        examples=[
            ["Hello world! This is a test of the Fish Audio S2 Pro model.", None, "", 1024, 200, 0.7, 1.2, 0.7],
            ["I can't believe it! [laugh] This is absolutely amazing!", None, "", 1024, 200, 0.7, 1.2, 0.7],
            ["[whisper in small voice] I have a secret to tell you... promise you won't tell anyone?", None, "", 1024, 200, 0.7, 1.2, 0.7]
        ],
        inputs=[text_input, ref_audio, ref_text, max_new_tokens, chunk_length, top_p, repetition_penalty, temperature],
        outputs=[audio_output],
        fn=tts_inference,
        cache_examples=False,
    )

    # Wire the main button to the inference function.
    generate_btn.click(
        fn=tts_inference,
        inputs=[text_input, ref_audio, ref_text, max_new_tokens, chunk_length, top_p, repetition_penalty, temperature],
        outputs=[audio_output]
    )

if __name__ == "__main__":
    app.launch()