Hareesh Polla commited on
Commit
28093f4
·
1 Parent(s): ce2ab10

add MPS support for Apple Silicon

Browse files
Files changed (2) hide show
  1. config.json +1 -0
  2. model.py +2 -2
config.json CHANGED
@@ -8,6 +8,7 @@
8
  "model_type": "inf5",
9
  "remove_sil": true,
10
  "speed": 1.0,
 
11
  "torch_dtype": "float32",
12
  "transformers_version": "4.49.0",
13
  "vocab_path": "checkpoints/vocab.txt"
 
8
  "model_type": "inf5",
9
  "remove_sil": true,
10
  "speed": 1.0,
11
+ "device": "",
12
  "torch_dtype": "float32",
13
  "transformers_version": "4.49.0",
14
  "vocab_path": "checkpoints/vocab.txt"
model.py CHANGED
@@ -24,8 +24,8 @@ import os
24
  class INF5Config(PretrainedConfig):
25
  model_type = "inf5"
26
 
27
- def __init__(self, device: str, ckpt_path: str = "checkpoints/model_best.pt", vocab_path: str = "checkpoints/vocab.txt",
28
- speed: float = 1.0, remove_sil: bool = True, **kwargs):
29
  super().__init__(**kwargs)
30
  self.ckpt_path = ckpt_path
31
  self.vocab_path = vocab_path
 
24
  class INF5Config(PretrainedConfig):
25
  model_type = "inf5"
26
 
27
+ def __init__(self, ckpt_path: str = "checkpoints/model_best.pt", vocab_path: str = "checkpoints/vocab.txt",
28
+ speed: float = 1.0, remove_sil: bool = True, device: str = "", **kwargs):
29
  super().__init__(**kwargs)
30
  self.ckpt_path = ckpt_path
31
  self.vocab_path = vocab_path