Spaces:
Runtime error
Runtime error
File size: 2,473 Bytes
4ab68e3 8a8eb9f c8417b9 88f5ef8 24dfae3 95c30e0 c8417b9 88f5ef8 95c30e0 8a8eb9f 95c30e0 c8417b9 95c30e0 c8417b9 88f5ef8 c8417b9 88f5ef8 8a8eb9f c8417b9 88f5ef8 8a8eb9f 88f5ef8 8a8eb9f c8417b9 8a8eb9f 88f5ef8 c8417b9 88f5ef8 c8417b9 88f5ef8 c8417b9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 | import gradio as gr
import torch
# Thêm BitsAndBytesConfig để cấu hình quantization
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, BitsAndBytesConfig
from threading import Thread
import os
# --- 1. CÀI ĐẶT MODEL VỚI QUANTIZATION 4-BIT ---
# Lấy token từ secrets của Space
hf_token = os.environ.get("HF_TOKEN")
model_id = "phamhoangf/struct-aware-baseline-qwen3-4b"
# Tải tokenizer (không thay đổi)
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
# Cấu hình quantization 4-bit
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16
)
# Tải model với cấu hình quantization
# Điều này sẽ giảm VRAM sử dụng đi ~ một nửa
model = AutoModelForCausalLM.from_pretrained(
model_id,
quantization_config=quantization_config,
device_map="auto", # device_map="auto" tự động xử lý việc đặt các lớp lên GPU
token=hf_token
)
# --- 2. HÀM DỰ ĐOÁN ĐÃ HỖ TRỢ STREAMING (KHÔNG THAY ĐỔI) ---
def predict(message, history):
messages = []
for user_msg, assistant_msg in history:
messages.append({"role": "user", "content": user_msg})
messages.append({"role": "assistant", "content": assistant_msg})
messages.append({"role": "user", "content": message})
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = dict(
**inputs,
streamer=streamer,
max_new_tokens=1024,
do_sample=True,
temperature=0.7,
top_p=0.8,
top_k=20,
)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
generated_text = ""
for new_text in streamer:
generated_text += new_text
yield generated_text
# --- 3. TẠO GIAO DIỆN ---
# Thêm type="messages" để loại bỏ cảnh báo (warning)
gr.ChatInterface(
predict,
chatbot=gr.Chatbot(height=500),
title="Struct-Aware Baseline Qwen3 4B (4-bit)",
description="Giao diện chat cho mô hình phamhoangf/struct-aware-baseline-qwen3-4b (chạy với 4-bit quantization).",
type="messages"
).launch() |