Hareesh Polla commited on
Commit ·
28093f4
1
Parent(s): ce2ab10
add MPS support for Apple Silicon
Browse files- config.json +1 -0
- model.py +2 -2
config.json
CHANGED
|
@@ -8,6 +8,7 @@
|
|
| 8 |
"model_type": "inf5",
|
| 9 |
"remove_sil": true,
|
| 10 |
"speed": 1.0,
|
|
|
|
| 11 |
"torch_dtype": "float32",
|
| 12 |
"transformers_version": "4.49.0",
|
| 13 |
"vocab_path": "checkpoints/vocab.txt"
|
|
|
|
| 8 |
"model_type": "inf5",
|
| 9 |
"remove_sil": true,
|
| 10 |
"speed": 1.0,
|
| 11 |
+
"device": "",
|
| 12 |
"torch_dtype": "float32",
|
| 13 |
"transformers_version": "4.49.0",
|
| 14 |
"vocab_path": "checkpoints/vocab.txt"
|
model.py
CHANGED
|
@@ -24,8 +24,8 @@ import os
|
|
| 24 |
class INF5Config(PretrainedConfig):
|
| 25 |
model_type = "inf5"
|
| 26 |
|
| 27 |
-
def __init__(self,
|
| 28 |
-
speed: float = 1.0, remove_sil: bool = True, **kwargs):
|
| 29 |
super().__init__(**kwargs)
|
| 30 |
self.ckpt_path = ckpt_path
|
| 31 |
self.vocab_path = vocab_path
|
|
|
|
| 24 |
class INF5Config(PretrainedConfig):
|
| 25 |
model_type = "inf5"
|
| 26 |
|
| 27 |
+
def __init__(self, ckpt_path: str = "checkpoints/model_best.pt", vocab_path: str = "checkpoints/vocab.txt",
|
| 28 |
+
speed: float = 1.0, remove_sil: bool = True, device: str = "", **kwargs):
|
| 29 |
super().__init__(**kwargs)
|
| 30 |
self.ckpt_path = ckpt_path
|
| 31 |
self.vocab_path = vocab_path
|