NV9523 commited on
Commit
079004b
·
verified ·
1 Parent(s): 6bc7f9b

Update generator.py

Browse files
Files changed (1) hide show
  1. generator.py +9 -5
generator.py CHANGED
@@ -1,27 +1,31 @@
1
  import torch
2
  from transformers import TextIteratorStreamer
3
  from threading import Thread
4
- from model_loader import model, tokenizer
 
 
 
5
 
6
  # Prompt ngầm
7
  SYSTEM_PROMPT = "Bạn là chatbot chuyên tóm tắt và tổng hợp ý chính từ nội dung người dùng đưa ra."
8
 
9
  def generate_text(prompt, max_new_tokens=1024):
 
10
  inputs = tokenizer(SYSTEM_PROMPT + "\nUser: " + prompt + "\nBot:", return_tensors="pt").to("cpu")
11
 
 
12
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
13
  generation_kwargs = dict(
14
  **inputs,
15
  max_new_tokens=max_new_tokens,
16
- do_sample=False, # Greedy decoding để nhanh hơn
17
- temperature=0.7,
18
- top_p=0.9,
19
  streamer=streamer
20
  )
21
 
 
22
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
23
  thread.start()
24
 
25
- # Trả về stream kết quả
26
  for new_text in streamer:
27
  yield new_text
 
1
  import torch
2
  from transformers import TextIteratorStreamer
3
  from threading import Thread
4
+ from model_loader import load_model
5
+
6
+ # Load model & tokenizer từ file model_loader
7
+ model, tokenizer = load_model()
8
 
9
  # Prompt ngầm
10
  SYSTEM_PROMPT = "Bạn là chatbot chuyên tóm tắt và tổng hợp ý chính từ nội dung người dùng đưa ra."
11
 
12
  def generate_text(prompt, max_new_tokens=1024):
13
+ # Ghép prompt ngầm + user input
14
  inputs = tokenizer(SYSTEM_PROMPT + "\nUser: " + prompt + "\nBot:", return_tensors="pt").to("cpu")
15
 
16
+ # Tạo streamer để stream output
17
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
18
  generation_kwargs = dict(
19
  **inputs,
20
  max_new_tokens=max_new_tokens,
21
+ do_sample=False, # Greedy decoding để nhanh hơn
 
 
22
  streamer=streamer
23
  )
24
 
25
+ # Chạy model.generate trong thread
26
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
27
  thread.start()
28
 
29
+ # Stream text trả ra dần
30
  for new_text in streamer:
31
  yield new_text