Hareesh Polla commited on
Commit
a60c006
·
1 Parent(s): 4cb53db

add MPS support for Apple Silicon

Browse files
Files changed (1) hide show
  1. model.py +6 -6
model.py CHANGED
@@ -37,10 +37,10 @@ class INF5Model(PreTrainedModel):
37
 
38
  def __init__(self, config):
39
  super().__init__(config)
40
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
41
 
42
  # Load vocoder
43
- self.vocoder = torch.compile(load_vocoder(vocoder_name="vocos", is_local=False, device=device))
44
 
45
  # Download and load model weights
46
  # safetensors_path = hf_hub_download(config.name_or_path, filename="model.safetensors")
@@ -55,7 +55,7 @@ class INF5Model(PreTrainedModel):
55
  dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4),
56
  mel_spec_type="vocos",
57
  vocab_file=vocab_path,
58
- device=device
59
  )
60
  )
61
 
@@ -83,8 +83,8 @@ class INF5Model(PreTrainedModel):
83
  ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_path, ref_text)
84
 
85
 
86
- self.ema_model.to(self.device)
87
- self.vocoder.to(self.device)
88
 
89
  # Perform inference
90
  audio, final_sample_rate, _ = infer_process(
@@ -95,7 +95,7 @@ class INF5Model(PreTrainedModel):
95
  self.vocoder,
96
  mel_spec_type="vocos",
97
  speed=self.config.speed,
98
- device=self.device,
99
  )
100
 
101
  # Convert to pydub format and remove silence if needed
 
37
 
38
  def __init__(self, config):
39
  super().__init__(config)
40
+ self._device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
41
 
42
  # Load vocoder
43
+ self.vocoder = torch.compile(load_vocoder(vocoder_name="vocos", is_local=False, device=self._device))
44
 
45
  # Download and load model weights
46
  # safetensors_path = hf_hub_download(config.name_or_path, filename="model.safetensors")
 
55
  dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4),
56
  mel_spec_type="vocos",
57
  vocab_file=vocab_path,
58
+ device=self._device
59
  )
60
  )
61
 
 
83
  ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_path, ref_text)
84
 
85
 
86
+ self.ema_model.to(self._device)
87
+ self.vocoder.to(self._device)
88
 
89
  # Perform inference
90
  audio, final_sample_rate, _ = infer_process(
 
95
  self.vocoder,
96
  mel_spec_type="vocos",
97
  speed=self.config.speed,
98
+ device=self._device,
99
  )
100
 
101
  # Convert to pydub format and remove silence if needed