import gradio as gr
from pathlib import Path
import torch
import urllib.request
import os

# HuggingFace Spaces GPU support: the `spaces` package only exists on HF
# Spaces hardware, so its absence is expected in local/CPU deployments.
try:
    import spaces
    SPACES_AVAILABLE = True
except ImportError:
    SPACES_AVAILABLE = False
    print("[!] spaces module not available, running without GPU decorator")

import soundfile as sf
import traceback
import huggingface_hub
from huggingface_hub import hf_hub_download

# Patch for older diffusers compatibility with newer huggingface_hub:
# `cached_download` was removed upstream, so alias it to `hf_hub_download`.
if not hasattr(huggingface_hub, "cached_download"):
    huggingface_hub.cached_download = hf_hub_download

from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from matcha.models.matcha_tts import MatchaTTS
from matcha.hifigan.models import Generator as HiFiGAN
from matcha.hifigan.config import v1
from matcha.hifigan.env import AttrDict
from matcha.text import text_to_sequence
from matcha.utils.utils import intersperse

HF_TOKEN = os.getenv("HF_TOKEN")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_REPO = "GAASH-Lab/Matcha-TTS-Kashmiri"


def load_models():
    """Download and load the Matcha-TTS acoustic model and HiFi-GAN vocoder.

    Returns:
        tuple: (MatchaTTS model, HiFiGAN vocoder), both in eval mode on DEVICE.

    Raises:
        Exception: propagated from the hub download, the vocoder URL fetch,
            or checkpoint loading — a failure here is fatal at startup.
    """
    print("[*] Downloading GAASH-Lab checkpoint...")
    ckpt = hf_hub_download(repo_id=MODEL_REPO, filename="model.ckpt", token=HF_TOKEN)
    # weights_only=False: the checkpoint carries pickled non-tensor state,
    # so a tensors-only load would reject it. Trusted repo only.
    model = MatchaTTS.load_from_checkpoint(ckpt, map_location=DEVICE, weights_only=False)
    model.eval()

    print("[*] Loading HiFi-GAN vocoder...")
    # The file 'generator_v1' is what the code calls 'hifigan_T2_v1'.
    # We download it from the official GitHub release if not found locally.
    vocoder_path = Path("hifigan_T2_v1")
    if not vocoder_path.exists():
        url = "https://github.com/shivammehta25/Matcha-TTS-checkpoints/releases/download/v1.0/generator_v1"
        # BUGFIX: download to a temporary name and rename into place on
        # success. Previously an interrupted download left a truncated
        # 'hifigan_T2_v1' on disk, which the exists() check above would
        # then treat as a valid checkpoint on every subsequent startup.
        tmp_path = vocoder_path.with_suffix(".part")
        try:
            urllib.request.urlretrieve(url, tmp_path)
            tmp_path.replace(vocoder_path)
        except BaseException:
            tmp_path.unlink(missing_ok=True)
            raise

    vocoder = HiFiGAN(AttrDict(v1)).to(DEVICE)
    state_dict = torch.load(vocoder_path, map_location=DEVICE)
    vocoder.load_state_dict(state_dict['generator'])
    vocoder.eval()
    # Weight norm is a training-time reparameterization; folding it away
    # speeds up inference.
    vocoder.remove_weight_norm()
    return model, vocoder


# Translation Config
TRANSLATION_BASE_MODEL = "sarvamai/sarvam-translate"
import threading
import time

TRANSLATION_ADAPTER = "GAASH-Lab/Sarvam-Kashmiri-finetuned"

# Global cache for the lazily-loaded translation model. Guarded by
# _trans_lock: Gradio runs request handlers concurrently, and without the
# lock two simultaneous first requests could each load a full copy of the
# multi-GB base model.
_trans_cache = {"tokenizer": None, "model": None, "loaded": False}
_trans_lock = threading.Lock()


