zhangyikai commited on
Commit
89c9672
·
1 Parent(s): 8518eef

Upload V0 model and UI

Browse files
app.py ADDED
@@ -0,0 +1,425 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
import pandas as pd
import json
import os
import torch
from huggingface_hub import hf_hub_download, snapshot_download  # hub download helpers

# ==========================================
# 0. Model initialisation
# ==========================================
MODEL_REPO_ID = "Now-Join-Us/Generalist-Value-Model-V0"
EMBEDDING_REPO_ID = "Qwen/Qwen3-Embedding-0.6B"

# Global model handle read by predict_performance(); None => "Mock Mode"
# (UI still works, but scores are random placeholders).
v0_model = None

print(">>> Starting V0 App...")

try:
    from v0_core.models.v0 import V0

    print(f">>> Downloading models...")

    # 1. Download the trained V0 checkpoint weights.
    checkpoint_path = hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename="v_0_for_grpo_training.pt"
    )

    # 2. Download the TabPFN classifier-head checkpoint.
    tabpfn_path = hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename="tabpfn-v2.5-classifier-v2.5_default.ckpt"
    )

    # 3. Download the Qwen embedding model (full repo snapshot).
    embedding_path = snapshot_download(
        repo_id=EMBEDDING_REPO_ID
    )

    print(">>> Models downloaded. Initializing V0 class...")

    # device = "cuda" if torch.cuda.is_available() else "cpu"
    device = "cpu"  # forced to CPU (presumably CPU-only hosting hardware — confirm)

    # Assemble the V0 value model from the three downloaded artifacts.
    v0_model = V0.from_pretrained(
        checkpoint_path=checkpoint_path,
        embedding_model_path=embedding_path,
        tabpfn_head_path=tabpfn_path,
        device=device
    )
    print(f">>> V0 Model Loaded Successfully on {device}!")

except Exception as e:
    # Any failure (missing v0_core package, download error, bad checkpoint)
    # leaves v0_model as None so the UI degrades to Mock Mode instead of crashing.
    print(f"Error loading model: {e}")
    print("UI will run in Mock Mode.")
    v0_model = None
58
+
59
# ==========================================
# 1. Core logic
# ==========================================

# Default data (serves as the Context): performance history of the built-in
# reference model. Each entry pairs a prompt with whether that model solved it.
history_default = [
    {"prompt": "Let $d(m)$ denote the number of positive integer divisors of a positive integer $m$. If $r$ is the number of integers $n \\leq 2023$ for which $\\sum_{i=1}^{n} d(i)$ is odd, find the sum of the digits of $r$.", "is_correct": True},
    {"prompt": "设在 $5 \\times 5$ 的方格表的第 $i$ 行第 $j$ 列所填的数为 $a_{i j}\\left(a_{i j} \\in\\{0,1\\}\\right), a_{i j}=a_{j i}(1 \\leqslant i、j \\leqslant 5)$ .则表中共有五个 1 的填表方法总数为 $\\qquad$ (用具体数字作答).", "is_correct": True},
    {"prompt": "Suppose $x, y \\in \\mathbb{Z}$ satisfy the equation:\n\\[\ny^4 + 4y^3 + 28y + 8x^3 + 6y^2 + 32x + 1 = (x^2 - y^2)(x^2 + y^2 + 24).\n\\]\nFind the sum of all possible values of $|xy|$.", "is_correct": False},
    {"prompt": "Three builders are scheduled to build a house in 60 days. However, they procrastinate and do nothing for the first 50 days. To complete the house on time, they decide to hire more workers and work at twice their original speed. If the new workers also work at this doubled rate, how many new workers are needed? Assume each builder works at the same rate and does not interfere with the others.", "is_correct": True},
    {"prompt": "Let $P_0 = (3,1)$ and define $P_{n+1} = (x_n, y_n)$ for $n \\ge 0$ by \\[ x_{n+1} = - \\frac{3x_n - y_n}{2}, \\quad y_{n+1} = - \\frac{x_n + y_n}{2} \\] Find the area of the quadrilateral formed by the points $P_{96}, P_{97}, P_{98}, P_{99}$.", "is_correct": False}
]
71
+
72
def format_model_card(data_list, model_name, is_custom=False):
    """Render a model's performance history as an HTML card.

    Args:
        data_list: list of dicts with 'prompt' (str) and 'is_correct' (bool),
            or None/empty when there is nothing to show.
        model_name: display name placed in the card header.
        is_custom: when True and data_list is empty, render the
            "No Custom Model Uploaded" placeholder instead of nothing.

    Returns:
        An HTML string (empty string for an empty non-custom list).
    """
    import html  # stdlib; used to escape user-supplied prompt text

    if not data_list:
        if is_custom:
            return "<div class='model-card empty'><div class='card-title'>No Custom Model Uploaded</div></div>"
        return ""

    total = len(data_list)
    rows_html = ""
    preview_limit = 3  # only the first few rows are shown; the rest collapse to "+N more"

    for item in data_list[:preview_limit]:
        p_text = item.get('prompt', '')
        if len(p_text) > 64:
            p_text = p_text[:64] + "..."
        # Prompts can come from user uploads: escape them so '<', '&', quotes
        # cannot break the card markup or inject HTML.
        p_text = html.escape(p_text)

        is_acc = item.get('is_correct', False)
        status_class = "status-green" if is_acc else "status-red"
        icon = "✔" if is_acc else "✘"

        rows_html += f"""
        <div class='history-row'>
            <div class='status-box {status_class}'>{icon}</div>
            <div class='prompt-text'>{p_text}</div>
        </div>
        """

    remaining = total - preview_limit
    if remaining > 0:
        rows_html += f"<div class='history-more'>+ {remaining} more items</div>"

    return f"""
    <div class='model-card populated'>
        <div class='card-header'>
            <span class='model-name'>{model_name}</span>
            <span class='acc-badge'>Total Samples: {total}</span>
        </div>
        <div class='card-body'>
            <div class='history-container'>{rows_html}</div>
        </div>
    </div>
    """
