prelington committed on
Commit
2eaa513
·
verified ·
1 Parent(s): 9c56b78

Create model_loader.py

Browse files
Files changed (1) hide show
  1. model_loader.py +26 -0
model_loader.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model_loader.py
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ from safetensors.torch import load_file
4
+ import torch
5
+ from config import MODEL_NAME, DEVICE
6
+
7
def load_model():
    """Load the tokenizer and causal-LM model named by ``MODEL_NAME``.

    Preferred path: load safetensors weights in float16 with automatic
    device placement (``device_map="auto"``). If that fails for any
    reason, fall back to a plain full-precision load moved to ``DEVICE``.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    # The tokenizer load is identical in both paths, so do it once up
    # front instead of duplicating it inside the try/except.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    try:
        # Attempt to load safetensors weights (fp16, auto device map).
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            device_map="auto",
            torch_dtype=torch.float16,
            use_safetensors=True,
        )
        print("[INFO] Model loaded using safetensors.")
    except Exception as e:
        # Broad catch is deliberate: any safetensors/load failure should
        # trigger the best-effort fallback rather than crash startup.
        print(f"[WARN] Could not use safetensors. Loading normal model. {e}")
        model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE)
    return tokenizer, model