# NOTE(review): the three lines below were HuggingFace web-viewer residue
# (uploader name / commit message / commit hash) accidentally pasted into the
# source; kept as comments so the file parses.
# saeedabdulmuizz's picture
# Update app.py
# aa24966 verified
import gradio as gr
from pathlib import Path
import torch
import urllib.request
import os

# HuggingFace Spaces GPU support: the `spaces` package only exists inside a
# Spaces runtime, so fall back gracefully when running elsewhere.
try:
    import spaces
    SPACES_AVAILABLE = True
except ImportError:
    SPACES_AVAILABLE = False
    print("[!] spaces module not available, running without GPU decorator")

import soundfile as sf
import traceback
import huggingface_hub
from huggingface_hub import hf_hub_download

# Patch for older diffusers compatibility with newer huggingface_hub:
# `cached_download` was removed upstream, so alias it to `hf_hub_download`
# when missing so older dependency code keeps working.
if not hasattr(huggingface_hub, "cached_download"):
    huggingface_hub.cached_download = hf_hub_download

from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from matcha.models.matcha_tts import MatchaTTS
from matcha.hifigan.models import Generator as HiFiGAN
from matcha.hifigan.config import v1
from matcha.hifigan.env import AttrDict
from matcha.text import text_to_sequence
from matcha.utils.utils import intersperse

# Optional auth token for private/gated HF repos (None if unset).
HF_TOKEN = os.getenv("HF_TOKEN")
# All models/tensors are placed on this device.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Repo hosting the Matcha-TTS Kashmiri checkpoint.
MODEL_REPO = "GAASH-Lab/Matcha-TTS-Kashmiri"
def load_models():
    """Download and initialise the Matcha-TTS model and HiFi-GAN vocoder.

    Returns:
        tuple: ``(model, vocoder)`` — the Matcha-TTS acoustic model and the
        HiFi-GAN vocoder, both in eval mode on ``DEVICE``.
    """
    print("[*] Downloading GAASH-Lab checkpoint...")
    ckpt = hf_hub_download(repo_id=MODEL_REPO, filename="model.ckpt", token=HF_TOKEN)
    # weights_only=False: Lightning checkpoints carry pickled hyperparameters,
    # not just tensors, so a tensors-only load would fail.
    model = MatchaTTS.load_from_checkpoint(ckpt, map_location=DEVICE, weights_only=False)
    model.eval()

    print("[*] Loading HiFi-GAN vocoder...")
    # The file 'generator_v1' is what the code calls 'hifigan_T2_v1'.
    # We download it from the official GitHub release if not found locally.
    vocoder_path = Path("hifigan_T2_v1")
    if not vocoder_path.exists():
        url = "https://github.com/shivammehta25/Matcha-TTS-checkpoints/releases/download/v1.0/generator_v1"
        # Download to a temporary name and rename on success. Writing straight
        # to vocoder_path (as before) meant an interrupted download left a
        # truncated file that passed the exists() check on the next startup.
        tmp_path = vocoder_path.with_suffix(".tmp")
        try:
            urllib.request.urlretrieve(url, tmp_path)
            tmp_path.replace(vocoder_path)
        except Exception:
            # Never leave a partial download behind; re-raise the original error.
            tmp_path.unlink(missing_ok=True)
            raise

    vocoder = HiFiGAN(AttrDict(v1)).to(DEVICE)
    state_dict = torch.load(vocoder_path, map_location=DEVICE)
    vocoder.load_state_dict(state_dict['generator'])
    vocoder.eval()
    # Weight norm is only needed during training; removing it speeds up inference.
    vocoder.remove_weight_norm()
    return model, vocoder
# Translation Config
# Base seq2seq-style causal LM and the Kashmiri LoRA adapter applied on top.
TRANSLATION_BASE_MODEL = "sarvamai/sarvam-translate"
TRANSLATION_ADAPTER = "GAASH-Lab/Sarvam-Kashmiri-finetuned"

