# Gradio demo: Kashmiri text-to-speech with optional English-to-Kashmiri translation.
import gradio as gr
from pathlib import Path
import torch
import urllib.request
import os
# HuggingFace Spaces GPU support: `spaces` is only present inside HF Spaces,
# so its absence just means we run without the GPU decorator.
try:
    import spaces
    SPACES_AVAILABLE = True
except ImportError:
    SPACES_AVAILABLE = False
    print("[!] spaces module not available, running without GPU decorator")
import soundfile as sf
import traceback
import huggingface_hub
from huggingface_hub import hf_hub_download
# Patch for older diffusers compatibility with newer huggingface_hub:
# alias the removed `cached_download` to `hf_hub_download` when missing.
if not hasattr(huggingface_hub, "cached_download"):
    huggingface_hub.cached_download = hf_hub_download
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from matcha.models.matcha_tts import MatchaTTS
from matcha.hifigan.models import Generator as HiFiGAN
from matcha.hifigan.config import v1
from matcha.hifigan.env import AttrDict
from matcha.text import text_to_sequence
from matcha.utils.utils import intersperse
# Token for HF Hub downloads; may be None when the repo is publicly readable.
HF_TOKEN = os.getenv("HF_TOKEN")
# Prefer CUDA when available; everything below moves tensors/models to DEVICE.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_REPO = "GAASH-Lab/Matcha-TTS-Kashmiri"
def load_models():
    """Load the Matcha-TTS acoustic model and the HiFi-GAN vocoder.

    Returns:
        tuple: ``(model, vocoder)`` — the MatchaTTS model and HiFi-GAN
        generator, both on ``DEVICE`` and in eval mode.
    """
    print("[*] Downloading GAASH-Lab checkpoint...")
    ckpt = hf_hub_download(repo_id=MODEL_REPO, filename="model.ckpt", token=HF_TOKEN)
    model = MatchaTTS.load_from_checkpoint(ckpt, map_location=DEVICE, weights_only=False)
    model.eval()
    print("[*] Loading HiFi-GAN vocoder...")
    # The file 'generator_v1' is what the code calls 'hifigan_T2_v1'.
    # We download it from the official GitHub release if not found locally.
    vocoder_path = Path("hifigan_T2_v1")
    if not vocoder_path.exists():
        url = "https://github.com/shivammehta25/Matcha-TTS-checkpoints/releases/download/v1.0/generator_v1"
        # Download to a temporary name and rename on success, so an
        # interrupted download never leaves a partial file at vocoder_path
        # that a later run would mistake for a valid checkpoint.
        tmp_path = vocoder_path.with_suffix(".part")
        urllib.request.urlretrieve(url, tmp_path)
        tmp_path.rename(vocoder_path)
    vocoder = HiFiGAN(AttrDict(v1)).to(DEVICE)
    state_dict = torch.load(vocoder_path, map_location=DEVICE)
    vocoder.load_state_dict(state_dict['generator'])
    vocoder.eval()
    # Weight norm is a training-time construct; removing it speeds up inference.
    vocoder.remove_weight_norm()
    return model, vocoder
