agent_support / generator.py
NV9523's picture
Update generator.py
079004b verified
import torch
from transformers import TextIteratorStreamer
from threading import Thread
from model_loader import load_model
# Load model & tokenizer từ file model_loader
model, tokenizer = load_model()
# Prompt ngầm
SYSTEM_PROMPT = "Bạn là chatbot chuyên tóm tắt và tổng hợp ý chính từ nội dung người dùng đưa ra."
def generate_text(prompt, max_new_tokens=1024):
# Ghép prompt ngầm + user input
inputs = tokenizer(SYSTEM_PROMPT + "\nUser: " + prompt + "\nBot:", return_tensors="pt").to("cpu")
# Tạo streamer để stream output
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = dict(
**inputs,
max_new_tokens=max_new_tokens,
do_sample=False, # Greedy decoding để nhanh hơn
streamer=streamer
)
# Chạy model.generate trong thread
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
# Stream text trả ra dần
for new_text in streamer:
yield new_text