# Global cache for translation model (loaded lazily when GPU is available).
# Keys: "tokenizer", "model" (merged PEFT model), "loaded" (load-completed flag).
_trans_cache = {"tokenizer": None, "model": None, "loaded": False}
def load_translation_models():
    """Load translation model lazily on first use (CPU deployment).

    Returns ``(tokenizer, model)`` from the module cache, or ``(None, None)``
    when loading fails for any reason.
    """
    # Fast path: already loaded once in this process.
    if _trans_cache["loaded"]:
        return _trans_cache["tokenizer"], _trans_cache["model"]

    print("[*] Loading Sarvam Translate Adapter (CPU mode)...")
    try:
        # Causal LMs require left padding for generation.
        tok = AutoTokenizer.from_pretrained(TRANSLATION_BASE_MODEL, trust_remote_code=True)
        tok.padding_side = "left"
        if tok.pad_token is None:
            tok.pad_token = tok.eos_token

        # bfloat16 halves memory versus fp32 and is better supported on CPU
        # than float16.
        print("[*] Loading base model on CPU (bfloat16)...")
        base = AutoModelForCausalLM.from_pretrained(
            TRANSLATION_BASE_MODEL,
            torch_dtype=torch.bfloat16,
            device_map="cpu",
            low_cpu_mem_usage=True,
            trust_remote_code=True,
        )

        print("[*] Loading LoRA adapter...")
        peft_model = PeftModel.from_pretrained(base, TRANSLATION_ADAPTER)

        # Folding LoRA weights into the base model removes adapter overhead
        # from every generation step.
        print("[*] Merging LoRA weights for faster inference...")
        merged = peft_model.merge_and_unload()
        merged.eval()
        print(f"[+] Translation model loaded and merged successfully on CPU.")

        _trans_cache.update(tokenizer=tok, model=merged, loaded=True)
        return tok, merged
    except Exception as e:
        print(f"[-] Error loading translation model: {e}")
        traceback.print_exc()
        return None, None
# Load TTS models at startup (they're smaller).
# The translation model is deliberately NOT loaded here — it is fetched
# lazily on the first translate() call (see load_translation_models).
model, vocoder = load_models()
def _translate_impl(text):
    """Internal translation implementation - matching evaluate_model.py approach."""
    # Fetch the (cached) tokenizer/model pair; may trigger the one-off load.
    trans_tokenizer, trans_model = load_translation_models()
    if trans_model is None:
        return "Translation model unavailable."

    # Chat-style prompt, identical to the one used in evaluate_model.py.
    messages = [
        {"role": "system", "content": "Translate the text below to Kashmiri."},
        {"role": "user", "content": text},
    ]

    try:
        prompt = trans_tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        encoded = trans_tokenizer(
            prompt, padding=True, truncation=True, max_length=512, return_tensors="pt"
        )
        # Keep every tensor on whatever device the model lives on.
        encoded = {name: t.to(trans_model.device) for name, t in encoded.items()}
        print(f"[DEBUG] Input tokens: {encoded['input_ids'].shape[1]}")
    except Exception as e:
        print(f"Chat template error: {e}")
        traceback.print_exc()
        return "Error in translation template."

    try:
        import time

        t0 = time.time()
        print("[DEBUG] Starting generation...")
        # Greedy decoding with a single beam is the cheapest option on CPU and
        # matches the near-greedy temp=0.01 quality used previously.
        with torch.no_grad():
            generated = trans_model.generate(
                **encoded,
                max_new_tokens=256,
                do_sample=False,
                num_beams=1,
            )
        elapsed = time.time() - t0
        print(f"[DEBUG] Generation completed in {elapsed:.2f}s")

        # Strip the prompt tokens; only the continuation is the translation.
        prompt_len = encoded['input_ids'].shape[1]
        new_tokens = generated[0][prompt_len:]
        decoded = trans_tokenizer.decode(new_tokens, skip_special_tokens=True).replace("\n", "")
        print(f"[DEBUG] New tokens: {len(new_tokens)}")
        print(f"[DEBUG] Decoded: '{decoded}'")
        return decoded.strip()
    except Exception as e:
        print(f"Generation error: {e}")
        traceback.print_exc()
        return "Error during translation generation."
