GirishaBuilds01 committed on
Commit
60fe91b
·
verified ·
1 Parent(s): 15d573f

Update core/model_loader.py

Browse files
Files changed (1) hide show
  1. core/model_loader.py +4 -4
core/model_loader.py CHANGED
@@ -3,19 +3,19 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
3
 
4
  SUPPORTED_MODELS = {
5
  "DistilGPT2 (Fast CPU)": "distilgpt2",
6
- "TinyLlama (Better LLM)": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
7
- "Phi-2 (Research Heavy)": "microsoft/phi-2"
8
  }
9
 
10
  def load_model(model_key):
11
  model_name = SUPPORTED_MODELS[model_key]
12
 
13
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
14
  model = AutoModelForCausalLM.from_pretrained(
15
  model_name,
16
- torch_dtype=torch.float32,
17
- device_map="cpu"
18
  )
19
 
 
20
  model.eval()
 
21
  return model, tokenizer
 
3
 
4
  SUPPORTED_MODELS = {
5
  "DistilGPT2 (Fast CPU)": "distilgpt2",
 
 
6
  }
7
 
8
  def load_model(model_key):
9
  model_name = SUPPORTED_MODELS[model_key]
10
 
11
  tokenizer = AutoTokenizer.from_pretrained(model_name)
12
+
13
  model = AutoModelForCausalLM.from_pretrained(
14
  model_name,
15
+ dtype=torch.float32 # ✅ Use dtype instead of torch_dtype
 
16
  )
17
 
18
+ model.to("cpu") # ✅ Explicit CPU move
19
  model.eval()
20
+
21
  return model, tokenizer