BurhaanZargar commited on
Commit
abcb225
·
verified ·
1 Parent(s): e42828e

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +11 -20
app.py CHANGED
@@ -17,32 +17,23 @@ MODEL_REPO = "GAASH-Lab/Matcha-TTS-Kashmiri"
17
  def load_models():
18
  print("[*] Downloading GAASH-Lab checkpoint...")
19
  ckpt = hf_hub_download(repo_id=MODEL_REPO, filename="model.ckpt", token=HF_TOKEN)
20
-
21
- # Load Matcha
22
  model = MatchaTTS.load_from_checkpoint(ckpt, map_location=DEVICE, weights_only=False)
23
  model.eval()
24
-
25
- print("[*] Downloading HiFi-GAN vocoder...")
26
- # FIX: Use the correct repository ID and filename
27
- # Many Matcha-TTS setups use 'shivammehta25/Matcha-TTS' or specific vocoder repos
28
- try:
29
- vocoder_ckpt = hf_hub_download(repo_id="shivammehta25/Matcha-TTS", filename="hifigan_v1")
30
- except Exception:
31
- # Fallback to another common public HiFi-GAN checkpoint if the above is unavailable
32
- vocoder_ckpt = hf_hub_download(repo_id="jaketae/hifigan-lj-v1", filename="generator.pth")
33
 
34
  vocoder = HiFiGAN(AttrDict(v1)).to(DEVICE)
35
-
36
- # Load state dict
37
- state_dict = torch.load(vocoder_ckpt, map_location=DEVICE)
38
- # HiFi-GAN checkpoints usually store the weights under the 'generator' key
39
- if 'generator' in state_dict:
40
- vocoder.load_state_dict(state_dict['generator'])
41
- else:
42
- vocoder.load_state_dict(state_dict)
43
-
44
  vocoder.eval()
45
  vocoder.remove_weight_norm()
 
46
  return model, vocoder
47
 
48
  model, vocoder = load_models()
 
17
  def load_models():
18
  print("[*] Downloading GAASH-Lab checkpoint...")
19
  ckpt = hf_hub_download(repo_id=MODEL_REPO, filename="model.ckpt", token=HF_TOKEN)
 
 
20
  model = MatchaTTS.load_from_checkpoint(ckpt, map_location=DEVICE, weights_only=False)
21
  model.eval()
22
+
23
+ print("[*] Loading HiFi-GAN vocoder...")
24
+ # The file 'generator_v1' is what the code calls 'hifigan_T2_v1'
25
+ # We download it from the official GitHub release if not found locally
26
+ vocoder_path = Path("hifigan_T2_v1")
27
+ if not vocoder_path.exists():
28
+ url = "https://github.com/shivammehta25/Matcha-TTS-checkpoints/releases/download/v1.0/generator_v1"
29
+ urllib.request.urlretrieve(url, vocoder_path)
 
30
 
31
  vocoder = HiFiGAN(AttrDict(v1)).to(DEVICE)
32
+ state_dict = torch.load(vocoder_path, map_location=DEVICE)
33
+ vocoder.load_state_dict(state_dict['generator'])
 
 
 
 
 
 
 
34
  vocoder.eval()
35
  vocoder.remove_weight_norm()
36
+
37
  return model, vocoder
38
 
39
  model, vocoder = load_models()