114
+
115
def process_upload(file_obj):
    """Parse an uploaded JSONL history file and build its preview card.

    Args:
        file_obj: Gradio file object (has a `.name` path attribute) or None
            when the upload is cleared.

    Returns:
        (data, html) tuple: `data` is the parsed list of records, or None on
        any validation/parse failure; `html` is the preview card / error card.
    """
    if file_obj is None:
        return None, format_model_card(None, "Custom", True)

    content = []
    try:
        with open(file_obj.name, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line:  # skip blank lines
                    content.append(json.loads(line))

        if not content:
            return None, "<div class='model-card empty'>File is empty</div>"
        # Schema check on the first record only (cheap sanity check).
        if 'is_correct' not in content[0]:
            return None, "<div class='model-card empty'>Missing 'is_correct' field</div>"

        # Simple validation: a usable context needs both outcome classes.
        has_positive = any(item.get('is_correct') for item in content)
        has_negative = any(not item.get('is_correct') for item in content)

        if not (has_positive and has_negative):
            return None, """
            <div class='model-card empty' style='border-color: var(--fail); color: var(--fail);'>
                <div class='card-title'>Invalid Dataset Distribution</div>
                <div class='card-subtitle'>Please upload at least one positive AND one negative sample.</div>
            </div>
            """

        return content, format_model_card(content, "Custom Model")

    except json.JSONDecodeError:
        return None, "<div class='model-card empty'>Invalid JSONL Format</div>"
    except Exception as e:
        return None, f"<div class='model-card empty'>Error: {str(e)}</div>"
149
+
150
def predict_performance(default_data, custom_data, t1, t2, t3):
    """Predict each model's success probability on the target instructions.

    Uses the loaded V0 model when available; otherwise (Mock Mode) emits
    random placeholder scores.

    Args:
        default_data: history list for the built-in model (or falsy to skip).
        custom_data: history list parsed from a user upload (or falsy to skip).
        t1, t2, t3: target instruction strings from the three textboxes.

    Returns:
        pandas.DataFrame with one row per (model, instruction) pair, or a
        single-row error frame when no target was entered.
    """
    import random  # hoisted out of the per-model loop; only used in Mock Mode

    # Guard against None as well as blank/whitespace-only textbox values.
    targets = [t for t in [t1, t2, t3] if t and t.strip()]
    if not targets:
        return pd.DataFrame([{"Error": "Please enter at least one target prompt."}])

    models_to_run = []
    if default_data:
        models_to_run.append(("Qwen3-4B-Instruct-2507", default_data))
    if custom_data:
        models_to_run.append(("Custom Uploaded Model", custom_data))

    results = []

    for m_name, m_history in models_to_run:
        # The model's history becomes the in-context examples.
        context_prompts = [item['prompt'] for item in m_history]
        context_labels = [1 if item.get('is_correct') else 0 for item in m_history]

        if v0_model:
            try:
                scores = v0_model.predict(
                    context_prompts=context_prompts,
                    context_labels=context_labels,
                    target_prompts=targets
                )
            except Exception as e:
                # Inference failure degrades to all-zero scores rather than crashing the UI.
                print(f"Inference Error: {e}")
                scores = [0.0] * len(targets)
        else:
            # Mock Mode: model failed to load at startup.
            scores = [random.uniform(0.1, 0.9) for _ in targets]

        for t_text, score in zip(targets, scores):
            # predict() may return tensors or plain floats.
            if isinstance(score, torch.Tensor):
                final_score = score.item()
            else:
                final_score = float(score)

            pred_str = "✔ Success" if final_score > 0.5 else "✘ Failure"

            results.append({
                "Model": m_name,
                "Instruction": t_text,
                "Predicted Value Score": round(final_score, 4),
                "Prediction": pred_str
            })

    return pd.DataFrame(results)
208
+
209
# ==========================================
# 2. CSS styles
# ==========================================
# Custom stylesheet injected into the gr.Blocks app below. `:root` defines the
# light palette as CSS variables; `.dark` remaps the same variables for dark
# mode. The string content must stay as-is — the class names are referenced by
# the HTML snippets built in format_model_card() and the UI section.
css = """
/* 全局变量 */
:root {
    --primary: #10b981;
    --primary-light: #ecfdf5;
    --primary-dark: #047857;
    --bg-card: #ffffff;
    --border-sub: #e5e7eb;
    --text-main: #1f2937;
    --text-sub: #6b7280;
    --success: #10b981;
    --fail: #ef4444;
    --popup-bg: #ffffff;
    --popup-text: #1f2937;
    --popup-border: #e5e7eb;
    --popup-shadow: rgba(0,0,0,0.15);
}
.dark {
    --bg-card: #1f2937;
    --border-sub: #374151;
    --text-main: #f3f4f6;
    --text-sub: #9ca3af;
    --popup-bg: #2d2d2d;
    --popup-text: #e5e5e5;
    --popup-border: #4b5563;
    --popup-shadow: rgba(0,0,0,0.4);
}
.label-row { display: flex; align-items: center; margin-bottom: 6px; font-family: 'Source Sans Pro', sans-serif; }
.upload-label-text { font-size: 1rem; color: var(--text-main); margin-right: 8px; }
.format-hint-wrapper { display: inline-block; position: relative; cursor: help; font-size: 0.9rem; color: var(--primary); font-weight: 600; border-bottom: 1px dashed var(--primary); line-height: 1.2; }
.format-popup {
    visibility: hidden; opacity: 0; position: absolute; bottom: 145%; left: -20px; width: 450px;
    background: var(--popup-bg); color: var(--popup-text); border: 1px solid var(--popup-border);
    padding: 16px; border-radius: 8px; box-shadow: 0 10px 30px var(--popup-shadow); z-index: 1000;
    transition: all 0.2s cubic-bezier(0.165, 0.84, 0.44, 1); transform: translateY(10px); pointer-events: none;
    font-size: 0.95rem; line-height: 1.5;
}
.format-hint-wrapper:hover .format-popup { visibility: visible; opacity: 1; transform: translateY(0); }
.format-popup::after {
    content: ""; position: absolute; top: 100%; left: 60px; border-width: 8px; border-style: solid;
    border-color: var(--popup-bg) transparent transparent transparent;
}
.code-snippet {
    display: block; background: #1a1a1a; color: #a7f3d0; font-family: 'Courier New', monospace;
    font-size: 0.85em; padding: 8px; border-radius: 6px; margin-top: 6px; white-space: pre; border: 1px solid #444;
}
.concept-banner {
    background: linear-gradient(135deg, rgba(16, 185, 129, 0.08) 0%, rgba(59, 130, 246, 0.05) 100%);
    border: 1px solid var(--primary-light); border-radius: 12px; padding: 24px; text-align: center; margin-bottom: 30px;
}
.concept-title { font-size: 1.8em; font-weight: 700; color: var(--text-main); margin-bottom: 8px;}
.concept-subtitle { font-size: 1em; color: var(--text-sub); max-width: 600px; margin: 0 auto; line-height: 1.5; }
.equation-box {
    margin-top: 15px; font-family: 'Courier New', monospace; font-weight: bold;
    color: var(--primary); background: var(--bg-card); display: inline-block;
    padding: 8px 16px; border-radius: 8px; border: 1px dashed var(--primary);
    box-shadow: 0 2px 6px rgba(0,0,0,0.05);
}
.step-header { display: flex; align-items: center; margin-bottom: 15px; border-bottom: 2px solid var(--border-sub); padding-bottom: 10px; }
.step-num {
    background: var(--primary); color: white; width: 28px; height: 28px;
    border-radius: 50%; display: flex; align-items: center; justify-content: center;
    font-weight: bold; margin-right: 10px; font-size: 0.9em;
}
.step-title { font-size: 1.2em; font-weight: 600; color: var(--text-main); }
.step-desc { font-size: 0.93em; color: var(--text-sub); margin-left: auto; font-style: italic;}
.model-card {
    background: var(--bg-card); border: 1px solid var(--border-sub);
    border-radius: 10px; padding: 16px; margin-bottom: 15px;
    transition: all 0.2s; position: relative; overflow: hidden;
}
.model-card.populated { border-left: 5px solid var(--primary); box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.05); }
.model-card.empty { border: 2px dashed var(--border-sub); text-align: center; opacity: 0.7; padding: 30px 16px; }
.card-title { font-weight: bold; color: var(--text-sub); }
.card-subtitle { font-size: 0.8em; color: var(--text-sub); }
.card-header { display: flex; justify-content: space-between; align-items: center; margin-bottom: 15px; }
.model-name { font-weight: bold; font-size: 1.1em; color: var(--text-main); }
.acc-badge { background: var(--primary-light); color: var(--primary-dark); font-size: 0.75em; padding: 3px 10px; border-radius: 12px; font-weight: 700; }
.history-container { display: flex; flex-direction: column; gap: 8px; margin-bottom: 15px; }
.history-row { display: flex; align-items: center; background: rgba(0,0,0,0.02); padding: 6px 8px; border-radius: 6px; }
.status-box {
    width: 24px; height: 24px; border-radius: 6px; display: flex; align-items: center; justify-content: center;
    color: white; font-size: 0.8em; font-weight: bold; margin-right: 10px; flex-shrink: 0;
}
.status-green { background-color: var(--success); }
.status-red { background-color: var(--fail); }
.prompt-text {
    font-size: 0.9em; color: var(--text-main); white-space: nowrap; overflow: hidden; text-overflow: ellipsis;
}
.history-more { font-size: 0.95em; color: var(--text-sub); text-align: center; font-style: italic; margin-top: -4px; }
.custom-btn { font-weight: bold !important; font-size: 1.1em !important; }
.paper-link {
    font-size: 0.5em; vertical-align: middle; color: var(--primary); text-decoration: none;
    border: 1px solid var(--primary); padding: 4px 10px; border-radius: 15px; font-weight: normal;
    transition: all 0.2s; background: transparent;
}
.paper-link:hover { background: var(--primary); color: white; }
"""
310
+
311
# ==========================================
# 3. UI construction
# ==========================================
with gr.Blocks(theme=gr.themes.Soft(primary_hue="emerald"), css=css, title="V0 Predictor") as demo:

    # Session state: the default history is fixed; the custom one is filled by upload.
    state_default = gr.State(value=history_default)
    state_custom = gr.State(value=None)

    # Banner: concept explanation + paper/code links.
    gr.HTML("""
    <div class="concept-banner">
        <div class="concept-title">
            V<sub>0</sub> Value Model
            <a href="TBD" target="_blank" class="paper-link">Paper ↗</a>
            <a href="TBD" target="_blank" class="paper-link">Code ↗</a>
        </div>
        <div class="concept-subtitle">
            <span style="color: var(--primary); font-weight: bold;">Function:</span> V<sub>0</sub> uses a model's historical performance to predict<br>
            how it will perform on unseen instructions<br>
            without running the model itself.
        </div>
        <div class="equation-box">
            Historical Perf. + Instruction &rarr; Predicted Perf.
        </div>
    </div>
    """)

    with gr.Row(equal_height=False):

        # --- Step 1: choose/upload the model history (left panel) ---
        with gr.Column(scale=1, variant="panel"):
            gr.HTML("""
            <div class="step-header">
                <div class="step-num">1</div>
                <div class="step-title">Represent Any Model with <span style="color: var(--primary);">Performance-Instruction Pairs</span></div>
            </div>
            """)

            preview_default = gr.HTML(format_model_card(history_default, "Qwen3-4B-Instruct-2507"))

            # Upload label with a hover popup documenting the required JSONL schema.
            gr.HTML("""
            <div class="label-row">
                <span class="upload-label-text"><span style="font-weight: 800;">[Optional]</span> Upload Your Model</span>
                <div class="format-hint-wrapper">
                    Required JSONL Format ⓘ
                    <div class="format-popup">
                        <div style="font-weight: bold; margin-bottom:4px;">File Content Example:</div>
                        <code class="code-snippet">
{"prompt": "Calculate 1+1", "is_correct": true}
{"prompt": "Write a poem", "is_correct": false}
                        </code>
                        <div style="margin-top:6px; font-size:0.9em; opacity: 0.8;">
                            Each line must be a valid JSON object containing <b>'prompt'</b> (string) and <b>'is_correct'</b> (boolean).
                        </div>
                    </div>
                </div>
            </div>
            """)

            upload_btn = gr.File(
                label=None,
                show_label=False,
                file_types=[".jsonl"],
                height=130
            )
            preview_custom = gr.HTML(format_model_card(None, "Custom", True))

        # --- Step 2: target instructions (right panel) ---
        with gr.Column(scale=1, variant="panel"):
            gr.HTML("""
            <div class="step-header" style="margin-top: 80px;">
                <div class="step-num">2</div>
                <div class="step-title">Enter Instructions</div>
                <div class="step-desc">trigger V<sub>0</sub> to predict the expected perf. for each model</div>
            </div>
            """)
            t1 = gr.Textbox(label="Instruction 1", value="What is the largest $n$ such that there exists a non-degenerate convex $n$-gon where each of its angles is an integer number of degrees, and all angles are distinct?", lines=2)
            t2 = gr.Textbox(label="Instruction 2", value="已知四面体 \\(A B C D\\) 内接于球 \\(O\\),且 \\(A D\\) 是球 \\(O\\) 的直径。若 \\(\\triangle A B C\\) 和 \\(\\triangle B C D\\) 都是边长为 1 的等边三角形,则四面体 \\(A B C D\\) 的体积是多少?原始答案的形式为 \\(\\frac{\\sqrt{c}}{b}\\),请给出a+b+c的值。", lines=2)
            t3 = gr.Textbox(label="Instruction 3", placeholder="Your instruction here ...", lines=2)

            gr.HTML("""
            <div style="margin-top: 15px; font-size: 1.05em; color: var(--text-main);">
                <span style="color: var(--primary); font-weight: bold;">Next:</span> Clicking <b>Run V<sub>0</sub> Prediction!</b>
            </div>
            """)

    # Full-width run button below the two panels.
    # NOTE(review): original nesting is ambiguous in the source — this assumes the
    # button/results rows sit at the Blocks top level; confirm against the live layout.
    with gr.Row():
        with gr.Column():
            predict_btn = gr.Button("Run V₀ Prediction", variant="primary", size="lg", elem_classes=["custom-btn"])

    # --- Step 3: results table ---
    gr.HTML("""
    <div class="step-header" style="margin-top: 20px; border-bottom: none;">
        <div class="step-num">3</div>
        <div class="step-title">Results</div>
    </div>
    """)

    output_df = gr.Dataframe(
        headers=["Model Entity", "Unseen Instruction", "Predicted Value Score", "Prediction"],
        datatype=["str", "str", "number", "str"],
        interactive=False,
        column_widths=["20%", "40%", "20%", "20%"]
    )

    # Wiring: file upload -> parse + preview; button -> prediction table.
    upload_btn.change(
        fn=process_upload,
        inputs=[upload_btn],
        outputs=[state_custom, preview_custom]
    )

    predict_btn.click(
        fn=predict_performance,
        inputs=[state_default, state_custom, t1, t2, t3],
        outputs=[output_df]
    )

if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio>=4.0.0
2
+ pandas
3
+ einops
4
+ numpy==2.2.6
5
+ scikit-learn==1.7.2
6
+ -e git+https://github.com/PriorLabs/TabPFN.git@2cd2326038e789a26f7a07e70e1ea986ffd040c9#egg=tabpfn
7
+ torch==2.7.1
8
+ tqdm==4.67.1
9
+ transformers==4.55.4
10
+ wandb==0.21.3
v0_core/config/__init__.py ADDED
File without changes
v0_core/config/arguments.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import argparse
3
+
4
# =============================================================================
# Argument parsing
# =============================================================================
def parse_args():
    """Parse command-line arguments for training/evaluating the value model.

    Returns:
        argparse.Namespace in which every ``*_data_paths`` argument and
        ``prompt_dict_path`` has been split from a comma-separated string
        into a list of path strings (empty list when not provided).
    """
    parser = argparse.ArgumentParser(description="Generalist Value Model")

    # --- Paths ---
    parser.add_argument("--time_str", type=str, required=True)
    parser.add_argument("--qwen_path", type=str, required=True, help="Qwen 模型路径")
    parser.add_argument("--tabpfn_checkpoint", type=str, required=True, help="TabPFN Checkpoint 路径")

    # Data path configuration (comma-separated lists).
    parser.add_argument("--context_data_paths", type=str, required=True, help="Context Pool Jsonl路径 (支持多个,逗号分隔)")
    parser.add_argument("--train_data_paths", type=str, default=None, help="Train Query Pool Jsonl路径 (支持多个,逗号分隔)")
    parser.add_argument("--eval_data_paths", type=str, default=None, help="Test Query Pool Jsonl路径 (支持多个,逗号分隔)")
    parser.add_argument("--validity_data_paths", type=str, default=None, help="Validity Test Pool Jsonl路径 (支持多个,逗号分隔)")

    parser.add_argument("--prompt_dict_path", type=str, required=True, help="Prompt 字典 JSON 路径")

    # --- Checkpoint saving ---
    parser.add_argument("--checkpoint_dir", type=str, default=None, help="模型保存目录")
    parser.add_argument("--save_interval", type=int, default=1, help="每隔多少个 Epoch 保存一次模型")
    parser.add_argument("--max_keep_checkpoints", type=int, default=2, help="最多保留多少个最新的 Checkpoint")
    parser.add_argument("--resume", action="store_true", help="是否尝试从 checkpoint_dir 恢复训练")
    parser.add_argument("--resume_from_specific_epoch", type=int, default=None, help="指定要 resume 的 epoch")

    # --- Logging ---
    parser.add_argument("--log_path", type=str, default=None)
    parser.add_argument("--log_interval", type=int, default=10, help="保存间隔")
    parser.add_argument("--metric_path", type=str, default=None)
    parser.add_argument("--wandb_project", type=str, default="context-v", help="Wandb 项目名称")
    parser.add_argument("--wandb_interval", type=int, default=1, help="Wandb 记录间隔")
    parser.add_argument("--wandb_id", type=str, default=None)

    # --- Run mode & strategies ---
    parser.add_argument("--run_mode", type=str, default="eval", choices=["train", "eval"])
    parser.add_argument("--pooling_strategy", type=str, default="dynamic_query",
                        choices=["last_token", "fixed_query", "dynamic_query"],
                        help="Embedding 提取策略")

    parser.add_argument("--label_strategy", type=str, default="binary",
                        choices=["binary", "minmax_norm"],
                        help="Label 处理策略")
    parser.add_argument("--loss_type", type=str, default="ce_hard",
                        choices=["ce_hard", "ce_soft", "kl_div", "pairwise", "combined"],
                        help="Loss 函数类型: combined = pairwise + ce_soft")
    parser.add_argument("--loss_alpha", type=float, default=0.5,
                        help="Combined Loss 中 Pairwise 的权重 (0.0-1.0)。Total = alpha * Pair + (1-alpha) * CE")
    parser.add_argument("--loss_balance", action="store_true", help="是否对正负样本加权")

    parser.add_argument("--kl_temperature", type=float, default=1.0,
                        help="KL 散度或 Softmax 的温度系数 T")

    # --- Dimensionality reduction ---
    parser.add_argument("--reduce_method", type=str, default="none",
                        choices=["none", "avg_pool", "max_pool"])
    parser.add_argument("--target_dim", type=int, default=1024)
    parser.add_argument("--num_heads", type=int, default=4)

    parser.add_argument("--context_clustering", action="store_true", help="是否启用 Support Set 聚类筛选")
    parser.add_argument("--context_num_clusters", type=int, default=128, help="聚类保留的原型数量 (k值)")

    # --- Model hyperparameters ---
    parser.add_argument("--num_queries", type=int, default=10)
    parser.add_argument("--embed_dim", type=int, default=32)
    parser.add_argument("--tabpfn_estimators", type=int, default=4)
    parser.add_argument("--dynamic_query_generator_bottleneck_dim", type=int, default=128)
    parser.add_argument("--dynamic_query_generator_dropout_rate", type=float, default=0.2)

    # --- Training hyperparameters ---
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--meta_batch_size", type=int, default=1, help="每次forward处理多少个Task(一个Task包含Support+Query)")
    parser.add_argument("--grad_accum_steps", type=int, default=4)

    parser.add_argument("--train_query_batch_size", type=int, default=8, help="每个Task包含多少个Query样本 (必须来自同一个Step)")
    # Help text below contained a mojibake byte in the original ("包含��少个");
    # restored to match the train_query_batch_size help string.
    parser.add_argument("--eval_query_batch_size", type=int, default=8, help="每个Task包含多少个Query样本 (必须来自同一个Step)")
    parser.add_argument("--support_size", type=int, default=256, help="每个Task采样的Context样本数量")

    parser.add_argument("--lr_backbone", type=float, default=1e-5)
    parser.add_argument("--lr_adapter", type=float, default=1e-4)
    parser.add_argument("--lr_tabpfn", type=float, default=1e-5)

    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--warmup_ratio", type=float, default=0.05)
    parser.add_argument("--lr_scheduler_type", type=str, default="cosine")
    parser.add_argument("--max_grad_norm", type=float, default=1.0)

    parser.add_argument("--train_embed_bs", type=int, default=4)
    parser.add_argument("--eval_embed_bs", type=int, default=4)

    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")

    args = parser.parse_args()

    def split_paths(path_str):
        # "a.jsonl, b.jsonl" -> ["a.jsonl", "b.jsonl"]; None/"" -> [].
        if not path_str:
            return []
        return [p.strip() for p in path_str.split(',') if p.strip()]

    args.context_data_paths = split_paths(args.context_data_paths)
    args.train_data_paths = split_paths(args.train_data_paths)
    args.eval_data_paths = split_paths(args.eval_data_paths)
    args.validity_data_paths = split_paths(args.validity_data_paths)
    args.prompt_dict_path = split_paths(args.prompt_dict_path)

    return args
109
+
110
def print_elegant_args(args):
    """Pretty-print the parsed argument namespace, one aligned line per key."""
    params = vars(args)
    ordered = sorted(params)
    # Pad key names to the longest one so the colons line up.
    width = max((len(name) for name in ordered), default=10)

    # ANSI colour codes (set these to "" to disable colouring).
    cyan = "\033[36m"    # keys
    yellow = "\033[33m"  # header
    reset = "\033[0m"

    print(f"\n{yellow}Arguments:{reset}")
    for name in ordered:
        # Left-align the key inside the padded field; print the value in full.
        print(f" {cyan}{name:<{width}}{reset} : {params[name]}")
    print()  # trailing blank line
v0_core/data/__init__.py ADDED
File without changes
v0_core/data/collator.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
def meta_collate_fn(batch):
    """Collate a list of meta-learning tasks into one flat batch.

    Prompts from every task are concatenated into a single list and labels
    into a single tensor; per-task (start, len) offsets plus the task's
    bookkeeping fields go into `metadata` so tasks can be sliced back out.
    """
    flat_prompts = []
    label_chunks = []
    metadata = []
    offset = 0

    for task in batch:
        count = len(task['prompts'])
        flat_prompts += task['prompts']
        label_chunks.append(task['labels'])
        metadata.append({
            'start': offset,       # index of this task's first prompt in the flat list
            'len': count,
            'split': task['split_idx'],
            'q_ids': task['q_ids'],
            'pair_ids': task['pair_ids'],
            'pair_types': task['pair_types'],
            'key': task['key'],
            'stats': task['stats'],
        })
        offset += count

    return {
        'flat_prompts': flat_prompts,
        'flat_labels': torch.cat(label_chunks),
        'metadata': metadata,
    }
v0_core/data/dataset.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import torch
3
+ import numpy as np
4
+ import random
5
+ from collections import defaultdict
6
+ from torch.utils.data import Dataset
7
+ from v0_core.data.utils import load_jsonl_lines
8
+
9
+ # =============================================================================
10
+ # 数据与日志工具
11
+ # =============================================================================
12
class ValueModelDataset(Dataset):
    """Meta-learning dataset pairing a context (support) pool with query pools.

    Builds an index of context samples keyed by (dataset, model, step), loads
    and groups query samples by the same key, and chunks them into tasks via
    generate_tasks() (defined further down in this class).
    """

    def __init__(self,
                 context_paths,
                 query_paths,
                 prompt_dict_path,
                 label_strategy='binary',
                 query_batch_size=8,
                 support_size=256,
                 mode='train'):
        """
        args:
            context_paths: List of paths to context_pool jsonl files
            query_paths: List of paths to query_pool jsonl files (train/test/validity)
            prompt_dict_path: List of paths to prompt dictionaries
            label_strategy: label processing strategy ('binary' or 'minmax_norm')
            query_batch_size: Number of queries in one forward pass (all from same step)
            support_size: Number of context samples to sample
            mode: 'train' (shuffle queries before chunking) or 'eval' (sequential)
        """
        self.label_strategy = label_strategy
        self.query_batch_size = query_batch_size
        self.support_size = support_size
        self.mode = mode

        # 1. Load Prompt Dictionary (maps "<dataset>_<id>" -> prompt text).
        print(f"Loading prompts from {prompt_dict_path}...")
        self.prompt_map = {}
        for path in prompt_dict_path:
            with open(path, 'r', encoding='utf-8') as f:
                self.prompt_map.update(json.load(f))

        # 2. Load Context Pool and Index it
        # Structure: {(dataset, model, step): [list of sample dicts]}
        # A secondary fallback index keyed by (model, step) is kept for keys
        # with no dataset-specific context.
        print("Loading Context Pool...")
        self.context_pool = defaultdict(list)
        self.context_pool_fallback = defaultdict(list)
        raw_context = load_jsonl_lines(context_paths)
        for item in raw_context:
            key = (item['dataset'], item['model'], item['step'])
            self.context_pool[key].append(item)
            fallback_key = (item['model'], item['step'])
            self.context_pool_fallback[fallback_key].append(item)
        print(f"Loaded Context Pool with {len(self.context_pool)} unique (dataset, model, step) keys.")

        # 3. Load Query Pool
        print(f"Loading Query Pool from {query_paths}...")
        raw_queries = load_jsonl_lines(query_paths)

        # 4. Group Queries by Key
        self.queries_by_key = defaultdict(list)
        print("Grouping Queries...")
        for item in raw_queries:
            key = (item['dataset'], item['model'], item['step'])
            # Pre-fetch prompt text to save time later, if ID exists
            s_id_str = f"{item['dataset']}_{item['id']}"
            # NOTE(review): direct lookup raises KeyError when an id is missing
            # from the prompt dictionary — confirm that is the intended contract.
            item['text'] = self.prompt_map[s_id_str]
            self.queries_by_key[key].append(item)

        # 5. Pre-calculate Class Statistics for Context-Aware Re-weighting
        # Count positives/negatives per context key in the query pool; used to
        # weight the loss per class.
        print("Calculating Global Context Statistics for Re-weighting...")
        self.context_stats = {}
        if mode == 'train':
            for key, items in self.queries_by_key.items():
                # Positive sample definition: score >= 0
                n_pos = sum(1 for x in items if float(x.get('score', -1)) >= 0)
                n_neg = len(items) - n_pos
                self.context_stats[key] = {'n_pos': n_pos, 'n_neg': n_neg}

        # Debug summary of the first 10 keys (table is empty in 'eval' mode,
        # where context_stats is never populated).
        print("\n" + "="*60)
        print(f"Top 10 Steps Statistics ({mode} mode)")
        print(f"{'Dataset':<15} | {'Model':<15} | {'Step':<6} | {'n_pos':<6} | {'n_neg':<6} | {'Total':<6}")
        print("-" * 60)

        sorted_keys = sorted(list(self.context_stats.keys()))

        for i, key in enumerate(sorted_keys[:10]):

            dataset_name, model_name, step_val = key
            stats = self.context_stats[key]
            total = stats['n_pos'] + stats['n_neg']
            print(f"{dataset_name:<15} | {model_name:<15} | {str(step_val):<6} | "
                  f"{stats['n_pos']:<6} | {stats['n_neg']:<6} | {total:<6}")

        print(f"... (Total {len(sorted_keys)} steps loaded)")
        print("="*60 + "\n")

        # 6. Create Tasks (Chunks of Queries)
        self.tasks = []
        self.generate_tasks(shuffle=(self.mode == 'train'))

        print(f"Dataset Initialized. Total Tasks: {len(self.tasks)}")
+
104
def generate_tasks(self, shuffle=True):
    """
    Build the task list (chunks of queries) from ``self.queries_by_key``.

    In 'train' mode, tasks are built pairwise with cyclic oversampling: no
    sample is dropped; the smaller of the positive/negative pools is reused
    cyclically to match the larger one.  In any other mode, samples are
    simply chunked in order.

    Args:
        shuffle: Randomize key order and per-key sample order (train epochs).

    Side effects:
        Replaces ``self.tasks`` with the newly generated task dicts.
    """

    def resolve_context_key(key):
        # Prefer the exact (dataset, model, step) context pool; fall back to
        # the coarser (model, step) pool when the exact one is empty/missing.
        # (This lookup was previously duplicated in both branches below.)
        if self.context_pool.get(key):
            return key
        fallback_key = (key[1], key[2])  # (model, step)
        if self.context_pool_fallback.get(fallback_key):
            return fallback_key
        return None

    new_tasks = []
    keys = sorted(list(self.queries_by_key.keys()))
    if shuffle:
        random.shuffle(keys)

    dropped_steps = 0  # steps skipped because one class is entirely missing
    total_pairs = 0

    for key in keys:
        samples = list(self.queries_by_key[key])

        if self.mode == 'train':
            # 1. Split into positives / negatives by the processed label.
            pos_list = [x for x in samples if self._process_label(x['score']) >= 0.5]
            neg_list = [x for x in samples if self._process_label(x['score']) < 0.5]
            n_pos, n_neg = len(pos_list), len(neg_list)

            # 2. A pair needs one of each; skip steps with only one class.
            if n_pos == 0 or n_neg == 0:
                dropped_steps += 1
                continue

            # 3. Shuffle so the cyclically reused samples differ per epoch.
            if shuffle:
                random.shuffle(pos_list)
                random.shuffle(neg_list)

            # 4. Cyclic oversampling: take max(n_pos, n_neg) pairs so every
            #    sample is used at least once.
            n_pairs = max(n_pos, n_neg)
            paired_samples = []
            for i in range(n_pairs):
                paired_samples.append(pos_list[i % n_pos])
                paired_samples.append(neg_list[i % n_neg])
            total_pairs += n_pairs

            # 5. Chunking; the per-task batch size must be even (and >= 2)
            #    so every chunk holds whole pairs.
            bs = self.query_batch_size
            if bs % 2 != 0:
                bs -= 1
            if bs < 2:
                bs = 2

            for i in range(0, len(paired_samples), bs):
                chunk = paired_samples[i:i + bs]

                # Drop a trailing incomplete pair (odd-length final chunk).
                if len(chunk) % 2 != 0:
                    chunk = chunk[:-1]

                context_key_to_use = resolve_context_key(key)
                if len(chunk) > 0 and context_key_to_use is not None:
                    new_tasks.append({
                        'key': key,
                        'context_key': context_key_to_use,
                        'queries': chunk,
                        'is_pairwise': True
                    })

        else:
            if shuffle:
                random.shuffle(samples)
            for i in range(0, len(samples), self.query_batch_size):
                chunk = samples[i:i + self.query_batch_size]
                context_key_to_use = resolve_context_key(key)
                if context_key_to_use is not None:
                    new_tasks.append({
                        'key': key,
                        'context_key': context_key_to_use,
                        'queries': chunk,
                        'is_pairwise': False
                    })

    self.tasks = new_tasks
    if self.mode == 'train':
        print(f" >>> [Dataset] Generated {len(self.tasks)} tasks from {len(keys)} contexts.")
        print(f" >>> [Pairwise Stats] Total Pairs: {total_pairs} (Using Oversampling). Dropped Steps (0 pos or 0 neg): {dropped_steps}")
207
+
208
def _process_label(self, reward):
    """Map a raw reward to a training label per ``self.label_strategy``."""
    score = float(reward)
    strategy = self.label_strategy
    if strategy == "binary":
        # Non-negative rewards count as the positive class.
        return 1.0 if score >= 0 else 0.0
    if strategy == "minmax_norm":
        # Clamp to [-1, 1], then rescale linearly onto [0, 1].
        return (np.clip(score, -1.0, 1.0) + 1.0) / 2.0
    # Unknown strategy: pass the raw value through unchanged.
    return score
215
+
216
def __len__(self):
    """Number of generated tasks (one task == one dataloader item)."""
    return len(self.tasks)
218
+
219
def __getitem__(self, idx):
    """Assemble one task: a randomly sampled support set followed by the
    task's fixed query chunk, flattened into parallel prompt/label lists."""
    task = self.tasks[idx]
    key = task['key']  # (dataset, model, step)
    query_samples = task['queries']

    # --- 1. Sample the support (context) set --------------------------------
    context_key = task.get('context_key', key)
    if context_key == key:
        available_context = self.context_pool[key]
    else:
        available_context = self.context_pool_fallback[context_key]

    if len(available_context) >= self.support_size:
        support_samples = random.sample(available_context, self.support_size)
    else:
        support_samples = available_context

    # --- 2. Format support + query into flat prompt/label lists -------------
    prompts, labels = [], []

    for item in support_samples:
        s_id_str = f"{item['dataset']}_{item['id']}"
        text = self.prompt_map[s_id_str]
        if text:
            prompts.append(text)
            labels.append(self._process_label(item['score']))

    split_idx = len(prompts)  # boundary between support and query sections

    q_ids, pair_ids, pair_types = [], [], []
    for item in query_samples:
        prompts.append(item['text'])
        labels.append(self._process_label(item['score']))
        q_ids.append(item['id'])
        if 'pair_id' in item:
            pair_ids.append(item['pair_id'])
        if 'pair_type' in item:
            pair_types.append(item['pair_type'])

    # Global pos/neg counts for this context (used for loss re-weighting).
    stats = self.context_stats.get(key, {'n_pos': 0, 'n_neg': 0})

    return {
        "prompts": prompts,
        "labels": torch.tensor(labels, dtype=torch.float),
        "split_idx": split_idx,
        "q_ids": q_ids,
        "pair_ids": pair_ids,
        "pair_types": pair_types,
        "key": key,
        "stats": stats,  # passed through to the collate function
    }
v0_core/data/utils.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+
4
def load_jsonl_lines(paths):
    """Read one or more JSONL files and merge all records into a single list.

    A single path may be passed instead of a list.  Missing paths are skipped
    with a warning; a file that fails mid-read is reported and abandoned, but
    records already collected (from it or earlier files) are kept.
    """
    if not isinstance(paths, list):
        paths = [paths]
    all_lines = []
    for p in paths:
        if not p or not os.path.exists(p):
            print(f"Warning: Path not found {p}")
            continue
        print(f"Loading {p}...")
        try:
            with open(p, 'r', encoding='utf-8') as f:
                for raw in f:
                    raw = raw.strip()
                    if raw:  # skip blank lines
                        all_lines.append(json.loads(raw))
        except Exception as e:
            print(f"Error reading {p}: {e}")
    return all_lines
v0_core/models/__init__.py ADDED
File without changes
v0_core/models/v0.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch import Tensor
5
+ from transformers import AutoTokenizer, AutoModel
6
+
7
# =============================================================================
# TabPFN monkey-patches
# =============================================================================
# TabPFNClassifier is a hard dependency of this module: fail fast at import
# time with an actionable message instead of raising later at call sites.
try:
    from tabpfn import TabPFNClassifier
except ImportError as e:
    print(f"导入 TabPFN 模块失败: {e}")
    print("请确保已安装 tabpfn,并且处于包含 tabpfn 源代码的环境中。")
    exit(1)

from v0_core.utils.tabpfn_patches import fixed_fit, fixed_forward

# Apply patches: replace TabPFN's fit/forward class-wide with fixed versions
# (the patches address fitting with differentiable_input=True; see
# v0_core/utils/tabpfn_patches.py).  This affects every TabPFNClassifier
# instance created after this module is imported.
TabPFNClassifier.fit = fixed_fit
TabPFNClassifier.forward = fixed_forward
# print("Applied final fit/forward fix patches to TabPFNClassifier.")
23
+
24
+ # =============================================================================
25
+ # Qwen Official Pooling
26
+ # =============================================================================
27
def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    """Qwen-official pooling: select each sequence's final *attended* token.

    With left padding every sequence ends at position -1; with right padding
    the last real token sits at (number of attended positions - 1).
    """
    # Left padding iff every row's final position is attended.
    if bool(attention_mask[:, -1].sum() == attention_mask.shape[0]):
        return last_hidden_states[:, -1]
    last_positions = attention_mask.sum(dim=1) - 1
    rows = torch.arange(last_hidden_states.shape[0], device=last_hidden_states.device)
    return last_hidden_states[rows, last_positions]
35
+
36
+ # =============================================================================
37
+ # Adapter 策略模块
38
+ # =============================================================================
39
class FixedQueryAdapter(nn.Module):
    """Pools a token sequence via cross-attention from a fixed set of learned
    query vectors; the output is the flattened attention result of shape
    (batch, num_queries * embed_dim)."""

    def __init__(self, input_dim, num_queries=10, embed_dim=32, num_heads=4):
        super().__init__()
        self.proj_kv = nn.Linear(input_dim, embed_dim)
        self.queries = nn.Parameter(torch.randn(1, num_queries, embed_dim))
        self.mha = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)
        self.ln_q = nn.LayerNorm(embed_dim)
        self.ln_kv = nn.LayerNorm(embed_dim)

    def forward(self, hidden_states, attention_mask=None):
        n = hidden_states.size(0)
        projected = self.proj_kv(hidden_states)
        tiled_queries = self.queries.repeat(n, 1, 1)
        if attention_mask is None:
            padding_mask = None
        else:
            # MultiheadAttention expects True at *padding* positions.
            padding_mask = ~attention_mask.bool()
        pooled, _ = self.mha(
            query=self.ln_q(tiled_queries),
            key=self.ln_kv(projected),
            value=projected,
            key_padding_mask=padding_mask,
        )
        return pooled.reshape(n, -1)
