import gradio as gr import numpy as np import os import time import torch import tempfile import threading import scipy.io.wavfile import traceback from huggingface_hub import login from pocket_tts import TTSModel # Configure PyTorch threading behavior for CPU optimization torch.set_num_threads(1) torch.set_num_interop_threads(1) # HF Token for gated models in Spaces hf_token = os.getenv("HF_TOKEN") if hf_token: print("HF_TOKEN found, logging in...") login(token=hf_token) VOICES = ['alba', 'marius', 'javert', 'jean', 'fantine', 'cosette', 'eponine', 'azelma'] # Default configuration values DEFAULT_VOICE = "alba" DEFAULT_MODEL_VARIANT = "b6369a24" DEFAULT_TEMPERATURE = 0.1 DEFAULT_LSD_DECODE_STEPS = 1 DEFAULT_EOS_THRESHOLD = -4.0 DEFAULT_NOISE_CLAMP = 0.0 DEFAULT_FRAMES_AFTER_EOS = 10 MAXIMUM_INPUT_LENGTH = 111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111 TEMPORARY_FILE_LIFETIME_SECONDS = 7200 # 2 hours generation_state_lock = threading.Lock() is_currently_generating = False stop_generation_requested = False temporary_files_registry = {} temporary_files_lock = threading.Lock() class TextToSpeechManager: """ Manages TTS model lifecycle and speech generation operations. Implements lazy loading and caching strategies for performance. """ def __init__(self): self.loaded_model = None self.current_configuration = {} self.voice_state_cache = {} def load_or_get_model( self, model_variant, temperature, lsd_decode_steps, noise_clamp, eos_threshold ): """Load a TTS model or return cached instance if configuration matches.""" processed_variant = str(model_variant or DEFAULT_MODEL_VARIANT).strip() processed_temperature = float(temperature) if temperature is not None else DEFAULT_TEMPERATURE processed_lsd_steps = int(lsd_decode_steps) if lsd_decode_steps is not None else DEFAULT_LSD_DECODE_STEPS processed_noise_clamp = float(noise_clamp) if noise_clamp and float(noise_clamp) > 0 else None processed_eos_threshold = float(eos_threshold) if eos_threshold is not None else DEFAULT_EOS_THRESHOLD requested_configuration = { "variant": processed_variant, "temp": processed_temperature, "lsd_decode_steps": processed_lsd_steps, "noise_clamp": processed_noise_clamp, "eos_threshold": processed_eos_threshold } if self.loaded_model is None or self.current_configuration != requested_configuration: print(f"Loading model with config: {requested_configuration}") self.loaded_model = TTSModel.load_model(**requested_configuration) self.current_configuration = requested_configuration self.voice_state_cache = {} print("Model loaded.") return self.loaded_model def get_voice_state_for_preset(self, voice_name): """Get or compute voice state for a preset voice with caching.""" validated_voice = voice_name if voice_name in VOICES else DEFAULT_VOICE if validated_voice not in self.voice_state_cache: self.voice_state_cache[validated_voice] = self.loaded_model.get_state_for_audio_prompt( audio_conditioning=validated_voice, truncate=False ) return self.voice_state_cache[validated_voice] def get_voice_state_for_clone(self, audio_file_path): """Compute voice state from uploaded audio file for voice cloning.""" return self.loaded_model.get_state_for_audio_prompt( audio_conditioning=audio_file_path, truncate=False ) def generate_audio(self, text_content, voice_state, frames_after_eos, enable_custom_frames): """Generate speech audio from text using the specified voice state.""" processed_frames = int(frames_after_eos) if enable_custom_frames else None return self.loaded_model.generate_audio( model_state=voice_state, text_to_generate=text_content, frames_after_eos=processed_frames, copy_state=True ) def save_audio_to_file(self, audio_tensor): """Save generated audio tensor to a temporary WAV file.""" audio_numpy_data = audio_tensor.numpy() audio_sample_rate = self.loaded_model.sample_rate output_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False) scipy.io.wavfile.write(output_file.name, audio_sample_rate, audio_numpy_data) with temporary_files_lock: temporary_files_registry[output_file.name] = time.time() return output_file.name # Create global TTS manager instance tts_manager = TextToSpeechManager() # Load model at startup with default parameters print("Loading PocketTTS model with default parameters...") tts_manager.load_or_get_model( DEFAULT_MODEL_VARIANT, DEFAULT_TEMPERATURE, DEFAULT_LSD_DECODE_STEPS, DEFAULT_NOISE_CLAMP, DEFAULT_EOS_THRESHOLD ) print("Model ready!") def cleanup_expired_temporary_files(): """Remove temporary files that have exceeded their lifetime.""" current_timestamp = time.time() expired_files = [] with temporary_files_lock: for file_path, creation_timestamp in list(temporary_files_registry.items()): if current_timestamp - creation_timestamp > TEMPORARY_FILE_LIFETIME_SECONDS: expired_files.append(file_path) for file_path in expired_files: try: if os.path.exists(file_path): os.remove(file_path) del temporary_files_registry[file_path] except Exception: pass def validate_text_input(text_content): """Validate and clean text input for speech generation.""" if not text_content or not isinstance(text_content, str): return False, "" cleaned_text = text_content.strip() if not cleaned_text: return False, "" if len(cleaned_text) > MAXIMUM_INPUT_LENGTH: return False, f"Input exceeds maximum length of {MAXIMUM_INPUT_LENGTH} characters." return True, cleaned_text def request_generation_stop(): """Signal a request to stop the current generation.""" global stop_generation_requested stop_generation_requested = True return gr.update(interactive=False) # Speech generation function def generate_speech( text, voice_mode, voice_dropdown, voice_upload, temperature, lsd_decode_steps, noise_clamp, eos_threshold, frames_after_eos, enable_custom_frames ): """Perform the complete speech generation workflow with thread safety.""" global is_currently_generating, stop_generation_requested cleanup_expired_temporary_files() is_valid, validation_result = validate_text_input(text) if not is_valid: if validation_result: raise gr.Error(validation_result) raise gr.Error("Please enter valid text to generate speech.") if voice_mode == "Voice Cloning" and not voice_upload: raise gr.Error("Please upload an audio file for voice cloning.") with generation_state_lock: if is_currently_generating: raise gr.Error("A generation is already in progress. Please wait.") is_currently_generating = True stop_generation_requested = False try: tts_manager.load_or_get_model( DEFAULT_MODEL_VARIANT, temperature, lsd_decode_steps, noise_clamp, eos_threshold ) if stop_generation_requested: return None if voice_mode == "Voice Cloning": voice_state = tts_manager.get_voice_state_for_clone(voice_upload) else: voice_state = tts_manager.get_voice_state_for_preset(voice_dropdown) if stop_generation_requested: return None print(f"Generating with voice mode: {voice_mode}, temp: {temperature}, lsd_steps: {lsd_decode_steps}") generated_audio = tts_manager.generate_audio( validation_result, voice_state, frames_after_eos, enable_custom_frames ) if stop_generation_requested: return None output_file_path = tts_manager.save_audio_to_file(generated_audio) return output_file_path except gr.Error: raise except Exception as e: full_error = traceback.format_exc() print(f"Unexpected error: {full_error}") raise gr.Error(f"An unexpected error occurred: {str(e)}") finally: with generation_state_lock: is_currently_generating = False stop_generation_requested = False # Load custom theme with fallback try: theme = gr.Theme.from_hub("JohnSmith9982/small_and_pretty") except Exception as e: print(f"Warning: Could not load custom theme: {e}. Using default Soft theme.") theme = gr.themes.Soft() css = """ footer {visibility: hidden} .gradio-container { max-width: 100% !important; padding: 0 !important; } @media (min-width: 768px) { .gradio-container { padding-left: 2% !important; padding-right: 2% !important; } } .header-section { text-align: left; margin-bottom: 1.5rem; } .main-title { color: #10b981; font-weight: 800; font-size: 1.8rem; margin: 5px 0; } @media (min-width: 768px) { .main-title { font-size: 2.2rem; } } .logo-container { display: flex; justify-content: flex-start; align-items: center; gap: 10px; margin-bottom: 10px; } .logo-img { height: 40px; border-radius: 8px; } @media (min-width: 768px) { .logo-img { height: 50px; } .logo-container { gap: 15px; } } .description { max-width: 900px; margin: 10px 0; font-size: 0.95rem; line-height: 1.5; color: #4b5563; } .links-row { display: flex; flex-wrap: wrap; justify-content: flex-start; gap: 8px; margin: 10px 0; font-size: 0.85rem; } @media (min-width: 768px) { .links-row { gap: 10px; font-size: 0.9rem; } } .links-row a { color: #10b981; text-decoration: none; padding: 3px 10px; border: 1px solid #10b981; border-radius: 15px; transition: all 0.2s; white-space: nowrap; } .links-row a:hover { background-color: #10b981; color: white; } .social-handles { display: flex; justify-content: center; gap: 20px; margin: 15px 0; } .social-icon { width: 28px; height: 28px; transition: all 0.3s ease; } .social-icon:hover { transform: scale(1.1) translateY(-3px); } .disclaimer { text-align: center; font-size: 0.8rem; color: #9ca3af; margin-top: 30px; padding: 15px; border-top: 1px solid #f3f4f6; } @media (min-width: 768px) { .disclaimer { margin-top: 40px; padding: 20px; } } #voice-mode .wrap { display: flex !important; flex-direction: row !important; width: 100% !important; } #voice-mode .wrap label { flex: 1 !important; justify-content: center !important; text-align: center !important; } """ with gr.Blocks() as demo: with gr.Column(elem_classes="header-section"): with gr.Row(): with gr.Column(scale=4): gr.HTML("""