Turtle1106 commited on
Commit
83998e5
·
verified ·
1 Parent(s): b0ae6dd

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -0
app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline
3
+
4
+ # --- 1. 設定模型路徑 (請替換成你自己的 Hugging Face Model ID 或 本地路徑) ---
5
+ # 為了示範,這裡使用兩個公開的小模型來模擬 "原版" vs "微調版"
6
+ base_model_id = "gpt2" # 假設這是原模型
7
+ tuned_model_id = "distilgpt2" # 假設這是訓練後模型 (這裡先用 distil 代替)
8
+
9
+ # --- 2. 載入模型 (使用 pipeline 比較方便) ---
10
+ print("正在載入模型,請稍候...")
11
+ # 實務上建議使用 GPU (device=0),如果是 Hugging Face Free Space 則用 CPU (device=-1)
12
+ pipe_base = pipeline("text-generation", model=base_model_id)
13
+ pipe_tuned = pipeline("text-generation", model=tuned_model_id)
14
+
15
+ # --- 3. 定義推論函數 ---
16
+ def compare_models(input_text):
17
+ # 設定生成參數 (可以根據需求調整 max_length, temperature 等)
18
+ gen_kwargs = {"max_length": 100, "num_return_sequences": 1}
19
+
20
+ # 原模型生成
21
+ out_base = pipe_base(input_text, **gen_kwargs)[0]['generated_text']
22
+
23
+ # 訓練後模型生成
24
+ out_tuned = pipe_tuned(input_text, **gen_kwargs)[0]['generated_text']
25
+
26
+ return out_base, out_tuned
27
+
28
+ # --- 4. 建置 Gradio 介面 ---
29
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
30
+ gr.Markdown("# 🚀 模型訓練成果比對系統")
31
+ gr.Markdown("輸入一段 Prompt,同時觀察原模型與訓練後模型的差異。")
32
+
33
+ with gr.Row():
34
+ input_box = gr.Textbox(label="輸入測試文字 (Prompt)", placeholder="例如:請寫一封信給客戶...", lines=3)
35
+
36
+ submit_btn = gr.Button("開始比對", variant="primary")
37
+
38
+ gr.Markdown("### 📊 比對結果")
39
+
40
+ # 使用 Row 來讓兩個輸出框並排顯示
41
+ with gr.Row():
42
+ with gr.Column():
43
+ output_base = gr.Textbox(label="❌ 原模型輸出 (Base Model)", lines=10, interactive=False)
44
+
45
+ with gr.Column():
46
+ output_tuned = gr.Textbox(label="✅ 訓練後模型 (Fine-tuned)", lines=10, interactive=False)
47
+
48
+ # 設定按鈕動作
49
+ submit_btn.click(
50
+ fn=compare_models,
51
+ inputs=input_box,
52
+ outputs=[output_base, output_tuned]
53
+ )
54
+
55
+ # --- 5. 啟動 ---
56
+ if __name__ == "__main__":
57
+ demo.launch()