55
+
56
class DynamicQueryAdapter(nn.Module):
    """Like FixedQueryAdapter, but the query vectors are input-conditioned:
    a generator MLP produces a per-sample delta that is added to a set of
    static learned queries.  The generator's final layer is zero-initialized
    so training starts exactly at the static-query behaviour."""

    def __init__(self, input_dim, num_queries=10, embed_dim=32, num_heads=4,
                 generator_bottleneck_dim=128, generator_dropout_rate=0.2):
        super().__init__()
        self.num_queries = num_queries
        self.embed_dim = embed_dim
        self.static_queries = nn.Parameter(torch.randn(1, num_queries, embed_dim))
        self.generator = nn.Sequential(
            nn.Linear(input_dim, generator_bottleneck_dim),
            nn.LayerNorm(generator_bottleneck_dim),
            nn.GELU(),
            nn.Dropout(generator_dropout_rate),
            nn.Linear(generator_bottleneck_dim, num_queries * embed_dim)
        )
        # Zero-init the delta head: at step 0 this adapter reduces to the
        # purely static query pooling.
        nn.init.zeros_(self.generator[-1].weight)
        nn.init.zeros_(self.generator[-1].bias)
        self.proj_kv = nn.Linear(input_dim, embed_dim)
        self.mha = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)
        self.ln_q = nn.LayerNorm(embed_dim)
        self.ln_kv = nn.LayerNorm(embed_dim)

    def forward(self, hidden_states, attention_mask):
        n = hidden_states.size(0)
        # Condition the queries on a global summary of the sequence.
        summary = last_token_pool(hidden_states, attention_mask)
        delta = self.generator(summary).view(n, self.num_queries, self.embed_dim)
        conditioned = self.static_queries.repeat(n, 1, 1) + delta
        projected = self.proj_kv(hidden_states)
        padding_mask = ~attention_mask.bool() if attention_mask is not None else None
        pooled, _ = self.mha(
            query=self.ln_q(conditioned),
            key=self.ln_kv(projected),
            value=projected,
            key_padding_mask=padding_mask,
        )
        return pooled.reshape(n, -1)
