Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| # Thêm BitsAndBytesConfig để cấu hình quantization | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, BitsAndBytesConfig | |
| from threading import Thread | |
| import os | |
| # --- 1. CÀI ĐẶT MODEL VỚI QUANTIZATION 4-BIT --- | |
| # Lấy token từ secrets của Space | |
| hf_token = os.environ.get("HF_TOKEN") | |
| model_id = "phamhoangf/struct-aware-baseline-qwen3-4b" | |
| # Tải tokenizer (không thay đổi) | |
| tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token) | |
| # Cấu hình quantization 4-bit | |
| quantization_config = BitsAndBytesConfig( | |
| load_in_4bit=True, | |
| bnb_4bit_quant_type="nf4", | |
| bnb_4bit_compute_dtype=torch.bfloat16 | |
| ) | |
| # Tải model với cấu hình quantization | |
| # Điều này sẽ giảm VRAM sử dụng đi ~ một nửa | |
| model = AutoModelForCausalLM.from_pretrained( | |
| model_id, | |
| quantization_config=quantization_config, | |
| device_map="auto", # device_map="auto" tự động xử lý việc đặt các lớp lên GPU | |
| token=hf_token | |
| ) | |
| # --- 2. HÀM DỰ ĐOÁN ĐÃ HỖ TRỢ STREAMING (KHÔNG THAY ĐỔI) --- | |
| def predict(message, history): | |
| messages = [] | |
| for user_msg, assistant_msg in history: | |
| messages.append({"role": "user", "content": user_msg}) | |
| messages.append({"role": "assistant", "content": assistant_msg}) | |
| messages.append({"role": "user", "content": message}) | |
| prompt = tokenizer.apply_chat_template( | |
| messages, | |
| tokenize=False, | |
| add_generation_prompt=True | |
| ) | |
| inputs = tokenizer(prompt, return_tensors="pt").to(model.device) | |
| streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) | |
| generation_kwargs = dict( | |
| **inputs, | |
| streamer=streamer, | |
| max_new_tokens=1024, | |
| do_sample=True, | |
| temperature=0.7, | |
| top_p=0.8, | |
| top_k=20, | |
| ) | |
| thread = Thread(target=model.generate, kwargs=generation_kwargs) | |
| thread.start() | |
| generated_text = "" | |
| for new_text in streamer: | |
| generated_text += new_text | |
| yield generated_text | |
| # --- 3. TẠO GIAO DIỆN --- | |
| # Thêm type="messages" để loại bỏ cảnh báo (warning) | |
| gr.ChatInterface( | |
| predict, | |
| chatbot=gr.Chatbot(height=500), | |
| title="Struct-Aware Baseline Qwen3 4B (4-bit)", | |
| description="Giao diện chat cho mô hình phamhoangf/struct-aware-baseline-qwen3-4b (chạy với 4-bit quantization).", | |
| type="messages" | |
| ).launch() |