Spaces:
Runtime error
Runtime error
feat(model): load model from safetensors
Browse files- inference.py +4 -3
inference.py
CHANGED
|
@@ -24,6 +24,7 @@ from utils.snac_utils import get_snac, generate_audio_data
|
|
| 24 |
import whisper
|
| 25 |
from tqdm import tqdm
|
| 26 |
from huggingface_hub import snapshot_download
|
|
|
|
| 27 |
|
| 28 |
|
| 29 |
torch.set_printoptions(sci_mode=False)
|
|
@@ -351,14 +352,14 @@ def load_model(ckpt_dir, device):
|
|
| 351 |
whispermodel = whisper.load_model("small").to(device)
|
| 352 |
text_tokenizer = Tokenizer(ckpt_dir)
|
| 353 |
fabric = L.Fabric(devices=1, strategy="auto")
|
| 354 |
-
config = Config.from_file(ckpt_dir + "/
|
| 355 |
config.post_adapter = False
|
| 356 |
|
| 357 |
with fabric.init_module(empty_init=False):
|
| 358 |
model = GPT(config)
|
| 359 |
|
| 360 |
model = fabric.setup(model)
|
| 361 |
-
state_dict =
|
| 362 |
model.load_state_dict(state_dict, strict=True)
|
| 363 |
model.to(device).eval()
|
| 364 |
|
|
@@ -366,7 +367,7 @@ def load_model(ckpt_dir, device):
|
|
| 366 |
|
| 367 |
|
| 368 |
def download_model(ckpt_dir):
|
| 369 |
-
repo_id = "
|
| 370 |
snapshot_download(repo_id, local_dir=ckpt_dir, revision="main")
|
| 371 |
|
| 372 |
|
|
|
|
| 24 |
import whisper
|
| 25 |
from tqdm import tqdm
|
| 26 |
from huggingface_hub import snapshot_download
|
| 27 |
+
from safetensors.torch import load_file
|
| 28 |
|
| 29 |
|
| 30 |
torch.set_printoptions(sci_mode=False)
|
|
|
|
| 352 |
whispermodel = whisper.load_model("small").to(device)
|
| 353 |
text_tokenizer = Tokenizer(ckpt_dir)
|
| 354 |
fabric = L.Fabric(devices=1, strategy="auto")
|
| 355 |
+
config = Config.from_file(ckpt_dir + "/config.json")
|
| 356 |
config.post_adapter = False
|
| 357 |
|
| 358 |
with fabric.init_module(empty_init=False):
|
| 359 |
model = GPT(config)
|
| 360 |
|
| 361 |
model = fabric.setup(model)
|
| 362 |
+
state_dict = load_file(ckpt_dir + "/lit_model.safetensors")
|
| 363 |
model.load_state_dict(state_dict, strict=True)
|
| 364 |
model.to(device).eval()
|
| 365 |
|
|
|
|
| 367 |
|
| 368 |
|
| 369 |
def download_model(ckpt_dir):
|
| 370 |
+
repo_id = "leafspark/mini-omni-safetensors"
|
| 371 |
snapshot_download(repo_id, local_dir=ckpt_dir, revision="main")
|
| 372 |
|
| 373 |
|