Hareesh Polla
commited on
Commit
·
ce2ab10
1
Parent(s):
b1e9a0e
add MPS support for Apple Silicon
Browse files- config.json +0 -1
- model.py +3 -6
config.json
CHANGED
|
@@ -8,7 +8,6 @@
|
|
| 8 |
"model_type": "inf5",
|
| 9 |
"remove_sil": true,
|
| 10 |
"speed": 1.0,
|
| 11 |
-
"device": "cpu",
|
| 12 |
"torch_dtype": "float32",
|
| 13 |
"transformers_version": "4.49.0",
|
| 14 |
"vocab_path": "checkpoints/vocab.txt"
|
|
|
|
| 8 |
"model_type": "inf5",
|
| 9 |
"remove_sil": true,
|
| 10 |
"speed": 1.0,
|
|
|
|
| 11 |
"torch_dtype": "float32",
|
| 12 |
"transformers_version": "4.49.0",
|
| 13 |
"vocab_path": "checkpoints/vocab.txt"
|
model.py
CHANGED
|
@@ -24,8 +24,8 @@ import os
|
|
| 24 |
class INF5Config(PretrainedConfig):
|
| 25 |
model_type = "inf5"
|
| 26 |
|
| 27 |
-
def __init__(self, ckpt_path: str = "checkpoints/model_best.pt", vocab_path: str = "checkpoints/vocab.txt",
|
| 28 |
-
speed: float = 1.0, remove_sil: bool = True,
|
| 29 |
super().__init__(**kwargs)
|
| 30 |
self.ckpt_path = ckpt_path
|
| 31 |
self.vocab_path = vocab_path
|
|
@@ -38,10 +38,7 @@ class INF5Model(PreTrainedModel):
|
|
| 38 |
|
| 39 |
def __init__(self, config):
|
| 40 |
super().__init__(config)
|
| 41 |
-
if config.device
|
| 42 |
-
self._device = config.device
|
| 43 |
-
else:
|
| 44 |
-
self._device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
| 45 |
|
| 46 |
# Load vocoder
|
| 47 |
self.vocoder = torch.compile(load_vocoder(vocoder_name="vocos", is_local=False, device=self._device))
|
|
|
|
| 24 |
class INF5Config(PretrainedConfig):
|
| 25 |
model_type = "inf5"
|
| 26 |
|
| 27 |
+
def __init__(self, device: str, ckpt_path: str = "checkpoints/model_best.pt", vocab_path: str = "checkpoints/vocab.txt",
|
| 28 |
+
speed: float = 1.0, remove_sil: bool = True, **kwargs):
|
| 29 |
super().__init__(**kwargs)
|
| 30 |
self.ckpt_path = ckpt_path
|
| 31 |
self.vocab_path = vocab_path
|
|
|
|
| 38 |
|
| 39 |
def __init__(self, config):
|
| 40 |
super().__init__(config)
|
| 41 |
+
self._device = config.device if config.device else "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
# Load vocoder
|
| 44 |
self.vocoder = torch.compile(load_vocoder(vocoder_name="vocos", is_local=False, device=self._device))
|