# Translation Config
TRANSLATION_BASE_MODEL = "sarvamai/sarvam-translate"  # base causal LM
TRANSLATION_ADAPTER = "GAASH-Lab/Sarvam-Kashmiri-finetuned"  # LoRA adapter
# Global cache for translation model (loaded lazily when GPU is available).
# Keys: "tokenizer", "model" (populated on first load), "loaded" (flag).
_trans_cache = {"tokenizer": None, "model": None, "loaded": False}
def load_translation_models():
    """Load translation model lazily on first use (CPU deployment).

    Downloads the Sarvam base model, applies the Kashmiri LoRA adapter,
    merges the adapter weights for faster inference, and caches the result
    in ``_trans_cache`` so subsequent calls are free.

    Returns:
        tuple: ``(tokenizer, model)`` on success, ``(None, None)`` on failure.
    """
    # No `global` statement needed: the cache dict is mutated in place,
    # never rebound.
    if _trans_cache["loaded"]:
        return _trans_cache["tokenizer"], _trans_cache["model"]
    print("[*] Loading Sarvam Translate Adapter (CPU mode)...")
    try:
        # Load the tokenizer with left padding (required for causal LM)
        tokenizer = AutoTokenizer.from_pretrained(TRANSLATION_BASE_MODEL, trust_remote_code=True)
        tokenizer.padding_side = "left"
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        # Load the base model on CPU with bfloat16 to reduce memory;
        # bfloat16 is better supported on CPU than float16.
        print("[*] Loading base model on CPU (bfloat16)...")
        base_model = AutoModelForCausalLM.from_pretrained(
            TRANSLATION_BASE_MODEL,
            torch_dtype=torch.bfloat16,
            device_map="cpu",
            low_cpu_mem_usage=True,
            trust_remote_code=True
        )
        # Load the LoRA adapter
        print("[*] Loading LoRA adapter...")
        model = PeftModel.from_pretrained(base_model, TRANSLATION_ADAPTER)
        # Merge LoRA weights into base model: eliminates adapter overhead
        # during generation.
        print("[*] Merging LoRA weights for faster inference...")
        model = model.merge_and_unload()
        model.eval()
        print("[+] Translation model loaded and merged successfully on CPU.")
        _trans_cache["tokenizer"] = tokenizer
        _trans_cache["model"] = model
        _trans_cache["loaded"] = True
        return tokenizer, model
    except Exception as e:
        # Best-effort: callers receive (None, None) and the UI reports
        # translation as unavailable instead of crashing.
        print(f"[-] Error loading translation model: {e}")
        traceback.print_exc()
        return None, None
# Load TTS models at startup (they're smaller); runs at import time, so
# importing this module downloads the checkpoint and vocoder if needed.
model, vocoder = load_models()
# Translation model will be loaded lazily when GPU is available
def _translate_impl(text):
    """Translate `text` to Kashmiri, mirroring the evaluate_model.py pipeline.

    Returns the decoded translation, or a short error message string if the
    model is unavailable or generation fails.
    """
    # The model is loaded lazily and cached after the first call.
    tokenizer, lm = load_translation_models()
    if lm is None:
        return "Translation model unavailable."
    # Chat-style prompt identical to the one used by evaluate_model.py.
    messages = [
        {"role": "system", "content": "Translate the text below to Kashmiri."},
        {"role": "user", "content": text},
    ]
    try:
        rendered = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = tokenizer(rendered, padding=True, truncation=True, max_length=512, return_tensors="pt")
        # Keep every tensor on the same device as the model.
        inputs = {name: tensor.to(lm.device) for name, tensor in inputs.items()}
        print(f"[DEBUG] Input tokens: {inputs['input_ids'].shape[1]}")
    except Exception as e:
        print(f"Chat template error: {e}")
        traceback.print_exc()
        return "Error in translation template."
    try:
        import time
        started = time.time()
        print("[DEBUG] Starting generation...")
        # Greedy decoding (do_sample=False, num_beams=1) is the cheapest
        # option on CPU and matches the near-greedy settings used before.
        with torch.no_grad():
            generated = lm.generate(
                **inputs,
                max_new_tokens=256,
                do_sample=False,
                num_beams=1,
            )
        elapsed = time.time() - started
        print(f"[DEBUG] Generation completed in {elapsed:.2f}s")
        # Decode only the newly generated tokens, skipping the prompt echo.
        prompt_len = inputs['input_ids'].shape[1]
        output_ids = generated[0][prompt_len:]
        decoded = tokenizer.decode(output_ids, skip_special_tokens=True).replace("\n", "")
        print(f"[DEBUG] New tokens: {len(output_ids)}")
        print(f"[DEBUG] Decoded: '{decoded}'")
        return decoded.strip()
    except Exception as e:
        print(f"Generation error: {e}")
        traceback.print_exc()
        return "Error during translation generation."
