File size: 2,473 Bytes
4ab68e3
8a8eb9f
c8417b9
 
88f5ef8
24dfae3
95c30e0
c8417b9
88f5ef8
95c30e0
 
8a8eb9f
95c30e0
c8417b9
95c30e0
 
c8417b9
 
 
 
 
 
 
 
 
88f5ef8
 
c8417b9
 
88f5ef8
8a8eb9f
 
c8417b9
88f5ef8
8a8eb9f
 
 
 
 
 
 
88f5ef8
 
 
 
 
 
 
 
 
 
 
 
 
8a8eb9f
 
c8417b9
 
8a8eb9f
88f5ef8
 
 
 
 
 
 
c8417b9
88f5ef8
 
c8417b9
88f5ef8
 
c8417b9
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import gradio as gr
import torch
# Thêm BitsAndBytesConfig để cấu hình quantization
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, BitsAndBytesConfig
from threading import Thread
import os

# --- 1. CÀI ĐẶT MODEL VỚI QUANTIZATION 4-BIT ---

# Lấy token từ secrets của Space
hf_token = os.environ.get("HF_TOKEN")
model_id = "phamhoangf/struct-aware-baseline-qwen3-4b"

# Tải tokenizer (không thay đổi)
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)

# Cấu hình quantization 4-bit
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

# Tải model với cấu hình quantization
# Điều này sẽ giảm VRAM sử dụng đi ~ một nửa
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    device_map="auto", # device_map="auto" tự động xử lý việc đặt các lớp lên GPU
    token=hf_token
)

# --- 2. HÀM DỰ ĐOÁN ĐÃ HỖ TRỢ STREAMING (KHÔNG THAY ĐỔI) ---

def predict(message, history):
    messages = []
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    prompt = tokenizer.apply_chat_template(
        messages, 
        tokenize=False, 
        add_generation_prompt=True
    )

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        temperature=0.7,
        top_p=0.8,
        top_k=20,
    )

    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        yield generated_text

# --- 3. TẠO GIAO DIỆN ---
# Thêm type="messages" để loại bỏ cảnh báo (warning)
gr.ChatInterface(
    predict,
    chatbot=gr.Chatbot(height=500),
    title="Struct-Aware Baseline Qwen3 4B (4-bit)",
    description="Giao diện chat cho mô hình phamhoangf/struct-aware-baseline-qwen3-4b (chạy với 4-bit quantization).",
    type="messages" 
).launch()