import gradio as gr import torch # Thêm BitsAndBytesConfig để cấu hình quantization from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, BitsAndBytesConfig from threading import Thread import os # --- 1. CÀI ĐẶT MODEL VỚI QUANTIZATION 4-BIT --- # Lấy token từ secrets của Space hf_token = os.environ.get("HF_TOKEN") model_id = "phamhoangf/struct-aware-baseline-qwen3-4b" # Tải tokenizer (không thay đổi) tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token) # Cấu hình quantization 4-bit quantization_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16 ) # Tải model với cấu hình quantization # Điều này sẽ giảm VRAM sử dụng đi ~ một nửa model = AutoModelForCausalLM.from_pretrained( model_id, quantization_config=quantization_config, device_map="auto", # device_map="auto" tự động xử lý việc đặt các lớp lên GPU token=hf_token ) # --- 2. HÀM DỰ ĐOÁN ĐÃ HỖ TRỢ STREAMING (KHÔNG THAY ĐỔI) --- def predict(message, history): messages = [] for user_msg, assistant_msg in history: messages.append({"role": "user", "content": user_msg}) messages.append({"role": "assistant", "content": assistant_msg}) messages.append({"role": "user", "content": message}) prompt = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) inputs = tokenizer(prompt, return_tensors="pt").to(model.device) streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) generation_kwargs = dict( **inputs, streamer=streamer, max_new_tokens=1024, do_sample=True, temperature=0.7, top_p=0.8, top_k=20, ) thread = Thread(target=model.generate, kwargs=generation_kwargs) thread.start() generated_text = "" for new_text in streamer: generated_text += new_text yield generated_text # --- 3. TẠO GIAO DIỆN --- # Thêm type="messages" để loại bỏ cảnh báo (warning) gr.ChatInterface( predict, chatbot=gr.Chatbot(height=500), title="Struct-Aware Baseline Qwen3 4B (4-bit)", description="Giao diện chat cho mô hình phamhoangf/struct-aware-baseline-qwen3-4b (chạy với 4-bit quantization).", type="messages" ).launch()