85
+
86
+ # =============================================================================
87
+ # Qwen Embedding 模型封装
88
+ # =============================================================================
89
class QwenEmbeddingModel(nn.Module):
    """Qwen embedding backbone plus a pooling adapter.

    Maps a list of prompt strings to an (N, D) embedding tensor.  The
    pooling adapter is selected by ``pooling_type``:
      - 'last_token':    Qwen's official parameter-free last-token pooling
      - 'fixed_query':   learned cross-attention pooling (static queries)
      - 'dynamic_query': learned pooling with input-conditioned queries
    """

    def __init__(self, model_path, pooling_type='last_token', num_queries=10, embed_dim=32,
                 reduce_method='avg_pool', target_dim=1024, num_heads=4, generator_bottleneck_dim=128, generator_dropout_rate=0.2, device='cuda'):
        super().__init__()
        self.device = device
        self.pooling_type = pooling_type
        self.reduce_method = reduce_method  # 'avg_pool' | 'max_pool'; applied only when wider than target_dim
        self.target_dim = target_dim

        # Left padding so the final position always holds a real token
        # (required by last-token pooling).
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side='left')
        self.backbone = AutoModel.from_pretrained(model_path, trust_remote_code=True).to(device)
        self.backbone.train()

        # no_grad has no effect on a plain config attribute read; kept as-is.
        with torch.no_grad(): hidden_size = self.backbone.config.hidden_size

        if self.pooling_type == 'fixed_query':
            self.adapter_layer = FixedQueryAdapter(input_dim=hidden_size, num_queries=num_queries, embed_dim=embed_dim, num_heads=num_heads).to(device)
        elif self.pooling_type == 'dynamic_query':
            self.adapter_layer = DynamicQueryAdapter(input_dim=hidden_size, num_queries=num_queries, embed_dim=embed_dim, num_heads=num_heads, generator_bottleneck_dim=generator_bottleneck_dim, generator_dropout_rate=generator_dropout_rate).to(device)
        elif self.pooling_type == 'last_token':
            # Parameter-free function, not an nn.Module.
            self.adapter_layer = last_token_pool

    def forward(self, prompts, batch_size=32):
        """Encode ``prompts`` in mini-batches; returns a (len(prompts), D) tensor."""
        embeddings = []
        for i in range(0, len(prompts), batch_size):
            batch_prompts = prompts[i : i + batch_size]
            batch_dict = self.tokenizer(batch_prompts, max_length=2048, padding=True, truncation=True, return_tensors="pt").to(self.device)
            # Backbone features are computed grad-free; the adapter call below
            # sits outside the no_grad scope and stays differentiable.
            # NOTE(review): indentation reconstructed from an indentation-lossy
            # diff -- confirm the intended no_grad scope against the original.
            with torch.no_grad():
                outputs = self.backbone(**batch_dict)
            last_hidden_state = outputs.last_hidden_state
            emb = self.adapter_layer(last_hidden_state, batch_dict['attention_mask'])

            # Shrink over-wide adapter outputs down to target_dim.
            if self.reduce_method == 'avg_pool' and emb.shape[1] > self.target_dim:
                emb = F.adaptive_avg_pool1d(emb.unsqueeze(1), self.target_dim).squeeze(1)
            elif self.reduce_method == 'max_pool' and emb.shape[1] > self.target_dim:
                emb = F.adaptive_max_pool1d(emb.unsqueeze(1), self.target_dim).squeeze(1)
            embeddings.append(emb)
        return torch.cat(embeddings, dim=0)