# Simple wrapper function for CPU deployment
def translate(text):
    """Translate `text` to Kashmiri via the lazily-loaded translation model."""
    return _translate_impl(text)
@torch.inference_mode()
def process(text, speaker_id, n_timesteps=10):
    """Synthesize Kashmiri speech for `text` with the given speaker.

    Args:
        text: Kashmiri text (Perso-Arabic script).
        speaker_id: integer speaker embedding index for the TTS model.
        n_timesteps: synthesis steps; more steps trade speed for quality.

    Returns:
        str: path to the generated 22,050 Hz WAV file.
    """
    import tempfile

    # 1. Kashmiri script normalization: map Arabic yeh/kaf to their
    # Kashmiri/Urdu forms and drop the Urdu full stop.
    text = text.replace("ي", "ی").replace("ك", "ک").replace("۔", "").strip()
    # 2. Text to sequence of symbol ids.
    cleaner = "basic_cleaners"
    sequence, _ = text_to_sequence(text, [cleaner])
    # Filter out any non-integer values (characters outside the TTS
    # model's vocabulary would otherwise break the tensor construction).
    filtered_sequence = [s for s in sequence if isinstance(s, int)]
    x = torch.tensor(intersperse(filtered_sequence, 0), dtype=torch.long, device=DEVICE)[None]
    x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=DEVICE)
    # 3. Speaker ID tensor — required by the model even for a single voice.
    spks = torch.tensor([int(speaker_id)], device=DEVICE, dtype=torch.long)
    # 4. Generate Mel-spectrogram.
    output = model.synthesise(
        x,
        x_lengths,
        n_timesteps=n_timesteps,
        temperature=0.667,
        spks=spks,
        length_scale=1.0
    )
    # 5. Vocode to a waveform.
    audio = vocoder(output['mel']).clamp(-1, 1).cpu().squeeze().numpy()
    # Write to a unique temp file: a fixed "out.wav" is a race under
    # concurrent Gradio requests — one user could be served another
    # user's audio, or read a half-written file.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        output_path = tmp.name
    sf.write(output_path, audio, 22050)
    return output_path
# --- Gradio UI with Translation Option ---
with gr.Blocks(title="GAASH-Lab: Kashmiri TTS & Translation") as demo:
    gr.Markdown("# GAASH-Lab: Kashmiri TTS & Translation")
    gr.Markdown("Enter text in English (check the box) or Kashmiri directly.")
    with gr.Row():
        # Left column: all inputs and the trigger button.
        with gr.Column():
            input_text = gr.Textbox(label="Input Text", placeholder="Type here...")
            is_english = gr.Checkbox(label="Input is English (Translate first)", value=False)
            speaker_radio = gr.Radio(choices=["Male", "Female"], value="Male", label="Speaker Voice")
            quality_radio = gr.Radio(
                choices=["Low (fast)", "Medium", "High"],
                value="Low (fast)",
                label="Quality"
            )
            gen_btn = gr.Button("Generate Speech", variant="primary")
        # Right column: outputs (final Kashmiri text + synthesized audio).
        with gr.Column():
            trans_view = gr.Textbox(label="Processed/Translated Kashmiri Text", interactive=False)
            audio_output = gr.Audio(label="Audio", type="filepath")

    def pipeline(text, is_eng, spk_voice, quality):
        """Optionally translate English input, then synthesize Kashmiri speech."""
        # Speaker embedding indices baked into the checkpoint.
        chosen_speaker = 422 if spk_voice == "Male" else 423
        # Quality label -> number of synthesis steps.
        n_steps = 10 if "Low" in quality else 50 if "Medium" in quality else 500
        final_text = text
        if is_eng:
            print(f"Translating input: {text}")
            final_text = translate(text)
        print(f"Synthesizing for: {final_text}")
        wav_path = process(final_text, chosen_speaker, n_steps)
        return final_text, wav_path

    gen_btn.click(
        pipeline,
        inputs=[input_text, is_english, speaker_radio, quality_radio],
        outputs=[trans_view, audio_output]
    )

demo.launch(ssr_mode=False)