"""Gradio Space for PrahaTTS-ML: a Malayalam LoRA adapter on top of Chatterbox TTS."""

import importlib
import os
import subprocess
import sys
import tempfile
import traceback

from huggingface_hub import hf_hub_download


# --- 1. PRE-FLIGHT: BYPASS BUILD ISOLATION ---
def _pip_install(*args: str) -> None:
    """Run ``pip install *args`` with the current interpreter; raise on failure."""
    subprocess.check_call([sys.executable, "-m", "pip", "install", *args])


def pre_flight_setup() -> None:
    """Install chatterbox-tts (and pkuseg) if ``chatterbox`` cannot be imported.

    pkuseg 0.0.25 is installed with ``--no-build-isolation`` after numpy<2,
    cython and wheel are present, so pip builds it against those already
    installed packages instead of inside an isolated build environment
    (presumably because its build needs numpy/cython at build time).
    """
    try:
        import chatterbox  # noqa: F401  (import probe only)
    except ImportError:
        print("Applying pkuseg build isolation bypass...")
        _pip_install("numpy<2.0.0", "cython", "wheel")
        _pip_install("--no-build-isolation", "pkuseg==0.0.25")
        _pip_install("chatterbox-tts>=0.1.7")
        # The import system caches directory contents; modules installed after
        # interpreter start may not be found until these caches are reset.
        importlib.invalidate_caches()


pre_flight_setup()

# --- 2. MAIN APPLICATION ---
# These imports intentionally follow pre_flight_setup(), which may have just
# installed chatterbox-tts and its dependencies.
import gradio as gr  # noqa: E402
import torch  # noqa: E402
import torch.nn as nn  # noqa: E402
import torchaudio as ta  # noqa: E402
from peft import PeftModel  # noqa: E402
from chatterbox.tts import ChatterboxTTS  # noqa: E402
from chatterbox.models.tokenizers import EnTokenizer  # noqa: E402

device = "cuda" if torch.cuda.is_available() else "cpu"
repo_id = "Praha-Labs/PrahaTTS-ML"

# Malayalam vocab size the base model's text layers must be resized to before
# the adapter weights are loaded.
MALAYALAM_VOCAB_SIZE = 2573


def _resize_text_layers(target_layer, vocab_size: int) -> None:
    """Replace ``text_emb`` / ``text_head`` on *target_layer* with new layers of *vocab_size*.

    The new layers keep the old embedding dim / input features, bias setting,
    device and dtype. Their weights are freshly initialised; presumably the
    adapter checkpoint supplies trained weights for them (e.g. via PEFT
    ``modules_to_save``) — TODO confirm against the adapter config.
    """
    resized = False
    if hasattr(target_layer, "text_emb"):
        old_emb = target_layer.text_emb
        target_layer.text_emb = nn.Embedding(
            vocab_size,
            old_emb.embedding_dim,
            device=old_emb.weight.device,
            dtype=old_emb.weight.dtype,
        )
        resized = True
    if hasattr(target_layer, "text_head"):
        old_head = target_layer.text_head
        target_layer.text_head = nn.Linear(
            old_head.in_features,
            vocab_size,
            bias=old_head.bias is not None,
            device=old_head.weight.device,
            dtype=old_head.weight.dtype,
        )
        resized = True
    if not resized:
        print("Warning: no text_emb/text_head found; vocabulary layers were not resized.")


