File size: 1,065 Bytes
6bc7f9b
 
 
079004b
 
 
 
6172539
6bc7f9b
 
6172539
6bc7f9b
079004b
6bc7f9b
6172539
079004b
6bc7f9b
 
 
6172539
079004b
6bc7f9b
6172539
6bc7f9b
079004b
6bc7f9b
 
 
079004b
6bc7f9b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
from transformers import TextIteratorStreamer
from threading import Thread
from model_loader import load_model

# Load model & tokenizer từ file model_loader
model, tokenizer = load_model()

# Prompt ngầm
SYSTEM_PROMPT = "Bạn là chatbot chuyên tóm tắt và tổng hợp ý chính từ nội dung người dùng đưa ra."

def generate_text(prompt, max_new_tokens=1024):
    # Ghép prompt ngầm + user input
    inputs = tokenizer(SYSTEM_PROMPT + "\nUser: " + prompt + "\nBot:", return_tensors="pt").to("cpu")

    # Tạo streamer để stream output
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,  # Greedy decoding để nhanh hơn
        streamer=streamer
    )

    # Chạy model.generate trong thread
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Stream text trả ra dần
    for new_text in streamer:
        yield new_text