# Simple wrapper function for CPU deployment
def translate(text):
    """Translate English text to Kashmiri.

    Thin wrapper around _translate_impl so the public name stays stable if
    the deployment later adds a GPU decorator. Returns the decoded Kashmiri
    string, or an error-message string when loading/generation fails.
    """
    return _translate_impl(text)
# --- Update the function signature to accept two arguments ---
@torch.inference_mode()
def process(text, speaker_id, n_timesteps=10):
    """Synthesise Kashmiri speech for *text*; returns the path of the wav file."""
    # 1. Kashmiri script normalisation: map Arabic-presentation variants to
    #    their Kashmiri forms and drop the full stop.
    for src, dst in (("ي", "ی"), ("ك", "ک"), ("۔", "")):
        text = text.replace(src, dst)
    text = text.strip()

    # 2. Text -> token-id sequence.
    sequence, _ = text_to_sequence(text, ["basic_cleaners"])
    # Keep only integer ids: characters outside the TTS vocabulary come back
    # as non-int entries and would crash tensor construction.
    clean_ids = [tok for tok in sequence if isinstance(tok, int)]
    x = torch.tensor(intersperse(clean_ids, 0), dtype=torch.long, device=DEVICE)[None]
    x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=DEVICE)

    # 3. Speaker id tensor — required by the model even with a single voice.
    spks = torch.tensor([int(speaker_id)], device=DEVICE, dtype=torch.long)

    # 4. Mel-spectrogram from Matcha-TTS.
    output = model.synthesise(
        x,
        x_lengths,
        n_timesteps=n_timesteps,
        temperature=0.667,
        spks=spks,
        length_scale=1.0,
    )

    # 5. Waveform via HiFi-GAN, clipped to the valid [-1, 1] range.
    audio = vocoder(output['mel']).clamp(-1, 1).cpu().squeeze().numpy()
    output_path = "out.wav"
    sf.write(output_path, audio, 22050)
    return output_path
# --- Gradio UI with Translation Option ---
with gr.Blocks(title="GAASH-Lab: Kashmiri TTS & Translation") as demo:
    gr.Markdown("# GAASH-Lab: Kashmiri TTS & Translation")
    gr.Markdown("Enter text in English (check the box) or Kashmiri directly.")

    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(label="Input Text", placeholder="Type here...")
            is_english = gr.Checkbox(label="Input is English (Translate first)", value=False)
            speaker_radio = gr.Radio(choices=["Male", "Female"], value="Male", label="Speaker Voice")
            quality_radio = gr.Radio(
                choices=["Low (fast)", "Medium", "High"],
                value="Low (fast)",
                label="Quality"
            )
            gen_btn = gr.Button("Generate Speech", variant="primary")
        with gr.Column():
            trans_view = gr.Textbox(label="Processed/Translated Kashmiri Text", interactive=False)
            audio_output = gr.Audio(label="Audio", type="filepath")

    def pipeline(text, is_eng, spk_voice, quality):
        """Optionally translate, then synthesise; returns (kashmiri_text, wav_path)."""
        # Speaker ids 422/423 select the male/female voices in the checkpoint.
        spk_id = 422 if spk_voice == "Male" else 423
        # Map the quality label onto the number of synthesis steps.
        steps = 10 if "Low" in quality else 50 if "Medium" in quality else 500

        kashmiri = text
        if is_eng:
            print(f"Translating input: {text}")
            kashmiri = translate(text)

        print(f"Synthesizing for: {kashmiri}")
        return kashmiri, process(kashmiri, spk_id, steps)

    gen_btn.click(
        pipeline,
        inputs=[input_text, is_english, speaker_radio, quality_radio],
        outputs=[trans_view, audio_output]
    )

demo.launch(ssr_mode=False)