amasha03 commited on
Commit
48ee974
·
verified ·
1 Parent(s): 16efb96

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -22
app.py CHANGED
@@ -1,52 +1,90 @@
1
  import gradio as gr
 
 
2
  from TTS.api import TTS
3
  from huggingface_hub import hf_hub_download
4
- import os
5
 
6
  # --- IMPORTING YOUR SEPARATE ROMANIZER ---
7
- from romanizer import sinhala_to_roman
 
 
 
 
 
 
 
8
 
9
- def load_my_model(repo_id):
10
  print(f"Downloading {repo_id}...")
11
  model_path = hf_hub_download(repo_id=repo_id, filename="best_model.pth")
12
  config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
13
- # Initialize TTS
14
  return TTS(model_path=model_path, config_path=config_path, gpu=False)
15
 
16
- # Load Models
17
- print("Initializing Sinhala...")
18
- sin_tts = load_my_model("E-motionAssistant/text-to-speech-VITS-sinhala")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- print("Initializing Tamil...")
21
- tam_tts = load_my_model("E-motionAssistant/text-to-speech-VITS-tamil")
 
 
 
 
 
 
 
 
 
22
 
23
  def generate_voice(text, language):
24
  try:
 
 
 
 
 
25
  if language == "Sinhala":
26
- # Use your separate function
27
  processed_text = sinhala_to_roman(text)
28
- print(f"Input: {text} -> Romanized: {processed_text}")
29
- engine = sin_tts
30
  else:
31
- processed_text = text
32
- engine = tam_tts
33
 
34
- output_path = "output.wav"
35
- engine.tts_to_file(text=processed_text, file_path=output_path)
36
  return output_path
 
37
  except Exception as e:
38
- print(f"Error: {e}")
39
  return None
40
 
41
- # Gradio Interface
 
42
  demo = gr.Interface(
43
  fn=generate_voice,
44
  inputs=[
45
- gr.Textbox(label="Input Text"),
46
- gr.Dropdown(["Sinhala", "Tamil"], label="Select Language")
47
  ],
48
  outputs=gr.Audio(label="Synthesized Speech", type="filepath"),
49
- title="Multilingual VITS TTS"
 
50
  )
51
 
52
- demo.launch()
 
 
1
  import gradio as gr
2
+ import torch
3
+ import os
4
  from TTS.api import TTS
5
  from huggingface_hub import hf_hub_download
 
6
 
7
  # --- IMPORTING YOUR SEPARATE ROMANIZER ---
8
+ # Ensure romanizer.py is in the same directory
9
+ try:
10
+ from romanizer import sinhala_to_roman
11
+ except ImportError:
12
+ print("Warning: romanizer.py not found. Sinhala might not process correctly.")
13
+ def sinhala_to_roman(text): return text
14
+
15
+ # --- MODEL LOADING LOGIC ---
16
 
17
+ def load_standard_model(repo_id):
18
  print(f"Downloading {repo_id}...")
19
  model_path = hf_hub_download(repo_id=repo_id, filename="best_model.pth")
20
  config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
 
21
  return TTS(model_path=model_path, config_path=config_path, gpu=False)
22
 
23
+ def load_eng_model_with_surgery():
24
+ repo_id = "E-motionAssistant/text-to-speech-VITS-english"
25
+ print("--- Starting Weights Surgery for English ---")
26
+
27
+ model_path = hf_hub_download(repo_id=repo_id, filename="best_model.pth")
28
+ config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
29
+
30
+ # Load and fix the tensor mismatch (137 -> 131)
31
+ checkpoint = torch.load(model_path, map_location="cpu")
32
+ raw_weights = checkpoint['model']['text_encoder.emb.weight']
33
+
34
+ if raw_weights.shape[0] == 137:
35
+ checkpoint['model']['text_encoder.emb.weight'] = raw_weights[:131, :]
36
+ fixed_model_path = "fixed_eng_model.pth"
37
+ torch.save(checkpoint, fixed_model_path)
38
+ print("Surgery complete.")
39
+ return TTS(model_path=fixed_model_path, config_path=config_path, gpu=False)
40
+
41
+ return TTS(model_path=model_path, config_path=config_path, gpu=False)
42
 
43
+ # --- INITIALIZE ALL MODELS ---
44
+
45
+ print("Initializing Models...")
46
+ models = {
47
+ "Sinhala": load_standard_model("E-motionAssistant/text-to-speech-VITS-sinhala"),
48
+ "Tamil": load_standard_model("E-motionAssistant/text-to-speech-VITS-tamil"),
49
+ "English": load_eng_model_with_surgery()
50
+ }
51
+ print("All systems online.")
52
+
53
+ # --- INFERENCE FUNCTION ---
54
 
55
  def generate_voice(text, language):
56
  try:
57
+ engine = models.get(language)
58
+ if not engine:
59
+ return None
60
+
61
+ # Apply specific preprocessing
62
  if language == "Sinhala":
 
63
  processed_text = sinhala_to_roman(text)
64
+ print(f"Sinhala Romanized: {processed_text}")
 
65
  else:
66
+ processed_text = text
 
67
 
68
+ output_path = f"output_{language.lower()}.wav"
69
+ engine.tts_to_file(text=str(processed_text), file_path=output_path)
70
  return output_path
71
+
72
  except Exception as e:
73
+ print(f"Synthesis Error ({language}): {e}")
74
  return None
75
 
76
+ # --- GRADIO INTERFACE ---
77
+
78
  demo = gr.Interface(
79
  fn=generate_voice,
80
  inputs=[
81
+ gr.Textbox(label="Input Text", placeholder="Enter text here..."),
82
+ gr.Dropdown(["English", "Sinhala", "Tamil"], label="Select Language", value="English")
83
  ],
84
  outputs=gr.Audio(label="Synthesized Speech", type="filepath"),
85
+ title="Trilingual VITS TTS System",
86
+ description="A unified interface for English (with weight surgery), Sinhala (romanized), and Tamil speech synthesis."
87
  )
88
 
89
+ if __name__ == "__main__":
90
+ demo.launch()