nambn0321 commited on
Commit
d2b5676
·
verified ·
1 Parent(s): 2ef63bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -47
app.py CHANGED
@@ -1,68 +1,60 @@
1
  import os
 
2
  import torch
3
- import gradio as gr
4
  import numpy as np
5
  import soundfile as sf
6
- import json
7
-
8
  from huggingface_hub import snapshot_download
 
9
  from TTS.utils.synthesizer import Synthesizer
 
10
 
11
- # Step 1: Download and load Glow-TTS from Hugging Face
12
  model_dir = snapshot_download(repo_id="nambn0321/TTS_model")
13
 
 
14
  synthesizer = Synthesizer(
15
  tts_checkpoint=os.path.join(model_dir, "best_model.pth"),
16
  tts_config_path=os.path.join(model_dir, "config.json"),
17
  use_cuda=torch.cuda.is_available()
18
  )
19
 
20
- # Step 2: Load HiFi-GAN
21
  hifigan_checkpoint_path = os.path.join(model_dir, "g_02500000.pth")
22
- hifigan_config_path = os.path.join(model_dir, "config (1).json") # or config.json if shared
23
 
24
-
25
- with open(hifigan_config_path) as f:
26
  hifigan_config = json.load(f)
27
 
28
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
-
30
- vocoder = Generator(hifigan_config).to(device)
31
- vocoder.load_state_dict(torch.load(hifigan_checkpoint_path, map_location=device)["generator"])
32
- vocoder.eval()
33
-
34
- # Step 3: Text → Mel → Waveform
35
- def tts_fn(text):
36
- try:
37
- # Generate mel spectrogram
38
- mel = synthesizer.tts(text, use_glow=False, speaker_name=None, return_spec=True)
39
-
40
- # De-normalize mel if Glow-TTS used symmetric normalization
41
- mel = (mel + 1) * (4.0 / 2) # from symmetric [-1, 1] → [0, 4]
42
-
43
- # Convert to tensor
44
- mel_tensor = torch.tensor(mel).unsqueeze(0).to(torch.float32).to(device)
45
-
46
- # Generate waveform
47
- with torch.no_grad():
48
- audio = vocoder(mel_tensor).squeeze().cpu().numpy()
49
-
50
- # Save to file
51
- out_path = "output.wav"
52
- sf.write(out_path, audio, samplerate=22050)
53
-
54
- return out_path
55
-
56
- except Exception as e:
57
- error_msg = f"Error during TTS processing: {str(e)}"
58
- print(error_msg)
59
- return error_msg
60
-
61
- # Step 4: Launch Gradio app
62
  gr.Interface(
63
- fn=tts_fn,
64
- inputs=gr.Textbox(label="Enter Text"),
65
- outputs=gr.Audio(label="Generated Audio", type="filepath"),
66
- title="Glow-TTS + HiFi-GAN Vocoder",
67
- description="Text-to-speech using a pretrained Glow-TTS model and your custom HiFi-GAN vocoder."
68
  ).launch()
 
1
  import os
2
+ import json
3
  import torch
 
4
  import numpy as np
5
  import soundfile as sf
6
+ import gradio as gr
 
7
  from huggingface_hub import snapshot_download
8
+
9
  from TTS.utils.synthesizer import Synthesizer
10
+ from models import Generator # Your HiFi-GAN Generator class
11
 
12
+ # Download and load models
13
  model_dir = snapshot_download(repo_id="nambn0321/TTS_model")
14
 
15
+ # Glow-TTS
16
  synthesizer = Synthesizer(
17
  tts_checkpoint=os.path.join(model_dir, "best_model.pth"),
18
  tts_config_path=os.path.join(model_dir, "config.json"),
19
  use_cuda=torch.cuda.is_available()
20
  )
21
 
22
+ # HiFi-GAN
23
  hifigan_checkpoint_path = os.path.join(model_dir, "g_02500000.pth")
24
+ hifigan_config_path = os.path.join(model_dir, "config (1).json")
25
 
26
+ with open(hifigan_config_path, "r") as f:
 
27
  hifigan_config = json.load(f)
28
 
29
+ hifigan = Generator(hifigan_config)
30
+ hifigan.load_state_dict(torch.load(hifigan_checkpoint_path, map_location="cpu")["generator"])
31
+ hifigan.eval()
32
+ if torch.cuda.is_available():
33
+ hifigan.cuda()
34
+
35
+ # Inference function
36
+ def tts(text):
37
+ # Glow-TTS: text -> mel
38
+ wav_tensor = synthesizer.tts(text, None, None, return_wav=False) # returns mel
39
+ mel = wav_tensor.squeeze().cpu().numpy()
40
+
41
+ # HiFi-GAN: mel -> waveform
42
+ mel_tensor = torch.from_numpy(mel).unsqueeze(0) # [1, num_mels, T]
43
+ if torch.cuda.is_available():
44
+ mel_tensor = mel_tensor.cuda()
45
+
46
+ with torch.no_grad():
47
+ audio_tensor = hifigan(mel_tensor).cpu().squeeze()
48
+
49
+ audio_np = audio_tensor.numpy()
50
+ sf.write("output.wav", audio_np, samplerate=22050)
51
+ return "output.wav"
52
+
53
+ # Gradio UI
 
 
 
 
 
 
 
 
 
54
  gr.Interface(
55
+ fn=tts,
56
+ inputs=gr.Textbox(label="Enter text", placeholder="Type something..."),
57
+ outputs=gr.Audio(label="Generated Speech"),
58
+ title="Glow-TTS + HiFi-GAN TTS",
59
+ description="Enter text and listen to the generated speech using Glow-TTS and HiFi-GAN"
60
  ).launch()