def load_translation_models():
    """Load the Sarvam translation model lazily on first use (CPU deployment).

    Returns:
        (tokenizer, model) on success, or (None, None) if loading failed.
        Successful results are cached; failures are retried on the next call.
    """
    # Fast path: cache already published, no lock needed for the read.
    if _trans_cache["loaded"]:
        return _trans_cache["tokenizer"], _trans_cache["model"]

    with _trans_lock:
        # Re-check under the lock: another thread may have finished loading
        # while we waited (double-checked lazy initialization).
        if _trans_cache["loaded"]:
            return _trans_cache["tokenizer"], _trans_cache["model"]

        print("[*] Loading Sarvam Translate Adapter (CPU mode)...")
        try:
            # Load the tokenizer with left padding (required for causal LM).
            tokenizer = AutoTokenizer.from_pretrained(TRANSLATION_BASE_MODEL, trust_remote_code=True)
            tokenizer.padding_side = "left"
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            # bfloat16 reduces memory and is better supported on CPU than float16.
            print("[*] Loading base model on CPU (bfloat16)...")
            base_model = AutoModelForCausalLM.from_pretrained(
                TRANSLATION_BASE_MODEL,
                torch_dtype=torch.bfloat16,
                device_map="cpu",
                low_cpu_mem_usage=True,
                trust_remote_code=True,
            )

            print("[*] Loading LoRA adapter...")
            model = PeftModel.from_pretrained(base_model, TRANSLATION_ADAPTER)

            # Merging folds the adapter into the base weights, eliminating
            # the per-forward adapter overhead during generation.
            print("[*] Merging LoRA weights for faster inference...")
            model = model.merge_and_unload()
            model.eval()

            print("[+] Translation model loaded and merged successfully on CPU.")
            _trans_cache["tokenizer"] = tokenizer
            _trans_cache["model"] = model
            _trans_cache["loaded"] = True
            return tokenizer, model
        except Exception as e:
            print(f"[-] Error loading translation model: {e}")
            traceback.print_exc()
            return None, None


# Load TTS models at startup (they're smaller than the translation model).
model, vocoder = load_models()


def _translate_impl(text):
    """Translate `text` to Kashmiri via the Sarvam chat model.

    Mirrors the evaluate_model.py approach (chat template + greedy decode).
    Returns the translated string, or a human-readable error message when
    the model is unavailable or generation fails — the UI shows the return
    value either way.
    """
    # Load model lazily (cached after first load).
    trans_tokenizer, trans_model = load_translation_models()
    if trans_model is None:
        return "Translation model unavailable."

    # Build chat messages (matching evaluate_model.py).
    messages = [
        {"role": "system", "content": "Translate the text below to Kashmiri."},
        {"role": "user", "content": text},
    ]

    try:
        # Apply chat template (matching evaluate_model.py).
        prompt = trans_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        # max_length=512 truncates overly long inputs instead of OOM-ing.
        inputs = trans_tokenizer(prompt, padding=True, truncation=True, max_length=512, return_tensors="pt")
        # Move inputs to the model's device.
        inputs = {k: v.to(trans_model.device) for k, v in inputs.items()}
        print(f"[DEBUG] Input tokens: {inputs['input_ids'].shape[1]}")
    except Exception as e:
        print(f"Chat template error: {e}")
        traceback.print_exc()
        return "Error in translation template."

    try:
        start_time = time.time()
        print("[DEBUG] Starting generation...")
        # Greedy decoding (do_sample=False, num_beams=1) is the fastest
        # setting on CPU; quality matches the near-greedy temp=0.01 used
        # previously.
        with torch.no_grad():
            generated = trans_model.generate(
                **inputs,
                max_new_tokens=256,  # keep full length for long texts
                do_sample=False,
                num_beams=1,
            )
        elapsed = time.time() - start_time
        print(f"[DEBUG] Generation completed in {elapsed:.2f}s")

        # Decode only the newly generated tokens (the prompt is echoed back).
        input_len = inputs['input_ids'].shape[1]
        output_ids = generated[0][input_len:]
        decoded = trans_tokenizer.decode(output_ids, skip_special_tokens=True).replace("\n", "")
        print(f"[DEBUG] New tokens: {len(output_ids)}")
        print(f"[DEBUG] Decoded: '{decoded}'")
        return decoded.strip()
    except Exception as e:
        print(f"Generation error: {e}")
        traceback.print_exc()
        return "Error during translation generation."
import tempfile


# Simple wrapper function for CPU deployment.
def translate(text):
    """Public translation entry point; delegates to _translate_impl."""
    return _translate_impl(text)


