artificialguybr committed on
Commit
a675cb1
·
verified ·
1 Parent(s): 981bfc5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -62
app.py CHANGED
@@ -13,9 +13,7 @@ REPO_URL = "https://github.com/fishaudio/fish-speech.git"
13
  REPO_DIR = "fish-speech"
14
 
15
  if not os.path.exists(REPO_DIR):
16
- print(f"Clonando o repositório de {REPO_URL}...")
17
  subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)
18
- print("Repositório clonado com sucesso!")
19
 
20
  os.chdir(REPO_DIR)
21
  sys.path.insert(0, os.getcwd())
@@ -23,21 +21,19 @@ sys.path.insert(0, os.getcwd())
23
  from fish_speech.models.text2semantic.inference import (
24
  init_model,
25
  generate_long,
26
- load_codec_model
27
  )
28
 
29
  device = "cuda" if torch.cuda.is_available() else "cpu"
30
  precision = torch.bfloat16
31
 
32
- print("Baixando os pesos do Fish Audio S2 Pro...")
33
  checkpoint_dir = snapshot_download(repo_id="fishaudio/s2-pro")
34
 
35
- print("Carregando o modelo LLAMA (isso pode levar alguns instantes)...")
36
  llama_model, decode_one_token = init_model(
37
- checkpoint_path=checkpoint_dir,
38
- device=device,
39
- precision=precision,
40
- compile=False
41
  )
42
 
43
  with torch.device(device):
@@ -47,47 +43,46 @@ with torch.device(device):
47
  dtype=next(llama_model.parameters()).dtype,
48
  )
49
 
50
- print("Carregando o modelo Codec (VQGAN)...")
51
  codec_checkpoint = os.path.join(checkpoint_dir, "codec.pth")
52
  codec_model = load_codec_model(codec_checkpoint, device=device, precision=precision)
53
 
54
- print("✅ Todos os modelos carregados com sucesso!")
55
-
56
 
57
@torch.no_grad()
def custom_encode_audio(audio_path, codec, device):
    """Encode a reference audio file into codec token indices.

    The file is loaded as mono at the codec's native sample rate, run
    through ``codec.encode``, and trimmed to the valid frame count.
    """
    waveform, _ = librosa.load(audio_path, sr=codec.sample_rate, mono=True)
    samples = torch.from_numpy(waveform).to(device)

    # The encoder expects a (batch, channel, time) tensor in its own dtype.
    target_dtype = next(codec.parameters()).dtype
    batched = samples[None, None, :].to(dtype=target_dtype)
    lengths = torch.tensor([samples.shape[0]], device=device, dtype=torch.long)

    indices, valid_lengths = codec.encode(batched, lengths)
    # Drop the batch dim and cut any padded frames past the valid length.
    return indices[0, :, : valid_lengths[0]]
 
69
@torch.no_grad()
def custom_decode_audio(codes, codec):
    """Decode codec token indices back into a mono waveform tensor."""
    # from_indices wants a leading batch dim; return the first batch item's
    # first channel as the waveform.
    return codec.from_indices(codes[None])[0, 0]
 
 
 
 
73
 
74
  @spaces.GPU(duration=120)
75
  def tts_inference(
76
- text,
77
- ref_audio,
78
- ref_text,
79
- max_new_tokens,
80
- chunk_length,
81
- top_p,
82
- repetition_penalty,
83
- temperature
84
  ):
85
  try:
86
  prompt_tokens_list = None
87
-
88
  if ref_audio is not None and ref_text:
