File size: 10,092 Bytes
61ee693
632183b
33d01bc
632183b
a6bef00
61ee693
5b50bdc
 
 
 
 
 
 
 
632183b
f750d21
1c76d21
632183b
1c76d21
 
 
 
 
d3c2948
2a22d6f
632183b
 
 
 
067912c
 
632183b
f7eb188
 
632183b
 
 
 
d53f049
 
 
 
abcb225
 
 
 
 
 
 
 
d53f049
632183b
abcb225
 
632183b
 
abcb225
632183b
 
2a22d6f
 
 
 
d3c2948
 
 
2a22d6f
d3c2948
 
 
 
 
 
 
2a22d6f
61ee693
f750d21
61ee693
 
 
 
d3c2948
 
 
61ee693
d3c2948
 
 
 
e0fba39
 
 
f750d21
6469bde
2a22d6f
48cb3b4
 
 
 
 
2a22d6f
7da8398
48cb3b4
d3c2948
 
 
2a22d6f
 
f750d21
 
2a22d6f
 
d3c2948
632183b
d3c2948
2a22d6f
5b50bdc
61ee693
d3c2948
 
 
2a22d6f
 
 
61ee693
2a22d6f
 
 
 
 
 
61ee693
 
7da8398
 
d3c2948
7da8398
dda3deb
7da8398
dda3deb
2a22d6f
 
dda3deb
2a22d6f
 
dda3deb
 
 
 
 
48cb3b4
 
 
dda3deb
61ee693
 
48cb3b4
 
 
dda3deb
 
 
 
61ee693
 
3b50d55
61ee693
 
 
 
 
 
 
dda3deb
 
 
 
 
2a22d6f
d3c2948
 
 
5b50bdc
632183b
bf328a0
632183b
aa24966
067912c
4267bdb
067912c
ea4e599
067912c
ea4e599
edd0ea2
 
 
 
48cb3b4
edd0ea2
067912c
 
bf328a0
 
 
ea4e599
 
067912c
 
 
aa24966
067912c
bf328a0
067912c
 
 
ea4e599
632183b
067912c
 
 
632183b
2a22d6f
 
 
 
 
 
 
 
 
68eb694
aa24966
 
 
 
 
2a22d6f
 
 
 
 
 
aa24966
68eb694
aa24966
 
 
 
 
 
 
 
2a22d6f
 
 
 
 
 
aa24966
2a22d6f
 
 
 
aa24966
2a22d6f
 
bf328a0
f750d21
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268

import gradio as gr
from pathlib import Path
import torch
import urllib.request
import os  

# HuggingFace Spaces GPU support: the `spaces` package only exists inside a
# Spaces runtime, so degrade gracefully when it is absent.
try:
    import spaces
    SPACES_AVAILABLE = True
except ImportError:
    SPACES_AVAILABLE = False
    print("[!] spaces module not available, running without GPU decorator")
import soundfile as sf
import traceback
import huggingface_hub
from huggingface_hub import hf_hub_download

# Compatibility shim for older diffusers with newer huggingface_hub:
# `cached_download` was removed from huggingface_hub, but older diffusers
# releases still import it, so alias it to `hf_hub_download`.
if not hasattr(huggingface_hub, "cached_download"):
    huggingface_hub.cached_download = hf_hub_download

from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from matcha.models.matcha_tts import MatchaTTS
from matcha.hifigan.models import Generator as HiFiGAN
from matcha.hifigan.config import v1
from matcha.hifigan.env import AttrDict
from matcha.text import text_to_sequence
from matcha.utils.utils import intersperse

# Optional HF access token (may be None for public repos).
HF_TOKEN = os.getenv("HF_TOKEN")

# Run on GPU when available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# HuggingFace repo hosting the Kashmiri Matcha-TTS checkpoint.
MODEL_REPO = "GAASH-Lab/Matcha-TTS-Kashmiri"

def load_models():
    """Download and initialise the Matcha-TTS acoustic model and HiFi-GAN vocoder.

    Returns:
        tuple: ``(model, vocoder)`` — both in eval mode, placed via DEVICE.
    """
    print("[*] Downloading GAASH-Lab checkpoint...")
    ckpt = hf_hub_download(repo_id=MODEL_REPO, filename="model.ckpt", token=HF_TOKEN)
    # NOTE(review): weights_only=False deserialises arbitrary pickled objects —
    # the checkpoint must come from a trusted source (here: the GAASH-Lab repo).
    model = MatchaTTS.load_from_checkpoint(ckpt, map_location=DEVICE, weights_only=False)
    model.eval()
    
    print("[*] Loading HiFi-GAN vocoder...")
    # The file 'generator_v1' is what the code calls 'hifigan_T2_v1'
    # We download it from the official GitHub release if not found locally
    vocoder_path = Path("hifigan_T2_v1")
    if not vocoder_path.exists():
        url = "https://github.com/shivammehta25/Matcha-TTS-checkpoints/releases/download/v1.0/generator_v1"
        urllib.request.urlretrieve(url, vocoder_path)

    vocoder = HiFiGAN(AttrDict(v1)).to(DEVICE)
    state_dict = torch.load(vocoder_path, map_location=DEVICE)
    vocoder.load_state_dict(state_dict['generator'])
    vocoder.eval()
    # Weight norm is a training-time reparameterisation; removing it speeds
    # up inference.
    vocoder.remove_weight_norm()
    
    return model, vocoder

