eddddyy commited on
Commit
0204d4a
·
verified ·
1 Parent(s): 27693fd

Update model_loader.py

Browse files
Files changed (1) hide show
  1. model_loader.py +4 -5
model_loader.py CHANGED
@@ -1,6 +1,6 @@
1
  import torch
2
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
3
- from config import HF_TOKEN, MODEL_ID # ✅ Make sure this line is here!
4
 
5
  def load_model():
6
  try:
@@ -8,13 +8,13 @@ def load_model():
8
 
9
  tokenizer = AutoTokenizer.from_pretrained(
10
  MODEL_ID,
11
- token=HF_TOKEN,
12
  trust_remote_code=True
13
  )
14
 
15
  model = AutoModelForCausalLM.from_pretrained(
16
  MODEL_ID,
17
- token=HF_TOKEN,
18
  trust_remote_code=True,
19
  device_map="auto" if torch.cuda.is_available() else "cpu",
20
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
@@ -27,7 +27,7 @@ def load_model():
27
  "text-generation",
28
  model=model,
29
  tokenizer=tokenizer,
30
- max_new_tokens=2048, # 🧠 Increased token window
31
  do_sample=True,
32
  temperature=0.7,
33
  top_p=0.9
@@ -36,4 +36,3 @@ def load_model():
36
  except Exception as e:
37
  print(f"❌ Failed to load model: {e}")
38
  raise RuntimeError(f"Model loading failed: {e}")
39
-
 
1
  import torch
2
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
3
+ from config import HF_TOKEN, MODEL_ID
4
 
5
  def load_model():
6
  try:
 
8
 
9
  tokenizer = AutoTokenizer.from_pretrained(
10
  MODEL_ID,
11
+ token=HF_TOKEN if HF_TOKEN else None,
12
  trust_remote_code=True
13
  )
14
 
15
  model = AutoModelForCausalLM.from_pretrained(
16
  MODEL_ID,
17
+ token=HF_TOKEN if HF_TOKEN else None,
18
  trust_remote_code=True,
19
  device_map="auto" if torch.cuda.is_available() else "cpu",
20
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
 
27
  "text-generation",
28
  model=model,
29
  tokenizer=tokenizer,
30
+ max_new_tokens=2048,
31
  do_sample=True,
32
  temperature=0.7,
33
  top_p=0.9
 
36
  except Exception as e:
37
  print(f"❌ Failed to load model: {e}")
38
  raise RuntimeError(f"Model loading failed: {e}")