svp19 committed on
Commit
cdf0dee
·
1 Parent(s): 35d62de

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +18 -6
model.py CHANGED
@@ -37,19 +37,35 @@ class INF5Model(PreTrainedModel):
37
  super().__init__(config)
38
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
39
 
40
- # Load the vocoder and model
41
  self.vocoder = torch.compile(load_vocoder(vocoder_name="vocos", is_local=False, device=device))
 
 
 
 
 
 
 
42
  self.ema_model = torch.compile(
43
  load_model(
44
  DiT,
45
  dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4),
46
- config.ckpt_path,
47
  mel_spec_type="vocos",
48
  vocab_file=config.vocab_path,
49
  device=device
50
  )
51
  )
52
 
 
 
 
 
 
 
 
 
 
53
  def forward(self, text: str, ref_audio_path: str, ref_text: str):
54
  """
55
  Generate speech given a reference audio & text input.
@@ -105,10 +121,6 @@ class INF5Model(PreTrainedModel):
105
 
106
  return np.array(audio_segment.get_array_of_samples())
107
 
108
- @classmethod
109
- def from_pretrained(cls, model_name_or_path, *model_args, **kwargs):
110
- config = AutoConfig.from_pretrained(model_name_or_path)
111
- return cls(config)
112
 
113
 
114
  if __name__ == '__main__':
 
37
  super().__init__(config)
38
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
39
 
40
+ # Load vocoder
41
  self.vocoder = torch.compile(load_vocoder(vocoder_name="vocos", is_local=False, device=device))
42
+
43
+ # Download and load model weights
44
+ safetensors_path = hf_hub_download(config.name_or_path, filename="model.safetensors")
45
+ print(f"Loading model weights from {safetensors_path} (safetensors)...")
46
+ state_dict = load_file(safetensors_path, device=device)
47
+
48
+ # Load the model
49
  self.ema_model = torch.compile(
50
  load_model(
51
  DiT,
52
  dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4),
53
+ None, # Skip loading from file, as we load state_dict directly
54
  mel_spec_type="vocos",
55
  vocab_file=config.vocab_path,
56
  device=device
57
  )
58
  )
59
 
60
+ # Load state dict into model
61
+ self.ema_model.load_state_dict(state_dict, strict=False)
62
+
63
+ @classmethod
64
+ def from_pretrained(cls, model_name_or_path, *model_args, **kwargs):
65
+ config = AutoConfig.from_pretrained(model_name_or_path)
66
+ config.name_or_path = model_name_or_path
67
+ return cls(config)
68
+
69
  def forward(self, text: str, ref_audio_path: str, ref_text: str):
70
  """
71
  Generate speech given a reference audio & text input.
 
121
 
122
  return np.array(audio_segment.get_array_of_samples())
123
 
 
 
 
 
124
 
125
 
126
  if __name__ == '__main__':