89
- prompt_tokens_list = [custom_encode_audio(ref_audio, codec_model, device).cpu()]
90
-
91
  generator = generate_long(
92
  model=llama_model,
93
  device=device,
@@ -105,27 +100,29 @@ def tts_inference(
105
  prompt_text=[ref_text] if ref_text else None,
106
  prompt_tokens=prompt_tokens_list,
107
  )
108
-
109
  codes = []
110
  for response in generator:
111
  if response.action == "sample":
112
  codes.append(response.codes)
113
  elif response.action == "next":
114
  break
115
-
116
  if not codes:
117
- raise gr.Error("Nenhum áudio foi gerado. Verifique o seu texto de entrada.")
118
-
119
- merged_codes = torch.cat(codes, dim=1).to(device).clone()
120
- with torch.inference_mode(False):
121
- audio_waveform = custom_decode_audio(merged_codes, codec_model)
122
  audio_np = audio_waveform.cpu().float().numpy()
123
-
124
  return (codec_model.sample_rate, audio_np)
125
 
 
 
126
  except Exception as e:
127
  traceback.print_exc()
128
- raise gr.Error(f"Erro na Inferência: {str(e)}")
 
129
 
130
  custom_theme = gr.themes.Soft(
131
  primary_hue="blue",
@@ -133,8 +130,8 @@ custom_theme = gr.themes.Soft(
133
  font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
134
  )
135
 
136
- with gr.Blocks(theme=custom_theme, title="Fish Audio S2 Pro") as app:
137
-
138
  gr.Markdown(
139
  """
140
  <div style="text-align: center; max-width: 800px; margin: 0 auto; padding: 20px 0;">
@@ -143,26 +140,29 @@ with gr.Blocks(theme=custom_theme, title="Fish Audio S2 Pro") as app:
143
  </h1>
144
  <p style="font-size: 1.1rem; color: #4B5563;">
145
  State-of-the-Art Dual-Autoregressive Text-to-Speech.<br>
146
- Supports over 80 languages, emotional control via text tags (e.g. <code>[laugh]</code>, <code>[whisper]</code>) and Zero-Shot voice cloning.
147
  </p>
148
  </div>
149
  """
150
  )
151
-
152
  with gr.Row():
153
  with gr.Column(scale=5):
154
  gr.Markdown("### ✍️ Input Text")
155
  text_input = gr.Textbox(
156
  show_label=False,
157
  placeholder="Type the text you want to synthesize here.\nTry adding tags like [laugh], [whisper], or [angry]!",
158
- lines=7
159
  )
160
-
161
  with gr.Accordion("🎙️ Voice Cloning (Optional Reference)", open=False):
162
  gr.Markdown("Upload a clean 5–10 second audio clip and type exactly what is said in it to clone the voice.")
163
  ref_audio = gr.Audio(label="Reference Audio", type="filepath")
164
- ref_text = gr.Textbox(label="Reference Audio Text", placeholder="Exact transcription of the reference audio...")
165
-
 
 
 
166
  with gr.Accordion("⚙️ Advanced Settings", open=False):
167
  with gr.Row():
168
  max_new_tokens = gr.Slider(0, 2048, 1024, step=8, label="Max New Tokens (0 = no limit)")
@@ -171,44 +171,48 @@ with gr.Blocks(theme=custom_theme, title="Fish Audio S2 Pro") as app:
171
  top_p = gr.Slider(0.1, 1.0, 0.7, step=0.01, label="Top-P")
172
  repetition_penalty = gr.Slider(0.9, 2.0, 1.2, step=0.01, label="Repetition Penalty")
173
  temperature = gr.Slider(0.1, 1.0, 0.7, step=0.01, label="Temperature")
174
-
175
  generate_btn = gr.Button("🚀 Generate Audio", variant="primary", size="lg")
176
-
177
  with gr.Column(scale=4):
178
  gr.Markdown("### 🎧 Result")
179
- audio_output = gr.Audio(label="Generated Audio", type="numpy", interactive=False, autoplay=True)
180
-
 
 
 
 
 
181
  gr.Markdown(
182
  """
183
  <div style="background-color: #EFF6FF; padding: 15px; border-radius: 8px; margin-top: 20px;">
184
  <h4 style="margin-top: 0; color: #1D4ED8;">💡 Pro Tips</h4>
185
  <ul style="margin-bottom: 0; color: #1E3A8A; font-size: 0.95rem;">
186
- <li>The model understands natural text perfectly — no need for manual phonemes.</li>
187
- <li>Wrap words in brackets to control emotion. Example: <i>[pitch up] Wow! [laugh]</i></li>
188
- <li>For cloning, the more accurate the transcription of the reference audio, the better the result.</li>
189
  </ul>
190
  </div>
191
  """
192
  )
193
-
194
  gr.Markdown("### 🌟 Examples")
195
  gr.Examples(
196
  examples=[
197
  ["Hello world! This is a test of the Fish Audio S2 Pro model.", None, "", 1024, 200, 0.7, 1.2, 0.7],
198
  ["I can't believe it! [laugh] This is absolutely amazing!", None, "", 1024, 200, 0.7, 1.2, 0.7],
199
- ["[whisper in small voice] I have a secret to tell you... promise you won't tell anyone?", None, "", 1024, 200, 0.7, 1.2, 0.7]
200
  ],
201
  inputs=[text_input, ref_audio, ref_text, max_new_tokens, chunk_length, top_p, repetition_penalty, temperature],
202
  outputs=[audio_output],
203
  fn=tts_inference,
204
  cache_examples=False,
205
  )
206
-
207
- # Evento de clique do botão
208
  generate_btn.click(
209
  fn=tts_inference,
210
  inputs=[text_input, ref_audio, ref_text, max_new_tokens, chunk_length, top_p, repetition_penalty, temperature],
211
- outputs=[audio_output]
212
  )
213
 
214
  if __name__ == "__main__":
 
13
  REPO_DIR = "fish-speech"
14
 
15
  if not os.path.exists(REPO_DIR):
 
16
  subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)
 
17
 
18
  os.chdir(REPO_DIR)
19
  sys.path.insert(0, os.getcwd())
 
21
  from fish_speech.models.text2semantic.inference import (
22
  init_model,
23
  generate_long,
24
+ load_codec_model,
25
  )
26
 
27
  device = "cuda" if torch.cuda.is_available() else "cpu"
28
  precision = torch.bfloat16
29
 
 
30
  checkpoint_dir = snapshot_download(repo_id="fishaudio/s2-pro")
31
 
 
32
  llama_model, decode_one_token = init_model(
33
+ checkpoint_path=checkpoint_dir,
34
+ device=device,
35
+ precision=precision,
36
+ compile=False,
37
  )
38
 
39
  with torch.device(device):
 
43
  dtype=next(llama_model.parameters()).dtype,
44
  )
