Update model.py
Browse files
model.py
CHANGED
|
@@ -43,9 +43,9 @@ class INF5Model(PreTrainedModel):
|
|
| 43 |
self.vocoder = torch.compile(load_vocoder(vocoder_name="vocos", is_local=False, device=device))
|
| 44 |
|
| 45 |
# Download and load model weights
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
|
| 50 |
# Download vocab.txt from HF Hub
|
| 51 |
vocab_path = hf_hub_download(config.name_or_path, filename="checkpoints/vocab.txt")
|
|
@@ -60,7 +60,7 @@ class INF5Model(PreTrainedModel):
|
|
| 60 |
)
|
| 61 |
|
| 62 |
# # Load state dict into model
|
| 63 |
-
|
| 64 |
|
| 65 |
|
| 66 |
def forward(self, text: str, ref_audio_path: str, ref_text: str):
|
|
|
|
| 43 |
self.vocoder = torch.compile(load_vocoder(vocoder_name="vocos", is_local=False, device=device))
|
| 44 |
|
| 45 |
# Download and load model weights
|
| 46 |
+
safetensors_path = hf_hub_download(config.name_or_path, filename="model.safetensors")
|
| 47 |
+
print(f"Loading model weights from {safetensors_path} (safetensors)...")
|
| 48 |
+
state_dict = load_file(safetensors_path, device=str(device))
|
| 49 |
|
| 50 |
# Download vocab.txt from HF Hub
|
| 51 |
vocab_path = hf_hub_download(config.name_or_path, filename="checkpoints/vocab.txt")
|
|
|
|
| 60 |
)
|
| 61 |
|
| 62 |
# # Load state dict into model
|
| 63 |
+
self.ema_model.load_state_dict(state_dict, strict=False)
|
| 64 |
|
| 65 |
|
| 66 |
def forward(self, text: str, ref_audio_path: str, ref_text: str):
|