Upload folder using huggingface_hub
Browse files
app.py
CHANGED
|
@@ -17,32 +17,23 @@ MODEL_REPO = "GAASH-Lab/Matcha-TTS-Kashmiri"
|
|
| 17 |
def load_models():
|
| 18 |
print("[*] Downloading GAASH-Lab checkpoint...")
|
| 19 |
ckpt = hf_hub_download(repo_id=MODEL_REPO, filename="model.ckpt", token=HF_TOKEN)
|
| 20 |
-
|
| 21 |
-
# Load Matcha
|
| 22 |
model = MatchaTTS.load_from_checkpoint(ckpt, map_location=DEVICE, weights_only=False)
|
| 23 |
model.eval()
|
| 24 |
-
|
| 25 |
-
print("[*]
|
| 26 |
-
#
|
| 27 |
-
#
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
vocoder_ckpt = hf_hub_download(repo_id="jaketae/hifigan-lj-v1", filename="generator.pth")
|
| 33 |
|
| 34 |
vocoder = HiFiGAN(AttrDict(v1)).to(DEVICE)
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
state_dict = torch.load(vocoder_ckpt, map_location=DEVICE)
|
| 38 |
-
# HiFi-GAN checkpoints usually store the weights under the 'generator' key
|
| 39 |
-
if 'generator' in state_dict:
|
| 40 |
-
vocoder.load_state_dict(state_dict['generator'])
|
| 41 |
-
else:
|
| 42 |
-
vocoder.load_state_dict(state_dict)
|
| 43 |
-
|
| 44 |
vocoder.eval()
|
| 45 |
vocoder.remove_weight_norm()
|
|
|
|
| 46 |
return model, vocoder
|
| 47 |
|
| 48 |
model, vocoder = load_models()
|
|
|
|
| 17 |
def load_models():
|
| 18 |
print("[*] Downloading GAASH-Lab checkpoint...")
|
| 19 |
ckpt = hf_hub_download(repo_id=MODEL_REPO, filename="model.ckpt", token=HF_TOKEN)
|
|
|
|
|
|
|
| 20 |
model = MatchaTTS.load_from_checkpoint(ckpt, map_location=DEVICE, weights_only=False)
|
| 21 |
model.eval()
|
| 22 |
+
|
| 23 |
+
print("[*] Loading HiFi-GAN vocoder...")
|
| 24 |
+
# The file 'generator_v1' is what the code calls 'hifigan_T2_v1'
|
| 25 |
+
# We download it from the official GitHub release if not found locally
|
| 26 |
+
vocoder_path = Path("hifigan_T2_v1")
|
| 27 |
+
if not vocoder_path.exists():
|
| 28 |
+
url = "https://github.com/shivammehta25/Matcha-TTS-checkpoints/releases/download/v1.0/generator_v1"
|
| 29 |
+
urllib.request.urlretrieve(url, vocoder_path)
|
|
|
|
| 30 |
|
| 31 |
vocoder = HiFiGAN(AttrDict(v1)).to(DEVICE)
|
| 32 |
+
state_dict = torch.load(vocoder_path, map_location=DEVICE)
|
| 33 |
+
vocoder.load_state_dict(state_dict['generator'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
vocoder.eval()
|
| 35 |
vocoder.remove_weight_norm()
|
| 36 |
+
|
| 37 |
return model, vocoder
|
| 38 |
|
| 39 |
model, vocoder = load_models()
|