45
 
 
46
  codec_checkpoint = os.path.join(checkpoint_dir, "codec.pth")
47
  codec_model = load_codec_model(codec_checkpoint, device=device, precision=precision)
48
 
 
 
49
 
50
@torch.no_grad()
def encode_reference_audio(audio_path):
    """Turn a reference audio file into codec token indices for voice cloning.

    Loads the file as mono at the global ``codec_model``'s native sample
    rate, encodes it, and trims the result to the valid number of frames.
    """
    raw, _ = librosa.load(audio_path, sr=codec_model.sample_rate, mono=True)
    samples = torch.from_numpy(raw).to(device)
    codec_dtype = next(codec_model.parameters()).dtype

    # Shape to (batch, channel, time) and match the encoder's dtype.
    batch = samples[None, None, :].to(dtype=codec_dtype)
    n_samples = torch.tensor([samples.shape[0]], device=device, dtype=torch.long)
    indices, n_frames = codec_model.encode(batch, n_samples)

    # Strip the batch dim and any padding beyond the valid frame count.
    return indices[0, :, : n_frames[0]]
59
 
60
+
61
def decode_codes_to_audio(merged_codes):
    """Decode concatenated semantic codes into a mono waveform tensor.

    Args:
        merged_codes: token indices from the sampler (chunks concatenated),
            possibly created under ``torch.inference_mode``.

    Returns:
        The decoded waveform: first batch item, first channel of what
        ``codec_model.from_indices`` returns (presumably a
        (batch, channel, time) tensor — inferred from the indexing).
    """
    # inference_mode(False) lets clone() produce a regular (non-inference)
    # tensor the codec can consume; no_grad still keeps autograd disabled.
    # Merged into a single `with` statement (idiomatic, same semantics).
    with torch.inference_mode(False), torch.no_grad():
        codes_clean = merged_codes.clone()
        audio = codec_model.from_indices(codes_clean[None])
        return audio[0, 0]
67
+
68
 
69
  @spaces.GPU(duration=120)
70
  def tts_inference(
71
+ text,
72
+ ref_audio,
73
+ ref_text,
74
+ max_new_tokens,
75
+ chunk_length,
76
+ top_p,
77
+ repetition_penalty,
78
+ temperature,
79
  ):
80
  try:
81
  prompt_tokens_list = None
82
+
83
  if ref_audio is not None and ref_text:
84
+ prompt_tokens_list = [encode_reference_audio(ref_audio).cpu()]
85
+
86
  generator = generate_long(
87
  model=llama_model,
88
  device=device,
 
100
  prompt_text=[ref_text] if ref_text else None,
101
  prompt_tokens=prompt_tokens_list,
102
  )
103
+
104
  codes = []
105
  for response in generator:
106
  if response.action == "sample":
107
  codes.append(response.codes)
108
  elif response.action == "next":
109
  break
110
+
111
  if not codes:
112
+ raise gr.Error("No audio was generated. Please check your input text.")
113
+
114
+ merged_codes = torch.cat(codes, dim=1).to(device)
115
+ audio_waveform = decode_codes_to_audio(merged_codes)
 
116
  audio_np = audio_waveform.cpu().float().numpy()
117
+
118
  return (codec_model.sample_rate, audio_np)
119
 
