phamhoangf commited on
Commit
c8417b9
·
verified ·
1 Parent(s): f5e213e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -27
app.py CHANGED
@@ -1,31 +1,38 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 
4
  from threading import Thread
5
  import os
6
 
7
- # --- 1. CÀI ĐẶT MODEL ---
8
 
9
  # Lấy token từ secrets của Space
10
  hf_token = os.environ.get("HF_TOKEN")
11
  model_id = "phamhoangf/struct-aware-baseline-qwen3-4b"
12
 
13
- # Tải tokenizer
14
  tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
15
 
16
- # Tải model trực tiếp thay vì dùng pipeline
17
- # Điều này cho phép chúng ta truy cập hàm .generate() với streamer
 
 
 
 
 
 
 
18
  model = AutoModelForCausalLM.from_pretrained(
19
  model_id,
20
- torch_dtype=torch.bfloat16,
21
- device_map="auto",
22
  token=hf_token
23
  )
24
 
25
- # --- 2. HÀM DỰ ĐOÁN ĐÃ HỖ TRỢ STREAMING ---
26
 
27
  def predict(message, history):
28
- # Xây dựng prompt từ lịch sử trò chuyện
29
  messages = []
30
  for user_msg, assistant_msg in history:
31
  messages.append({"role": "user", "content": user_msg})
@@ -38,42 +45,33 @@ def predict(message, history):
38
  add_generation_prompt=True
39
  )
40
 
41
- # Tokenize input
42
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
43
-
44
- # Tạo một streamer
45
- # skip_prompt=True để không lặp lại prompt trong output
46
- # skip_special_tokens=True để bỏ qua các token như </s>
47
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
48
 
49
- # Các tham số cho việc sinh token
50
  generation_kwargs = dict(
51
  **inputs,
52
  streamer=streamer,
53
  max_new_tokens=1024,
54
  do_sample=True,
55
  temperature=0.7,
56
- top_p=0.95,
57
- top_k=50,
58
  )
59
 
60
- # Chạy việc sinh token trong một luồng (thread) riêng
61
- # để nó không block luồng chính đang trả kết quả cho Gradio
62
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
63
  thread.start()
64
 
65
- # Yield từng phần của văn bản được tạo ra
66
  generated_text = ""
67
  for new_text in streamer:
68
  generated_text += new_text
69
- yield generated_text # Trả về chuỗi đã được cập nhật cho Gradio
70
-
71
 
72
  # --- 3. TẠO GIAO DIỆN ---
73
- # Giữ nguyên như
74
  gr.ChatInterface(
75
  predict,
76
- title="Struct-Aware Baseline Qwen3 4B",
77
- description="Giao diện chat cho mô hình phamhoangf/struct-aware-baseline-qwen3-4b"
78
- ).launch()
79
-
 
 
1
  import gradio as gr
2
  import torch
3
+ # Thêm BitsAndBytesConfig để cấu hình quantization
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, BitsAndBytesConfig
5
  from threading import Thread
6
  import os
7
 
8
+ # --- 1. CÀI ĐẶT MODEL VỚI QUANTIZATION 4-BIT ---
9
 
10
  # Lấy token từ secrets của Space
11
  hf_token = os.environ.get("HF_TOKEN")
12
  model_id = "phamhoangf/struct-aware-baseline-qwen3-4b"
13
 
14
+ # Tải tokenizer (không thay đổi)
15
  tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
16
 
17
+ # Cấu hình quantization 4-bit
18
+ quantization_config = BitsAndBytesConfig(
19
+ load_in_4bit=True,
20
+ bnb_4bit_quant_type="nf4",
21
+ bnb_4bit_compute_dtype=torch.bfloat16
22
+ )
23
+
24
+ # Tải model với cấu hình quantization
25
+ # Điều này sẽ giảm VRAM sử dụng đi ~ một nửa
26
  model = AutoModelForCausalLM.from_pretrained(
27
  model_id,
28
+ quantization_config=quantization_config,
29
+ device_map="auto", # device_map="auto" tự động xử lý việc đặt các lớp lên GPU
30
  token=hf_token
31
  )
32
 
33
+ # --- 2. HÀM DỰ ĐOÁN ĐÃ HỖ TRỢ STREAMING (KHÔNG THAY ĐỔI) ---
34
 
35
  def predict(message, history):
 
36
  messages = []
37
  for user_msg, assistant_msg in history:
38
  messages.append({"role": "user", "content": user_msg})
 
45
  add_generation_prompt=True
46
  )
47
 
 
48
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
 
 
 
 
49
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
50
 
 
51
  generation_kwargs = dict(
52
  **inputs,
53
  streamer=streamer,
54
  max_new_tokens=1024,
55
  do_sample=True,
56
  temperature=0.7,
57
+ top_p=0.8,
58
+ top_k=20,
59
  )
60
 
 
 
61
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
62
  thread.start()
63
 
 
64
  generated_text = ""
65
  for new_text in streamer:
66
  generated_text += new_text
67
+ yield generated_text
 
68
 
69
  # --- 3. TẠO GIAO DIỆN ---
70
+ # Thêm type="messages" để loại bỏ cảnh báo (warning)
71
  gr.ChatInterface(
72
  predict,
73
+ chatbot=gr.Chatbot(height=500),
74
+ title="Struct-Aware Baseline Qwen3 4B (4-bit)",
75
+ description="Giao diện chat cho mô hình phamhoangf/struct-aware-baseline-qwen3-4b (chạy với 4-bit quantization).",
76
+ type="messages"
77
+ ).launch()