phamhoangf commited on
Commit
88f5ef8
·
verified ·
1 Parent(s): 00c5fdd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -30
app.py CHANGED
@@ -1,54 +1,79 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import pipeline, AutoTokenizer
 
4
  import os
5
 
 
 
6
  # Lấy token từ secrets của Space
7
  hf_token = os.environ.get("HF_TOKEN")
8
-
9
- # Tải mô hình và tokenizer
10
- # device_map="auto" sẽ tự động sử dụng GPU nếu có
11
  model_id = "phamhoangf/struct-aware-baseline-qwen3-4b"
12
 
13
- # Sử dụng token để xác thực khi tải tokenizer
14
  tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
15
 
16
- # ---- SỬA LỖI ĐÂY ----
17
- # Truyền token trực tiếp cho pipeline, không dùng model_kwargs
18
- pipe = pipeline(
19
- "text-generation",
20
- model=model_id,
21
  torch_dtype=torch.bfloat16,
22
  device_map="auto",
23
- token=hf_token # Sửa từ 'model_kwargs' thành cách này
24
  )
25
 
 
 
26
  def predict(message, history):
27
- # Xây dựng prompt từ lịch sử trò chuyện theo template của Qwen2
28
  messages = []
29
  for user_msg, assistant_msg in history:
30
  messages.append({"role": "user", "content": user_msg})
31
  messages.append({"role": "assistant", "content": assistant_msg})
32
  messages.append({"role": "user", "content": message})
33
 
34
- # Tạo prompt hoàn chỉnh
35
- prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
36
-
37
- # Tạo văn bản
38
- outputs = pipe(
39
- prompt,
40
- max_new_tokens=256,
 
 
 
 
 
 
 
 
 
 
 
 
41
  do_sample=True,
42
  temperature=0.7,
43
- top_k=20,
44
- top_p=0.8,
45
  )
46
-
47
- # Trích xuất phần trả lời
48
- generated_text = outputs[0]["generated_text"]
49
- # Lấy phần văn bản mới được tạo ra (sau prompt)
50
- response = generated_text[len(prompt):]
51
- return response
52
-
53
- # Tạo giao diện Chat, giao diện này cũng tự động tạo ra một API
54
- gr.ChatInterface(predict).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
4
+ from threading import Thread
5
  import os
6
 
7
+ # --- 1. CÀI ĐẶT MODEL ---
8
+
9
  # Lấy token từ secrets của Space
10
  hf_token = os.environ.get("HF_TOKEN")
 
 
 
11
  model_id = "phamhoangf/struct-aware-baseline-qwen3-4b"
12
 
13
+ # Tải tokenizer
14
  tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
15
 
16
+ # Tải model trực tiếp thay vì dùng pipeline
17
+ # Điều này cho phép chúng ta truy cập hàm .generate() với streamer
18
+ model = AutoModelForCausalLM.from_pretrained(
19
+ model_id,
 
20
  torch_dtype=torch.bfloat16,
21
  device_map="auto",
22
+ token=hf_token
23
  )
24
 
25
+ # --- 2. HÀM DỰ ĐOÁN ĐÃ HỖ TRỢ STREAMING ---
26
+
27
  def predict(message, history):
28
+ # Xây dựng prompt từ lịch sử trò chuyện
29
  messages = []
30
  for user_msg, assistant_msg in history:
31
  messages.append({"role": "user", "content": user_msg})
32
  messages.append({"role": "assistant", "content": assistant_msg})
33
  messages.append({"role": "user", "content": message})
34
 
35
+ prompt = tokenizer.apply_chat_template(
36
+ messages,
37
+ tokenize=False,
38
+ add_generation_prompt=True
39
+ )
40
+
41
+ # Tokenize input
42
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
43
+
44
+ # Tạo một streamer
45
+ # skip_prompt=True để không lặp lại prompt trong output
46
+ # skip_special_tokens=True để bỏ qua các token như </s>
47
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
48
+
49
+ # Các tham số cho việc sinh token
50
+ generation_kwargs = dict(
51
+ **inputs,
52
+ streamer=streamer,
53
+ max_new_tokens=1024,
54
  do_sample=True,
55
  temperature=0.7,
56
+ top_p=0.95,
57
+ top_k=50,
58
  )
59
+
60
+ # Chạy việc sinh token trong một luồng (thread) riêng
61
+ # để nó không block luồng chính đang trả kết quả cho Gradio
62
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
63
+ thread.start()
64
+
65
+ # Yield từng phần của văn bản được tạo ra
66
+ generated_text = ""
67
+ for new_text in streamer:
68
+ generated_text += new_text
69
+ yield generated_text # Trả về chuỗi đã được cập nhật cho Gradio
70
+
71
+
72
+ # --- 3. TẠO GIAO DIỆN ---
73
+ # Giữ nguyên như cũ
74
+ gr.ChatInterface(
75
+ predict,
76
+ title="Struct-Aware Baseline Qwen3 4B",
77
+ description="Giao diện chat cho mô hình phamhoangf/struct-aware-baseline-qwen3-4b"
78
+ ).launch()```
79
+