127
+
128
+
129
class V0:
    """Generalist value model V0.

    Combines a fine-tuned Qwen embedding model with a TabPFN classifier head
    used as an in-context learner: ``predict`` fits TabPFN on an embedded
    support set and then scores embedded query prompts.
    """

    def __init__(self, embedding_model, tabpfn_model, device):
        # embedding_model: QwenEmbeddingModel mapping List[str] -> Tensor
        # tabpfn_model:    TabPFNClassifier used as the scoring head
        self.embedding_model = embedding_model
        self.tabpfn = tabpfn_model
        self.device = device

    @classmethod
    def from_pretrained(cls,
                        checkpoint_path,
                        embedding_model_path,
                        tabpfn_head_path,
                        device="cuda",
                        num_queries=168,
                        embed_dim=6,
                        num_heads=3,
                        bottleneck_dim=128,
                        tabpfn_estimators=4):
        """Construct a V0 from saved weights.

        Args:
            checkpoint_path: .pt file holding 'model_state_dict' for the
                embedding model (adapter and possibly backbone weights).
            embedding_model_path: HF repo/dir of the Qwen embedding backbone.
            tabpfn_head_path: TabPFN checkpoint file path.
            device: torch device string.
            num_queries / embed_dim / num_heads / bottleneck_dim: adapter
                hyper-parameters.  NOTE(review): pooling_type is left at
                QwenEmbeddingModel's 'last_token' default here, in which case
                these adapter hyper-parameters are unused -- confirm intended.
            tabpfn_estimators: TabPFN ensemble size.
        """

        # 1. Initialize Embedding Model (Qwen + Adapter)
        embedding_model = QwenEmbeddingModel(
            model_path=embedding_model_path,
            num_queries=num_queries,
            embed_dim=embed_dim,
            num_heads=num_heads,
            generator_bottleneck_dim=bottleneck_dim,
            generator_dropout_rate=0.0,  # Dropout not needed for inference
            device=device
        )

        # 2. Load Trained Weights (Adapter + potentially Backbone)
        # NOTE(review): torch.load without weights_only=True unpickles
        # arbitrary objects -- only load trusted checkpoints.
        ckpt = torch.load(checkpoint_path, map_location=device)
        state_dict = ckpt['model_state_dict']

        # Clean DDP 'module.' prefix if present
        if list(state_dict.keys())[0].startswith('module.'):
            state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}

        # Load weights (strict=False: the checkpoint may cover only a subset
        # of the model's parameters, e.g. adapter-only training)
        msg = embedding_model.load_state_dict(state_dict, strict=False)

        # 3. Initialize TabPFN
        tabpfn = TabPFNClassifier(
            model_path=tabpfn_head_path,
            device=device,
            n_estimators=tabpfn_estimators,
            inference_precision=torch.float32,
            differentiable_input=True  # As per training script
        )
        # Manual init to ensure weights are loaded (private TabPFN API --
        # relies on the patched fit/forward applied at module import)
        tabpfn._initialize_model_variables()

        return cls(embedding_model, tabpfn, device)

    def predict(self, context_prompts, context_labels, target_prompts, batch_size=32):
        """
        Args:
            context_prompts: List[str] - Support Set Texts
            context_labels: List[float] - Support Set Scores (0.0 to 1.0)
            target_prompts: List[str] - Query Set Texts to be scored
            batch_size: embedding mini-batch size for the support set
        Returns:
            scores: List[float] - Predicted scores (probability of class 1)
        """
        # 1. Encode Context (Support Set)
        X_sup = self.embedding_model(context_prompts, batch_size=batch_size)

        # 2. Process Labels (Training script logic: >= 0.5 is Positive)
        y_sup = torch.tensor(context_labels, device=self.device)
        y_sup_hard = (y_sup >= 0.5).long()  # Convert to class indices 0 or 1

        # 3. Fit TabPFN (In-Context Learning)
        # TabPFN learns from this specific batch of context
        self.tabpfn.fit(X_sup, y_sup_hard)

        # 4. Encode Targets (Query Set)
        # NOTE(review): uses the embedding model's default batch size here,
        # not the `batch_size` argument -- confirm intended.
        X_que = self.embedding_model(target_prompts)

        # 5. Predict
        # use_inference_mode=True as per eval logic in run_epoch
        with torch.no_grad():
            logits = self.tabpfn.forward(X_que, use_inference_mode=True, return_logits=True)
            probs = torch.softmax(logits, dim=1)

        # Return probability of the positive class (class 1)
        # If batch size is 1, output might be squeezed, handling that:
        if probs.dim() == 1:
            return [probs[1].item()]
        else:
            return probs[:, 1].tolist()
