import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

from rag.retrieval import build_prompt

MODEL_NAME = "SeaLLMs/SeaLLMs-v3-1.5B-Chat"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    dtype=torch.float32,
)


def generate_answer(question: str) -> str:
    """Build a prompt (with retrieved context) and generate a short answer."""
    prompt = build_prompt(question)
    inputs = tokenizer(prompt, return_tensors="pt")

    # Greedy decoding keeps the short answer deterministic.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=160,
            do_sample=False,
        )

    # Decode only the newly generated tokens, skipping the echoed prompt.
    generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
    answer = tokenizer.decode(generated_ids, skip_special_tokens=True)

    return answer.strip()
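

# Minimal usage sketch: running the module directly asks one question and
# prints the generated answer. The example question below is purely
# illustrative and makes no assumption about what the retrieval corpus contains.
if __name__ == "__main__":
    example_question = "What topics does the knowledge base cover?"  # hypothetical example
    print(generate_answer(example_question))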