| |
|
import os
import tempfile
import urllib.request
from pathlib import Path

import gradio as gr
import torch
| |
|
| | |
# Optional dependency: `spaces` provides GPU decorators on Hugging Face Spaces.
# When absent (e.g. local runs), fall back gracefully and record availability.
try:
    import spaces
    SPACES_AVAILABLE = True
except ImportError:
    SPACES_AVAILABLE = False
    print("[!] spaces module not available, running without GPU decorator")
import soundfile as sf
import traceback
import huggingface_hub
from huggingface_hub import hf_hub_download

# Back-compat shim: some downstream libraries still call
# `huggingface_hub.cached_download`, which newer huggingface_hub releases
# removed; alias it to `hf_hub_download` so those imports keep working.
if not hasattr(huggingface_hub, "cached_download"):
    huggingface_hub.cached_download = hf_hub_download
| |
|
| | from transformers import AutoTokenizer, AutoModelForCausalLM |
| | from peft import PeftModel |
| | from matcha.models.matcha_tts import MatchaTTS |
| | from matcha.hifigan.models import Generator as HiFiGAN |
| | from matcha.hifigan.config import v1 |
| | from matcha.hifigan.env import AttrDict |
| | from matcha.text import text_to_sequence |
| | from matcha.utils.utils import intersperse |
| |
|
# Optional HF auth token from the environment; passed to hf_hub_download below.
HF_TOKEN = os.getenv("HF_TOKEN")

# Run on GPU when available, otherwise CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Hugging Face repo holding the Kashmiri Matcha-TTS checkpoint.
MODEL_REPO = "GAASH-Lab/Matcha-TTS-Kashmiri"
| |
|
def load_models():
    """Download the Matcha-TTS checkpoint and HiFi-GAN vocoder; return both in eval mode."""
    print("[*] Downloading GAASH-Lab checkpoint...")
    ckpt_path = hf_hub_download(repo_id=MODEL_REPO, filename="model.ckpt", token=HF_TOKEN)
    tts_model = MatchaTTS.load_from_checkpoint(ckpt_path, map_location=DEVICE, weights_only=False)
    tts_model.eval()

    print("[*] Loading HiFi-GAN vocoder...")
    voc_file = Path("hifigan_T2_v1")
    if not voc_file.exists():
        # Fetch the pretrained universal vocoder released with Matcha-TTS.
        urllib.request.urlretrieve(
            "https://github.com/shivammehta25/Matcha-TTS-checkpoints/releases/download/v1.0/generator_v1",
            voc_file,
        )

    hifigan = HiFiGAN(AttrDict(v1)).to(DEVICE)
    checkpoint = torch.load(voc_file, map_location=DEVICE)
    hifigan.load_state_dict(checkpoint['generator'])
    hifigan.eval()
    # Weight norm is only needed during training; removing it speeds up inference.
    hifigan.remove_weight_norm()

    return tts_model, hifigan
| |
|
| | |
# Translation stack: Sarvam base model + GAASH Kashmiri LoRA adapter.
TRANSLATION_BASE_MODEL = "sarvamai/sarvam-translate"
TRANSLATION_ADAPTER = "GAASH-Lab/Sarvam-Kashmiri-finetuned"

# Lazy singleton cache so the heavy translation model is loaded only on first use.
_trans_cache = {"tokenizer": None, "model": None, "loaded": False}
| |
|
def load_translation_models():
    """Load translation model lazily on first use (CPU deployment).

    Returns a ``(tokenizer, model)`` pair, or ``(None, None)`` on failure.
    Subsequent calls return the cached pair without reloading.
    """
    global _trans_cache

    if _trans_cache["loaded"]:
        return _trans_cache["tokenizer"], _trans_cache["model"]

    print("[*] Loading Sarvam Translate Adapter (CPU mode)...")
    try:
        tok = AutoTokenizer.from_pretrained(TRANSLATION_BASE_MODEL, trust_remote_code=True)
        # Left padding + a defined pad token are required for batched generation.
        tok.padding_side = "left"
        if tok.pad_token is None:
            tok.pad_token = tok.eos_token

        print("[*] Loading base model on CPU (bfloat16)...")
        base = AutoModelForCausalLM.from_pretrained(
            TRANSLATION_BASE_MODEL,
            torch_dtype=torch.bfloat16,
            device_map="cpu",
            low_cpu_mem_usage=True,
            trust_remote_code=True,
        )

        print("[*] Loading LoRA adapter...")
        peft_model = PeftModel.from_pretrained(base, TRANSLATION_ADAPTER)

        # Folding the adapter into the base weights removes the PEFT
        # indirection, which speeds up CPU inference.
        print("[*] Merging LoRA weights for faster inference...")
        merged = peft_model.merge_and_unload()
        merged.eval()

        print(f"[+] Translation model loaded and merged successfully on CPU.")
        _trans_cache["tokenizer"] = tok
        _trans_cache["model"] = merged
        _trans_cache["loaded"] = True
        return tok, merged
    except Exception as e:
        # Best-effort: translation becomes unavailable but the app keeps running.
        print(f"[-] Error loading translation model: {e}")
        traceback.print_exc()
        return None, None
