smartTranscend committed
Commit 55fafff · verified · 1 Parent(s): 3312ec2

Upload app.py

Files changed (1)
  1. app.py +2172 -0
app.py ADDED
@@ -0,0 +1,2172 @@
import gradio as gr
import pandas as pd
import torch
from datasets import Dataset, DatasetDict
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding
)
from peft import (
    LoraConfig,
    AdaLoraConfig,
    AdaptionPromptConfig,
    PromptTuningConfig,
    PrefixTuningConfig,
    get_peft_model,
    TaskType,
    PeftModel
)
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from sklearn.utils import resample
import numpy as np
import json
from datetime import datetime
import os
import gc
from huggingface_hub import login

# ==================== Global variables ====================
LAST_MODEL_PATH = None
LAST_TOKENIZER = None
MAX_LENGTH = 512

# ==================== Hugging Face token login ====================
print("🔐 檢查 Hugging Face Token...")
if "HF_TOKEN" in os.environ:
    try:
        login(token=os.environ["HF_TOKEN"])
        print("✅ 已使用 HF Token 登入")
    except Exception as e:
        print(f"⚠️ Token 登入失敗: {e}")
else:
    print("⚠️ 未找到 HF_TOKEN,可能無法下載 Llama 模型")

# Detect the compute device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🖥️ 使用設備: {device}")

# ==================== Core training function (original logic, unchanged) ====================
def run_llama_training(
    file_path,
    model_name,
    target_samples,
    use_class_weights,
    num_epochs,
    batch_size,
    learning_rate,
    tuning_method,
    lora_r,
    lora_alpha,
    lora_dropout,
    lora_target_modules,
    adalora_init_r,
    adalora_target_r,
    adalora_alpha,
    adalora_tinit,
    adalora_tfinal,
    adalora_delta_t,
    adapter_reduction_factor,
    prompt_tuning_num_tokens,
    prefix_tuning_num_tokens,
    best_metric,
    # [New] second fine-tuning parameters
    is_second_finetuning=False,
    base_model_path=None
):
    """
    Original Llama training logic.
    """

    global LAST_MODEL_PATH, LAST_TOKENIZER

    # ==================== Clear memory (before training) ====================
    torch.cuda.empty_cache()
    gc.collect()
    print("🧹 記憶體已清空")

    # ==================== 1. Load data ====================
    training_type = "二次微調" if is_second_finetuning else "第一次微調"

    print("\n" + "="*80)
    print(f"🦙 Llama NBCD {training_type} - {tuning_method} 方法")
    print("="*80)
    print(f"開始時間: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"訓練類型: {training_type}")
    print(f"微調方法: {tuning_method}")
    if is_second_finetuning:
        print(f"基礎模型: {base_model_path}")
    print("="*80)

    print("📂 載入訓練數據...")
    df = pd.read_csv(file_path)
    print(f"✅ 成功載入 {len(df)} 筆數據")

    # Auto-detect the text and label columns
    text_col = None
    label_col = None

    # Supported text column names
    if 'Text' in df.columns:
        text_col = 'Text'
    elif 'text' in df.columns:
        text_col = 'text'

    # Supported label column names
    if 'Label' in df.columns:
        label_col = 'Label'
    elif 'label' in df.columns:
        label_col = 'label'

    if text_col is None or label_col is None:
        raise ValueError(
            f"❌ 無法偵測到正確的欄位名稱!\n"
            f"📋 您的 CSV 欄位: {list(df.columns)}\n\n"
            f"✅ 請使用以下欄位名稱:\n"
            f"   文本欄位: 'Text' 或 'text'\n"
            f"   標籤欄位: 'Label' 或 'label'"
        )

    print(f"   ✅ 偵測到文本欄位: '{text_col}'")
    print(f"   ✅ 偵測到標籤欄位: '{label_col}'")

    # Rename to the standard column names
    df = df.rename(columns={text_col: 'Text', label_col: 'nbcd'})

    print(f"   原始 Class 0: {(df['nbcd']==0).sum()} 筆")
    print(f"   原始 Class 1: {(df['nbcd']==1).sum()} 筆")

    # ==================== 2. Data balancing ====================
    print("\n⚖️ 執行資料平衡...")

    df_class_0 = df[df['nbcd'] == 0]
    df_class_1 = df[df['nbcd'] == 1]

    target_n = int(target_samples)

    # Undersample Class 0
    if len(df_class_0) > target_n:
        df_class_0_balanced = resample(df_class_0, n_samples=target_n, random_state=42, replace=False)
        print(f"✅ Class 0 欠採樣: {len(df_class_0)} → {len(df_class_0_balanced)} 筆")
    else:
        df_class_0_balanced = df_class_0
        print(f"⚠️ Class 0 樣本數不足,保持 {len(df_class_0)} 筆")

    # Oversample Class 1
    if len(df_class_1) < target_n:
        df_class_1_balanced = resample(df_class_1, n_samples=target_n, random_state=42, replace=True)
        print(f"✅ Class 1 過採樣: {len(df_class_1)} → {len(df_class_1_balanced)} 筆")
    else:
        df_class_1_balanced = df_class_1
        print(f"⚠️ Class 1 樣本數充足,保持 {len(df_class_1)} 筆")

    df_balanced = pd.concat([df_class_0_balanced, df_class_1_balanced])
    df_balanced = df_balanced.sample(frac=1, random_state=42).reset_index(drop=True)

    print(f"\n📊 平衡後數據:")
    print(f"   總樣本數: {len(df_balanced)} 筆")
    print(f"   Class 0: {(df_balanced['nbcd']==0).sum()} 筆")
    print(f"   Class 1: {(df_balanced['nbcd']==1).sum()} 筆")

    # ==================== 3. Compute class weights ====================
    if use_class_weights:
        print("\n⚖️ 計算類別權重...")
        class_counts = df_balanced['nbcd'].value_counts().sort_index()
        total = len(df_balanced)
        num_classes = 2

        class_weight_0 = total / (num_classes * class_counts[0])
        class_weight_1 = total / (num_classes * class_counts[1])
        class_weights = torch.tensor([class_weight_0, class_weight_1], dtype=torch.float32)

        print(f"✅ 類別權重計算完成:")
        print(f"   Class 0 權重: {class_weight_0:.4f}")
        print(f"   Class 1 權重: {class_weight_1:.4f}")

        if device == "cuda":
            class_weights = class_weights.to(device)
    else:
        class_weights = None
        print("\n⚠️ 未使用類別權重")

    # ==================== 4. Split data ====================
    print("\n✂️ 分割訓練集和測試集...")
    train_df, test_df = train_test_split(
        df_balanced,
        test_size=0.2,
        stratify=df_balanced['nbcd'],
        random_state=42
    )
    print(f"✅ 訓練集: {len(train_df)} 筆 (Class 0: {(train_df['nbcd']==0).sum()}, Class 1: {(train_df['nbcd']==1).sum()})")
    print(f"✅ 測試集: {len(test_df)} 筆 (Class 0: {(test_df['nbcd']==0).sum()}, Class 1: {(test_df['nbcd']==1).sum()})")

    dataset = DatasetDict({
        'train': Dataset.from_pandas(train_df[['Text', 'nbcd']]),
        'test': Dataset.from_pandas(test_df[['Text', 'nbcd']])
    })

    # ==================== 5. Load model and tokenizer ====================
    print("\n🤖 載入 Llama 模型和 Tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # ==================== 6. Load the un-finetuned base model (baseline) ====================
    print("\n📦 載入未微調的基礎模型 (Baseline)...")
    baseline_model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=2,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        device_map="auto" if device == "cuda" else None
    )
    baseline_model.config.pad_token_id = tokenizer.pad_token_id
    print("✅ Baseline 模型載入完成")

    # ==================== 7. Load the model to fine-tune ====================
    print("\n🔧 載入用於微調的模型...")

    # [New] second fine-tuning logic
    if is_second_finetuning and base_model_path:
        print(f"📦 載入第一次微調模型: {base_model_path}")

        # Read the first-round model info
        with open('./saved_llama_models_list.json', 'r') as f:
            models_list = json.load(f)

        base_model_info = None
        for model_info in models_list:
            if model_info['model_path'] == base_model_path:
                base_model_info = model_info
                break

        if base_model_info is None:
            raise ValueError(f"找不到基礎模型資訊: {base_model_path}")

        base_tuning_method = base_model_info['tuning_method']
        print(f"   第一次微調方法: {base_tuning_method}")

        # Load the model according to the first-round method
        if base_tuning_method in ["LoRA", "AdaLoRA", "Adapter", "Prompt Tuning"]:
            # Load a PEFT model
            base_bert = AutoModelForSequenceClassification.from_pretrained(
                model_name,
                num_labels=2,
                torch_dtype=torch.float16 if device == "cuda" else torch.float32
            )
            base_model = PeftModel.from_pretrained(base_bert, base_model_path)
            print(f"   ✅ 已載入 {base_tuning_method} 模型")
        else:
            # Load a plain model (BitFit)
            base_model = AutoModelForSequenceClassification.from_pretrained(
                base_model_path,
                num_labels=2,
                torch_dtype=torch.float16 if device == "cuda" else torch.float32
            )
            print(f"   ✅ 已載入 BitFit 模型")

        if device == "cuda":
            base_model = base_model.to(device)

        print(f"   ⚠️ 注意:二次微調將使用與第一次相同的方法 ({base_tuning_method})")

        # Force the second round to use the same method as the first
        tuning_method = base_tuning_method

    else:
        # [Original logic] first fine-tuning: start from plain Llama
        base_model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            num_labels=2,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            device_map="auto" if device == "cuda" else None
        )

    base_model.config.pad_token_id = tokenizer.pad_token_id
    print("✅ 基礎模型載入完成")

    # ==================== 8. Configure the tuning method ====================
    print(f"\n🔧 配置 {tuning_method}...")

    if tuning_method == "LoRA":
        # LoRA config with the full parameter set
        target_modules_map = {
            "query,value": ["q_proj", "v_proj"],
            "query,key,value": ["q_proj", "k_proj", "v_proj"],
            "all": ["q_proj", "k_proj", "v_proj", "o_proj"]
        }

        peft_config = LoraConfig(
            task_type=TaskType.SEQ_CLS,
            r=int(lora_r),
            lora_alpha=int(lora_alpha),
            lora_dropout=float(lora_dropout),
            target_modules=target_modules_map.get(lora_target_modules, ["q_proj", "v_proj"]),
            bias="none"
        )
        print(f"✅ LoRA 配置完成")
        print(f"   LoRA rank (r): {lora_r}")
        print(f"   LoRA alpha: {lora_alpha}")
        print(f"   LoRA dropout: {lora_dropout}")
        print(f"   目標模組: {lora_target_modules}")

    elif tuning_method == "AdaLoRA":
        # AdaLoRA config with its own dedicated parameters
        try:
            peft_config = AdaLoraConfig(
                task_type=TaskType.SEQ_CLS,
                inference_mode=False,
                r=int(adalora_target_r),
                lora_alpha=int(adalora_alpha),
                lora_dropout=0.1,
                target_modules=["q_proj", "v_proj"],
                # AdaLoRA-specific parameters
                init_r=int(adalora_init_r),
                target_r=int(adalora_target_r),
                tinit=int(adalora_tinit),
                tfinal=int(adalora_tfinal),
                deltaT=int(adalora_delta_t),
            )
            print(f"✅ AdaLoRA 配置完成")
            print(f"   初始 rank: {adalora_init_r}")
            print(f"   目標 rank: {adalora_target_r}")
            print(f"   Alpha: {adalora_alpha}")
            print(f"   Tinit: {adalora_tinit}, Tfinal: {adalora_tfinal}")
            print(f"   Delta T: {adalora_delta_t}")
            print(f"   自適應秩調整: 啟用")
        except Exception as e:
            print(f"⚠️ AdaLoRA 配置失敗,回退到 LoRA: {e}")
            peft_config = LoraConfig(
                task_type=TaskType.SEQ_CLS,
                r=int(adalora_target_r),
                lora_alpha=int(adalora_alpha),
                lora_dropout=0.1,
                target_modules=["q_proj", "v_proj"],
                bias="none"
            )

    elif tuning_method == "Adapter":
        # LLaMA-Adapter style adaption prompt.
        # Note: AdaptionPromptConfig accepts adapter_len/adapter_layers only;
        # it has no reduction_factor argument, so adapter_reduction_factor is not applied here.
        peft_config = AdaptionPromptConfig(
            task_type=TaskType.SEQ_CLS,
            adapter_len=10,
            adapter_layers=30
        )
        print(f"✅ Adapter 配置完成")
        print(f"   Reduction factor: {adapter_reduction_factor}")

    elif tuning_method == "Prompt Tuning":
        # Soft prompt tuning
        peft_config = PromptTuningConfig(
            task_type=TaskType.SEQ_CLS,
            num_virtual_tokens=int(prompt_tuning_num_tokens),
            prompt_tuning_init="TEXT",
            prompt_tuning_init_text="Classify if the following text indicates NBCD:",
            tokenizer_name_or_path=model_name
        )
        print(f"✅ Prompt Tuning 配置完成")
        print(f"   Virtual tokens: {prompt_tuning_num_tokens}")

    elif tuning_method == "Prefix Tuning":
        # Prefix tuning can hit compatibility issues; attempt it anyway
        print(f"⚠️ Prefix Tuning 在某些環境可能有兼容性問題")
        print(f"   如果遇到錯誤,建議使用 Prompt Tuning 替代")

        try:
            # Disable the model's cache first
            base_model.config.use_cache = False

            peft_config = PrefixTuningConfig(
                task_type=TaskType.SEQ_CLS,
                num_virtual_tokens=int(prefix_tuning_num_tokens),
                prefix_projection=False,
                inference_mode=False
            )
            print(f"✅ Prefix Tuning 配置完成")
            print(f"   Virtual tokens: {prefix_tuning_num_tokens}")
            print(f"   已禁用緩存")
        except Exception as e:
            print(f"❌ Prefix Tuning 配置失敗: {e}")
            raise ValueError(
                f"Prefix Tuning 配置失敗,原因: {e}\n"
                f"建議使用 Prompt Tuning 作為替代方案"
            )

    elif tuning_method == "BitFit":
        # BitFit: train only the bias parameters
        model = base_model

        # Freeze all parameters
        for param in model.parameters():
            param.requires_grad = False

        # Unfreeze only biases and the classification head
        trainable_params_list = []
        for name, param in model.named_parameters():
            if 'bias' in name or 'score' in name or 'classifier' in name:
                param.requires_grad = True
                trainable_params_list.append(name)

        print(f"✅ BitFit 配置完成")
        print(f"   僅訓練 bias 和分類頭參數")
        print(f"   可訓練參數: {', '.join(trainable_params_list[:5])}...")

    # Apply the PEFT config (all methods except BitFit)
    if tuning_method != "BitFit":
        model = get_peft_model(base_model, peft_config)

    # Extra setup for Prefix Tuning
    if tuning_method == "Prefix Tuning":
        model.config.use_cache = False

    # Count trainable parameters
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"   可訓練參數: {trainable_params:,} / {total_params:,} ({trainable_params/total_params*100:.2f}%)")

    # ==================== 9. Preprocess data ====================
    print("\n📄 預處理數據...")

    def preprocess_function(examples):
        return tokenizer(
            examples['Text'],
            truncation=True,
            padding='max_length',
            max_length=MAX_LENGTH
        )

    tokenized_dataset = dataset.map(preprocess_function, batched=True, remove_columns=['Text'])
    tokenized_dataset = tokenized_dataset.rename_column("nbcd", "labels")
    print("✅ 數據預處理完成")

    # ==================== 10. Metrics function ====================
    def compute_metrics(eval_pred):
        predictions, labels = eval_pred
        predictions = np.argmax(predictions, axis=1)

        accuracy = accuracy_score(labels, predictions)
        precision, recall, f1, _ = precision_recall_fscore_support(
            labels, predictions, average='binary', zero_division=0
        )

        # Confusion matrix to derive sensitivity and specificity
        from sklearn.metrics import confusion_matrix
        cm = confusion_matrix(labels, predictions)

        if cm.shape == (2, 2):
            tn, fp, fn, tp = cm.ravel()
            sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0  # sensitivity = recall
            specificity = tn / (tn + fp) if (tn + fp) > 0 else 0  # specificity
        else:
            sensitivity = 0
            specificity = 0

        return {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'sensitivity': sensitivity,
            'specificity': specificity
        }

    # ==================== 11. Evaluate the baseline model ====================
    # [Runs only on the first fine-tuning]
    if not is_second_finetuning:
        print("\n" + "="*70)
        print("📊 評估未微調的 Baseline 模型...")
        print("="*70)

        baseline_trainer = Trainer(
            model=baseline_model,
            args=TrainingArguments(
                output_dir="./temp_baseline_llama",
                per_device_eval_batch_size=int(batch_size),
                bf16=(device == "cuda"),
                report_to="none"
            ),
            tokenizer=tokenizer,
            data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
            compute_metrics=compute_metrics
        )

        baseline_test_results = baseline_trainer.evaluate(eval_dataset=tokenized_dataset['test'])

        print("\n📋 Baseline 模型 - 測試集結果:")
        print(f"   Accuracy: {baseline_test_results['eval_accuracy']:.4f}")
        print(f"   Precision: {baseline_test_results['eval_precision']:.4f}")
        print(f"   Recall: {baseline_test_results['eval_recall']:.4f}")
        print(f"   F1 Score: {baseline_test_results['eval_f1']:.4f}")
        print(f"   Sensitivity: {baseline_test_results['eval_sensitivity']:.4f}")
        print(f"   Specificity: {baseline_test_results['eval_specificity']:.4f}")

        # Free the baseline model's memory
        del baseline_model
        del baseline_trainer
        torch.cuda.empty_cache()
        gc.collect()
    else:
        # No baseline evaluation for second fine-tuning
        baseline_test_results = None
        del baseline_model
        torch.cuda.empty_cache()
        gc.collect()

    # ==================== 12. Custom Trainer ====================
    if use_class_weights:
        class WeightedTrainer(Trainer):
            def __init__(self, *args, class_weights=None, **kwargs):
                super().__init__(*args, **kwargs)
                self.class_weights = class_weights

            def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
                labels = inputs.pop("labels")
                outputs = model(**inputs)
                logits = outputs.logits

                loss_fct = torch.nn.CrossEntropyLoss(weight=self.class_weights)
                loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))

                return (loss, outputs) if return_outputs else loss

        TrainerClass = WeightedTrainer
    else:
        TrainerClass = Trainer

    # ==================== 13. Training configuration ====================
    print("\n" + "="*70)
    print("⚙️ 配置微調訓練器...")
    print("="*70)

    # Metric name mapping
    metric_map = {
        "f1": "f1",
        "accuracy": "accuracy",
        "precision": "precision",
        "recall": "recall",
        "sensitivity": "sensitivity",
        "specificity": "specificity"
    }

    training_label = "second" if is_second_finetuning else "first"
    output_dir = f'./llama_nbcd_{tuning_method.lower().replace(" ", "_")}_{training_label}_{datetime.now().strftime("%Y%m%d_%H%M%S")}'

    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=int(num_epochs),
        per_device_train_batch_size=int(batch_size),
        per_device_eval_batch_size=int(batch_size),
        learning_rate=float(learning_rate),
        weight_decay=0.01,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model=metric_map.get(best_metric, "recall"),
        logging_dir=f"{output_dir}/logs",
        logging_steps=10,
        bf16=(device == "cuda"),
        gradient_accumulation_steps=2,
        warmup_steps=50,
        report_to="none",
        seed=42
    )

    if use_class_weights:
        trainer = TrainerClass(
            model=model,
            args=training_args,
            train_dataset=tokenized_dataset['train'],
            eval_dataset=tokenized_dataset['test'],
            tokenizer=tokenizer,
            data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
            compute_metrics=compute_metrics,
            class_weights=class_weights
        )
    else:
        trainer = TrainerClass(
            model=model,
            args=training_args,
            train_dataset=tokenized_dataset['train'],
            eval_dataset=tokenized_dataset['test'],
            tokenizer=tokenizer,
            data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
            compute_metrics=compute_metrics
        )

    # ==================== 14. Start training ====================
    print("\n" + "="*70)
    print(f"🚀 開始{training_type}訓練...")
    print("="*70 + "\n")

    start_time = datetime.now()
    train_result = trainer.train()
    end_time = datetime.now()
    duration = (end_time - start_time).total_seconds() / 60

    print("\n" + "="*70)
    print(f"✅ 訓練完成!")
    print(f"   耗時: {duration:.1f} 分鐘")
    print("="*70)

    # ==================== 15. Evaluate the fine-tuned model ====================
    print("\n" + "="*70)
    print(f"📊 評估{training_type}後的模型...")
    print("="*70)

    finetuned_test_results = trainer.evaluate(eval_dataset=tokenized_dataset['test'])

    print(f"\n📋 {training_type}模型 - 測試集結果:")
    print(f"   Accuracy: {finetuned_test_results['eval_accuracy']:.4f}")
    print(f"   Precision: {finetuned_test_results['eval_precision']:.4f}")
    print(f"   Recall: {finetuned_test_results['eval_recall']:.4f}")
    print(f"   F1 Score: {finetuned_test_results['eval_f1']:.4f}")
    print(f"   Sensitivity: {finetuned_test_results['eval_sensitivity']:.4f}")
    print(f"   Specificity: {finetuned_test_results['eval_specificity']:.4f}")

    # ==================== 16. Save model and results ====================
    print("\n💾 保存模型和結果...")
    trainer.save_model()
    tokenizer.save_pretrained(output_dir)

    # Save model info to a JSON file
    metric_key = 'eval_' + metric_map.get(best_metric, "recall")
    model_info = {
        'model_path': output_dir,
        'model_name': model_name,
        'tuning_method': tuning_method,
        'training_type': training_type,
        'best_metric': best_metric,
        'best_metric_value': float(finetuned_test_results[metric_key]),
        'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        'target_samples': target_samples,
        'epochs': num_epochs,
        'batch_size': batch_size,
        'learning_rate': learning_rate,
        'lora_r': lora_r if tuning_method in ["LoRA", "AdaLoRA"] else None,
        'lora_alpha': lora_alpha if tuning_method in ["LoRA", "AdaLoRA"] else None,
        'is_second_finetuning': is_second_finetuning,
        'base_model_path': base_model_path if is_second_finetuning else None
    }

    # Load the existing model list
    models_list_file = './saved_llama_models_list.json'
    if os.path.exists(models_list_file):
        with open(models_list_file, 'r') as f:
            models_list = json.load(f)
    else:
        models_list = []

    # Append the new model info
    models_list.append(model_info)

    # Save the updated list
    with open(models_list_file, 'w') as f:
        json.dump(models_list, f, indent=2)

    # Update the globals
    LAST_MODEL_PATH = output_dir
    LAST_TOKENIZER = tokenizer

    print(f"✅ 模型已儲存至: {output_dir}")

    # ==================== Clear memory (after training) ====================
    del model
    del trainer
    torch.cuda.empty_cache()
    gc.collect()
    print("🧹 訓練後記憶體已清空")

    # Prepare the return values
    results = {
        'baseline_results': baseline_test_results,
        'finetuned_results': finetuned_test_results,
        'model_path': output_dir,
        'duration': duration,
        'best_metric': best_metric,
        'model_name': model_name,
        'tuning_method': tuning_method,
        'training_type': training_type,
        'is_second_finetuning': is_second_finetuning
    }

    return results

# ==================== Gradio wrapper functions ====================
def train_first_wrapper(
    file,
    model_name,
    target_samples,
    use_class_weights,
    num_epochs,
    batch_size,
    learning_rate,
    tuning_method,
    lora_r,
    lora_alpha,
    lora_dropout,
    lora_target_modules,
    adalora_init_r,
    adalora_target_r,
    adalora_alpha,
    adalora_tinit,
    adalora_tfinal,
    adalora_delta_t,
    adapter_reduction_factor,
    prompt_tuning_num_tokens,
    prefix_tuning_num_tokens,
    best_metric
):
    """Wrapper for the first fine-tuning."""

    if file is None:
        return "請上傳 CSV 檔案", "", ""

    try:
        # Call the training function
        results = run_llama_training(
            file_path=file.name,
            model_name=model_name,
            target_samples=target_samples,
            use_class_weights=use_class_weights,
            num_epochs=num_epochs,
            batch_size=batch_size,
            learning_rate=learning_rate,
            tuning_method=tuning_method,
            lora_r=lora_r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            lora_target_modules=lora_target_modules,
            adalora_init_r=adalora_init_r,
            adalora_target_r=adalora_target_r,
            adalora_alpha=adalora_alpha,
            adalora_tinit=adalora_tinit,
            adalora_tfinal=adalora_tfinal,
            adalora_delta_t=adalora_delta_t,
            adapter_reduction_factor=adapter_reduction_factor,
            prompt_tuning_num_tokens=prompt_tuning_num_tokens,
            prefix_tuning_num_tokens=prefix_tuning_num_tokens,
            best_metric=best_metric,
            is_second_finetuning=False
        )

        baseline_results = results['baseline_results']
        finetuned_results = results['finetuned_results']

        # Panel 1: data info
        data_info = f"""
# 📊 資料資訊 (第一次微調)

## 🔧 訓練配置
- **模型**: {results['model_name']}
- **微調方法**: {results['tuning_method']}
- **最佳化指標**: {results['best_metric']}
- **訓練時長**: {results['duration']:.1f} 分鐘

## ⚙️ 訓練參數
- **目標樣本數**: {target_samples} 筆/類別
- **使用類別權重**: {'是' if use_class_weights else '否'}
- **訓練輪數**: {num_epochs}
- **批次大小**: {batch_size}
- **學習率**: {learning_rate}

✅ 第一次微調完成!可進行二次微調或預測!
"""

        # Panel 2: un-finetuned Llama
        baseline_output = f"""
# 🔵 未微調 Llama (Baseline)
## 未經訓練

### 📈 評估指標

| 指標 | 數值 |
|------|------|
| **Accuracy** | {baseline_results['eval_accuracy']:.4f} |
| **Precision** | {baseline_results['eval_precision']:.4f} |
| **Recall** | {baseline_results['eval_recall']:.4f} |
| **F1 Score** | {baseline_results['eval_f1']:.4f} |
| **Sensitivity** | {baseline_results['eval_sensitivity']:.4f} |
| **Specificity** | {baseline_results['eval_specificity']:.4f} |
"""

        # Panel 3: fine-tuned Llama
        finetuned_output = f"""
# 🟢 第一次微調 Llama
## {results['tuning_method']}

### 📈 評估指標

| 指標 | 數值 |
|------|------|
| **Accuracy** | {finetuned_results['eval_accuracy']:.4f} |
| **Precision** | {finetuned_results['eval_precision']:.4f} |
| **Recall** | {finetuned_results['eval_recall']:.4f} |
| **F1 Score** | {finetuned_results['eval_f1']:.4f} |
| **Sensitivity** | {finetuned_results['eval_sensitivity']:.4f} |
| **Specificity** | {finetuned_results['eval_specificity']:.4f} |
"""

        return data_info, baseline_output, finetuned_output

    except Exception as e:
        import traceback
        error_msg = f"❌ 錯誤:{str(e)}\n\n詳細錯誤訊息:\n{traceback.format_exc()}"
        return error_msg, "", ""

820
+def train_second_wrapper(
+    base_model_choice,
+    file,
+    target_samples,
+    use_class_weights,
+    num_epochs,
+    batch_size,
+    learning_rate,
+    best_metric
+):
+    """Wrapper for the second round of fine-tuning."""
+
+    # Sentinel value must match the dropdown default defined in the UI
+    if base_model_choice == "請先進行第一次微調":
+        return "Please train a model on the 'First Fine-Tuning' tab first", ""
+
+    if file is None:
+        return "Please upload a new training-data CSV file", ""
+
+    try:
+        # Resolve the base-model path
+        base_model_path = base_model_choice
+
+        # Read the metadata saved for the first-round models
+        with open('./saved_llama_models_list.json', 'r') as f:
+            models_list = json.load(f)
+
+        base_model_info = None
+        for model_info in models_list:
+            if model_info['model_path'] == base_model_path:
+                base_model_info = model_info
+                break
+
+        if base_model_info is None:
+            return "Base-model metadata not found", ""
+
+        # Reuse the first round's settings (the method is not changed in round two)
+        tuning_method = base_model_info['tuning_method']
+        model_name = base_model_info['model_name']
+
+        # PEFT hyperparameters from the first round (with defaults)
+        lora_r = base_model_info.get('lora_r', 16)
+        lora_alpha = base_model_info.get('lora_alpha', 32)
+        lora_dropout = 0.1
+        lora_target_modules = "query,value"
+        adalora_init_r = 12
+        adalora_target_r = 8
+        adalora_alpha = 32
+        adalora_tinit = 0
+        adalora_tfinal = 0
+        adalora_delta_t = 1
+        adapter_reduction_factor = 16
+        prompt_tuning_num_tokens = 20
+        prefix_tuning_num_tokens = 30
+
+        results = run_llama_training(
+            file_path=file.name,
+            model_name=model_name,
+            target_samples=target_samples,
+            use_class_weights=use_class_weights,
+            num_epochs=num_epochs,
+            batch_size=batch_size,
+            learning_rate=learning_rate,
+            tuning_method=tuning_method,
+            lora_r=lora_r,
+            lora_alpha=lora_alpha,
+            lora_dropout=lora_dropout,
+            lora_target_modules=lora_target_modules,
+            adalora_init_r=adalora_init_r,
+            adalora_target_r=adalora_target_r,
+            adalora_alpha=adalora_alpha,
+            adalora_tinit=adalora_tinit,
+            adalora_tfinal=adalora_tfinal,
+            adalora_delta_t=adalora_delta_t,
+            adapter_reduction_factor=adapter_reduction_factor,
+            prompt_tuning_num_tokens=prompt_tuning_num_tokens,
+            prefix_tuning_num_tokens=prefix_tuning_num_tokens,
+            best_metric=best_metric,
+            is_second_finetuning=True,
+            base_model_path=base_model_path
+        )
+
+        finetuned_results = results['finetuned_results']
+
+        data_info = f"""
+# 📊 Second Fine-Tuning Results
+
+## 🔧 Training Configuration
+- **Base model**: {base_model_path}
+- **Tuning method**: {results['tuning_method']} (inherited from round one)
+- **Optimized metric**: {results['best_metric']}
+- **Best metric value**: {finetuned_results['eval_' + results['best_metric']]:.4f}
+- **Training time**: {results['duration']:.1f} minutes
+
+## ⚙️ Training Parameters
+- **Target samples**: {target_samples} per class
+- **Class weights**: {'yes' if use_class_weights else 'no'}
+- **Epochs**: {num_epochs}
+- **Batch size**: {batch_size}
+- **Learning rate**: {learning_rate}
+
+✅ Second fine-tuning complete! The model is ready for prediction.
+"""
+
+        finetuned_output = f"""
+# 🟢 Second Fine-Tuned Llama
+## {results['tuning_method']}
+
+### 📈 Evaluation Metrics
+
+| Metric | Value |
+|------|------|
+| **Accuracy** | {finetuned_results['eval_accuracy']:.4f} |
+| **Precision** | {finetuned_results['eval_precision']:.4f} |
+| **Recall** | {finetuned_results['eval_recall']:.4f} |
+| **F1 Score** | {finetuned_results['eval_f1']:.4f} |
+| **Sensitivity** | {finetuned_results['eval_sensitivity']:.4f} |
+| **Specificity** | {finetuned_results['eval_specificity']:.4f} |
+"""
+
+        return data_info, finetuned_output
+
+    except Exception as e:
+        import traceback
+        error_msg = f"❌ Error: {str(e)}\n\nFull traceback:\n{traceback.format_exc()}"
+        return error_msg, ""
+
+# ==================== New-Data Testing Functions ====================
+
+def test_on_new_data(test_file_path, baseline_choice, first_choice, second_choice):
+    """
+    Compare up to three models on a new test set:
+    1. Stock Llama (baseline)
+    2. The first fine-tuned model
+    3. The second fine-tuned model
+    """
+
+    print("\n" + "=" * 80)
+    print("📊 New-data test - three-model comparison")
+    print("=" * 80)
+
+    # Load the test data
+    df_test = pd.read_csv(test_file_path)
+
+    # Auto-detect column names
+    text_col = 'Text' if 'Text' in df_test.columns else 'text'
+    label_col = 'Label' if 'Label' in df_test.columns else 'label'
+
+    df_clean = pd.DataFrame({
+        'text': df_test[text_col],
+        'label': df_test[label_col]
+    })
+    df_clean = df_clean.dropna()
+
+    print(f"\nTest data:")
+    print(f"  Total rows: {len(df_clean)}")
+    print(f"  Class 0: {sum(df_clean['label']==0)} rows")
+    print(f"  Class 1: {sum(df_clean['label']==1)} rows")
+
+    # Build the test dataset
+    test_dataset = Dataset.from_pandas(df_clean[['text', 'label']])
+
+    # Evaluation helper
+    def evaluate_model(model, tokenizer, model_name_str, dataset_name):
+        model.eval()
+
+        # Make sure the tokenizer has a pad token
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+            tokenizer.pad_token_id = tokenizer.eos_token_id
+
+        # Mirror the pad token id on the model config
+        if hasattr(model, 'config'):
+            model.config.pad_token_id = tokenizer.pad_token_id
+
+        def preprocess_function(examples):
+            return tokenizer(examples['text'], truncation=True, padding='max_length', max_length=MAX_LENGTH)
+
+        test_tokenized = test_dataset.map(preprocess_function, batched=True)
+
+        trainer_args = TrainingArguments(
+            output_dir='./temp_test',
+            per_device_eval_batch_size=32,
+            report_to="none"
+        )
+
+        def compute_metrics_test(eval_pred):
+            predictions, labels = eval_pred
+            predictions = np.argmax(predictions, axis=1)
+
+            accuracy = accuracy_score(labels, predictions)
+            precision, recall, f1, _ = precision_recall_fscore_support(
+                labels, predictions, average='binary', zero_division=0
+            )
+
+            from sklearn.metrics import confusion_matrix
+            cm = confusion_matrix(labels, predictions)
+
+            if cm.shape == (2, 2):
+                tn, fp, fn, tp = cm.ravel()
+                sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
+                specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
+            else:
+                sensitivity = 0
+                specificity = 0
+                tn = fp = fn = tp = 0
+
+            return {
+                'accuracy': accuracy,
+                'precision': precision,
+                'recall': recall,
+                'f1': f1,
+                'sensitivity': sensitivity,
+                'specificity': specificity,
+                'tp': int(tp),
+                'tn': int(tn),
+                'fp': int(fp),
+                'fn': int(fn)
+            }
+
+        trainer = Trainer(
+            model=model,
+            args=trainer_args,
+            compute_metrics=compute_metrics_test,
+            data_collator=DataCollatorWithPadding(tokenizer=tokenizer)
+        )
+
+        predictions_output = trainer.predict(test_tokenized)
+
+        # Trainer.predict prefixes every metric name with "test_"
+        results = {
+            'accuracy': predictions_output.metrics['test_accuracy'],
+            'precision': predictions_output.metrics['test_precision'],
+            'recall': predictions_output.metrics['test_recall'],
+            'f1': predictions_output.metrics['test_f1'],
+            'sensitivity': predictions_output.metrics['test_sensitivity'],
+            'specificity': predictions_output.metrics['test_specificity'],
+            'tp': predictions_output.metrics['test_tp'],
+            'tn': predictions_output.metrics['test_tn'],
+            'fp': predictions_output.metrics['test_fp'],
+            'fn': predictions_output.metrics['test_fn']
+        }
+
+        print(f"\n✅ {dataset_name} evaluation done")
+
+        del trainer
+        torch.cuda.empty_cache()
+        gc.collect()
+
+        return results
+
+    all_results = {}
+
+    # 1. Evaluate stock Llama; the choice strings are set by the UI radio/dropdowns
+    if baseline_choice == "評估純 Llama":
+        print("\n" + "-" * 80)
+        print("1️⃣ Evaluating stock Llama (baseline)")
+        print("-" * 80)
+
+        # Default to the stock checkpoint; override with the fine-tuned model's base when known
+        model_name = "meta-llama/Llama-3.2-1B"
+        if first_choice != "請選擇":
+            with open('./saved_llama_models_list.json', 'r') as f:
+                models_list = json.load(f)
+            for model_info in models_list:
+                if model_info['model_path'] == first_choice:
+                    model_name = model_info['model_name']
+                    break
+
+        baseline_tokenizer = AutoTokenizer.from_pretrained(model_name)
+        if baseline_tokenizer.pad_token is None:
+            baseline_tokenizer.pad_token = baseline_tokenizer.eos_token
+            baseline_tokenizer.pad_token_id = baseline_tokenizer.eos_token_id
+
+        baseline_model = AutoModelForSequenceClassification.from_pretrained(
+            model_name,
+            num_labels=2,
+            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
+            device_map="auto" if device == "cuda" else None
+        )
+        baseline_model.config.pad_token_id = baseline_tokenizer.pad_token_id
+
+        all_results['baseline'] = evaluate_model(baseline_model, baseline_tokenizer, model_name, "stock Llama")
+        del baseline_model, baseline_tokenizer
+        torch.cuda.empty_cache()
+    else:
+        all_results['baseline'] = None
+
+    # 2. Evaluate the first fine-tuned model
+    if first_choice != "請選擇":
+        print("\n" + "-" * 80)
+        print("2️⃣ Evaluating the first fine-tuned model")
+        print("-" * 80)
+
+        # Read the model metadata
+        with open('./saved_llama_models_list.json', 'r') as f:
+            models_list = json.load(f)
+
+        first_model_info = None
+        for model_info in models_list:
+            if model_info['model_path'] == first_choice:
+                first_model_info = model_info
+                break
+
+        if first_model_info:
+            tuning_method = first_model_info['tuning_method']
+            model_name = first_model_info['model_name']
+
+            first_tokenizer = AutoTokenizer.from_pretrained(first_choice)
+            if first_tokenizer.pad_token is None:
+                first_tokenizer.pad_token = first_tokenizer.eos_token
+                first_tokenizer.pad_token_id = first_tokenizer.eos_token_id
+
+            if tuning_method in ["LoRA", "AdaLoRA", "Adapter", "Prompt Tuning"]:
+                base_model = AutoModelForSequenceClassification.from_pretrained(
+                    model_name,
+                    num_labels=2,
+                    torch_dtype=torch.float16 if device == "cuda" else torch.float32
+                )
+                first_model = PeftModel.from_pretrained(base_model, first_choice)
+                if device == "cuda":
+                    first_model = first_model.to(device)
+            else:
+                first_model = AutoModelForSequenceClassification.from_pretrained(
+                    first_choice,
+                    num_labels=2,
+                    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
+                    device_map="auto" if device == "cuda" else None
+                )
+
+            all_results['first'] = evaluate_model(first_model, first_tokenizer, model_name, "first fine-tuned model")
+            del first_model, first_tokenizer
+            torch.cuda.empty_cache()
+        else:
+            all_results['first'] = None
+    else:
+        all_results['first'] = None
+
+    # 3. Evaluate the second fine-tuned model
+    if second_choice != "請選擇":
+        print("\n" + "-" * 80)
+        print("3️⃣ Evaluating the second fine-tuned model")
+        print("-" * 80)
+
+        # Read the model metadata
+        with open('./saved_llama_models_list.json', 'r') as f:
+            models_list = json.load(f)
+
+        second_model_info = None
+        for model_info in models_list:
+            if model_info['model_path'] == second_choice:
+                second_model_info = model_info
+                break
+
+        if second_model_info:
+            tuning_method = second_model_info['tuning_method']
+            model_name = second_model_info['model_name']
+
+            second_tokenizer = AutoTokenizer.from_pretrained(second_choice)
+            if second_tokenizer.pad_token is None:
+                second_tokenizer.pad_token = second_tokenizer.eos_token
+                second_tokenizer.pad_token_id = second_tokenizer.eos_token_id
+
+            if tuning_method in ["LoRA", "AdaLoRA", "Adapter", "Prompt Tuning"]:
+                base_model = AutoModelForSequenceClassification.from_pretrained(
+                    model_name,
+                    num_labels=2,
+                    torch_dtype=torch.float16 if device == "cuda" else torch.float32
+                )
+                second_model = PeftModel.from_pretrained(base_model, second_choice)
+                if device == "cuda":
+                    second_model = second_model.to(device)
+            else:
+                second_model = AutoModelForSequenceClassification.from_pretrained(
+                    second_choice,
+                    num_labels=2,
+                    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
+                    device_map="auto" if device == "cuda" else None
+                )
+
+            all_results['second'] = evaluate_model(second_model, second_tokenizer, model_name, "second fine-tuned model")
+            del second_model, second_tokenizer
+            torch.cuda.empty_cache()
+        else:
+            all_results['second'] = None
+    else:
+        all_results['second'] = None
+
+    print("\n" + "=" * 80)
+    print("✅ New-data test complete")
+    print("=" * 80)
+
+    return all_results
+
+def test_new_data_wrapper(test_file, baseline_choice, first_choice, second_choice):
+    """Wrapper for testing on new data."""
+
+    if test_file is None:
+        return "Please upload a test-data CSV file", "", ""
+
+    try:
+        all_results = test_on_new_data(
+            test_file.name,
+            baseline_choice,
+            first_choice,
+            second_choice
+        )
+
+        # Format the outputs
+        outputs = []
+
+        # 1. Stock Llama
+        if all_results['baseline']:
+            r = all_results['baseline']
+            baseline_output = f"""
+# 🔵 Stock Llama (Baseline)
+
+| Metric | Value |
+|------|------|
+| **F1 Score** | {r['f1']:.4f} |
+| **Accuracy** | {r['accuracy']:.4f} |
+| **Precision** | {r['precision']:.4f} |
+| **Recall** | {r['recall']:.4f} |
+| **Sensitivity** | {r['sensitivity']:.4f} |
+| **Specificity** | {r['specificity']:.4f} |
+
+### Confusion Matrix
+| | Predicted: Class 0 | Predicted: Class 1 |
+|---|-----------|-----------|
+| **Actual: Class 0** | TN={r['tn']} | FP={r['fp']} |
+| **Actual: Class 1** | FN={r['fn']} | TP={r['tp']} |
+"""
+        else:
+            baseline_output = "Stock-Llama evaluation was skipped"
+        outputs.append(baseline_output)
+
+        # 2. First fine-tuning
+        if all_results['first']:
+            r = all_results['first']
+            first_output = f"""
+# 🟢 First Fine-Tuned Model
+
+| Metric | Value |
+|------|------|
+| **F1 Score** | {r['f1']:.4f} |
+| **Accuracy** | {r['accuracy']:.4f} |
+| **Precision** | {r['precision']:.4f} |
+| **Recall** | {r['recall']:.4f} |
+| **Sensitivity** | {r['sensitivity']:.4f} |
+| **Specificity** | {r['specificity']:.4f} |
+
+### Confusion Matrix
+| | Predicted: Class 0 | Predicted: Class 1 |
+|---|-----------|-----------|
+| **Actual: Class 0** | TN={r['tn']} | FP={r['fp']} |
+| **Actual: Class 1** | FN={r['fn']} | TP={r['tp']} |
+"""
+        else:
+            first_output = "No first fine-tuned model selected"
+        outputs.append(first_output)
+
+        # 3. Second fine-tuning
+        if all_results['second']:
+            r = all_results['second']
+            second_output = f"""
+# 🟡 Second Fine-Tuned Model
+
+| Metric | Value |
+|------|------|
+| **F1 Score** | {r['f1']:.4f} |
+| **Accuracy** | {r['accuracy']:.4f} |
+| **Precision** | {r['precision']:.4f} |
+| **Recall** | {r['recall']:.4f} |
+| **Sensitivity** | {r['sensitivity']:.4f} |
+| **Specificity** | {r['specificity']:.4f} |
+
+### Confusion Matrix
+| | Predicted: Class 0 | Predicted: Class 1 |
+|---|-----------|-----------|
+| **Actual: Class 0** | TN={r['tn']} | FP={r['fp']} |
+| **Actual: Class 1** | FN={r['fn']} | TP={r['tp']} |
+"""
+        else:
+            second_output = "No second fine-tuned model selected"
+        outputs.append(second_output)
+
+        return outputs[0], outputs[1], outputs[2]
+
+    except Exception as e:
+        import traceback
+        error_msg = f"❌ Error: {str(e)}\n\nFull traceback:\n{traceback.format_exc()}"
+        return error_msg, "", ""
+
+# ==================== Prediction Function ====================
+def predict_text(model_choice, text_input):
+    """
+    Prediction - pick any trained model and show the untuned and fine-tuned
+    predictions side by side.
+    """
+
+    if not text_input or text_input.strip() == "":
+        return "Please enter some text", "Please enter some text"
+
+    try:
+        # ==================== Untuned Llama prediction ====================
+        print("\nPredicting with untuned Llama...")
+
+        # "請先訓練模型" is the dropdown placeholder shown before any training
+        if model_choice != "請先訓練模型":
+            # Dropdown entries start with "路徑: <model_path> | ..."
+            model_path = model_choice.split(" | ")[0].replace("路徑: ", "")
+
+            # Read the model metadata from the JSON list
+            with open('./saved_llama_models_list.json', 'r') as f:
+                models_list = json.load(f)
+
+            selected_model_info = None
+            for model_info in models_list:
+                if model_info['model_path'] == model_path:
+                    selected_model_info = model_info
+                    break
+
+            if selected_model_info is None:
+                return "Model metadata not found", "Model metadata not found"
+
+            model_name = selected_model_info['model_name']
+            baseline_tokenizer = AutoTokenizer.from_pretrained(model_name)
+        else:
+            baseline_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
+            model_name = "meta-llama/Llama-3.2-1B"
+
+        if baseline_tokenizer.pad_token is None:
+            baseline_tokenizer.pad_token = baseline_tokenizer.eos_token
+            baseline_tokenizer.pad_token_id = baseline_tokenizer.eos_token_id
+
+        baseline_model = AutoModelForSequenceClassification.from_pretrained(
+            model_name,
+            num_labels=2,
+            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
+            device_map="auto" if device == "cuda" else None
+        )
+        baseline_model.config.pad_token_id = baseline_tokenizer.pad_token_id
+        baseline_model.eval()
+
+        # Tokenize the input (untuned)
+        baseline_inputs = baseline_tokenizer(
+            text_input,
+            return_tensors="pt",
+            truncation=True,
+            max_length=MAX_LENGTH
+        )
+        if device == "cuda":
+            baseline_inputs = {k: v.to(baseline_model.device) for k, v in baseline_inputs.items()}
+
+        # Predict (untuned)
+        with torch.no_grad():
+            baseline_outputs = baseline_model(**baseline_inputs)
+            baseline_probs = torch.nn.functional.softmax(baseline_outputs.logits, dim=-1)
+            baseline_pred_class = torch.argmax(baseline_probs, dim=-1).item()
+            baseline_confidence = baseline_probs[0][baseline_pred_class].item()
+
+        baseline_result = "NBCD = 0" if baseline_pred_class == 0 else "NBCD = 1"
+        baseline_prob_class0 = baseline_probs[0][0].item()
+        baseline_prob_class1 = baseline_probs[0][1].item()
+
+        baseline_output = f"""
+# 🔵 Untuned Llama Prediction
+
+## Predicted class: **{baseline_result}**
+
+## Confidence: **{baseline_confidence:.1%}**
+
+## Probability distribution:
+- **Class 0**: {baseline_prob_class0:.2%}
+- **Class 1**: {baseline_prob_class1:.2%}
+
+---
+**Note**: this is the stock Llama model, never trained on domain data
+"""
+
+        # Free memory
+        del baseline_model
+        del baseline_tokenizer
+        torch.cuda.empty_cache()
+
+        # ==================== Fine-tuned Llama prediction ====================
+
+        if model_choice == "請先訓練模型":
+            finetuned_output = """
+# 🟢 Fine-Tuned Llama Prediction
+
+❌ No model has been trained yet - please train one on the training tab first
+"""
+            return baseline_output, finetuned_output
+
+        print(f"\nUsing fine-tuned model: {model_path}")
+
+        # Load the tokenizer
+        finetuned_tokenizer = AutoTokenizer.from_pretrained(model_path)
+        if finetuned_tokenizer.pad_token is None:
+            finetuned_tokenizer.pad_token = finetuned_tokenizer.eos_token
+            finetuned_tokenizer.pad_token_id = finetuned_tokenizer.eos_token_id
+
+        # Load the fine-tuned model according to its tuning method
+        tuning_method = selected_model_info.get('tuning_method', 'LoRA')
+
+        if tuning_method == "BitFit":
+            # BitFit saves a full model, so load it directly
+            finetuned_model = AutoModelForSequenceClassification.from_pretrained(
+                model_path,
+                num_labels=2,
+                torch_dtype=torch.float16 if device == "cuda" else torch.float32,
+                device_map="auto" if device == "cuda" else None
+            )
+        else:
+            # The other methods store PEFT adapters on top of the base model
+            base_model = AutoModelForSequenceClassification.from_pretrained(
+                model_name,
+                num_labels=2,
+                torch_dtype=torch.float16 if device == "cuda" else torch.float32,
+                device_map="auto" if device == "cuda" else None
+            )
+            finetuned_model = PeftModel.from_pretrained(base_model, model_path)
+
+        # Prefix Tuning requires the cache to be disabled
+        if tuning_method == "Prefix Tuning":
+            finetuned_model.config.use_cache = False
+
+        finetuned_model.config.pad_token_id = finetuned_tokenizer.pad_token_id
+        finetuned_model.eval()
+
+        # Tokenize the input (fine-tuned)
+        finetuned_inputs = finetuned_tokenizer(
+            text_input,
+            return_tensors="pt",
+            truncation=True,
+            max_length=MAX_LENGTH
+        )
+        if device == "cuda":
+            finetuned_inputs = {k: v.to(finetuned_model.device) for k, v in finetuned_inputs.items()}
+
+        # Predict (fine-tuned)
+        with torch.no_grad():
+            finetuned_outputs = finetuned_model(**finetuned_inputs)
+            finetuned_probs = torch.nn.functional.softmax(finetuned_outputs.logits, dim=-1)
+            finetuned_pred_class = torch.argmax(finetuned_probs, dim=-1).item()
+            finetuned_confidence = finetuned_probs[0][finetuned_pred_class].item()
+
+        finetuned_result = "NBCD = 0" if finetuned_pred_class == 0 else "NBCD = 1"
+        finetuned_prob_class0 = finetuned_probs[0][0].item()
+        finetuned_prob_class1 = finetuned_probs[0][1].item()
+
+        training_type_label = "second fine-tuning" if selected_model_info.get('is_second_finetuning', False) else "first fine-tuning"
+
+        finetuned_output = f"""
+# 🟢 Fine-Tuned Llama Prediction
+
+## Predicted class: **{finetuned_result}**
+
+## Confidence: **{finetuned_confidence:.1%}**
+
+## Probability distribution:
+- **Class 0**: {finetuned_prob_class0:.2%}
+- **Class 1**: {finetuned_prob_class1:.2%}
+
+---
+### Model info:
+- **Training type**: {training_type_label}
+- **Model name**: {selected_model_info['model_name']}
+- **Tuning method**: {selected_model_info['tuning_method']}
+- **Optimized metric**: {selected_model_info['best_metric']}
+- **Trained at**: {selected_model_info['timestamp']}
+- **Model path**: {model_path}
+
+---
+**Note**: this prediction is for reference only.
+"""
+
+        # Free memory
+        del finetuned_model
+        del finetuned_tokenizer
+        torch.cuda.empty_cache()
+
+        return baseline_output, finetuned_output
+
+    except Exception as e:
+        import traceback
+        error_msg = f"❌ Prediction error: {str(e)}\n\nFull traceback:\n{traceback.format_exc()}"
+        return error_msg, error_msg
+
+def get_available_models():
+    """Return every trained model as a dropdown choice string."""
+    models_list_file = './saved_llama_models_list.json'
+    if not os.path.exists(models_list_file):
+        return ["請先訓練模型"]  # placeholder shown when nothing has been trained
+
+    with open(models_list_file, 'r') as f:
+        models_list = json.load(f)
+
+    if len(models_list) == 0:
+        return ["請先訓練模型"]
+
+    # Format the dropdown choices; predict_text parses the "路徑: " prefix back out
+    model_choices = []
+    for model_info in models_list:
+        training_type = model_info.get('training_type', '第一次微調')
+        choice = f"路徑: {model_info['model_path']} | 類型: {training_type} | 方法: {model_info['tuning_method']} | 時間: {model_info['timestamp']}"
+        model_choices.append(choice)
+
+    return model_choices
+
+def get_first_finetuning_models():
+    """Return the first-round models (candidates for a second fine-tuning)."""
+    models_list_file = './saved_llama_models_list.json'
+    if not os.path.exists(models_list_file):
+        return ["請先進行第一次微調"]
+
+    with open(models_list_file, 'r') as f:
+        models_list = json.load(f)
+
+    # Keep only the first-round models
+    first_models = [m for m in models_list if not m.get('is_second_finetuning', False)]
+
+    if len(first_models) == 0:
+        return ["請先進行第一次微調"]
+
+    model_choices = [model_info['model_path'] for model_info in first_models]
+
+    return model_choices
+
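+# Assumed shape of one entry in saved_llama_models_list.json, inferred from the
+# fields read by the helpers above; the values below are illustrative only:
+#
+# {
+#     "model_path": "./llama_finetuned_20240101_120000",   # hypothetical path
+#     "model_name": "meta-llama/Llama-3.2-1B",
+#     "tuning_method": "LoRA",
+#     "best_metric": "f1",
+#     "timestamp": "2024-01-01 12:00:00",
+#     "lora_r": 16,
+#     "lora_alpha": 32,
+#     "is_second_finetuning": false,
+#     "training_type": "第一次微調"
+# }
+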
+# ==================== Gradio UI (visual layout follows the fourth reference file) ====================
+with gr.Blocks(title="🦙 Llama NBCD Two-Stage Fine-Tuning Platform", theme=gr.themes.Soft()) as demo:
+
+    gr.Markdown("""
+    # 🦙 Llama NBCD Two-Stage Fine-Tuning Platform
+
+    ### 🌟 Features:
+    - 🎯 First fine-tuning: train from stock Llama
+    - 🔄 Second fine-tuning: continue training the first model on new data
+    - 📊 Automatic comparison of tuned vs. untuned performance
+    - 🎨 Selectable optimization metric (F1, Accuracy, Precision, Recall)
+    - 🔮 Predict new samples right after training
+    - 💾 Automatic best-model saving
+    - 🧹 Automatic memory management
+
+    ✅ **Supported tuning methods**: LoRA, AdaLoRA, Adapter, BitFit, Prompt Tuning
+    ⚠️ **Not supported for now**: Prefix Tuning (version-compatibility issues; use Prompt Tuning instead)
+    """)
+
+    # Tab 1: first fine-tuning
+    with gr.Tab("1️⃣ First Fine-Tuning"):
+        with gr.Row():
+            with gr.Column(scale=1):
+                gr.Markdown("### 📤 Data Upload")
+
+                file_input = gr.File(
+                    label="Upload CSV file",
+                    file_types=[".csv"]
+                )
+
+                gr.Markdown("### 🤖 Model Selection")
+
+                model_name_input = gr.Textbox(
+                    value="meta-llama/Llama-3.2-1B",
+                    label="Hugging Face model name",
+                    info="e.g. meta-llama/Llama-3.2-1B"
+                )
+
+                gr.Markdown("### 🔧 Tuning Method")
+
+                tuning_method = gr.Radio(
+                    choices=["LoRA", "AdaLoRA", "Adapter", "BitFit", "Prompt Tuning"],
+                    value="LoRA",
+                    label="Choose a tuning method",
+                    info="Parameter-efficient fine-tuning methods (Prefix Tuning is not supported for now)"
+                )
+
+                gr.Markdown("### 🎯 Best-Model Selection")
+
+                best_metric = gr.Dropdown(
+                    choices=["f1", "accuracy", "precision", "recall", "sensitivity", "specificity"],
+                    value="recall",
+                    label="Optimization metric",
+                    info="The best checkpoint is picked by this metric"
+                )
+
+                gr.Markdown("### ⚙️ Data-Balancing Parameters")
+
+                target_samples_input = gr.Number(
+                    value=700,
+                    label="Target samples (per class)"
+                )
+
+                use_weights_checkbox = gr.Checkbox(
+                    value=True,
+                    label="Use class weights",
+                    info="Apply class weights in the loss function"
+                )
+
+                gr.Markdown("### ⚙️ Training Parameters")
+
+                epochs_input = gr.Number(
+                    value=3,
+                    label="Epochs"
+                )
+
+                batch_size_input = gr.Number(
+                    value=4,
+                    label="Batch size"
+                )
+
+                lr_input = gr.Number(
+                    value=1e-4,
+                    label="Learning rate"
+                )
+
+                gr.Markdown("---")
+
+                # LoRA parameters
+                with gr.Column(visible=True) as lora_params:
+                    gr.Markdown("### 🔷 LoRA Parameters")
+
+                    lora_r_input = gr.Slider(
+                        minimum=4,
+                        maximum=64,
+                        value=16,
+                        step=4,
+                        label="LoRA rank (r)",
+                        info="Rank of the low-rank decomposition"
+                    )
+
+                    lora_alpha_input = gr.Slider(
+                        minimum=8,
+                        maximum=128,
+                        value=32,
+                        step=8,
+                        label="LoRA alpha",
+                        info="LoRA scaling parameter"
+                    )
+
+                    lora_dropout_input = gr.Slider(
+                        minimum=0.0,
+                        maximum=0.5,
+                        value=0.1,
+                        step=0.05,
+                        label="LoRA dropout",
+                        info="Dropout rate"
+                    )
+
+                    lora_target_input = gr.Dropdown(
+                        choices=["query,value", "query,key,value", "all"],
+                        value="query,value",
+                        label="Target modules",
+                        info="Comma-separated"
+                    )
+
+                # AdaLoRA parameters
+                with gr.Column(visible=False) as adalora_params:
+                    gr.Markdown("### 🔶 AdaLoRA Parameters")
+
+                    adalora_init_r_input = gr.Slider(
+                        minimum=4,
+                        maximum=64,
+                        value=12,
+                        step=4,
+                        label="Initial rank",
+                        info="Rank at the start of training"
+                    )
+
+                    adalora_target_r_input = gr.Slider(
+                        minimum=4,
+                        maximum=64,
+                        value=8,
+                        step=4,
+                        label="Target rank",
+                        info="Target rank at the end of training"
+                    )
+
+                    adalora_alpha_input = gr.Slider(
+                        minimum=8,
+                        maximum=128,
+                        value=32,
+                        step=8,
+                        label="LoRA alpha",
+                        info="Scaling parameter"
+                    )
+
+                    adalora_tinit_input = gr.Number(
+                        value=0,
+                        label="Tinit",
+                        info="Step at which rank pruning starts"
+                    )
+
+                    adalora_tfinal_input = gr.Number(
+                        value=0,
+                        label="Tfinal",
+                        info="Step at which rank pruning ends"
+                    )
+
+                    adalora_delta_t_input = gr.Number(
+                        value=1,
+                        label="Delta T",
+                        info="Pruning interval"
+                    )
+
+                # Adapter parameters
+                with gr.Column(visible=False) as adapter_params:
+                    gr.Markdown("### 🔶 Adapter Parameters")
+
+                    adapter_reduction_input = gr.Slider(
+                        minimum=2,
+                        maximum=64,
+                        value=16,
+                        step=2,
+                        label="Reduction factor",
+                        info="Bottleneck factor - larger means fewer parameters"
+                    )
+
+                # Prompt Tuning parameters
+                with gr.Column(visible=False) as prompt_tuning_params:
+                    gr.Markdown("### 🔷 Prompt Tuning Parameters")
+
+                    prompt_tokens_input = gr.Slider(
+                        minimum=1,
+                        maximum=100,
+                        value=20,
+                        step=1,
+                        label="Number of virtual tokens"
+                    )
+
+                # Prefix Tuning parameters
+                with gr.Column(visible=False) as prefix_tuning_params:
+                    gr.Markdown("### 🔶 Prefix Tuning Parameters")
+                    gr.Markdown("⚠️ **Note**: the current version may have compatibility issues; Prompt Tuning is recommended")
+
+                    prefix_tokens_input = gr.Slider(
+                        minimum=1,
+                        maximum=100,
+                        value=30,
+                        step=1,
+                        label="Number of virtual tokens"
+                    )
+
+                train_button = gr.Button(
+                    "🚀 Start First Fine-Tuning",
+                    variant="primary",
+                    size="lg"
+                )
+
+            with gr.Column(scale=2):
+                gr.Markdown("### 📊 First Fine-Tuning Results and Comparison")
+
+                # Panel 1: data info
+                data_info_output = gr.Markdown(
+                    value="### Waiting for training...\n\nData info and the training configuration appear here once training finishes",
+                    label="Data info"
+                )
+
+                # Panels 2 and 3: side by side
+                with gr.Row():
+                    # Panel 2: untuned Llama
+                    baseline_output = gr.Markdown(
+                        value="### Untuned Llama\nWaiting for training...",
+                        label="Untuned Llama"
+                    )
+
+                    # Panel 3: fine-tuned Llama
+                    finetuned_output = gr.Markdown(
+                        value="### First fine-tuned Llama\nWaiting for training...",
+                        label="First fine-tuned Llama"
+                    )
+
+    # Tab 2: second fine-tuning
+    with gr.Tab("2️⃣ Second Fine-Tuning"):
+        with gr.Row():
+            with gr.Column(scale=1):
+                gr.Markdown("### 🔄 Base Model")
+                # The placeholder value must match the check in train_second_wrapper
+                base_model_dropdown = gr.Dropdown(
+                    label="Pick a first fine-tuned model",
+                    choices=["請先進行第一次微調"],
+                    value="請先進行第一次微調"
+                )
+                refresh_base_models = gr.Button("🔄 Refresh model list", size="sm")
+
+                gr.Markdown("### 📤 New Training Data")
+                file_input_second = gr.File(label="Upload the new training CSV", file_types=[".csv"])
+
+                gr.Markdown("### ⚙️ Training Parameters")
+                gr.Markdown("⚠️ The tuning method is inherited from the first fine-tuning")
+                best_metric_second = gr.Dropdown(
+                    choices=["f1", "accuracy", "precision", "recall", "sensitivity", "specificity"],
+                    value="f1",
+                    label="Optimization metric"
+                )
+
+                target_samples_second = gr.Number(
+                    value=700,
+                    label="Target samples (per class)"
+                )
+
+                use_weights_second = gr.Checkbox(
+                    value=True,
+                    label="Use class weights"
+                )
+
+                epochs_input_second = gr.Number(value=3, label="Epochs", info="Use fewer than in round one")
+                batch_size_input_second = gr.Number(value=4, label="Batch size")
+                lr_input_second = gr.Number(value=5e-5, label="Learning rate", info="Use a smaller value than in round one")
+
+                train_button_second = gr.Button("🚀 Start Second Fine-Tuning", variant="primary", size="lg")
+
+            with gr.Column(scale=2):
+                gr.Markdown("### 📊 Second Fine-Tuning Results")
+                data_info_output_second = gr.Markdown(value="Waiting for training...")
+                finetuned_output_second = gr.Markdown(value="### Second fine-tuning\nWaiting for training...")
+
+    # Tab 3: testing on new data
+    with gr.Tab("3️⃣ New-Data Testing"):
+        with gr.Row():
+            with gr.Column(scale=1):
+                gr.Markdown("### 📤 Test Data")
+                test_file_input = gr.File(label="Upload the test CSV", file_types=[".csv"])
+
+                gr.Markdown("### 🎯 Models to Compare")
+                gr.Markdown("Pick one to three models")
+
+                # The choice strings must match the checks in test_on_new_data
+                baseline_test_choice = gr.Radio(
+                    choices=["評估純 Llama", "跳過"],
+                    value="評估純 Llama",
+                    label="Stock Llama (baseline)"
+                )
+
+                first_model_test_dropdown = gr.Dropdown(
+                    label="First fine-tuned model",
+                    choices=["請選擇"],
+                    value="請選擇"
+                )
+
+                second_model_test_dropdown = gr.Dropdown(
+                    label="Second fine-tuned model",
+                    choices=["請選擇"],
+                    value="請選擇"
+                )
+
+                refresh_test_models = gr.Button("🔄 Refresh model list", size="sm")
+                test_button = gr.Button("📊 Run Test", variant="primary", size="lg")
+
+            with gr.Column(scale=2):
+                gr.Markdown("### 📊 New-Data Test Results - Three-Model Comparison")
+                with gr.Row():
+                    baseline_test_output = gr.Markdown(value="### Stock Llama\nWaiting for the test...")
+                    first_test_output = gr.Markdown(value="### First fine-tuning\nWaiting for the test...")
+                    second_test_output = gr.Markdown(value="### Second fine-tuning\nWaiting for the test...")
+
+    # Tab 4: prediction
+    with gr.Tab("4️⃣ Prediction"):
+        gr.Markdown("""
+        ### Predict with a trained model
+
+        Pick a trained model and enter some text. The predictions of the untuned and fine-tuned models are shown side by side for comparison.
+        """)
+
+        with gr.Row():
+            with gr.Column():
+                # Model dropdown; the placeholder must match the check in predict_text
+                model_dropdown = gr.Dropdown(
+                    label="Model",
+                    choices=["請先訓練模型"],
+                    value="請先訓練模型",
+                    info="Pick the trained model to use"
+                )
+
+                refresh_button = gr.Button(
+                    "🔄 Refresh model list",
+                    size="sm"
+                )
+
+                text_input = gr.Textbox(
+                    label="Input text",
+                    placeholder="Enter the text to classify...",
+                    lines=10
+                )
+
+                predict_button = gr.Button(
+                    "🔮 Predict",
+                    variant="primary",
+                    size="lg"
+                )
+
+            with gr.Column():
+                gr.Markdown("### Prediction Comparison")
+
+                # Upper panel: untuned Llama
+                baseline_prediction_output = gr.Markdown(
+                    label="Untuned Llama",
+                    value="Waiting for a prediction..."
+                )
+
+                # Lower panel: fine-tuned Llama
+                finetuned_prediction_output = gr.Markdown(
+                    label="Fine-tuned Llama",
+                    value="Waiting for a prediction..."
+                )
+
+     # Tab 5: Usage guide
+     with gr.Tab("📖 Usage Guide"):
+         gr.Markdown("""
+ ## 🔄 Two-stage fine-tuning workflow
+ 
+ ### Step 1: First fine-tuning
+ 1. Upload training dataset A (CSV format: Text, label)
+ 2. Choose a fine-tuning method (LoRA / AdaLoRA / Adapter / BitFit / Prompt Tuning)
+ 3. Adjust the training parameters
+ 4. Start training
+ 5. The system automatically compares pure Llama vs. the first fine-tuned model
+ 
+ ### Step 2: Second fine-tuning
+ 1. Select a trained first-stage model
+ 2. Upload new training dataset B
+ 3. Adjust the training parameters (smaller epochs and learning rate are recommended)
+ 4. Start training (the method is inherited from the first stage)
+ 5. The model continues learning from the first-stage weights
+ 
+ ### Step 3: Prediction
+ 1. Select any trained model
+ 2. Enter text
+ 3. View the prediction
+ 
+ ## 🎯 Fine-tuning methods
+ 
+ | Method | Trainable params | Memory | Training speed | Best for |
+ |------|--------|--------|----------|----------|
+ | **LoRA** | Very few (~1%) | Low | Fast | General purpose, strong results |
+ | **AdaLoRA** | Very few (~1%) | Low | Fast | Adaptive rank allocation, often better |
+ | **Adapter** | Few (~2-5%) | Low | Medium | Multi-task learning |
+ | **BitFit** | Minimal (~0.1%) | Very low | Very fast | Quick fine-tuning |
+ | **Prompt Tuning** | Minimal (tunable) | Very low | Fast | Small datasets |
+ 
+ ## 💡 Second fine-tuning tips
+ 
+ ### Training parameter adjustments:
+ - **Epochs**: 3-5 for the second stage (vs. the usual 8-10 for the first)
+ - **Learning rate**: 5e-5 for the second stage (vs. the usual 1e-4 for the first)
+ - **Warmup steps**: halve them for the second stage
+ 
+ ### When to use it:
+ 1. **Domain adaptation**: general medical data first, hospital-specific data second
+ 2. **Incremental learning**: new case data accumulating over time
+ 3. **Data scarcity**: pre-train on plentiful related data, then fine-tune on a small target set
+ 
+ ## ⚠️ Notes
+ 
+ - The CSV must contain `Text` and `label` columns
+ - The second stage automatically reuses the first stage's fine-tuning method
+ - Use a smaller learning rate in the second stage to avoid overwriting what was already learned
+ - Training time depends on data size and hardware (10-30 minutes)
+ - A Hugging Face token is required to download the Llama model
+ - GPU training works best; CPU will be very slow
+ 
+ ## 📊 Metrics
+ 
+ - **F1 score**: harmonic mean of precision and recall; a balanced metric
+ - **Accuracy**: overall proportion of correct predictions
+ - **Precision**: fraction of predicted positives that are actually positive
+ - **Recall/Sensitivity**: fraction of actual positives correctly identified
+ - **Specificity**: fraction of actual negatives correctly identified
+ 
+ ## 🔧 Fixed issues
+ 
+ - ✅ **AdaLoRA**: simplified config parameters to avoid version-compatibility problems
+ - ✅ **BitFit**: gradients are now set correctly, including training the classification head
+ - ✅ **Parameter display**: AdaLoRA now shows its own parameter panel
+ - ❌ **Prefix Tuning**: temporarily removed due to PEFT version issues; use Prompt Tuning instead
+ 
+ ## 🔐 Setting the HF token
+ 
+ Set it as an environment variable:
+ ```
+ export HF_TOKEN=your_token_here
+ ```
+ """)
+ 
+     # ==================== Event bindings ====================
+ 
+     # Show/hide the parameter panel for the selected fine-tuning method
+     def update_params_visibility(method):
+         # Panels in output order; BitFit (and any other choice) shows no extra panel
+         panels = ["LoRA", "AdaLoRA", "Adapter", "Prompt Tuning", "Prefix Tuning"]
+         return tuple(gr.update(visible=(method == name)) for name in panels)
+ 
+     tuning_method.change(
+         fn=update_params_visibility,
+         inputs=[tuning_method],
+         outputs=[lora_params, adalora_params, adapter_params, prompt_tuning_params, prefix_tuning_params]
+     )
+ 
+     # First fine-tuning button
+     train_button.click(
+         fn=train_first_wrapper,
+         inputs=[
+             file_input,
+             model_name_input,
+             target_samples_input,
+             use_weights_checkbox,
+             epochs_input,
+             batch_size_input,
+             lr_input,
+             tuning_method,
+             lora_r_input,
+             lora_alpha_input,
+             lora_dropout_input,
+             lora_target_input,
+             adalora_init_r_input,
+             adalora_target_r_input,
+             adalora_alpha_input,
+             adalora_tinit_input,
+             adalora_tfinal_input,
+             adalora_delta_t_input,
+             adapter_reduction_input,
+             prompt_tokens_input,
+             prefix_tokens_input,
+             best_metric
+         ],
+         outputs=[data_info_output, baseline_output, finetuned_output]
+     )
+ 
+     # Refresh the base-model list
+     def refresh_base_models_list():
+         choices = get_first_finetuning_models()
+         return gr.update(choices=choices, value=choices[0])
+ 
+     refresh_base_models.click(
+         fn=refresh_base_models_list,
+         outputs=[base_model_dropdown]
+     )
+ 
+     # Second fine-tuning button
+     train_button_second.click(
+         fn=train_second_wrapper,
+         inputs=[
+             base_model_dropdown,
+             file_input_second,
+             target_samples_second,
+             use_weights_second,
+             epochs_input_second,
+             batch_size_input_second,
+             lr_input_second,
+             best_metric_second
+         ],
+         outputs=[data_info_output_second, finetuned_output_second]
+     )
+ 
+     # Refresh the model lists used for testing
+     def refresh_test_models_list():
+         first_models = get_first_finetuning_models()
+ 
+         # Collect second-stage models; the registry file may not exist yet
+         try:
+             with open('./saved_llama_models_list.json', 'r') as f:
+                 models_list = json.load(f)
+             second_models = [m['model_path'] for m in models_list if m.get('is_second_finetuning', False)]
+         except (FileNotFoundError, json.JSONDecodeError):
+             second_models = []
+ 
+         if len(second_models) == 0:
+             second_models = ["請選擇"]  # "please select" sentinel
+ 
+         return (
+             gr.update(choices=first_models if first_models[0] != "請先進行第一次微調" else ["請選擇"], value="請選擇"),
+             gr.update(choices=second_models, value="請選擇")
+         )
+ 
+     refresh_test_models.click(
+         fn=refresh_test_models_list,
+         outputs=[first_model_test_dropdown, second_model_test_dropdown]
+     )
+ 
+     # Test button
+     test_button.click(
+         fn=test_new_data_wrapper,
+         inputs=[test_file_input, baseline_test_choice, first_model_test_dropdown, second_model_test_dropdown],
+         outputs=[baseline_test_output, first_test_output, second_test_output]
+     )
+ 
+     # Refresh the prediction model list
+     def refresh_models():
+         choices = get_available_models()  # call once instead of twice
+         return gr.update(choices=choices, value=choices[0])
+ 
+     refresh_button.click(
+         fn=refresh_models,
+         inputs=[],
+         outputs=[model_dropdown]
+     )
+ 
+     # Predict button
+     predict_button.click(
+         fn=predict_text,
+         inputs=[model_dropdown, text_input],
+         outputs=[baseline_prediction_output, finetuned_prediction_output]
+     )
+ 
+ if __name__ == "__main__":
+     demo.launch()
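
The second-stage filter used when refreshing the test dropdowns can be exercised standalone. This is a minimal sketch, not part of the app: only the `model_path` and `is_second_finetuning` keys mirror the records the app stores in `saved_llama_models_list.json`; the paths and entries below are made up for illustration.

```python
# Hypothetical registry entries shaped like the app's
# saved_llama_models_list.json records (paths are illustrative).
models_list = [
    {"model_path": "models/first_run", "is_second_finetuning": False},
    {"model_path": "models/second_run", "is_second_finetuning": True},
    {"model_path": "models/legacy_run"},  # older entry without the flag
]

# Same filter the app applies: .get() with a default tolerates entries
# that predate the is_second_finetuning flag.
second_models = [
    m["model_path"] for m in models_list if m.get("is_second_finetuning", False)
]
print(second_models)  # -> ['models/second_run']
```

In the app, an empty result is replaced by a placeholder choice before it reaches the dropdown, so the user always sees at least one option.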