Turtle1106 committed on
Commit
bafddf4
·
verified ·
1 Parent(s): 5a9e453

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -78
app.py CHANGED
@@ -1,100 +1,80 @@
1
  import gradio as gr
2
- from ctransformers import AutoModelForCausalLM
3
- import time
4
 
5
- # --- 1. 設定模型 (使用 GGUF 量化版以節省記憶體) ---
6
- print("正在初始化模型...請稍候 (約需 1-2 分鐘)")
 
 
 
 
7
 
8
- # Base Model (未經微調只會接龍)
9
- base_model_id = "TheBloke/Llama-2-7b-GGUF"
10
- base_file = "llama-2-7b.Q4_K_M.gguf"
11
 
12
- # Chat Model (指令微調,會聽話)
13
- chat_model_id = "TheBloke/Llama-2-7b-Chat-GGUF"
14
- chat_file = "llama-2-7b-chat.Q4_K_M.gguf"
 
15
 
16
- # 載入 Base
17
- print(f"Loading {base_file}...")
18
- model_base = AutoModelForCausalLM.from_pretrained(
19
- base_model_id,
20
- model_file=base_file,
21
- model_type="llama",
22
- context_length=2048,
23
- hf=True
24
- )
25
 
26
- # 載入 Chat
27
- print(f"Loading {chat_file}...")
28
- model_chat = AutoModelForCausalLM.from_pretrained(
29
- chat_model_id,
30
- model_file=chat_file,
31
- model_type="llama",
32
- context_length=2048,
33
- hf=True
34
- )
 
35
 
36
- # --- 2. 定義生成邏輯 ---
37
- def compare(prompt):
38
- # 參數設定
39
- # Base 故意設高一點的 temperature 讓它更容易胡言亂語
40
- tokens_base = model_base.generate(
41
- prompt,
42
- max_new_tokens=128,
43
- temperature=0.8,
44
- repetition_penalty=1.1
45
- )
46
 
47
- # Llama-2 Chat 需要特定的 Prompt 格式 [INST] ... [/INST]
48
- chat_prompt = f"[INST] {prompt} [/INST]"
49
- tokens_chat = model_chat.generate(
50
- chat_prompt,
51
- max_new_tokens=128,
52
- temperature=0.7,
53
- repetition_penalty=1.1
54
- )
55
 
56
- # 串流輸出 (Streaming) 模擬打字效果,讓 Demo 看起來更厲害
57
- output_base = ""
58
- output_chat = ""
59
-
60
- # 這裡簡化處理,直接解碼顯示
61
- # GGUF 的 generate 返回的是 generator,我們把它轉成文字
62
-
63
- # 處理 Base
64
- raw_base = ""
65
- for token in tokens_base:
66
- # ctransformers 直接回傳文字
67
- raw_base += token
68
-
69
- # 處理 Chat
70
- raw_chat = ""
71
- for token in tokens_chat:
72
- raw_chat += token
73
-
74
- return raw_base, raw_chat
75
 
76
- # --- 3. 建立 Gradio 介面 ---
77
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
78
- gr.Markdown("# ⚔️ Llama-2-7B: Pretrained vs Chat 對決展示")
79
- gr.Markdown("輸入「**我想去台灣旅遊 幫我推薦景點**」,看看左右兩邊的差異。")
80
- gr.Markdown("⚠️ *注意:運行於 CPU 免費環境,生成速度約需 30-60 秒,請耐心等待。*")
81
 
82
  with gr.Row():
83
- input_box = gr.Textbox(label="輸入 Prompt", value="我想去台灣旅遊 幫我推薦景點", lines=2)
84
- btn = gr.Button("🚀 開始比對", variant="primary")
85
 
86
  with gr.Row():
87
  with gr.Column():
88
- gr.Markdown("### ❌ Base Model (未微調)")
89
- out_base = gr.Textbox(label="Llama-2-7b Base", lines=10)
90
- gr.Markdown("**預期現象:** 把你的問題當成考卷題目,繼續出題,或者開始鬼打牆。")
91
 
92
  with gr.Column():
93
- gr.Markdown("### Chat Model (指令微調)")
94
- out_chat = gr.Textbox(label="Llama-2-7b Chat", lines=10)
95
- gr.Markdown("**預期現象:** 理解你是要「推薦」,並列出景點 (可能回英文)。")
96
 
97
- btn.click(compare, inputs=input_box, outputs=[out_base, out_chat])
98
 
