Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -105,10 +105,76 @@ def _init_pocket(
|
|
| 105 |
"sample_rate": model.sample_rate,
|
| 106 |
})
|
| 107 |
print(f"Pocket TTS initialized. Sample rate: {model.sample_rate} Hz")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
except Exception as e:
|
| 109 |
raise gr.Error(f"Failed to initialize Pocket TTS model: {str(e)}")
|
| 110 |
|
| 111 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
def _convert_to_wav(audio_path: str) -> str:
|
| 113 |
"""Convert audio file to WAV format if needed.
|
| 114 |
|
|
@@ -181,7 +247,40 @@ def _get_voice_state(voice_name: str | None, custom_audio_path: str | None):
|
|
| 181 |
if voice_name in _POCKET_STATE["voice_states"]:
|
| 182 |
return _POCKET_STATE["voice_states"][voice_name]
|
| 183 |
|
| 184 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
voice_path = PRESET_VOICES[voice_name]
|
| 186 |
print(f"Loading preset voice '{voice_name}' from: {voice_path}")
|
| 187 |
|
|
@@ -381,7 +480,6 @@ with gr.Blocks() as demo:
|
|
| 381 |
label="Generated Speech",
|
| 382 |
streaming=True,
|
| 383 |
autoplay=True,
|
| 384 |
-
buttons=["download"],
|
| 385 |
)
|
| 386 |
|
| 387 |
with gr.Accordion("Advanced Options", open=False):
|
|
|
|
| 105 |
"sample_rate": model.sample_rate,
|
| 106 |
})
|
| 107 |
print(f"Pocket TTS initialized. Sample rate: {model.sample_rate} Hz")
|
| 108 |
+
|
| 109 |
+
# Auto-create missing embeddings if voice cloning is available
|
| 110 |
+
if model.has_voice_cloning:
|
| 111 |
+
_create_missing_embeddings(model)
|
| 112 |
+
else:
|
| 113 |
+
print("Voice cloning not available - using pre-computed embeddings only")
|
| 114 |
+
|
| 115 |
except Exception as e:
|
| 116 |
raise gr.Error(f"Failed to initialize Pocket TTS model: {str(e)}")
|
| 117 |
|
| 118 |
|
| 119 |
+
def _create_missing_embeddings(model) -> None:
    """Create embeddings for any voices that have audio files but no embedding.

    Walks ``PRESET_VOICES`` and, for every entry that points at a local audio
    file and has no ``embeddings/<voice_name>.safetensors`` next to this
    module, encodes the audio with ``model._encode_audio`` and saves the
    result under the key ``"audio_prompt"``.

    Entries are skipped when the embedding file already exists, when the
    voice path is ``None``, or when the path starts with ``hf://``. The whole
    function is a no-op if the ``voices`` directory next to this module does
    not exist.

    Failures are best-effort: an error for one voice is printed and the loop
    moves on to the next voice.

    Args:
        model: The initialized TTS model. This function reads
            ``model.config.mimi.sample_rate`` and ``model.device`` and calls
            ``model._encode_audio``.
    """
    import os
    import tempfile

    import safetensors.torch
    from pocket_tts.data.audio import audio_read
    from pocket_tts.data.audio_utils import convert_audio

    base_dir = os.path.dirname(__file__)
    voices_dir = os.path.join(base_dir, "voices")
    embeddings_dir = os.path.join(base_dir, "embeddings")

    if not os.path.exists(voices_dir):
        return

    os.makedirs(embeddings_dir, exist_ok=True)

    for voice_name, voice_path in PRESET_VOICES.items():
        embedding_path = os.path.join(embeddings_dir, f"{voice_name}.safetensors")

        # Skip if embedding already exists or no local file
        if os.path.exists(embedding_path) or voice_path is None:
            continue

        # Skip fallback HuggingFace voices
        if voice_path.startswith("hf://"):
            continue

        print(f"Creating embedding for '{voice_name}'...")

        # Path of the intermediate WAV, if one had to be created; tracked so
        # it can be removed in ``finally`` even when encoding fails.
        temp_wav_path = None
        try:
            # Convert to WAV if needed
            audio_path = voice_path
            if not voice_path.lower().endswith('.wav'):
                from pydub import AudioSegment

                segment = AudioSegment.from_file(voice_path)
                # mkstemp returns an open descriptor; close it right away so
                # pydub can write to the path and no handle is left dangling.
                fd, temp_wav_path = tempfile.mkstemp(suffix='.wav')
                os.close(fd)
                segment.export(temp_wav_path, format='wav')
                audio_path = temp_wav_path

            # Read and encode audio
            waveform, sr = audio_read(audio_path)
            audio_tensor = convert_audio(waveform, sr, model.config.mimi.sample_rate, 1)

            with torch.no_grad():
                audio_prompt = model._encode_audio(audio_tensor.unsqueeze(0).to(model.device))

            # Save embedding
            safetensors.torch.save_file(
                {"audio_prompt": audio_prompt.cpu()},
                embedding_path,
            )
            print(f" Saved: {embedding_path}")

        except Exception as e:
            # Deliberately broad: one bad voice file must not abort startup.
            print(f" Error creating embedding for {voice_name}: {e}")
        finally:
            # Remove the intermediate WAV so repeated startups do not pile up
            # temp files.
            if temp_wav_path is not None:
                try:
                    os.remove(temp_wav_path)
                except OSError:
                    pass
|
| 176 |
+
|
| 177 |
+
|
| 178 |
def _convert_to_wav(audio_path: str) -> str:
|
| 179 |
"""Convert audio file to WAV format if needed.
|
| 180 |
|
|
|
|
| 247 |
if voice_name in _POCKET_STATE["voice_states"]:
|
| 248 |
return _POCKET_STATE["voice_states"][voice_name]
|
| 249 |
|
| 250 |
+
# Check for pre-computed embedding first (no voice cloning needed)
|
| 251 |
+
import os
|
| 252 |
+
embeddings_dir = os.path.join(os.path.dirname(__file__), "embeddings")
|
| 253 |
+
embedding_path = os.path.join(embeddings_dir, f"{voice_name}.safetensors")
|
| 254 |
+
|
| 255 |
+
if os.path.exists(embedding_path):
|
| 256 |
+
print(f"Loading pre-computed embedding for '{voice_name}' from: {embedding_path}")
|
| 257 |
+
import safetensors.torch
|
| 258 |
+
from pocket_tts.modules.stateful_module import init_states
|
| 259 |
+
|
| 260 |
+
# Load the audio prompt embedding
|
| 261 |
+
state_dict = safetensors.torch.load_file(embedding_path)
|
| 262 |
+
audio_prompt = state_dict["audio_prompt"].to(model.device)
|
| 263 |
+
|
| 264 |
+
# Create fresh model state and condition it with the audio prompt
|
| 265 |
+
# (same logic as model.get_state_for_audio_prompt uses internally)
|
| 266 |
+
voice_state = init_states(model.flow_lm, batch_size=1, sequence_length=1000)
|
| 267 |
+
model._run_flow_lm_and_increment_step(model_state=voice_state, audio_conditioning=audio_prompt)
|
| 268 |
+
|
| 269 |
+
# Detach all tensors to make them leaf tensors (required for deepcopy)
|
| 270 |
+
def detach_tensors(obj):
|
| 271 |
+
if isinstance(obj, torch.Tensor):
|
| 272 |
+
return obj.detach().clone()
|
| 273 |
+
elif isinstance(obj, dict):
|
| 274 |
+
return {k: detach_tensors(v) for k, v in obj.items()}
|
| 275 |
+
else:
|
| 276 |
+
return obj
|
| 277 |
+
|
| 278 |
+
voice_state = detach_tensors(voice_state)
|
| 279 |
+
|
| 280 |
+
_POCKET_STATE["voice_states"][voice_name] = voice_state
|
| 281 |
+
return voice_state
|
| 282 |
+
|
| 283 |
+
# Fall back to voice cloning (requires auth)
|
| 284 |
voice_path = PRESET_VOICES[voice_name]
|
| 285 |
print(f"Loading preset voice '{voice_name}' from: {voice_path}")
|
| 286 |
|
|
|
|
| 480 |
label="Generated Speech",
|
| 481 |
streaming=True,
|
| 482 |
autoplay=True,
|
|
|
|
| 483 |
)
|
| 484 |
|
| 485 |
with gr.Accordion("Advanced Options", open=False):
|