Turtle1106 committed on
Commit
5a0326c
·
verified ·
1 Parent(s): 5cff108

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -56
app.py CHANGED
@@ -1,57 +1,70 @@
import gradio as gr
from transformers import pipeline

# --- 1. Model paths (replace with your own Hugging Face model IDs or local paths) ---
# For demo purposes, two small public models stand in for "base" vs "fine-tuned".
base_model_id = "gpt2"          # the base model
tuned_model_id = "distilgpt2"   # stand-in for the fine-tuned model

# --- 2. Load the models (pipelines keep this simple) ---
print("正在載入模型,請稍候...")
# In production prefer a GPU (device=0); on a free Hugging Face Space use CPU (device=-1).
pipe_base = pipeline("text-generation", model=base_model_id)
pipe_tuned = pipeline("text-generation", model=tuned_model_id)

# --- 3. Inference function ---
def compare_models(input_text):
    """Generate from both pipelines and return (base_output, tuned_output).

    Both outputs include the prompt, as produced by the text-generation
    pipeline's `generated_text` field.
    """
    # max_new_tokens counts only generated tokens. The previous
    # max_length=100 included the prompt, so a prompt near 100 tokens
    # produced little or no new text.
    gen_kwargs = {"max_new_tokens": 100, "num_return_sequences": 1}

    # Base model generation
    out_base = pipe_base(input_text, **gen_kwargs)[0]["generated_text"]

    # Fine-tuned model generation
    out_tuned = pipe_tuned(input_text, **gen_kwargs)[0]["generated_text"]

    return out_base, out_tuned

# --- 4. Build the Gradio UI ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🚀 模型訓練成果比對系統")
    gr.Markdown("輸入一段 Prompt,同時觀察原模型與訓練後模型的差異。")

    with gr.Row():
        input_box = gr.Textbox(label="輸入測試文字 (Prompt)", placeholder="例如:請寫一封信給客戶...", lines=3)

    submit_btn = gr.Button("開始比對", variant="primary")

    gr.Markdown("### 📊 比對結果")

    # Two output boxes side by side
    with gr.Row():
        with gr.Column():
            output_base = gr.Textbox(label="❌ 原模型輸出 (Base Model)", lines=10, interactive=False)

        with gr.Column():
            output_tuned = gr.Textbox(label="✅ 訓練後模型 (Fine-tuned)", lines=10, interactive=False)

    # Wire the button to the comparison function
    submit_btn.click(
        fn=compare_models,
        inputs=input_box,
        outputs=[output_base, output_tuned]
    )

# --- 5. Launch ---
if __name__ == "__main__":
    demo.launch()
 
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# --- Configuration ---
base_model_id = "如果你用Llama3這裡填Llama3的路徑"  # e.g. "unsloth/llama-3-8b-bnb-4bit"
adapter_model_id = "你的帳號/你的模型名稱-lora"  # the LoRA adapter ID pushed from Kaggle

# --- 1. Load the model (memory-friendly) ---
print("正在載入基底模型...")
tokenizer = AutoTokenizer.from_pretrained(base_model_id)

# On a free CPU Space do not use 4-bit loading (bitsandbytes has poor CPU
# support). float16 generation is also unreliable on CPU, so pick the dtype
# from the actual device instead of hard-coding float16: float16 when CUDA
# is available, float32 otherwise. On a paid GPU Space you may additionally
# pass load_in_4bit=True.
_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    device_map="auto",
    torch_dtype=_dtype,
    low_cpu_mem_usage=True,
)

print("正在掛載 LoRA Adapter...")
# Attach the fine-tuned LoRA layers on top of the base weights.
model = PeftModel.from_pretrained(model, adapter_model_id)
# --- 2. Inference logic (core trick) ---
def compare_inference(prompt):
    """Generate with and without the LoRA adapter and return both responses.

    Parameters:
        prompt: the user's input text.

    Returns:
        (response_base, response_tuned): the newly generated text from the
        base model (adapter disabled) and the fine-tuned model, with the
        prompt tokens stripped.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[1]

    # Generation parameters
    gen_kwargs = {
        "max_new_tokens": 150,
        "do_sample": True,
        "temperature": 0.7,
        "top_p": 0.9,
    }

    # Inference needs no gradients; no_grad avoids keeping activations
    # around and roughly halves peak memory during generate().
    with torch.no_grad():
        # === A. Base model ===
        # disable_adapter() temporarily switches the LoRA layers off, so
        # the model behaves exactly like the original base weights.
        with model.disable_adapter():
            base_ids = model.generate(**inputs, **gen_kwargs)

        # === B. Fine-tuned model (LoRA active) ===
        tuned_ids = model.generate(**inputs, **gen_kwargs)

    # Slice off the prompt tokens before decoding. The previous
    # output.replace(prompt, "") removed *every* occurrence of the prompt
    # in the output and silently failed whenever the tokenizer round-trip
    # did not reproduce the prompt byte-for-byte.
    response_base = tokenizer.decode(
        base_ids[0][prompt_len:], skip_special_tokens=True
    ).strip()
    response_tuned = tokenizer.decode(
        tuned_ids[0][prompt_len:], skip_special_tokens=True
    ).strip()

    return response_base, response_tuned
# --- 3. UI layout ---
# `demo` must keep this name: Hugging Face Spaces looks it up at module level.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# ⚔️ SFT 模型效果比對")
    gr.Markdown(f"基底模型: `{base_model_id}` vs 微調權重: `{adapter_model_id}`")

    prompt_box = gr.Textbox(
        label="輸入測試指令 (Prompt)",
        placeholder="例如:這家公司的財報重點是什麼?",
        lines=2,
    )
    run_btn = gr.Button("開始比對 (Generate)", variant="primary")

    # Side-by-side comparison of the two outputs
    with gr.Row():
        base_out = gr.Textbox(label="原始模型 (Base)", lines=8)
        tuned_out = gr.Textbox(label="訓練後模型 (SFT)", lines=8)

    run_btn.click(compare_inference, inputs=prompt_box, outputs=[base_out, tuned_out])

if __name__ == "__main__":
    demo.launch()