99
  # 啟動
100
  if __name__ == "__main__":
 
1
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# --- 1. Model IDs ---
# Fallback strategy: use small public checkpoints to demo the "Base vs SFT" gap.
# Base model (left column): pretrained only, no chat data — it just continues text.
base_model_id = "Qwen/Qwen2.5-1.5B"
# Instruct model (right column): SFT-tuned, follows instructions
# (stands in for the fine-tuned training result).
chat_model_id = "Qwen/Qwen2.5-1.5B-Instruct"

print("正在載入模型這可能需要幾分鐘...")


def _load(model_id):
    """Return (tokenizer, model) for *model_id*, loaded on CPU in float32.

    float32 is the safest dtype for the free CPU tier — no half-precision
    kernels required.
    """
    tok = AutoTokenizer.from_pretrained(model_id)
    mdl = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)
    return tok, mdl


# --- 2. Load both models ---
tokenizer_base, model_base = _load(base_model_id)
tokenizer_chat, model_chat = _load(chat_model_id)
 
 
 
 
 
 
22
 
23
# --- 3. Core comparison logic ---
def compare_models(prompt):
    """Run the same prompt through the Base and the Instruct model.

    Args:
        prompt: Raw user text typed into the UI.

    Returns:
        Tuple ``(response_base, response_chat)`` containing ONLY the newly
        generated text for each model — the echoed prompt / chat template
        is stripped by slicing off the input tokens before decoding.
    """
    # Shared sampling parameters for a fair side-by-side comparison.
    gen_kwargs = {
        "max_new_tokens": 150,
        "do_sample": True,
        "temperature": 0.7,
        "top_p": 0.9,
        "repetition_penalty": 1.1,
    }

    # === A. Left: raw Base model ===
    # A base model knows no chat format — feed the plain text directly.
    inputs_base = tokenizer_base(prompt, return_tensors="pt")
    outputs_base = model_base.generate(**inputs_base, **gen_kwargs)
    # BUGFIX: decode only the tokens generated AFTER the prompt; previously
    # the whole sequence (prompt included) was returned.
    new_tokens_base = outputs_base[0][inputs_base["input_ids"].shape[1]:]
    response_base = tokenizer_base.decode(new_tokens_base, skip_special_tokens=True)

    # === B. Right: Instruct/SFT model ===
    # Instruct models expect their chat template wrapped around the user turn.
    messages = [{"role": "user", "content": prompt}]
    text_chat = tokenizer_chat.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs_chat = tokenizer_chat(text_chat, return_tensors="pt")
    outputs_chat = model_chat.generate(**inputs_chat, **gen_kwargs)
    # Slicing off the prompt tokens makes the previous brittle
    # split("assistant\n") string hack (with its bare except) unnecessary.
    new_tokens_chat = outputs_chat[0][inputs_chat["input_ids"].shape[1]:]
    response_chat = tokenizer_chat.decode(new_tokens_chat, skip_special_tokens=True)

    return response_base, response_chat
 
 
 
 
 
 
 
 
58
 
59
# --- 4. Build the Gradio interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Page header.
    gr.Markdown("# ⚔️ 模型微調效果展示 (SFT Demo)")
    gr.Markdown("展示 **原始基底模型 (Base)** 與 **指令微調後模型 (SFT)** 針對同一問題回答差異。")

    # Input row: prompt box plus the trigger button.
    with gr.Row():
        prompt_box = gr.Textbox(
            lines=2,
            label="輸入測試問題 (Prompt)",
            placeholder="例如:請推薦台灣旅遊景點",
        )
        run_btn = gr.Button("開始比對", variant="primary")

    # Output row: before/after columns shown side by side.
    with gr.Row():
        with gr.Column():
            base_out = gr.Textbox(lines=10, interactive=False, label="❌ 訓練前 (Base Model)")
            gr.Markdown("*特徵:容易答非所問、不斷重複問題、或像機器人一樣講廢話。*")

        with gr.Column():
            chat_out = gr.Textbox(lines=10, interactive=False, label="✅ 訓練後 (SFT Model)")
            gr.Markdown("*特徵:能理解指令、條理分明、回答符合人類邏輯。*")

    # Wire the button to the comparison function.
    run_btn.click(compare_models, inputs=prompt_box, outputs=[base_out, chat_out])
78
 
79
  # 啟動
80
  if __name__ == "__main__":