Hareesh Polla commited on
Commit
ce2ab10
·
1 Parent(s): b1e9a0e

add MPS support for Apple Silicon

Browse files
Files changed (2) hide show
  1. config.json +0 -1
  2. model.py +3 -6
config.json CHANGED
@@ -8,7 +8,6 @@
8
  "model_type": "inf5",
9
  "remove_sil": true,
10
  "speed": 1.0,
11
- "device": "cpu",
12
  "torch_dtype": "float32",
13
  "transformers_version": "4.49.0",
14
  "vocab_path": "checkpoints/vocab.txt"
 
8
  "model_type": "inf5",
9
  "remove_sil": true,
10
  "speed": 1.0,
 
11
  "torch_dtype": "float32",
12
  "transformers_version": "4.49.0",
13
  "vocab_path": "checkpoints/vocab.txt"
model.py CHANGED
@@ -24,8 +24,8 @@ import os
24
  class INF5Config(PretrainedConfig):
25
  model_type = "inf5"
26
 
27
- def __init__(self, ckpt_path: str = "checkpoints/model_best.pt", vocab_path: str = "checkpoints/vocab.txt",
28
- speed: float = 1.0, remove_sil: bool = True, device: str = "cpu", **kwargs):
29
  super().__init__(**kwargs)
30
  self.ckpt_path = ckpt_path
31
  self.vocab_path = vocab_path
@@ -38,10 +38,7 @@ class INF5Model(PreTrainedModel):
38
 
39
  def __init__(self, config):
40
  super().__init__(config)
41
- if config.device:
42
- self._device = config.device
43
- else:
44
- self._device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
45
 
46
  # Load vocoder
47
  self.vocoder = torch.compile(load_vocoder(vocoder_name="vocos", is_local=False, device=self._device))
 
24
  class INF5Config(PretrainedConfig):
25
  model_type = "inf5"
26
 
27
+ def __init__(self, device: str, ckpt_path: str = "checkpoints/model_best.pt", vocab_path: str = "checkpoints/vocab.txt",
28
+ speed: float = 1.0, remove_sil: bool = True, **kwargs):
29
  super().__init__(**kwargs)
30
  self.ckpt_path = ckpt_path
31
  self.vocab_path = vocab_path
 
38
 
39
  def __init__(self, config):
40
  super().__init__(config)
41
+ self._device = config.device if config.device else "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 
 
 
42
 
43
  # Load vocoder
44
  self.vocoder = torch.compile(load_vocoder(vocoder_name="vocos", is_local=False, device=self._device))