| |
|
| | |
| | model, vocoder = load_models() |
| | |
| |
|
def _translate_impl(text):
    """Internal translation implementation - matching evaluate_model.py approach."""
    tokenizer, lm = load_translation_models()
    if lm is None:
        return "Translation model unavailable."

    chat = [
        {"role": "system", "content": "Translate the text below to Kashmiri."},
        {"role": "user", "content": text},
    ]

    try:
        rendered = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
        encoded = tokenizer(rendered, padding=True, truncation=True, max_length=512, return_tensors="pt")
        # Move every tensor onto the model's device before generation.
        encoded = {name: tensor.to(lm.device) for name, tensor in encoded.items()}
        print(f"[DEBUG] Input tokens: {encoded['input_ids'].shape[1]}")
    except Exception as e:
        print(f"Chat template error: {e}")
        traceback.print_exc()
        return "Error in translation template."

    try:
        import time
        t0 = time.time()
        print("[DEBUG] Starting generation...")

        # Greedy decoding (no sampling, single beam) for deterministic output.
        with torch.no_grad():
            result = lm.generate(
                **encoded,
                max_new_tokens=256,
                do_sample=False,
                num_beams=1,
            )

        elapsed = time.time() - t0
        print(f"[DEBUG] Generation completed in {elapsed:.2f}s")

        # Keep only the newly generated tokens (strip the echoed prompt).
        prompt_len = encoded['input_ids'].shape[1]
        new_tokens = result[0][prompt_len:]
        decoded = tokenizer.decode(new_tokens, skip_special_tokens=True).replace("\n", "")

        print(f"[DEBUG] New tokens: {len(new_tokens)}")
        print(f"[DEBUG] Decoded: '{decoded}'")

        return decoded.strip()
    except Exception as e:
        print(f"Generation error: {e}")
        traceback.print_exc()
        return "Error during translation generation."
| |
|
| | |
def translate(text):
    """Translate *text* to Kashmiri using the lazily loaded Sarvam model."""
    return _translate_impl(text)
| |
|
| |
|
| | |
@torch.inference_mode()
def process(text, speaker_id, n_timesteps=10):
    """Synthesize Kashmiri speech for *text* and return a WAV file path.

    Args:
        text: Kashmiri input text (Perso-Arabic script).
        speaker_id: integer speaker-embedding index for the TTS model.
        n_timesteps: ODE solver steps for Matcha-TTS (more = slower, higher quality).

    Returns:
        Path to a 22.05 kHz WAV file containing the synthesized audio.
    """
    # Normalize Arabic-script letter variants to the forms the model was
    # trained on and drop the full stop character.
    text = text.replace("ي", "ی").replace("ك", "ک").replace("۔", "").strip()

    cleaner = "basic_cleaners"
    sequence, _ = text_to_sequence(text, [cleaner])

    # Guard against any non-integer entries the cleaner might emit.
    filtered_sequence = [s for s in sequence if isinstance(s, int)]

    # Matcha-TTS expects the token sequence interspersed with blank (0) tokens.
    x = torch.tensor(intersperse(filtered_sequence, 0), dtype=torch.long, device=DEVICE)[None]
    x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=DEVICE)
    spks = torch.tensor([int(speaker_id)], device=DEVICE, dtype=torch.long)

    output = model.synthesise(
        x,
        x_lengths,
        n_timesteps=n_timesteps,
        temperature=0.667,
        spks=spks,
        length_scale=1.0
    )

    audio = vocoder(output['mel']).clamp(-1, 1).cpu().squeeze().numpy()

    # Bug fix: a fixed "out.wav" path was clobbered by concurrent Gradio
    # requests; write each synthesis to its own unique temp file instead.
    fd, output_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)  # sf.write reopens by path; close the raw descriptor.
    sf.write(output_path, audio, 22050)
    return output_path
| |
|
| | |
# --- Gradio UI: text input (optionally translated from English) -> speech ---
with gr.Blocks(title="GAASH-Lab: Kashmiri TTS & Translation") as demo:
    gr.Markdown("# GAASH-Lab: Kashmiri TTS & Translation")
    gr.Markdown("Enter text in English (check the box) or Kashmiri directly.")

    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(label="Input Text", placeholder="Type here...")
            is_english = gr.Checkbox(label="Input is English (Translate first)", value=False)
            speaker_radio = gr.Radio(choices=["Male", "Female"], value="Male", label="Speaker Voice")
            quality_radio = gr.Radio(
                choices=["Low (fast)", "Medium", "High"],
                value="Low (fast)",
                label="Quality"
            )
            gen_btn = gr.Button("Generate Speech", variant="primary")

        with gr.Column():
            trans_view = gr.Textbox(label="Processed/Translated Kashmiri Text", interactive=False)
            audio_output = gr.Audio(label="Audio", type="filepath")

    def pipeline(text, is_eng, spk_voice, quality):
        """Optionally translate to Kashmiri, then synthesize speech.

        Returns (kashmiri_text, wav_path) for the two output widgets.
        """
        # Speaker-embedding indices for the two trained voices.
        chosen_speaker = 422 if spk_voice == "Male" else 423
        # Map the quality label onto ODE solver step counts.
        steps = 10 if "Low" in quality else (50 if "Medium" in quality else 500)

        kashmiri = text
        if is_eng:
            print(f"Translating input: {text}")
            kashmiri = translate(text)

        print(f"Synthesizing for: {kashmiri}")
        wav_path = process(kashmiri, chosen_speaker, steps)
        return kashmiri, wav_path

    gen_btn.click(
        pipeline,
        inputs=[input_text, is_english, speaker_radio, quality_radio],
        outputs=[trans_view, audio_output]
    )

demo.launch(ssr_mode=False)