Heng2004 committed on
Commit
cf716c3
·
verified ·
1 Parent(s): 1406a1d

Update model_utils.py

Browse files
Files changed (1) hide show
  1. model_utils.py +8 -2
model_utils.py CHANGED
@@ -14,10 +14,16 @@ from loader import load_curriculum, load_manual_qa, rebuild_combined_qa
14
  MODEL_NAME = "SeaLLMs/SeaLLMs-v3-1.5B-Chat"
15
 
16
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
 
 
 
17
  model = AutoModelForCausalLM.from_pretrained(
18
  MODEL_NAME,
19
  torch_dtype=torch.float32,
20
- )
 
 
 
21
 
22
  # Load data once at import time
23
  load_curriculum()
@@ -109,7 +115,7 @@ def build_prompt(question: str) -> str:
109
 
110
  def generate_answer(question: str) -> str:
111
  prompt = build_prompt(question)
112
- inputs = tokenizer(prompt, return_tensors="pt")
113
  with torch.no_grad():
114
  outputs = model.generate(
115
  **inputs,
 
14
  MODEL_NAME = "SeaLLMs/SeaLLMs-v3-1.5B-Chat"
15
 
16
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
17
+
18
+ device = "cuda" if torch.cuda.is_available() else "cpu"
19
+
20
  model = AutoModelForCausalLM.from_pretrained(
21
  MODEL_NAME,
22
  torch_dtype=torch.float32,
23
+ ).to(device)
24
+
25
+ model.eval()
26
+
27
 
28
  # Load data once at import time
29
  load_curriculum()
 
115
 
116
  def generate_answer(question: str) -> str:
117
  prompt = build_prompt(question)
118
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
119
  with torch.no_grad():
120
  outputs = model.generate(
121
  **inputs,