NV9523 commited on
Commit
6bc7f9b
·
verified ·
1 Parent(s): d044d45

Update generator.py

Browse files
Files changed (1) hide show
  1. generator.py +22 -16
generator.py CHANGED
@@ -1,21 +1,27 @@
1
- from model_loader import load_model
 
 
 
2
 
3
- # load 1 lần để tái sử dụng
4
- pipe = load_model()
5
 
6
- SYSTEM_PROMPT = (
7
- "Bạn một chatbot chuyên TÓM TẮT TỔNG HỢP ý chính. "
8
- "Nhiệm vụ: nhận câu hỏi hoặc đoạn văn bản và trả lời ngắn gọn, "
9
- "liệt kê các điểm quan trọng nhất.\n\n"
10
- )
11
 
12
- def generate_text(user_input: str, max_new_tokens: int = 128):
13
- prompt = SYSTEM_PROMPT + "Người dùng hỏi: " + user_input + "\n\nTóm tắt & tổng hợp:"
14
- outputs = pipe(
15
- prompt,
16
  max_new_tokens=max_new_tokens,
17
- do_sample=False, # greedy decoding nhanh và ổn định hơn
18
- temperature=0.6,
19
- top_p=0.9
 
20
  )
21
- return outputs[0]["generated_text"]
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import TextIteratorStreamer
3
+ from threading import Thread
4
+ from model_loader import model, tokenizer
5
 
6
+ # Prompt ngầm
7
+ SYSTEM_PROMPT = "Bạn là chatbot chuyên tóm tắt và tổng hợp ý chính từ nội dung người dùng đưa ra."
8
 
9
+ def generate_text(prompt, max_new_tokens=1024):
10
+ inputs = tokenizer(SYSTEM_PROMPT + "\nUser: " + prompt + "\nBot:", return_tensors="pt").to("cpu")
 
 
 
11
 
12
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
13
+ generation_kwargs = dict(
14
+ **inputs,
 
15
  max_new_tokens=max_new_tokens,
16
+ do_sample=False, # Greedy decoding để nhanh hơn
17
+ temperature=0.7,
18
+ top_p=0.9,
19
+ streamer=streamer
20
  )
21
+
22
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
23
+ thread.start()
24
+
25
+ # Trả về stream kết quả
26
+ for new_text in streamer:
27
+ yield new_text