def load_model(vocab_size: int = MALAYALAM_VOCAB_SIZE):
    """Build the Malayalam TTS model.

    Loads the base Chatterbox model, swaps in the Indic tokenizer from
    ``repo_id``, resizes the text vocabulary layers to *vocab_size*, and
    applies the LoRA adapter from ``repo_id``.

    Args:
        vocab_size: Number of rows for the resized text embedding / output head.

    Returns:
        The ChatterboxTTS model with ``t3`` wrapped in a ``PeftModel``. If the
        model has no ``t3`` attribute, the whole model is wrapped instead.

    Raises:
        RuntimeError: If the tokenizer injection or the adapter load fails.
            Continuing would serve a model whose tokenizer and/or text layers
            do not match the adapter.
    """
    print(f"Loading base Chatterbox model on {device}...")
    model = ChatterboxTTS.from_pretrained(device=device)

    print("Applying custom Indic tokenizer...")
    try:
        tokenizer_path = hf_hub_download(repo_id=repo_id, filename="tokenizer_indic.json")
        model.tokenizer = EnTokenizer(tokenizer_path)
    except Exception as e:
        raise RuntimeError(f"Error during tokenizer inject: {e}") from e

    # --- CRITICAL FIX: MANUALLY RESIZE PYTORCH EMBEDDINGS ---
    # The base model's vocabulary layers must match the Malayalam vocab size
    # before the adapter weights are loaded.
    print(f"Resizing base model embeddings to handle vocab size of {vocab_size}...")
    target_layer = model.t3 if hasattr(model, "t3") else model
    _resize_text_layers(target_layer, vocab_size)
    # Send resized layers to the correct device
    target_layer.to(device)

    print("Loading LoRA adapter weights...")
    try:
        if hasattr(model, "t3"):
            model.t3 = PeftModel.from_pretrained(model.t3, repo_id)
        else:
            model = PeftModel.from_pretrained(model, repo_id)
    except Exception as e:
        raise RuntimeError(f"Failed to load PEFT adapter: {e}") from e
    print("LoRA adapter loaded successfully.")

    return model


# Initialize Model
tts_model = load_model()


def synthesize_audio(text, ref_audio, exaggeration, cfg_weight):
    """Synthesize *text* with the loaded model and write the result to a temporary WAV file.

    Args:
        text: Input text; empty, whitespace-only or ``None`` is rejected.
        ref_audio: Path to an optional reference clip for voice cloning (falsy -> none).
        exaggeration: Emotion exaggeration passed through to ``generate``.
        cfg_weight: CFG weight passed through to ``generate``.

    Returns:
        ``(wav_path, status)`` on success, ``(None, message)`` on bad input or error.
    """
    if not text or not text.strip():
        return None, "Please enter some text."

    audio_prompt_path = ref_audio or None

    try:
        wav = tts_model.generate(
            text,
            audio_prompt_path=audio_prompt_path,
            exaggeration=exaggeration,
            cfg_weight=cfg_weight,
        )
        # Only reserve a file name here; the handle is closed before writing so
        # no descriptor leaks and torchaudio can reopen the path (on Windows an
        # open NamedTemporaryFile cannot be opened a second time).
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
            out_path = tmp.name
        ta.save(out_path, wav.cpu(), tts_model.sr)
        return out_path, "Generation successful!"
    except Exception as e:
        # UI boundary: report to the user, but keep the traceback in the logs.
        traceback.print_exc()
        return None, f"Generation Error: {str(e)}"


# Define the Gradio Interface
with gr.Blocks(title="PrahaTTS-ML: Malayalam TTS", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🗣️ PrahaTTS-ML: Malayalam LoRA Adapter")
    gr.Markdown(
        "This Space runs the [Praha-Labs/PrahaTTS-ML](https://huggingface.co/Praha-Labs/PrahaTTS-ML) model. "
        "It is a Malayalam LoRA adapter built on top of ResembleAI's Chatterbox non-turbo TTS model. \n\n"
        "**Note**: Provide up to 5-10 seconds of clear reference audio for voice cloning capabilities."
    )

    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="Input Text (Malayalam)",
                lines=4,
                placeholder="നമസ്കാരം, നിങ്ങൾക്കെങ്ങനെയുണ്ട്?",
            )
            ref_audio_input = gr.Audio(
                label="Reference Voice Audio (Optional, for Voice Cloning)",
                type="filepath",
            )

            with gr.Accordion("Advanced Voice Controls", open=False):
                exaggeration_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.5,
                    step=0.05,
                    label="Emotion Exaggeration",
                    info="Lower for monotone, higher for dramatic/expressive",
                )
                cfg_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.5,
                    step=0.05,
                    label="CFG Weight",
                    info="Lower if speech is too fast, higher to strictly mimic the reference voice",
                )

            generate_btn = gr.Button("Synthesize Speech", variant="primary")

        with gr.Column():
            audio_output = gr.Audio(label="Generated Output", interactive=False)
            status_output = gr.Textbox(label="Status Logging", interactive=False)

    generate_btn.click(
        fn=synthesize_audio,
        inputs=[text_input, ref_audio_input, exaggeration_slider, cfg_slider],
        outputs=[audio_output, status_output],
    )

if __name__ == "__main__":
    demo.launch()