# Hugging Face Space: Base vs. SFT (instruction-tuned) model comparison demo
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# --- 1. Model IDs ---
# Fallback strategy: use small public checkpoints to showcase the
# "Base vs. SFT" difference.
# Base model (left column): never trained on dialogue, only continues text.
base_model_id = "Qwen/Qwen2.5-1.5B"
# Instruct model (right column): already SFT-tuned, follows instructions
# (stands in for the project's own training result).
chat_model_id = "Qwen/Qwen2.5-1.5B-Instruct"

print("正在載入模型,這可能需要幾分鐘...")

# --- 2. Load both models (CPU, float32 for maximum compatibility) ---
tokenizer_base = AutoTokenizer.from_pretrained(base_model_id)
model_base = AutoModelForCausalLM.from_pretrained(base_model_id, torch_dtype=torch.float32)

tokenizer_chat = AutoTokenizer.from_pretrained(chat_model_id)
model_chat = AutoModelForCausalLM.from_pretrained(chat_model_id, torch_dtype=torch.float32)
# --- 3. Core comparison logic ---
def compare_models(prompt):
    """Generate a reply to *prompt* from both the base and the instruct model.

    Args:
        prompt: The user's test question as plain text.

    Returns:
        A ``(response_base, response_chat)`` tuple containing only the newly
        generated text from each model (the echoed prompt is stripped).
    """
    # Shared generation settings for a fair side-by-side comparison.
    gen_kwargs = {
        "max_new_tokens": 150,
        "do_sample": True,
        "temperature": 0.7,
        "top_p": 0.9,
        "repetition_penalty": 1.1,
    }

    # === A. Left: original base model ===
    # The base model knows no chat format, so feed the raw text directly
    # (no apply_chat_template).
    inputs_base = tokenizer_base(prompt, return_tensors="pt")
    outputs_base = model_base.generate(**inputs_base, **gen_kwargs)
    # Bug fix: generate() returns prompt + continuation; slice off the input
    # tokens so only the newly generated content is decoded. (The previous
    # code decoded the full sequence and echoed the prompt back to the user.)
    new_tokens_base = outputs_base[0][inputs_base["input_ids"].shape[1]:]
    response_base = tokenizer_base.decode(new_tokens_base, skip_special_tokens=True)

    # === B. Right: instruct/SFT model ===
    # The instruct model requires its chat template.
    messages = [{"role": "user", "content": prompt}]
    text_chat = tokenizer_chat.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs_chat = tokenizer_chat(text_chat, return_tensors="pt")
    outputs_chat = model_chat.generate(**inputs_chat, **gen_kwargs)
    # Same token-level slicing removes the entire templated prompt reliably,
    # replacing the fragile 'split("assistant\n")' string surgery (which also
    # hid failures behind a bare except).
    new_tokens_chat = outputs_chat[0][inputs_chat["input_ids"].shape[1]:]
    response_chat = tokenizer_chat.decode(new_tokens_chat, skip_special_tokens=True)

    return response_base.strip(), response_chat.strip()
# --- 4. Build the interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Bug fix: Markdown needs a space after '#' to render a heading;
    # "#SFT Demo" would display as literal text.
    gr.Markdown("# SFT Demo")
    gr.Markdown("展示 **原始基底模型 (Base)** 與 **指令微調後模型 (SFT)** 針對同一問題的回答差異。")

    # Input row: prompt textbox plus the trigger button.
    with gr.Row():
        input_box = gr.Textbox(label="輸入測試問題 (Prompt)", placeholder="例如:請推薦台灣的旅遊景點。", lines=2)
        submit_btn = gr.Button("開始比對", variant="primary")

    # Output row: base-model answer on the left, SFT answer on the right.
    with gr.Row():
        with gr.Column():
            output_base = gr.Textbox(label="❌ 訓練前 (Base Model)", lines=10, interactive=False)
            gr.Markdown("*特徵:容易答非所問、不斷重複問題、或像機器人一樣講廢話。*")
        with gr.Column():
            output_chat = gr.Textbox(label="✅ 訓練後 (SFT Model)", lines=10, interactive=False)
            gr.Markdown("*特徵:能理解指令、條理分明、回答符合人類邏輯。*")

    submit_btn.click(compare_models, inputs=input_box, outputs=[output_base, output_chat])

# Standard script entry-point guard.
if __name__ == "__main__":
    demo.launch()