# Translation Config
TRANSLATION_BASE_MODEL = "sarvamai/sarvam-translate"  # base Sarvam translation LLM
TRANSLATION_ADAPTER = "GAASH-Lab/Sarvam-Kashmiri-finetuned"  # Kashmiri LoRA adapter

# Global cache for the translation model, filled lazily on the first
# translation request by load_translation_models().
_trans_cache = {"tokenizer": None, "model": None, "loaded": False}

def load_translation_models():
    """Load the EN→Kashmiri translation model lazily on first use (CPU deployment).

    Loads the base Sarvam model on CPU in bfloat16, applies the Kashmiri LoRA
    adapter, and merges the adapter weights for faster generation. The result
    is memoised in ``_trans_cache`` so subsequent calls are free; a failed
    load is NOT cached, so the next call retries.

    Returns:
        tuple: ``(tokenizer, model)``, or ``(None, None)`` if loading failed.
    """
    # No `global` statement needed: the cache dict is mutated in place,
    # never rebound.
    if _trans_cache["loaded"]:
        return _trans_cache["tokenizer"], _trans_cache["model"]
    
    print("[*] Loading Sarvam Translate Adapter (CPU mode)...")
    try:
        # Load the tokenizer with left padding (required for causal LM)
        tokenizer = AutoTokenizer.from_pretrained(TRANSLATION_BASE_MODEL, trust_remote_code=True)
        tokenizer.padding_side = "left"
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        
        # Load the base model on CPU with bfloat16 to reduce memory
        # bfloat16 is better supported on CPU than float16
        print("[*] Loading base model on CPU (bfloat16)...")
        base_model = AutoModelForCausalLM.from_pretrained(
            TRANSLATION_BASE_MODEL,
            torch_dtype=torch.bfloat16,
            device_map="cpu",
            low_cpu_mem_usage=True,
            trust_remote_code=True
        )
        
        # Load the LoRA adapter
        print("[*] Loading LoRA adapter...")
        model = PeftModel.from_pretrained(base_model, TRANSLATION_ADAPTER)
        
        # Merge LoRA weights into base model for faster inference
        # This eliminates adapter overhead during generation
        print("[*] Merging LoRA weights for faster inference...")
        model = model.merge_and_unload()
        model.eval()
        
        # Plain string (the original used an f-string with no placeholders).
        print("[+] Translation model loaded and merged successfully on CPU.")
        _trans_cache["tokenizer"] = tokenizer
        _trans_cache["model"] = model
        _trans_cache["loaded"] = True
        return tokenizer, model
    except Exception as e:
        # Best-effort: callers treat (None, None) as "translation unavailable".
        print(f"[-] Error loading translation model: {e}")
        traceback.print_exc()
        return None, None

# Load TTS models eagerly at import time (they're comparatively small).
model, vocoder = load_models()
# The heavyweight translation model is loaded lazily on the first
# translation request (see load_translation_models).

def _translate_impl(text):
    """Translate English text to Kashmiri using the fine-tuned Sarvam model.

    Mirrors the prompting/decoding strategy of evaluate_model.py: a chat
    template prompt, greedy single-beam decoding, and decoding of only the
    newly generated tokens.

    Args:
        text: English source text.

    Returns:
        str: The Kashmiri translation, or an English error message on failure.
    """
    import time

    # The loader memoises its result, so repeated calls are cheap.
    tok, lm = load_translation_models()
    if lm is None:
        return "Translation model unavailable."

    # Chat-format prompt, identical to the one used during evaluation.
    messages = [
        {"role": "system", "content": "Translate the text below to Kashmiri."},
        {"role": "user", "content": text},
    ]

    try:
        prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = tok(prompt, padding=True, truncation=True, max_length=512, return_tensors="pt")
        # Keep every input tensor on the same device as the model.
        inputs = {key: tensor.to(lm.device) for key, tensor in inputs.items()}
        print(f"[DEBUG] Input tokens: {inputs['input_ids'].shape[1]}")
    except Exception as e:
        print(f"Chat template error: {e}")
        traceback.print_exc()
        return "Error in translation template."

    try:
        start_time = time.time()
        print("[DEBUG] Starting generation...")
        # Greedy decoding with a single beam: the fastest option on CPU and
        # effectively equivalent to near-zero-temperature sampling.
        with torch.no_grad():
            generated = lm.generate(
                **inputs,
                max_new_tokens=256,
                do_sample=False,
                num_beams=1,
            )
        print(f"[DEBUG] Generation completed in {time.time() - start_time:.2f}s")

        # Strip the prompt: only tokens after the input belong to the answer.
        prompt_len = inputs['input_ids'].shape[1]
        new_tokens = generated[0][prompt_len:]
        decoded = tok.decode(new_tokens, skip_special_tokens=True).replace("\n", "")
        print(f"[DEBUG] New tokens: {len(new_tokens)}")
        print(f"[DEBUG] Decoded: '{decoded}'")
        return decoded.strip()
    except Exception as e:
        print(f"Generation error: {e}")
        traceback.print_exc()
        return "Error during translation generation."

