Spaces:
Sleeping
Sleeping
| import subprocess | |
| import sys | |
| import os | |
| import tempfile | |
| from huggingface_hub import hf_hub_download | |
| # --- 1. PRE-FLIGHT: BYPASS BUILD ISOLATION --- | |
def pre_flight_setup():
    """Install chatterbox-tts (and its build prerequisites) if it is not importable.

    When ``import chatterbox`` fails, three pip installs run in order:
    numpy/cython/wheel, then ``pkuseg==0.0.25`` with ``--no-build-isolation``
    (presumably so its build can see the packages installed in the first
    step), then ``chatterbox-tts>=0.1.7``.

    Raises:
        subprocess.CalledProcessError: if any pip invocation exits non-zero.
    """
    try:
        import chatterbox  # noqa: F401 -- availability probe only
    except ImportError:
        pass
    else:
        return

    import importlib

    print("Applying pkuseg build isolation bypass...")
    pip_install = [sys.executable, "-m", "pip", "install"]
    subprocess.check_call(pip_install + ["numpy<2.0.0", "cython", "wheel"])
    subprocess.check_call(pip_install + ["--no-build-isolation", "pkuseg==0.0.25"])
    subprocess.check_call(pip_install + ["chatterbox-tts>=0.1.7"])

    # The packages were installed while this interpreter was already running;
    # drop the finders' cached directory listings so later imports see them.
    importlib.invalidate_caches()
# Make sure chatterbox is installed before the imports further down need it.
pre_flight_setup()
| # --- 2. MAIN APPLICATION --- | |
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import torchaudio as ta | |
| from peft import PeftModel | |
| from chatterbox.tts import ChatterboxTTS | |
| from chatterbox.models.tokenizers import EnTokenizer | |
# Run on the GPU when torch reports a usable CUDA device, otherwise on CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Hugging Face Hub repository that provides the Indic tokenizer file and the
# LoRA adapter weights loaded by load_model().
repo_id = "Praha-Labs/PrahaTTS-ML"
def _resize_text_layers(target, vocab_size):
    """Resize ``target.text_emb`` and ``target.text_head`` to ``vocab_size`` entries.

    Rows shared by the old and new vocabulary keep their pretrained values;
    only rows beyond the old size get fresh initialisation. Replacement layers
    are created on the same device and with the same dtype as the layers they
    replace. Attributes that are missing (or already the right size) are left
    untouched.
    """
    with torch.no_grad():
        old_emb = getattr(target, "text_emb", None)
        if old_emb is not None and old_emb.num_embeddings != vocab_size:
            new_emb = nn.Embedding(vocab_size, old_emb.embedding_dim).to(
                device=old_emb.weight.device, dtype=old_emb.weight.dtype
            )
            keep = min(vocab_size, old_emb.num_embeddings)
            new_emb.weight[:keep].copy_(old_emb.weight[:keep])
            target.text_emb = new_emb

        old_head = getattr(target, "text_head", None)
        if old_head is not None and old_head.out_features != vocab_size:
            has_bias = old_head.bias is not None
            new_head = nn.Linear(old_head.in_features, vocab_size, bias=has_bias).to(
                device=old_head.weight.device, dtype=old_head.weight.dtype
            )
            keep = min(vocab_size, old_head.out_features)
            new_head.weight[:keep].copy_(old_head.weight[:keep])
            if has_bias:
                new_head.bias[:keep].copy_(old_head.bias[:keep])
            target.text_head = new_head


def load_model(vocab_size=2573):
    """Load base Chatterbox, swap in the Indic tokenizer and attach the LoRA adapter.

    Args:
        vocab_size: Number of text tokens the resized embedding and output
            head must cover. Defaults to 2573, the Malayalam vocabulary size
            this adapter was built for.

    Returns:
        The ChatterboxTTS model. Tokenizer and adapter loading are best
        effort: a failure is printed and the model is returned without that
        step applied.
    """
    print(f"Loading base Chatterbox model on {device}...")
    model = ChatterboxTTS.from_pretrained(device=device)

    print("Applying custom Indic tokenizer...")
    try:
        tokenizer_path = hf_hub_download(repo_id=repo_id, filename="tokenizer_indic.json")
        model.tokenizer = EnTokenizer(tokenizer_path)
    except Exception as e:
        print(f"Error during tokenizer inject: {e}")

    # The base model's text layers must match the new vocabulary size before
    # the adapter weights are loaded. Pretrained rows are carried over so the
    # model is not left with random text layers if the adapter fails to load.
    print(f"Resizing base model embeddings to handle vocab size of {vocab_size}...")
    target_layer = model.t3 if hasattr(model, 't3') else model
    _resize_text_layers(target_layer, vocab_size)

    # Send resized layers to the correct device
    target_layer.to(device)

    print("Loading LoRA adapter weights...")
    try:
        if hasattr(model, 't3'):
            model.t3 = PeftModel.from_pretrained(model.t3, repo_id)
        else:
            model = PeftModel.from_pretrained(model, repo_id)
        print("LoRA adapter loaded successfully.")
    except Exception as e:
        print(f"Failed to load PEFT adapter: {e}")

    return model
