# Taiwan-Bot / app.py
# (Hugging Face Space file header — author: Turtle1106, commit 98d1185 "Update app.py")
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# --- 1. Model IDs ---
# Stop-gap demo strategy: use small public checkpoints to showcase the
# "Base vs SFT" difference.
# Base model (left column): pretrained only — it has never seen chat data,
# so it just continues text instead of answering.
base_model_id = "Qwen/Qwen2.5-1.5B"
# Instruct model (right column): already SFT-tuned, follows instructions
# (used here to stand in for the project's own training result).
chat_model_id = "Qwen/Qwen2.5-1.5B-Instruct"
print("正在載入模型,這可能需要幾分鐘...")  # "Loading models, this may take a few minutes..."
# --- 2. Load models (CPU, float32 to maximize compatibility) ---
# Load the base model.
tokenizer_base = AutoTokenizer.from_pretrained(base_model_id)
model_base = AutoModelForCausalLM.from_pretrained(base_model_id, torch_dtype=torch.float32)
# Load the instruct model.
tokenizer_chat = AutoTokenizer.from_pretrained(chat_model_id)
model_chat = AutoModelForCausalLM.from_pretrained(chat_model_id, torch_dtype=torch.float32)
# --- 3. Core comparison logic ---
def compare_models(prompt):
    """Generate a response to *prompt* from both models and return the pair.

    Args:
        prompt: The user's test question, as plain text.

    Returns:
        Tuple[str, str]: (base model response, instruct model response),
        each containing only the newly generated text — the echoed prompt
        tokens are stripped before decoding.
    """
    # Shared sampling parameters so both models generate under equal conditions.
    gen_kwargs = {
        "max_new_tokens": 150,
        "do_sample": True,
        "temperature": 0.7,
        "top_p": 0.9,
        "repetition_penalty": 1.1,
    }
    # === A. Left side: base model ===
    # A base model has no chat training, so no chat template is applied —
    # the raw prompt is fed in directly and the model simply continues it.
    inputs_base = tokenizer_base(prompt, return_tensors="pt")
    outputs_base = model_base.generate(**inputs_base, **gen_kwargs)
    # generate() returns prompt + completion; slice off the prompt tokens
    # so only the newly generated content is decoded.
    prompt_len_base = inputs_base["input_ids"].shape[1]
    response_base = tokenizer_base.decode(
        outputs_base[0][prompt_len_base:], skip_special_tokens=True
    )
    # === B. Right side: instruct (SFT) model ===
    # Instruct models expect their chat template (system/user/assistant markers).
    messages = [{"role": "user", "content": prompt}]
    text_chat = tokenizer_chat.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs_chat = tokenizer_chat(text_chat, return_tensors="pt")
    outputs_chat = model_chat.generate(**inputs_chat, **gen_kwargs)
    # Slicing by input length removes the templated prompt cleanly, so no
    # fragile post-hoc string splitting on role markers is needed.
    prompt_len_chat = inputs_chat["input_ids"].shape[1]
    response_chat = tokenizer_chat.decode(
        outputs_chat[0][prompt_len_chat:], skip_special_tokens=True
    )
    return response_base, response_chat
# --- 4. Build the interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Fixed: "#SFT Demo" (no space) does not render as a markdown heading.
    gr.Markdown("# SFT Demo")
    gr.Markdown("展示 **原始基底模型 (Base)** 與 **指令微調後模型 (SFT)** 針對同一問題的回答差異。")
    with gr.Row():
        input_box = gr.Textbox(label="輸入測試問題 (Prompt)", placeholder="例如:請推薦台灣的旅遊景點。", lines=2)
    submit_btn = gr.Button("開始比對", variant="primary")
    with gr.Row():
        # Left column: raw base model output (typically rambling / off-topic).
        with gr.Column():
            output_base = gr.Textbox(label="❌ 訓練前 (Base Model)", lines=10, interactive=False)
            gr.Markdown("*特徵:容易答非所問、不斷重複問題、或像機器人一樣講廢話。*")
        # Right column: SFT model output (follows the instruction).
        with gr.Column():
            output_chat = gr.Textbox(label="✅ 訓練後 (SFT Model)", lines=10, interactive=False)
            gr.Markdown("*特徵:能理解指令、條理分明、回答符合人類邏輯。*")
    # One click runs both models and fills both output boxes.
    submit_btn.click(compare_models, inputs=input_box, outputs=[output_base, output_chat])
# Launch only when run as a script (Spaces executes the file directly).
if __name__ == "__main__":
    demo.launch()