# Simple wrapper function for CPU deployment
def translate(text):
    """Translate English `text` to Kashmiri.

    Thin wrapper over _translate_impl — presumably kept separate so a GPU
    decorator (e.g. @spaces.GPU) could be applied here without touching the
    implementation; confirm before removing.
    """
    return _translate_impl(text)


@torch.inference_mode()
def process(text, speaker_id, n_timesteps=10, output_path="out.wav"):
    """Synthesise Kashmiri speech for `text` and write it to a WAV file.

    Args:
        text: Kashmiri text (Perso-Arabic script).
        speaker_id: Integer speaker index for the multi-speaker model.
        n_timesteps: ODE solver steps for synthesis — more steps trade speed
            for quality.
        output_path: Destination WAV path. Defaults to "out.wav" for backward
            compatibility; pass a unique path to stop concurrent requests
            from overwriting each other's audio.

    Returns:
        str: Path of the written WAV file (22.05 kHz).
    """
    # 1. Kashmiri script normalization: map Arabic letter forms to their
    # Kashmiri/Urdu equivalents and drop the full stop.
    text = text.replace("ي", "ی").replace("ك", "ک").replace("۔", "").strip()

    # 2. Text to sequence of symbol IDs.
    cleaner = "basic_cleaners"
    sequence, _ = text_to_sequence(text, [cleaner])

    # Filter out any non-integer values (unknown characters not in vocabulary).
    # This happens when text contains characters not supported by the TTS model.
    filtered_sequence = [s for s in sequence if isinstance(s, int)]

    # Intersperse with the blank symbol (0), as the model expects.
    x = torch.tensor(intersperse(filtered_sequence, 0), dtype=torch.long, device=DEVICE)[None]
    x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=DEVICE)

    # 3. Speaker conditioning tensor — required even when only one voice is used.
    spks = torch.tensor([int(speaker_id)], device=DEVICE, dtype=torch.long)

    # 4. Generate the mel-spectrogram.
    output = model.synthesise(
        x,
        x_lengths,
        n_timesteps=n_timesteps,
        temperature=0.667,
        spks=spks,
        length_scale=1.0
    )

    # 5. Vocode to a waveform, clamp to the valid [-1, 1] range, and write
    # a 22.05 kHz WAV (Matcha-TTS's native sample rate).
    audio = vocoder(output['mel']).clamp(-1, 1).cpu().squeeze().numpy()
    sf.write(output_path, audio, 22050)
    return output_path

# --- Gradio UI with Translation Option ---
with gr.Blocks(title="GAASH-Lab: Kashmiri TTS & Translation") as demo:
    gr.Markdown("# GAASH-Lab: Kashmiri TTS & Translation")
    gr.Markdown("Enter text in English (check the box) or Kashmiri directly.")
    
    with gr.Row():
        with gr.Column():
            # Left column: text input and synthesis controls.
            input_text = gr.Textbox(label="Input Text", placeholder="Type here...")
            is_english = gr.Checkbox(label="Input is English (Translate first)", value=False)
            speaker_radio = gr.Radio(choices=["Male", "Female"], value="Male", label="Speaker Voice")
            quality_radio = gr.Radio(
                choices=["Low (fast)", "Medium", "High"], 
                value="Low (fast)", 
                label="Quality"
            )
            gen_btn = gr.Button("Generate Speech", variant="primary")
        
        with gr.Column():
            # Right column: the (possibly translated) text and the audio result.
            trans_view = gr.Textbox(label="Processed/Translated Kashmiri Text", interactive=False)
            audio_output = gr.Audio(label="Audio", type="filepath")

    def pipeline(text, is_eng, spk_voice, quality):
        """End-to-end handler: optional EN→KS translation, then TTS.

        Returns (processed Kashmiri text, path to the generated WAV file).
        """
        # NOTE(review): 422/423 look like dataset-specific speaker indices for
        # the male/female voices — confirm against the training speaker map.
        spk_id = 422 if spk_voice == "Male" else 423
        
        # Map the quality label onto synthesis step counts.
        if "Low" in quality:
            steps = 10
        elif "Medium" in quality:
            steps = 50
        else:
            steps = 500

        processed_text = text
        if is_eng:
            print(f"Translating input: {text}")
            processed_text = translate(text)
        
        print(f"Synthesizing for: {processed_text}")
        audio_path = process(processed_text, spk_id, steps)
        return processed_text, audio_path

    gen_btn.click(
        pipeline, 
        inputs=[input_text, is_english, speaker_radio, quality_radio], 
        outputs=[trans_view, audio_output]
    )

# ssr_mode=False disables server-side rendering (avoids SSR issues on Spaces).
demo.launch(ssr_mode=False)