# Build the model once at import time; synthesize_audio() reuses this instance.
tts_model = load_model()
def synthesize_audio(text, ref_audio, exaggeration, cfg_weight):
    """Generate speech for ``text`` and write it to a temporary WAV file.

    Args:
        text: Text to synthesize. ``None``, empty or whitespace-only input is
            rejected with a status message instead of being sent to the model.
        ref_audio: Path to a reference recording, or a falsy value for none.
        exaggeration: Passed through to ``tts_model.generate``.
        cfg_weight: Passed through to ``tts_model.generate``.

    Returns:
        ``(wav_path, status)``. ``wav_path`` is ``None`` when the input was
        rejected or generation failed; ``status`` describes the outcome.
    """
    if not text or not text.strip():
        return None, "Please enter some text."

    try:
        wav = tts_model.generate(
            text,
            audio_prompt_path=ref_audio or None,
            exaggeration=exaggeration,
            cfg_weight=cfg_weight,
        )
        # Only reserve a unique path here; the handle is closed immediately so
        # it is not leaked and the file can be rewritten on every platform.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_out:
            out_path = temp_out.name
        ta.save(out_path, wav.cpu(), tts_model.sr)
    except Exception as e:
        return None, f"Generation Error: {str(e)}"

    return out_path, "Generation successful!"
# Gradio front end: inputs and controls in the first column, generated audio
# and status text in the second.

# Both advanced sliders share the same 0..1 range, default and step.
_unit_slider = dict(minimum=0.0, maximum=1.0, value=0.5, step=0.05)

with gr.Blocks(title="PrahaTTS-ML: Malayalam TTS", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🗣️ PrahaTTS-ML: Malayalam LoRA Adapter")
    gr.Markdown(
        "This Space runs the [Praha-Labs/PrahaTTS-ML](https://huggingface.co/Praha-Labs/PrahaTTS-ML) model. "
        "It is a Malayalam LoRA adapter built on top of ResembleAI's Chatterbox non-turbo TTS model. \n\n"
        "**Note**: Provide up to 5-10 seconds of clear reference audio for voice cloning capabilities."
    )

    with gr.Row():
        # Input column.
        with gr.Column():
            text_input = gr.Textbox(
                label="Input Text (Malayalam)",
                lines=4,
                placeholder="നമസ്കാരം, നിങ്ങൾക്കെങ്ങനെയുണ്ട്?",
            )
            ref_audio_input = gr.Audio(
                label="Reference Voice Audio (Optional, for Voice Cloning)",
                type="filepath",
            )

            with gr.Accordion("Advanced Voice Controls", open=False):
                exaggeration_slider = gr.Slider(
                    label="Emotion Exaggeration",
                    info="Lower for monotone, higher for dramatic/expressive",
                    **_unit_slider,
                )
                cfg_slider = gr.Slider(
                    label="CFG Weight",
                    info="Lower if speech is too fast, higher to strictly mimic the reference voice",
                    **_unit_slider,
                )

            generate_btn = gr.Button("Synthesize Speech", variant="primary")

        # Output column.
        with gr.Column():
            audio_output = gr.Audio(label="Generated Output", interactive=False)
            status_output = gr.Textbox(label="Status Logging", interactive=False)

    # The handler's two return values fill the audio player and the status box.
    generate_btn.click(
        fn=synthesize_audio,
        inputs=[text_input, ref_audio_input, exaggeration_slider, cfg_slider],
        outputs=[audio_output, status_output],
    )

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()