amasha03 committed on
Commit
c36673e
·
verified ·
1 Parent(s): 754e278

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -36
app.py CHANGED
@@ -1,63 +1,52 @@
1
  import gradio as gr
2
  import torch
3
- from TTS.utils.synthesizer import Synthesizer
4
- from TTS.tts.models.vits import Vits
5
- from TTS.tts.configs.vits_config import VitsConfig
6
  from huggingface_hub import hf_hub_download
7
  import os
8
 
9
  def load_eng_model():
10
  repo_id = "E-motionAssistant/text-to-speech-VITS-english"
11
- print(f"--- Bypassing TTS Library English Defaults ---")
12
 
13
  model_path = hf_hub_download(repo_id=repo_id, filename="best_model.pth")
14
  config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
15
 
16
- # 1. Load the Config
17
- config = VitsConfig()
18
- config.load_json(config_path)
19
-
20
- # 2. THE LOBOTOMY: Strip the language and characters from the library's view
21
- # This stops the library from forcing '131'
22
- config.model_args.num_chars = 137
23
- if hasattr(config, 'characters'):
24
- config.characters = None # Forces the model to use the checkpoint's internal map
25
 
26
- # 3. Build the Model Architecture Manually
27
- model = Vits.init_from_config(config)
 
 
28
 
29
- # 4. Load the weights
30
- checkpoint = torch.load(model_path, map_location="cpu")
 
31
 
32
- # Use 'strict=False' but now the architecture should actually match
33
- model.load_state_dict(checkpoint["model"], strict=False)
34
- model.eval()
35
-
36
- # 5. Build Synthesizer WITHOUT a language label
37
- syn = Synthesizer(
38
- tts_checkpoint=model_path,
39
- tts_config_path=config_path,
40
- use_cuda=False
41
- )
42
- syn.tts_model = model
43
 
44
- return syn
45
 
46
  # --- Initialization ---
47
  try:
48
  eng_tts = load_eng_model()
49
- print("--- SUCCESS: LIBRARY BYPASSED, MODEL LOADED ---")
50
  except Exception as e:
51
- print(f"LOAD FAILED: {e}")
52
  eng_tts = None
53
 
54
  def generate_voice(text):
55
  if not eng_tts: return None
56
  try:
57
- output_path = os.path.join(os.getcwd(), "output.wav")
58
- # Synthesize using the manual model
59
- wav = eng_tts.tts(text=str(text))
60
- eng_tts.save_wav(wav, output_path)
61
  return output_path
62
  except Exception as e:
63
  print(f"Synthesis Error: {e}")
@@ -67,7 +56,7 @@ demo = gr.Interface(
67
  fn=generate_voice,
68
  inputs=gr.Textbox(label="English Text"),
69
  outputs=gr.Audio(label="Result", type="filepath"),
70
- title="TTS Library Override"
71
  )
72
 
73
  if __name__ == "__main__":
 
1
  import gradio as gr
2
  import torch
3
+ from TTS.api import TTS
 
 
4
  from huggingface_hub import hf_hub_download
5
  import os
6
 
7
  def load_eng_model():
8
  repo_id = "E-motionAssistant/text-to-speech-VITS-english"
9
+ print("--- Starting Weights Surgery ---")
10
 
11
  model_path = hf_hub_download(repo_id=repo_id, filename="best_model.pth")
12
  config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
13
 
14
+ # 1. Load the "Brain" (Checkpoint) directly into PyTorch
15
+ checkpoint = torch.load(model_path, map_location="cpu")
 
 
 
 
 
 
 
16
 
17
+ # 2. PERFORM SURGERY: Shrink the layer from 137 down to 131
18
+ # This removes the mismatch error entirely
19
+ raw_weights = checkpoint['model']['text_encoder.emb.weight']
20
+ print(f"Original weight shape: {raw_weights.shape}")
21
 
22
+ if raw_weights.shape[0] == 137:
23
+ print("Trimming 137 -> 131...")
24
+ checkpoint['model']['text_encoder.emb.weight'] = raw_weights[:131, :]
25
 
26
+ # 3. Save the "Fixed" brain to a new file
27
+ fixed_model_path = os.path.join(os.getcwd(), "fixed_model.pth")
28
+ torch.save(checkpoint, fixed_model_path)
29
+ print("Surgery complete. Fixed model saved.")
30
+
31
+ # 4. Load using the standard TTS library
32
+ # Now that the weights match (131), it won't crash!
33
+ tts = TTS(model_path=fixed_model_path, config_path=config_path, gpu=False)
 
 
 
34
 
35
+ return tts
36
 
37
  # --- Initialization ---
38
  try:
39
  eng_tts = load_eng_model()
40
+ print("--- SUCCESS: SURGERY WORKED, SYSTEM ONLINE ---")
41
  except Exception as e:
42
+ print(f"CRITICAL ERROR: {e}")
43
  eng_tts = None
44
 
45
  def generate_voice(text):
46
  if not eng_tts: return None
47
  try:
48
+ output_path = "output.wav"
49
+ eng_tts.tts_to_file(text=str(text), file_path=output_path)
 
 
50
  return output_path
51
  except Exception as e:
52
  print(f"Synthesis Error: {e}")
 
56
  fn=generate_voice,
57
  inputs=gr.Textbox(label="English Text"),
58
  outputs=gr.Audio(label="Result", type="filepath"),
59
+ title="English TTS (Surgery Version)"
60
  )
61
 
62
  if __name__ == "__main__":