Update model.py
model.py
CHANGED
@@ -37,19 +37,35 @@ class INF5Model(PreTrainedModel):
         super().__init__(config)
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-        # Load
+        # Load vocoder
         self.vocoder = torch.compile(load_vocoder(vocoder_name="vocos", is_local=False, device=device))
+
+        # Download and load model weights
+        safetensors_path = hf_hub_download(config.name_or_path, filename="model.safetensors")
+        print(f"Loading model weights from {safetensors_path} (safetensors)...")
+        state_dict = load_file(safetensors_path, device=device)
+
+        # Load the model
         self.ema_model = torch.compile(
             load_model(
                 DiT,
                 dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4),
-
+                None,  # Skip loading from file, as we load state_dict directly
                 mel_spec_type="vocos",
                 vocab_file=config.vocab_path,
                 device=device
             )
         )
 
+        # Load state dict into model
+        self.ema_model.load_state_dict(state_dict, strict=False)
+
+    @classmethod
+    def from_pretrained(cls, model_name_or_path, *model_args, **kwargs):
+        config = AutoConfig.from_pretrained(model_name_or_path)
+        config.name_or_path = model_name_or_path
+        return cls(config)
+
     def forward(self, text: str, ref_audio_path: str, ref_text: str):
         """
         Generate speech given a reference audio & text input.
@@ -105,10 +121,6 @@ class INF5Model(PreTrainedModel):
 
         return np.array(audio_segment.get_array_of_samples())
 
-    @classmethod
-    def from_pretrained(cls, model_name_or_path, *model_args, **kwargs):
-        config = AutoConfig.from_pretrained(model_name_or_path)
-        return cls(config)
 
 
 if __name__ == '__main__':
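For reference, a minimal usage sketch of the updated loading path. The repository id, reference clip, and transcript below are placeholders, and the import assumes this file is available as model.py on the Python path:

# Minimal usage sketch; repo id, reference clip, and transcript are placeholders.
from model import INF5Model  # the class updated in this commit

# from_pretrained() now only needs the repo id; __init__ downloads model.safetensors
# via hf_hub_download() and loads it with load_file() + load_state_dict(strict=False).
model = INF5Model.from_pretrained("user/inf5-tts")  # placeholder repository id

# forward() synthesizes speech conditioned on a reference clip and its transcript,
# returning the raw samples as a NumPy array.
samples = model(
    text="Hello, this is a test sentence.",
    ref_audio_path="ref.wav",           # placeholder reference audio
    ref_text="Transcript of ref.wav.",  # placeholder reference transcript
)
print(samples.shape, samples.dtype)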