saeedabdulmuizz committed on
Commit
2a22d6f
·
verified ·
1 Parent(s): 56b24d2

Update app.py

Browse files

Added option for translation

Files changed (1) hide show
  1. app.py +78 -11
app.py CHANGED
@@ -6,6 +6,8 @@ import urllib.request
6
  import os # Add this import at the top
7
  import soundfile as sf
8
  from huggingface_hub import hf_hub_download
 
 
9
  from matcha.models.matcha_tts import MatchaTTS
10
  from matcha.hifigan.models import Generator as HiFiGAN
11
  from matcha.hifigan.config import v1
@@ -40,7 +42,52 @@ def load_models():
40
 
41
  return model, vocoder
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  model, vocoder = load_models()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  # --- Update the function signature to accept two arguments ---
46
  @torch.inference_mode()
@@ -74,16 +121,36 @@ def process(text, speaker_id):
74
  sf.write(output_path, audio, 22050)
75
  return output_path
76
 
77
# Old single-function UI: build the two input widgets first, then
# assemble the Interface around the existing `process` callback.
kashmiri_text_input = gr.Textbox(label="Kashmiri Text")
# Slider selects the voice; 0 is usually the default speaker.
speaker_id_input = gr.Slider(0, model.n_spks - 1, step=1, value=0, label="Speaker ID")

demo = gr.Interface(
    fn=process,
    inputs=[kashmiri_text_input, speaker_id_input],
    outputs=gr.Audio(label="Audio", type="filepath"),
    title="GAASH-Lab: Kashmiri TTS",
)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
  demo.launch()
 
6
  import os # Add this import at the top
7
  import soundfile as sf
8
  from huggingface_hub import hf_hub_download
9
+ from transformers import AutoModelForCausalLM, AutoTokenizer
10
+ from peft import PeftModel
11
  from matcha.models.matcha_tts import MatchaTTS
12
  from matcha.hifigan.models import Generator as HiFiGAN
13
  from matcha.hifigan.config import v1
 
42
 
43
  return model, vocoder
44
 
45
# Translation Config
TRANSLATION_BASE_MODEL = "sarvamai/sarvam-translate"
TRANSLATION_ADAPTER = "GAASH-Lab/Sarvam-Kashmiri-finetuned"

def load_translation_models():
    """Load the Sarvam translation tokenizer and LoRA-adapted model.

    Returns:
        (tokenizer, model) on success, or (None, None) when loading fails,
        so the app can still start with translation disabled.
    """
    print("[*] Loading Sarvam Translate Adapter...")
    try:
        tok = AutoTokenizer.from_pretrained(TRANSLATION_BASE_MODEL)
        base = AutoModelForCausalLM.from_pretrained(
            TRANSLATION_BASE_MODEL,
            device_map="auto",
            torch_dtype=torch.float16,
        )
        adapted = PeftModel.from_pretrained(base, TRANSLATION_ADAPTER)
        adapted.eval()
    except Exception as e:
        # Best-effort: translation is optional; TTS alone still works.
        print(f"Error loading translation model: {e}")
        return None, None
    return tok, adapted
64
+
65
# Load the TTS model and vocoder once at startup.
model, vocoder = load_models()
# trans_model/trans_tokenizer are None when the translation weights
# fail to load (see load_translation_models); translate() handles that.
trans_tokenizer, trans_model = load_translation_models()
67
+
68
def translate(text, max_new_tokens=256):
    """Translate ``text`` into Kashmiri using the Sarvam adapter model.

    Args:
        text: Source text (expected to be English).
        max_new_tokens: Generation budget for the translation output.
            Previously hard-coded; exposed with the same default (256)
            so callers can raise it for longer inputs.

    Returns:
        The translated Kashmiri string, or a human-readable error message
        when the model is unavailable or prompt templating fails — the UI
        displays this string instead of crashing.
    """
    if trans_model is None:
        return "Translation model unavailable."

    messages = [
        {"role": "system", "content": "Translate the text below to Kashmiri."},
        {"role": "user", "content": text},
    ]

    try:
        # apply_chat_template returns the input_ids tensor directly when
        # tokenize=True and return_tensors="pt".
        input_ids = trans_tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(trans_model.device)
    except Exception as e:
        print(f"Chat template error: {e}")
        return "Error in translation template."

    with torch.no_grad():
        outputs = trans_model.generate(input_ids, max_new_tokens=max_new_tokens)

    # Decode only the newly generated tokens, skipping the echoed prompt.
    prompt_len = input_ids.shape[1]
    decoded = trans_tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return decoded.strip()
90
+
91
 
92
  # --- Update the function signature to accept two arguments ---
93
  @torch.inference_mode()
 
121
  sf.write(output_path, audio, 22050)
122
  return output_path
123
 
124
# --- Gradio UI with Translation Option ---

def _run_pipeline(text, is_eng, spk_id):
    """Optionally translate English input, then synthesize Kashmiri speech."""
    kashmiri_text = text
    if is_eng:
        print(f"Translating input: {text}")
        kashmiri_text = translate(text)

    print(f"Synthesizing for: {kashmiri_text}")
    return kashmiri_text, process(kashmiri_text, spk_id)

with gr.Blocks(title="GAASH-Lab: Kashmiri TTS & Translation") as demo:
    gr.Markdown("# GAASH-Lab: Kashmiri TTS & Translation")
    gr.Markdown("Enter text in English (check the box) or Kashmiri directly.")

    with gr.Row():
        # Left column: user inputs.
        with gr.Column():
            text_box = gr.Textbox(label="Input Text", placeholder="Type here...")
            english_flag = gr.Checkbox(label="Input is English (Translate first)", value=False)
            speaker_choice = gr.Slider(0, model.n_spks - 1, step=1, value=0, label="Speaker ID")
            synth_button = gr.Button("Generate Speech", variant="primary")

        # Right column: translated text + synthesized audio.
        with gr.Column():
            kashmiri_view = gr.Textbox(label="Processed/Translated Kashmiri Text", interactive=False)
            speech_out = gr.Audio(label="Audio", type="filepath")

    synth_button.click(
        _run_pipeline,
        inputs=[text_box, english_flag, speaker_choice],
        outputs=[kashmiri_view, speech_out],
    )

demo.launch()