# models/sea_llm.py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

from rag.retrieval import build_prompt

MODEL_NAME = "SeaLLMs/SeaLLMs-v3-1.5B-Chat"

# Load tokenizer + model once at import time so each request skips the load cost.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    dtype=torch.float32,  # CPU on free tier → float32 (older transformers call this torch_dtype)
)


def generate_answer(question: str) -> str:
    """
    Build a prompt (with retrieved context) and generate a short answer.
    """
    prompt = build_prompt(question)
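    # SeaLLMs-v3 is a chat-tuned model. If build_prompt returns plain text
    # (an assumption; rag.retrieval is not shown here), wrapping it with the
    # tokenizer's chat template matches the model's expected input format.
    # A minimal sketch:
    #
    #     messages = [{"role": "user", "content": prompt}]
    #     prompt = tokenizer.apply_chat_template(
    #         messages, tokenize=False, add_generation_prompt=True
    #     )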
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=160,  # shorter answers = faster
            do_sample=False,     # greedy decoding → stable & a bit faster
        )
    # Slice off the prompt tokens so only the newly generated text is decoded.
    generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
    answer = tokenizer.decode(generated_ids, skip_special_tokens=True)
    return answer.strip()
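

if __name__ == "__main__":
    # Quick local smoke test: a sketch assuming rag.retrieval and its index
    # are importable from here; the question is a hypothetical example.
    print(generate_answer("What is SeaLLM?"))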