120
+ except gr.Error:
121
+ raise
122
  except Exception as e:
123
  traceback.print_exc()
124
+ raise gr.Error(f"Inference error: {str(e)}")
125
+
126
 
127
  custom_theme = gr.themes.Soft(
128
  primary_hue="blue",
 
130
  font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
131
  )
132
 
133
+ with gr.Blocks(title="Fish Audio S2 Pro") as app:
134
+
135
  gr.Markdown(
136
  """
137
  <div style="text-align: center; max-width: 800px; margin: 0 auto; padding: 20px 0;">
 
140
  </h1>
141
  <p style="font-size: 1.1rem; color: #4B5563;">
142
  State-of-the-Art Dual-Autoregressive Text-to-Speech.<br>
143
+ Supports 80+ languages, emotion tags (e.g. <code>[laugh]</code>, <code>[whisper]</code>) and zero-shot voice cloning.
144
  </p>
145
  </div>
146
  """
147
  )
148
+
149
  with gr.Row():
150
  with gr.Column(scale=5):
151
  gr.Markdown("### ✍️ Input Text")
152
  text_input = gr.Textbox(
153
  show_label=False,
154
  placeholder="Type the text you want to synthesize here.\nTry adding tags like [laugh], [whisper], or [angry]!",
155
+ lines=7,
156
  )
157
+
158
  with gr.Accordion("🎙️ Voice Cloning (Optional Reference)", open=False):
159
  gr.Markdown("Upload a clean 5–10 second audio clip and type exactly what is said in it to clone the voice.")
160
  ref_audio = gr.Audio(label="Reference Audio", type="filepath")
161
+ ref_text = gr.Textbox(
162
+ label="Reference Audio Transcription",
163
+ placeholder="Exact transcription of the reference audio...",
164
+ )
165
+
166
  with gr.Accordion("⚙️ Advanced Settings", open=False):
167
  with gr.Row():
168
  max_new_tokens = gr.Slider(0, 2048, 1024, step=8, label="Max New Tokens (0 = no limit)")
 
171
  top_p = gr.Slider(0.1, 1.0, 0.7, step=0.01, label="Top-P")
172
  repetition_penalty = gr.Slider(0.9, 2.0, 1.2, step=0.01, label="Repetition Penalty")
173
  temperature = gr.Slider(0.1, 1.0, 0.7, step=0.01, label="Temperature")
174
+
175
  generate_btn = gr.Button("🚀 Generate Audio", variant="primary", size="lg")
176
+
177
  with gr.Column(scale=4):
178
  gr.Markdown("### 🎧 Result")
179
+ audio_output = gr.Audio(
180
+ label="Generated Audio",
181
+ type="numpy",
182
+ interactive=False,
183
+ autoplay=True,
184
+ )
185
+
186
  gr.Markdown(
187
  """
188
  <div style="background-color: #EFF6FF; padding: 15px; border-radius: 8px; margin-top: 20px;">
189
  <h4 style="margin-top: 0; color: #1D4ED8;">💡 Pro Tips</h4>
190
  <ul style="margin-bottom: 0; color: #1E3A8A; font-size: 0.95rem;">
191
+ <li>The model understands natural text — no need for manual phonemes.</li>
192
+ <li>Control emotion with brackets: <i>[pitch up] Wow! [laugh]</i></li>
193
+ <li>For cloning, the more accurate the transcription, the better the result.</li>
194
  </ul>
195
  </div>
196
  """
197
  )
198
+
199
  gr.Markdown("### 🌟 Examples")
200
  gr.Examples(
201
  examples=[
202
  ["Hello world! This is a test of the Fish Audio S2 Pro model.", None, "", 1024, 200, 0.7, 1.2, 0.7],
203
  ["I can't believe it! [laugh] This is absolutely amazing!", None, "", 1024, 200, 0.7, 1.2, 0.7],
204
+ ["[whisper in small voice] I have a secret to tell you... promise you won't tell anyone?", None, "", 1024, 200, 0.7, 1.2, 0.7],
205
  ],
206
  inputs=[text_input, ref_audio, ref_text, max_new_tokens, chunk_length, top_p, repetition_penalty, temperature],
207
  outputs=[audio_output],
208
  fn=tts_inference,
209
  cache_examples=False,
210
  )
211
+
 
212
  generate_btn.click(
213
  fn=tts_inference,
214
  inputs=[text_input, ref_audio, ref_text, max_new_tokens, chunk_length, top_p, repetition_penalty, temperature],
215
+ outputs=[audio_output],
216
  )
217
 
218
  if __name__ == "__main__":