v0_core/utils/__init__.py ADDED
File without changes
v0_core/utils/checkpoint.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import glob
4
+ import torch
5
+
6
+ # =============================================================================
7
+ # Checkpoint 管理器
8
+ # =============================================================================
9
class CheckpointManager:
    """Saves, rotates, and restores training checkpoints.

    Only the master rank writes to disk; saving is atomic (write to a temp
    file + fsync + os.replace) so a crash never leaves a truncated
    checkpoint where a valid one is expected.
    """

    def __init__(self, checkpoint_dir, max_keep=2, is_master=False):
        self.checkpoint_dir = checkpoint_dir
        self.max_keep = max_keep      # number of newest checkpoints to retain
        self.is_master = is_master    # only the master rank touches disk
        if self.is_master and self.checkpoint_dir:
            os.makedirs(self.checkpoint_dir, exist_ok=True)

    @staticmethod
    def _extract_epoch(filename):
        """Parse the epoch out of 'checkpoint_epoch_XXXX.pt'; -1 if absent."""
        try:
            match = re.search(r"epoch_(\d+)", filename)
            return int(match.group(1)) if match else -1
        except Exception:
            return -1

    def _list_checkpoints(self):
        """Existing checkpoint files (tmp files excluded), sorted by epoch."""
        files = glob.glob(os.path.join(self.checkpoint_dir, "checkpoint_epoch_*.pt"))
        files = [f for f in files if not f.endswith('.tmp')]
        files.sort(key=self._extract_epoch)
        return files

    def save(self, model, optimizer, scheduler, epoch, args, wandb_run_id=None):
        """Atomically persist the training state for ``epoch``, then rotate
        old checkpoints.  No-op on non-master ranks."""
        if not self.is_master or not self.checkpoint_dir:
            return
        raw_model = model.module if hasattr(model, 'module') else model  # unwrap DDP
        state = {
            'epoch': epoch,
            'model_state_dict': raw_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict() if optimizer else None,
            'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
            'args': vars(args),
            'wandb_run_id': wandb_run_id
        }
        filename = f"checkpoint_epoch_{epoch:04d}.pt"
        filepath = os.path.join(self.checkpoint_dir, filename)
        tmp_filepath = filepath + ".tmp"
        print(f">> Saving Checkpoint to {filepath} (Atomic)...")
        try:
            # 1. Write to a temp file first.
            torch.save(state, tmp_filepath)
            # 2. Force the data to disk before the rename.
            if os.path.exists(tmp_filepath):
                with open(tmp_filepath, 'rb') as f:
                    os.fsync(f.fileno())
            # 3. Atomic rename: before it, the old file survives a crash;
            #    after it, the new one is in effect.
            os.replace(tmp_filepath, filepath)
        except Exception as e:
            print(f"Error saving checkpoint: {e}")
            if os.path.exists(tmp_filepath):
                os.remove(tmp_filepath)
            return
        self._rotate_checkpoints()

    def _rotate_checkpoints(self):
        """Delete all but the newest ``max_keep`` checkpoints."""
        files = self._list_checkpoints()
        if len(files) > self.max_keep:
            for f in files[: -self.max_keep]:
                try:
                    print(f"Removing old checkpoint: {f}")
                    os.remove(f)
                except OSError as e:
                    print(f"Error removing {f}: {e}")

    def find_latest_epoch_num(self):
        """Highest epoch number among saved checkpoints, or 0 if none."""
        if not self.checkpoint_dir or not os.path.exists(self.checkpoint_dir):
            return 0
        files = self._list_checkpoints()
        if not files:
            return 0
        return self._extract_epoch(files[-1])

    def load_specific_epoch(self, target_epoch, model, optimizer, scheduler, device):
        """Restore model/optimizer/scheduler state from ``target_epoch``.

        Returns:
            (next_epoch, wandb_run_id).  BUGFIX: the previous version
            returned a bare int (1) for target_epoch <= 0 while the normal
            path returned a 2-tuple, crashing callers that unpack both
            values; it now consistently returns a 2-tuple.

        Raises:
            FileNotFoundError: if the checkpoint file never appears.
        """
        if target_epoch <= 0:
            return 1, None
        filename = f"checkpoint_epoch_{target_epoch:04d}.pt"
        filepath = os.path.join(self.checkpoint_dir, filename)
        if not os.path.exists(filepath):
            import time
            # Shared filesystems can lag behind the rank that wrote the file.
            print(f">> [Warning] Checkpoint {filepath} not found immediately. Waiting for FS sync...")
            time.sleep(5)
            if not os.path.exists(filepath):
                raise FileNotFoundError(f"Checkpoint {filepath} does not exist.")
        print(f">> Resuming from checkpoint: {filepath}")
        checkpoint = torch.load(filepath, map_location=device)

        state_dict = checkpoint['model_state_dict']
        raw_model = model.module if hasattr(model, 'module') else model

        # Reconcile a possible 'module.' prefix mismatch (saved with DDP but
        # loaded bare, or vice versa).
        model_keys = set(raw_model.state_dict().keys())
        ckpt_keys = set(state_dict.keys())
        if list(model_keys)[0].startswith('module.') and not list(ckpt_keys)[0].startswith('module.'):
            state_dict = {f"module.{k}": v for k, v in state_dict.items()}
        elif not list(model_keys)[0].startswith('module.') and list(ckpt_keys)[0].startswith('module.'):
            state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}

        raw_model.load_state_dict(state_dict)

        if optimizer is not None and 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if scheduler is not None and 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        start_epoch = checkpoint['epoch'] + 1
        wandb_id = checkpoint.get('wandb_run_id', None)

        print(f"✅ Successfully resumed. Next epoch: {start_epoch}. WandB ID: {wandb_id}")
        return start_epoch, wandb_id
