Nymbo committed on
Commit
a23619f
·
verified ·
1 Parent(s): 4d21128

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -2
app.py CHANGED
@@ -105,10 +105,76 @@ def _init_pocket(
105
  "sample_rate": model.sample_rate,
106
  })
107
  print(f"Pocket TTS initialized. Sample rate: {model.sample_rate} Hz")
 
 
 
 
 
 
 
108
  except Exception as e:
109
  raise gr.Error(f"Failed to initialize Pocket TTS model: {str(e)}")
110
 
111
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  def _convert_to_wav(audio_path: str) -> str:
113
  """Convert audio file to WAV format if needed.
114
 
@@ -181,7 +247,40 @@ def _get_voice_state(voice_name: str | None, custom_audio_path: str | None):
181
  if voice_name in _POCKET_STATE["voice_states"]:
182
  return _POCKET_STATE["voice_states"][voice_name]
183
 
184
- # Load and cache voice state
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  voice_path = PRESET_VOICES[voice_name]
186
  print(f"Loading preset voice '{voice_name}' from: {voice_path}")
187
 
@@ -381,7 +480,6 @@ with gr.Blocks() as demo:
381
  label="Generated Speech",
382
  streaming=True,
383
  autoplay=True,
384
- buttons=["download"],
385
  )
386
 
387
  with gr.Accordion("Advanced Options", open=False):
 
105
  "sample_rate": model.sample_rate,
106
  })
107
  print(f"Pocket TTS initialized. Sample rate: {model.sample_rate} Hz")
108
+
109
+ # Auto-create missing embeddings if voice cloning is available
110
+ if model.has_voice_cloning:
111
+ _create_missing_embeddings(model)
112
+ else:
113
+ print("Voice cloning not available - using pre-computed embeddings only")
114
+
115
  except Exception as e:
116
  raise gr.Error(f"Failed to initialize Pocket TTS model: {str(e)}")
117
 
118
 
119
def _create_missing_embeddings(model) -> None:
    """Create embeddings for any voices that have audio files but no embedding.

    Scans ``PRESET_VOICES`` for entries backed by a local audio file and, for
    each one lacking a pre-computed ``.safetensors`` embedding, encodes the
    audio with the model and saves the result to the ``embeddings`` directory.
    Failures for a single voice are logged and skipped so one bad file does
    not block the remaining voices.

    Args:
        model: Initialized Pocket TTS model; must expose
            ``config.mimi.sample_rate``, ``device`` and ``_encode_audio``.
    """
    import os
    import tempfile

    import safetensors.torch
    from pocket_tts.data.audio import audio_read
    from pocket_tts.data.audio_utils import convert_audio

    base_dir = os.path.dirname(__file__)
    voices_dir = os.path.join(base_dir, "voices")
    embeddings_dir = os.path.join(base_dir, "embeddings")

    # Nothing to do when no local voices directory is present.
    if not os.path.exists(voices_dir):
        return

    os.makedirs(embeddings_dir, exist_ok=True)

    for voice_name, voice_path in PRESET_VOICES.items():
        embedding_path = os.path.join(embeddings_dir, f"{voice_name}.safetensors")

        # Skip if embedding already exists or there is no local file.
        if voice_path is None or os.path.exists(embedding_path):
            continue

        # Skip fallback voices that live on the HuggingFace Hub.
        if voice_path.startswith("hf://"):
            continue

        print(f"Creating embedding for '{voice_name}'...")

        temp_wav_path = None
        try:
            # Convert to WAV first when the source is in another format.
            audio_path = voice_path
            if not voice_path.lower().endswith('.wav'):
                from pydub import AudioSegment

                segment = AudioSegment.from_file(voice_path)
                # mkstemp + close instead of NamedTemporaryFile(delete=False):
                # the handle is released before pydub writes, and the path is
                # removed in the finally block (the old code leaked the file).
                fd, temp_wav_path = tempfile.mkstemp(suffix='.wav')
                os.close(fd)
                segment.export(temp_wav_path, format='wav')
                audio_path = temp_wav_path

            # Read the prompt audio and resample to the model's rate, mono.
            audio, sr = audio_read(audio_path)
            audio_tensor = convert_audio(audio, sr, model.config.mimi.sample_rate, 1)

            with torch.no_grad():
                audio_prompt = model._encode_audio(audio_tensor.unsqueeze(0).to(model.device))

            # Persist on CPU so the embedding loads on any device later.
            safetensors.torch.save_file(
                {"audio_prompt": audio_prompt.cpu()},
                embedding_path
            )
            print(f" Saved: {embedding_path}")

        except Exception as e:
            # Best-effort: log and continue with the remaining voices.
            print(f" Error creating embedding for {voice_name}: {e}")
        finally:
            # Always remove the intermediate WAV, even on failure.
            if temp_wav_path is not None:
                try:
                    os.unlink(temp_wav_path)
                except OSError:
                    pass
176
+
177
+
178
  def _convert_to_wav(audio_path: str) -> str:
179
  """Convert audio file to WAV format if needed.
180
 
 
247
  if voice_name in _POCKET_STATE["voice_states"]:
248
  return _POCKET_STATE["voice_states"][voice_name]
249
 
250
+ # Check for pre-computed embedding first (no voice cloning needed)
251
+ import os
252
+ embeddings_dir = os.path.join(os.path.dirname(__file__), "embeddings")
253
+ embedding_path = os.path.join(embeddings_dir, f"{voice_name}.safetensors")
254
+
255
+ if os.path.exists(embedding_path):
256
+ print(f"Loading pre-computed embedding for '{voice_name}' from: {embedding_path}")
257
+ import safetensors.torch
258
+ from pocket_tts.modules.stateful_module import init_states
259
+
260
+ # Load the audio prompt embedding
261
+ state_dict = safetensors.torch.load_file(embedding_path)
262
+ audio_prompt = state_dict["audio_prompt"].to(model.device)
263
+
264
+ # Create fresh model state and condition it with the audio prompt
265
+ # (same logic as model.get_state_for_audio_prompt uses internally)
266
+ voice_state = init_states(model.flow_lm, batch_size=1, sequence_length=1000)
267
+ model._run_flow_lm_and_increment_step(model_state=voice_state, audio_conditioning=audio_prompt)
268
+
269
+ # Detach all tensors to make them leaf tensors (required for deepcopy)
270
+ def detach_tensors(obj):
271
+ if isinstance(obj, torch.Tensor):
272
+ return obj.detach().clone()
273
+ elif isinstance(obj, dict):
274
+ return {k: detach_tensors(v) for k, v in obj.items()}
275
+ else:
276
+ return obj
277
+
278
+ voice_state = detach_tensors(voice_state)
279
+
280
+ _POCKET_STATE["voice_states"][voice_name] = voice_state
281
+ return voice_state
282
+
283
+ # Fall back to voice cloning (requires auth)
284
  voice_path = PRESET_VOICES[voice_name]
285
  print(f"Loading preset voice '{voice_name}' from: {voice_path}")
286
 
 
480
  label="Generated Speech",
481
  streaming=True,
482
  autoplay=True,
 
483
  )
484
 
485
  with gr.Accordion("Advanced Options", open=False):