amasha03 committed on
Commit
12dfa83
·
verified ·
1 Parent(s): 97553c6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -25
app.py CHANGED
@@ -1,38 +1,44 @@
1
  import gradio as gr
2
- from transformers import VitsModel, AutoTokenizer
3
  import torch
 
4
 
5
- # 1. Load the models specifically as VITS models
6
- # This avoids the "Unrecognized model" error
7
- models = {
8
- "English": {
9
- "model": VitsModel.from_pretrained("E-motionAssistant/text-to-speech-VITS-english"),
10
- "tokenizer": AutoTokenizer.from_pretrained("E-motionAssistant/text-to-speech-VITS-english")
11
- },
12
- "Sinhala": {
13
- "model": VitsModel.from_pretrained("E-motionAssistant/text-to-speech-VITS-sinhala"),
14
- "tokenizer": AutoTokenizer.from_pretrained("E-motionAssistant/text-to-speech-VITS-sinhala")
15
- },
16
- "Tamil": {
17
- "model": VitsModel.from_pretrained("E-motionAssistant/text-to-speech-VITS-tamil"),
18
- "tokenizer": AutoTokenizer.from_pretrained("E-motionAssistant/text-to-speech-VITS-tamil")
19
- }
20
- }
21
 
22
  def generate_speech(text, language):
23
  try:
24
- selected = models[language]
25
- inputs = selected["tokenizer"](text, return_tensors="pt")
 
 
 
 
 
 
26
 
27
  with torch.no_grad():
28
- output = selected["model"](**inputs).waveform
29
-
30
- # VITS models typically output at 22050Hz
31
- # We convert the tensor to a numpy array for Gradio
32
- return (22050, output.cpu().numpy().squeeze())
 
33
 
34
  except Exception as e:
35
- print(f"Error: {e}")
36
  return None
37
 
38
  demo = gr.Interface(
 
1
  import gradio as gr
2
+ from transformers import AutoModel, AutoTokenizer
3
  import torch
4
+ import scipy.io.wavfile
5
 
6
+ # Load models with 'trust_remote_code' to handle custom architectures
7
def load_model(model_id):
    """Download and return a ``(model, tokenizer)`` pair for *model_id*.

    ``trust_remote_code=True`` allows transformers to execute the repo's
    own modeling code, which is required for architectures that are not
    natively registered with the library.
    """
    return (
        AutoModel.from_pretrained(model_id, trust_remote_code=True),
        AutoTokenizer.from_pretrained(model_id, trust_remote_code=True),
    )
12
+
13
# Eagerly load every supported language's model/tokenizer at import time
# so the first Gradio request does not pay the download/initialization cost.
_MODEL_REPOS = {
    "English": "E-motionAssistant/text-to-speech-VITS-english",
    "Sinhala": "E-motionAssistant/text-to-speech-VITS-sinhala",
    "Tamil": "E-motionAssistant/text-to-speech-VITS-tamil",
}

_loaded = {}
for _lang, _repo in _MODEL_REPOS.items():
    print(f"Loading {_lang} Model...")
    _loaded[_lang] = load_model(_repo)

eng_model, eng_tok = _loaded["English"]
sin_model, sin_tok = _loaded["Sinhala"]
tam_model, tam_tok = _loaded["Tamil"]
 
21
 
22
def generate_speech(text, language):
    """Synthesize speech for *text* using the model for *language*.

    Args:
        text: Input string to convert to speech.
        language: "English" or "Sinhala"; any other value falls back to
            the Tamil model (preserves the original ``else`` branch).

    Returns:
        A ``(sample_rate, waveform)`` tuple suitable for a Gradio Audio
        output, or ``None`` if synthesis fails for any reason.
    """
    try:
        if language == "English":
            model, tokenizer = eng_model, eng_tok
        elif language == "Sinhala":
            model, tokenizer = sin_model, sin_tok
        else:
            model, tokenizer = tam_model, tam_tok

        inputs = tokenizer(text, return_tensors="pt")

        with torch.no_grad():
            # VITS models expose the generated audio via a 'waveform'
            # attribute on the model output.
            output = model(**inputs)
            waveform = output.waveform.cpu().numpy().squeeze()

        # Standard VITS sampling rate is 22050 Hz.
        return (22050, waveform)

    except Exception as e:
        # Fix: the previous revision bound `e` but never used it,
        # swallowing every failure silently. Log the error (as the
        # earlier version of this file did) before returning None,
        # which Gradio renders as "no audio".
        print(f"Error: {e}")
        return None
43
 
44
  demo = gr.Interface(