v0_core/utils/metrics.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ from collections import defaultdict
4
+ from sklearn.metrics import roc_auc_score, accuracy_score
5
+ import torch.distributed as dist
6
+
7
def append_jsonl(path, data):
    """Append ``data`` as one JSON line to ``path``; errors are logged, not raised."""
    try:
        with open(path, 'a', encoding='utf-8') as fh:
            fh.write(json.dumps(data, ensure_ascii=False) + '\n')
    except Exception as err:
        # Best-effort logging sink: never let metric logging kill training.
        print(f"Error appending to jsonl: {err}")
13
+
14
+ # =============================================================================
15
+ # Global Metrics Calculation & Aggregation
16
+ # =============================================================================
17
def calculate_metrics_by_group(all_results, phase, epoch, is_master=True, output_dir=None, dataset_name_tag="", avg_loss=None):
    """Gather per-sample prediction records from all DDP ranks and compute
    global, pairwise, and per-step metrics; optionally write logs to disk.

    Each record in ``all_results`` is a dict with at least 'label' (float),
    'prob' (positive-class score), 'pred' (hard 0/1 prediction), 'step',
    and optionally 'pair_id' linking a positive/negative sample pair.

    Collective call: every rank must enter this function (all_gather_object);
    non-master ranks return {} right after the gather.  Metrics that are
    undefined (no pairs, single-class AUC) are reported as -1.
    """

    # 1. Gather from all ranks
    world_size = dist.get_world_size()
    gathered_results = [None for _ in range(world_size)]
    dist.all_gather_object(gathered_results, all_results)

    if not is_master:
        return {}

    # Flatten list of lists (one sub-list per rank)
    flat_results = []
    for rank_res in gathered_results:
        flat_results.extend(rank_res)

    print(f"[{phase}] Collected {len(flat_results)} samples for evaluation.")

    if len(flat_results) == 0:
        return {}

    metrics_summary = {"epoch": epoch}
    # =========================================================================
    # Part A: Pair-wise Metrics Calculation (Global)
    # =========================================================================

    # Bucket samples sharing a pair_id into positive/negative halves.
    pair_grouping = defaultdict(lambda: {'pos': [], 'neg': []})

    for r in flat_results:
        pid = r.get('pair_id')
        if pid is not None:
            if r['label'] >= 0.5:
                pair_grouping[pid]['pos'].append(r)
            else:
                pair_grouping[pid]['neg'].append(r)

    valid_pairs = []

    # A pair only counts when exactly one positive and one negative share it.
    for pid, group in pair_grouping.items():
        if len(group['pos']) == 1 and len(group['neg']) == 1:
            valid_pairs.append((group['pos'][0], group['neg'][0]))

    total_valid_pairs = len(valid_pairs)
    strict_pair_correct_count = 0
    rlhf_pair_correct_count = 0

    # 3. Calculate Global Pair Metrics
    for pos_item, neg_item in valid_pairs:
        # Strict: both hard predictions must be right.
        if (pos_item['pred'] == 1) and (neg_item['pred'] == 0):
            strict_pair_correct_count += 1
        # RLHF-style: the positive merely needs a higher score.
        if pos_item['prob'] > neg_item['prob']:
            rlhf_pair_correct_count += 1

    metrics_summary[f"{phase}/global_strict_pair_acc"] = strict_pair_correct_count / total_valid_pairs if total_valid_pairs > 0 else -1
    metrics_summary[f"{phase}/global_rlhf_pair_acc"] = rlhf_pair_correct_count / total_valid_pairs if total_valid_pairs > 0 else -1
    metrics_summary[f"{phase}/num_valid_pairs"] = total_valid_pairs

    # =========================================================================
    # Part B: Standard Global Metrics (Acc / AUC)
    # =========================================================================
    y_true_binary = [1 if r['label'] >= 0.5 else 0 for r in flat_results]
    y_scores = [r['prob'] for r in flat_results]
    y_preds = [r['pred'] for r in flat_results]

    def get_auc_strict(y_t, y_s):
        # AUC is undefined with a single class; report -1 in that case.
        try:
            return roc_auc_score(y_t, y_s) if len(set(y_t)) > 1 else -1
        except:
            return -1

    g_auc = get_auc_strict(y_true_binary, y_scores)
    metrics_summary[f"{phase}/global_acc"] = accuracy_score(y_true_binary, y_preds)
    metrics_summary[f"{phase}/global_auc"] = g_auc

    if avg_loss is not None:
        metrics_summary[f"{phase}/loss"] = avg_loss

    # =========================================================================
    # Part C: Step-wise Metrics
    # =========================================================================
    step_groups = defaultdict(list)
    for r in flat_results:
        step_groups[r['step']].append(r)

    # Only pairs whose two halves belong to the same step count per-step.
    step_valid_pairs = defaultdict(list)
    for pos_item, neg_item in valid_pairs:
        if pos_item['step'] == neg_item['step']:
            step_valid_pairs[pos_item['step']].append((pos_item, neg_item))

    print(f"[{phase}] Calculating metrics for {len(step_groups)} distinct steps...")

    gauc_weighted_sum = 0.0
    gauc_total_weight = 0.0
    valid_gauc_steps = 0

    step_details_list = []

    for step_val, items in step_groups.items():
        s_true = [1 if x['label'] >= 0.5 else 0 for x in items]
        s_scores = [x['prob'] for x in items]
        s_preds = [x['pred'] for x in items]

        # 1. Basic Step Metrics
        step_acc = accuracy_score(s_true, s_preds)
        step_auc = get_auc_strict(s_true, s_scores)  # -1 if only one class present

        step_record = {
            "step": step_val,
            "count": len(items),
            "acc": step_acc,
            "auc": step_auc
        }

        # 2. Accumulate the sample-count-weighted gAUC terms.
        if step_auc != -1:
            weight = len(items)
            gauc_weighted_sum += step_auc * weight
            gauc_total_weight += weight
            valid_gauc_steps += 1

        # 3. Step Pair Metrics
        pairs_in_step = step_valid_pairs.get(step_val, [])
        n_pairs = len(pairs_in_step)

        if n_pairs > 0:
            s_strict_corr = sum(1 for p, n in pairs_in_step if (p['pred'] == 1 and n['pred'] == 0))
            s_rlhf_corr = sum(1 for p, n in pairs_in_step if p['prob'] > n['prob'])

            step_record["pair_count"] = n_pairs
            step_record["strict_pair_acc"] = s_strict_corr / n_pairs
            step_record["rlhf_pair_acc"] = s_rlhf_corr / n_pairs
        else:
            step_record["pair_count"] = 0
            step_record["strict_pair_acc"] = -1
            step_record["rlhf_pair_acc"] = -1

        step_details_list.append(step_record)

    # Calculate Weighted gAUC (weighted by per-step sample counts)
    final_gauc = gauc_weighted_sum / gauc_total_weight if gauc_total_weight > 0 else -1

    metrics_summary[f"{phase}/gAUC"] = final_gauc
    metrics_summary[f"{phase}/gAUC_valid_steps"] = valid_gauc_steps

    print(f"[{phase}] gAUC: {final_gauc:.4f} (Computed over {valid_gauc_steps} valid steps out of {len(step_groups)})")

    # =========================================================================
    # Part D: Save Logs
    # =========================================================================
    if output_dir:
        # 1. Save Raw Predictions (one JSON record per sample)
        log_filename = f"{phase}_predictions_epoch_{epoch}{dataset_name_tag}.jsonl"
        log_path = os.path.join(output_dir, log_filename)
        valid_pair_ids = set(p[0]['pair_id'] for p in valid_pairs)

        print(f"Saving raw predictions to {log_path}...")
        with open(log_path, 'w', encoding='utf-8') as f:
            for item in flat_results:
                # Mark whether this sample belonged to a counted (1 pos + 1 neg) pair.
                item['is_valid_pair_part'] = item.get('pair_id') in valid_pair_ids
                f.write(json.dumps(item, ensure_ascii=False) + '\n')

        # 2. Save Global Metrics (one summary row appended per call)
        metric_filename = "all_metrics.jsonl"
        metric_path = os.path.join(output_dir, metric_filename)
        append_jsonl(metric_path, metrics_summary)

        # 3. Save Step-wise Details to a separate file
        step_log_filename = f"{phase}_step_metrics_epoch_{epoch}{dataset_name_tag}.jsonl"
        step_log_path = os.path.join(output_dir, step_log_filename)
        print(f"Saving step-wise metrics to {step_log_path}...")

        # Sort by step for readability (non-int steps sort first)
        step_details_list.sort(key=lambda x: x['step'] if isinstance(x['step'], int) else -1)

        with open(step_log_path, 'w', encoding='utf-8') as f:
            for item in step_details_list:
                f.write(json.dumps(item, ensure_ascii=False) + '\n')

    return metrics_summary