@torch.inference_mode()
def process(text, speaker_id, n_timesteps=10):
    """Synthesize Kashmiri speech for `text` and return a WAV file path.

    Args:
        text: Kashmiri text (Perso-Arabic script).
        speaker_id: integer speaker id for the multi-speaker model.
        n_timesteps: diffusion/ODE steps; more steps = higher quality, slower.

    Returns:
        str: path to a 22050 Hz WAV file.

    Raises:
        gr.Error: if no character of `text` is in the TTS vocabulary.
    """
    # 1. Kashmiri script normalization: map Arabic yeh/kaf to their
    # Farsi-form codepoints and strip the Urdu full stop.
    text = text.replace("ي", "ی").replace("ك", "ک").replace("۔", "").strip()

    # 2. Text to sequence.
    sequence, _ = text_to_sequence(text, ["basic_cleaners"])
    # Filter out any non-integer values (unknown characters not in the
    # model's vocabulary).
    filtered_sequence = [s for s in sequence if isinstance(s, int)]
    # BUGFIX: input made entirely of unknown characters used to produce an
    # empty tensor and an opaque crash inside the model; surface a clear
    # error to the UI instead.
    if not filtered_sequence:
        raise gr.Error("No synthesizable characters found in the input text.")

    x = torch.tensor(intersperse(filtered_sequence, 0), dtype=torch.long, device=DEVICE)[None]
    x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=DEVICE)

    # 3. Speaker id tensor — the model requires it even when only one
    # voice is used.
    spks = torch.tensor([int(speaker_id)], device=DEVICE, dtype=torch.long)

    # 4. Generate mel-spectrogram.
    output = model.synthesise(
        x, x_lengths,
        n_timesteps=n_timesteps,
        temperature=0.667,
        spks=spks,
        length_scale=1.0
    )

    # 5. Generate waveform.
    audio = vocoder(output['mel']).clamp(-1, 1).cpu().squeeze().numpy()

    # BUGFIX: write each result to a unique temp file. The fixed "out.wav"
    # name let concurrent requests overwrite each other's audio before
    # Gradio had served it.
    fd, output_path = tempfile.mkstemp(suffix=".wav", prefix="matcha_tts_")
    os.close(fd)  # sf.write reopens by path; keep only the unique name
    sf.write(output_path, audio, 22050)
    return output_path


# --- Gradio UI with Translation Option ---
with gr.Blocks(title="GAASH-Lab: Kashmiri TTS & Translation") as demo:
    gr.Markdown("# GAASH-Lab: Kashmiri TTS & Translation")
    gr.Markdown("Enter text in English (check the box) or Kashmiri directly.")

    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(label="Input Text", placeholder="Type here...")
            is_english = gr.Checkbox(label="Input is English (Translate first)", value=False)
            speaker_radio = gr.Radio(choices=["Male", "Female"], value="Male", label="Speaker Voice")
            quality_radio = gr.Radio(
                choices=["Low (fast)", "Medium", "High"],
                value="Low (fast)",
                label="Quality"
            )
            gen_btn = gr.Button("Generate Speech", variant="primary")
        with gr.Column():
            trans_view = gr.Textbox(label="Processed/Translated Kashmiri Text", interactive=False)
            audio_output = gr.Audio(label="Audio", type="filepath")

    def pipeline(text, is_eng, spk_voice, quality):
        """Click handler: optionally translate English input, then synthesize."""
        # 422/423 are the male/female speaker ids used by this checkpoint
        # — NOTE(review): confirm against the fine-tuning speaker table.
        spk_id = 422 if spk_voice == "Male" else 423
        # Quality maps to synthesis steps (more steps = slower, cleaner).
        if "Low" in quality:
            steps = 10
        elif "Medium" in quality:
            steps = 50
        else:
            steps = 500
        processed_text = text
        if is_eng:
            print(f"Translating input: {text}")
            processed_text = translate(text)
        print(f"Synthesizing for: {processed_text}")
        audio_path = process(processed_text, spk_id, steps)
        return processed_text, audio_path

    gen_btn.click(
        pipeline,
        inputs=[input_text, is_english, speaker_radio, quality_radio],
        outputs=[trans_view, audio_output]
    )

demo.launch(ssr_mode=False)