v0_core/utils/tabpfn_patches.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import logging
3
+ import numpy as np
4
+ try:
5
+ from tabpfn import TabPFNClassifier
6
+ from tabpfn.base import create_inference_engine, determine_precision
7
+ from tabpfn.utils import infer_random_state
8
+ from tabpfn.classifier import _validate_eval_metric
9
+ from tabpfn.inference import InferenceEngineBatchedNoPreprocessing
10
+ except ImportError as e:
11
+ print(f"导入 TabPFN 模块失败: {e}")
12
+ print("请确保已安装 tabpfn,并且处于包含 tabpfn 源代码的环境中。")
13
+ exit(1)
14
+
15
def fixed_fit(self, X, y) -> "TabPFNClassifier":
    """Patched ``TabPFNClassifier.fit``.

    Fixes the upstream bug where ``ensemble_configs`` is left undefined when
    ``differentiable_input=True``: both init paths now run the dataset
    preprocessing step, and the inference engine is created with
    ``inference_mode`` disabled for differentiable inputs so gradients flow.
    """
    self.eval_metric_ = _validate_eval_metric(self.eval_metric)

    # 'batched' fitting is incompatible with this patched path.
    if self.fit_mode == "batched":
        logging.warning("Switching from 'batched' to 'fit_preprocessors' mode...")
        self.fit_mode = "fit_preprocessors"

    # Reuse already-initialized model state only when the model exists AND
    # differentiable input is requested; otherwise do a full (re)initialization.
    reuse_models = hasattr(self, "models_") and self.differentiable_input
    if reuse_models:
        _, rng = infer_random_state(self.random_state)
        _, _, byte_size = determine_precision(self.inference_precision, self.devices_)
    else:
        byte_size, rng = self._initialize_model_variables()

    # Common to both paths — this is the call the original upstream code
    # skipped in the differentiable branch, leaving ensemble_configs unset.
    ensemble_configs, X, y = self._initialize_dataset_preprocessing(X, y, rng)

    self._maybe_calibrate_temperature_and_tune_decision_thresholds(X=X, y=y)

    self.executor_ = create_inference_engine(
        X_train=X,
        y_train=y,
        models=self.models_,
        ensemble_configs=ensemble_configs,
        cat_ix=self.inferred_categorical_indices_,
        fit_mode=self.fit_mode,
        devices_=self.devices_,
        rng=rng,
        n_preprocessing_jobs=self.n_preprocessing_jobs,
        byte_size=byte_size,
        forced_inference_dtype_=self.forced_inference_dtype_,
        memory_saving_mode=self.memory_saving_mode,
        use_autocast_=self.use_autocast_,
        # Disable torch inference_mode when inputs must stay differentiable.
        inference_mode=not self.differentiable_input,
    )
    return self
50
+
51
def fixed_forward(
    self,
    X: list[torch.Tensor] | torch.Tensor,
    *,
    use_inference_mode: bool = False,
    return_logits: bool = False,
    return_raw_logits: bool = False,
) -> torch.Tensor:
    """Patched forward pass that keeps gradients in standard inference.

    Collects per-chunk estimator outputs, undoes each ensemble member's
    class permutation, then averages temperature-scaled logits (or returns
    raw/softmaxed values depending on the flags).
    """
    if return_logits and return_raw_logits:
        raise ValueError("Cannot return both logits and raw logits.")

    # Exactly one of these two execution modes must hold.
    standard_engine = not isinstance(
        self.executor_, InferenceEngineBatchedNoPreprocessing
    )
    batched_grad_pass = (
        not use_inference_mode
        and isinstance(self.executor_, InferenceEngineBatchedNoPreprocessing)
        and isinstance(X, list)
    )
    assert standard_engine or batched_grad_pass, "Invalid forward pass."

    if self.fit_mode in ["fit_preprocessors", "batched"]:
        # Toggle torch inference_mode so gradients survive when requested.
        self.executor_.use_torch_inference_mode(use_inference=use_inference_mode)

    chunk_outputs = []
    for raw, cfg in self.executor_.iter_outputs(X, autocast=self.use_autocast_):
        # A 2-D chunk holds a single estimator; normalize to (Samples, Est, Classes).
        single_estimator = raw.ndim == 2
        estimators = raw.unsqueeze(1) if single_estimator else raw
        cfgs = [cfg] if single_estimator else cfg

        columns = []
        for idx, est_cfg in enumerate(cfgs):
            perm = est_cfg.class_permutation
            if perm is None:
                columns.append(estimators[:, idx, : self.n_classes_])
                continue
            if len(perm) != self.n_classes_:
                # Pad a short permutation with identity indices for the rest.
                padded = np.arange(self.n_classes_)
                padded[: len(perm)] = perm
                perm = padded
            columns.append(estimators[:, idx, perm])
        chunk_outputs.append(torch.stack(columns, dim=1))

    stacked = torch.stack(chunk_outputs)  # (Chunks, Samples, Est, Classes)

    # The two flags are mutually exclusive (checked above), so branch order
    # relative to the original is immaterial.
    if return_raw_logits:
        result = stacked
    elif return_logits:
        result = self._apply_temperature(stacked).mean(dim=(0, 2))
    else:
        averaged = self._apply_temperature(stacked).mean(dim=(0, 2))
        result = torch.nn.functional.softmax(averaged, dim=-1)

    if not use_inference_mode:
        if return_logits and result.ndim == 2:
            return result
        if result.ndim == 2:
            result = result.unsqueeze(0)
        result = result.transpose(0, 1).transpose(1, 2)
    elif result.ndim > 2:
        # use_inference_mode is necessarily True in this branch.
        result = result.squeeze(1) if not return_raw_logits else result.squeeze(2)

    return result