Commit ·
89c9672
1
Parent(s): 8518eef
Upload V0 model and UI
Browse files- app.py +425 -0
- requirements.txt +10 -0
- v0_core/config/__init__.py +0 -0
- v0_core/config/arguments.py +133 -0
- v0_core/data/__init__.py +0 -0
- v0_core/data/collator.py +27 -0
- v0_core/data/dataset.py +273 -0
- v0_core/data/utils.py +20 -0
- v0_core/models/__init__.py +0 -0
- v0_core/models/v0.py +216 -0
- v0_core/utils/__init__.py +0 -0
- v0_core/utils/checkpoint.py +117 -0
- v0_core/utils/metrics.py +195 -0
- v0_core/utils/tabpfn_patches.py +116 -0
app.py
ADDED
|
@@ -0,0 +1,425 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
import torch
|
| 6 |
+
from huggingface_hub import hf_hub_download, snapshot_download # 引入下载工具
|
| 7 |
+
|
| 8 |
+
# ==========================================
|
| 9 |
+
# 0. 模型初始化
|
| 10 |
+
# ==========================================
|
| 11 |
+
MODEL_REPO_ID = "Now-Join-Us/Generalist-Value-Model-V0"
|
| 12 |
+
EMBEDDING_REPO_ID = "Qwen/Qwen3-Embedding-0.6B"
|
| 13 |
+
|
| 14 |
+
v0_model = None
|
| 15 |
+
|
| 16 |
+
print(">>> Starting V0 App...")
|
| 17 |
+
|
| 18 |
+
try:
|
| 19 |
+
from v0_core.models.v0 import V0
|
| 20 |
+
|
| 21 |
+
print(f">>> Downloading models...")
|
| 22 |
+
|
| 23 |
+
# 1. 下载你的训练权重
|
| 24 |
+
checkpoint_path = hf_hub_download(
|
| 25 |
+
repo_id=MODEL_REPO_ID,
|
| 26 |
+
filename="v_0_for_grpo_training.pt"
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
# 2. 下载 TabPFN
|
| 30 |
+
tabpfn_path = hf_hub_download(
|
| 31 |
+
repo_id=MODEL_REPO_ID,
|
| 32 |
+
filename="tabpfn-v2.5-classifier-v2.5_default.ckpt"
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
# 3. 下载 Qwen Embedding
|
| 36 |
+
embedding_path = snapshot_download(
|
| 37 |
+
repo_id=EMBEDDING_REPO_ID
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
print(">>> Models downloaded. Initializing V0 class...")
|
| 41 |
+
|
| 42 |
+
# device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 43 |
+
device = "cpu"
|
| 44 |
+
|
| 45 |
+
# 加载模型
|
| 46 |
+
v0_model = V0.from_pretrained(
|
| 47 |
+
checkpoint_path=checkpoint_path,
|
| 48 |
+
embedding_model_path=embedding_path,
|
| 49 |
+
tabpfn_head_path=tabpfn_path,
|
| 50 |
+
device=device
|
| 51 |
+
)
|
| 52 |
+
print(f">>> V0 Model Loaded Successfully on {device}!")
|
| 53 |
+
|
| 54 |
+
except Exception as e:
|
| 55 |
+
print(f"Error loading model: {e}")
|
| 56 |
+
print("UI will run in Mock Mode.")
|
| 57 |
+
v0_model = None
|
| 58 |
+
|
| 59 |
+
# ==========================================
|
| 60 |
+
# 1. 核心逻辑
|
| 61 |
+
# ==========================================
|
| 62 |
+
|
| 63 |
+
# 默认数据 (作为 Context)
|
| 64 |
+
history_default = [
|
| 65 |
+
{"prompt": "Let $d(m)$ denote the number of positive integer divisors of a positive integer $m$. If $r$ is the number of integers $n \\leq 2023$ for which $\\sum_{i=1}^{n} d(i)$ is odd, find the sum of the digits of $r$.", "is_correct": True},
|
| 66 |
+
{"prompt": "设在 $5 \\times 5$ 的方格表的第 $i$ 行第 $j$ 列所填的数为 $a_{i j}\\left(a_{i j} \\in\\{0,1\\}\\right), a_{i j}=a_{j i}(1 \\leqslant i、j \\leqslant 5)$ .则表中共有五个 1 的填表方法总数为 $\\qquad$ (用具体数字作答).", "is_correct": True},
|
| 67 |
+
{"prompt": "Suppose $x, y \\in \\mathbb{Z}$ satisfy the equation:\n\\[\ny^4 + 4y^3 + 28y + 8x^3 + 6y^2 + 32x + 1 = (x^2 - y^2)(x^2 + y^2 + 24).\n\\]\nFind the sum of all possible values of $|xy|$.", "is_correct": False},
|
| 68 |
+
{"prompt": "Three builders are scheduled to build a house in 60 days. However, they procrastinate and do nothing for the first 50 days. To complete the house on time, they decide to hire more workers and work at twice their original speed. If the new workers also work at this doubled rate, how many new workers are needed? Assume each builder works at the same rate and does not interfere with the others.", "is_correct": True},
|
| 69 |
+
{"prompt": "Let $P_0 = (3,1)$ and define $P_{n+1} = (x_n, y_n)$ for $n \\ge 0$ by \\[ x_{n+1} = - \\frac{3x_n - y_n}{2}, \\quad y_{n+1} = - \\frac{x_n + y_n}{2} \\] Find the area of the quadrilateral formed by the points $P_{96}, P_{97}, P_{98}, P_{99}$.", "is_correct": False}
|
| 70 |
+
]
|
| 71 |
+
|
| 72 |
+
def format_model_card(data_list, model_name, is_custom=False):
|
| 73 |
+
if not data_list:
|
| 74 |
+
if is_custom:
|
| 75 |
+
return f"<div class='model-card empty'><div class='card-title'>No Custom Model Uploaded</div></div>"
|
| 76 |
+
return ""
|
| 77 |
+
|
| 78 |
+
total = len(data_list)
|
| 79 |
+
rows_html = ""
|
| 80 |
+
preview_limit = 3
|
| 81 |
+
preview_data = data_list[:preview_limit]
|
| 82 |
+
|
| 83 |
+
for item in preview_data:
|
| 84 |
+
p_text = item.get('prompt', '')
|
| 85 |
+
if len(p_text) > 64:
|
| 86 |
+
p_text = p_text[:64] + "..."
|
| 87 |
+
|
| 88 |
+
is_acc = item.get('is_correct', False)
|
| 89 |
+
status_class = "status-green" if is_acc else "status-red"
|
| 90 |
+
icon = "✔" if is_acc else "✘"
|
| 91 |
+
|
| 92 |
+
rows_html += f"""
|
| 93 |
+
<div class='history-row'>
|
| 94 |
+
<div class='status-box {status_class}'>{icon}</div>
|
| 95 |
+
<div class='prompt-text'>{p_text}</div>
|
| 96 |
+
</div>
|
| 97 |
+
"""
|
| 98 |
+
|
| 99 |
+
remaining = total - preview_limit
|
| 100 |
+
if remaining > 0:
|
| 101 |
+
rows_html += f"<div class='history-more'>+ {remaining} more items</div>"
|
| 102 |
+
|
| 103 |
+
return f"""
|
| 104 |
+
<div class='model-card populated'>
|
| 105 |
+
<div class='card-header'>
|
| 106 |
+
<span class='model-name'>{model_name}</span>
|
| 107 |
+
<span class='acc-badge'>Total Samples: {total}</span>
|
| 108 |
+
</div>
|
| 109 |
+
<div class='card-body'>
|
| 110 |
+
<div class='history-container'>{rows_html}</div>
|
| 111 |
+
</div>
|
| 112 |
+
</div>
|
| 113 |
+
"""
|
| 114 |
+
|
| 115 |
+
def process_upload(file_obj):
|
| 116 |
+
if file_obj is None:
|
| 117 |
+
return None, format_model_card(None, "Custom", True)
|
| 118 |
+
|
| 119 |
+
content = []
|
| 120 |
+
try:
|
| 121 |
+
with open(file_obj.name, 'r', encoding='utf-8') as f:
|
| 122 |
+
for line in f:
|
| 123 |
+
line = line.strip()
|
| 124 |
+
if line:
|
| 125 |
+
json_obj = json.loads(line)
|
| 126 |
+
content.append(json_obj)
|
| 127 |
+
|
| 128 |
+
if not content: return None, "<div class='model-card empty'>File is empty</div>"
|
| 129 |
+
if 'is_correct' not in content[0]: return None, "<div class='model-card empty'>Missing 'is_correct' field</div>"
|
| 130 |
+
|
| 131 |
+
# 简单的验证逻辑
|
| 132 |
+
has_positive = any(item.get('is_correct') for item in content)
|
| 133 |
+
has_negative = any(not item.get('is_correct') for item in content)
|
| 134 |
+
|
| 135 |
+
if not (has_positive and has_negative):
|
| 136 |
+
return None, """
|
| 137 |
+
<div class='model-card empty' style='border-color: var(--fail); color: var(--fail);'>
|
| 138 |
+
<div class='card-title'>Invalid Dataset Distribution</div>
|
| 139 |
+
<div class='card-subtitle'>Please upload at least one positive AND one negative sample.</div>
|
| 140 |
+
</div>
|
| 141 |
+
"""
|
| 142 |
+
|
| 143 |
+
return content, format_model_card(content, "Custom Model")
|
| 144 |
+
|
| 145 |
+
except json.JSONDecodeError:
|
| 146 |
+
return None, f"<div class='model-card empty'>Invalid JSONL Format</div>"
|
| 147 |
+
except Exception as e:
|
| 148 |
+
return None, f"<div class='model-card empty'>Error: {str(e)}</div>"
|
| 149 |
+
|
| 150 |
+
def predict_performance(default_data, custom_data, t1, t2, t3):
|
| 151 |
+
"""
|
| 152 |
+
使用加载的 V0 模型进行预测。
|
| 153 |
+
"""
|
| 154 |
+
targets = [t for t in [t1, t2, t3] if t.strip()]
|
| 155 |
+
if not targets:
|
| 156 |
+
return pd.DataFrame([{"Error": "Please enter at least one target prompt."}])
|
| 157 |
+
|
| 158 |
+
models_to_run = []
|
| 159 |
+
if default_data:
|
| 160 |
+
models_to_run.append(("Qwen3-4B-Instruct-2507", default_data))
|
| 161 |
+
if custom_data:
|
| 162 |
+
models_to_run.append(("Custom Uploaded Model", custom_data))
|
| 163 |
+
|
| 164 |
+
results = []
|
| 165 |
+
|
| 166 |
+
for m_name, m_history in models_to_run:
|
| 167 |
+
context_prompts = [item['prompt'] for item in m_history]
|
| 168 |
+
context_labels = [1 if item.get('is_correct') else 0 for item in m_history]
|
| 169 |
+
|
| 170 |
+
scores = []
|
| 171 |
+
|
| 172 |
+
if v0_model:
|
| 173 |
+
try:
|
| 174 |
+
# print(f"Running inference for {m_name} on {len(targets)} targets with {len(context_prompts)} context examples...")
|
| 175 |
+
scores = v0_model.predict(
|
| 176 |
+
context_prompts=context_prompts,
|
| 177 |
+
context_labels=context_labels,
|
| 178 |
+
target_prompts=targets
|
| 179 |
+
)
|
| 180 |
+
except Exception as e:
|
| 181 |
+
print(f"Inference Error: {e}")
|
| 182 |
+
scores = [0.0] * len(targets)
|
| 183 |
+
else:
|
| 184 |
+
import random
|
| 185 |
+
scores = [random.uniform(0.1, 0.9) for _ in targets]
|
| 186 |
+
|
| 187 |
+
for t_text, score in zip(targets, scores):
|
| 188 |
+
# 处理 Tensor 或 float
|
| 189 |
+
if isinstance(score, torch.Tensor):
|
| 190 |
+
final_score = score.item()
|
| 191 |
+
else:
|
| 192 |
+
final_score = float(score)
|
| 193 |
+
|
| 194 |
+
if final_score > 0.5:
|
| 195 |
+
pred_str = "✔ Success"
|
| 196 |
+
else:
|
| 197 |
+
pred_str = "✘ Failure"
|
| 198 |
+
|
| 199 |
+
results.append({
|
| 200 |
+
"Model": m_name,
|
| 201 |
+
"Instruction": t_text,
|
| 202 |
+
"Predicted Value Score": round(final_score, 4),
|
| 203 |
+
"Prediction": pred_str
|
| 204 |
+
})
|
| 205 |
+
|
| 206 |
+
df = pd.DataFrame(results)
|
| 207 |
+
return df
|
| 208 |
+
|
| 209 |
+
# ==========================================
|
| 210 |
+
# 2. CSS 样式
|
| 211 |
+
# ==========================================
|
| 212 |
+
css = """
|
| 213 |
+
/* 全局变量 */
|
| 214 |
+
:root {
|
| 215 |
+
--primary: #10b981;
|
| 216 |
+
--primary-light: #ecfdf5;
|
| 217 |
+
--primary-dark: #047857;
|
| 218 |
+
--bg-card: #ffffff;
|
| 219 |
+
--border-sub: #e5e7eb;
|
| 220 |
+
--text-main: #1f2937;
|
| 221 |
+
--text-sub: #6b7280;
|
| 222 |
+
--success: #10b981;
|
| 223 |
+
--fail: #ef4444;
|
| 224 |
+
--popup-bg: #ffffff;
|
| 225 |
+
--popup-text: #1f2937;
|
| 226 |
+
--popup-border: #e5e7eb;
|
| 227 |
+
--popup-shadow: rgba(0,0,0,0.15);
|
| 228 |
+
}
|
| 229 |
+
.dark {
|
| 230 |
+
--bg-card: #1f2937;
|
| 231 |
+
--border-sub: #374151;
|
| 232 |
+
--text-main: #f3f4f6;
|
| 233 |
+
--text-sub: #9ca3af;
|
| 234 |
+
--popup-bg: #2d2d2d;
|
| 235 |
+
--popup-text: #e5e5e5;
|
| 236 |
+
--popup-border: #4b5563;
|
| 237 |
+
--popup-shadow: rgba(0,0,0,0.4);
|
| 238 |
+
}
|
| 239 |
+
.label-row { display: flex; align-items: center; margin-bottom: 6px; font-family: 'Source Sans Pro', sans-serif; }
|
| 240 |
+
.upload-label-text { font-size: 1rem; color: var(--text-main); margin-right: 8px; }
|
| 241 |
+
.format-hint-wrapper { display: inline-block; position: relative; cursor: help; font-size: 0.9rem; color: var(--primary); font-weight: 600; border-bottom: 1px dashed var(--primary); line-height: 1.2; }
|
| 242 |
+
.format-popup {
|
| 243 |
+
visibility: hidden; opacity: 0; position: absolute; bottom: 145%; left: -20px; width: 450px;
|
| 244 |
+
background: var(--popup-bg); color: var(--popup-text); border: 1px solid var(--popup-border);
|
| 245 |
+
padding: 16px; border-radius: 8px; box-shadow: 0 10px 30px var(--popup-shadow); z-index: 1000;
|
| 246 |
+
transition: all 0.2s cubic-bezier(0.165, 0.84, 0.44, 1); transform: translateY(10px); pointer-events: none;
|
| 247 |
+
font-size: 0.95rem; line-height: 1.5;
|
| 248 |
+
}
|
| 249 |
+
.format-hint-wrapper:hover .format-popup { visibility: visible; opacity: 1; transform: translateY(0); }
|
| 250 |
+
.format-popup::after {
|
| 251 |
+
content: ""; position: absolute; top: 100%; left: 60px; border-width: 8px; border-style: solid;
|
| 252 |
+
border-color: var(--popup-bg) transparent transparent transparent;
|
| 253 |
+
}
|
| 254 |
+
.code-snippet {
|
| 255 |
+
display: block; background: #1a1a1a; color: #a7f3d0; font-family: 'Courier New', monospace;
|
| 256 |
+
font-size: 0.85em; padding: 8px; border-radius: 6px; margin-top: 6px; white-space: pre; border: 1px solid #444;
|
| 257 |
+
}
|
| 258 |
+
.concept-banner {
|
| 259 |
+
background: linear-gradient(135deg, rgba(16, 185, 129, 0.08) 0%, rgba(59, 130, 246, 0.05) 100%);
|
| 260 |
+
border: 1px solid var(--primary-light); border-radius: 12px; padding: 24px; text-align: center; margin-bottom: 30px;
|
| 261 |
+
}
|
| 262 |
+
.concept-title { font-size: 1.8em; font-weight: 700; color: var(--text-main); margin-bottom: 8px;}
|
| 263 |
+
.concept-subtitle { font-size: 1em; color: var(--text-sub); max-width: 600px; margin: 0 auto; line-height: 1.5; }
|
| 264 |
+
.equation-box {
|
| 265 |
+
margin-top: 15px; font-family: 'Courier New', monospace; font-weight: bold;
|
| 266 |
+
color: var(--primary); background: var(--bg-card); display: inline-block;
|
| 267 |
+
padding: 8px 16px; border-radius: 8px; border: 1px dashed var(--primary);
|
| 268 |
+
box-shadow: 0 2px 6px rgba(0,0,0,0.05);
|
| 269 |
+
}
|
| 270 |
+
.step-header { display: flex; align-items: center; margin-bottom: 15px; border-bottom: 2px solid var(--border-sub); padding-bottom: 10px; }
|
| 271 |
+
.step-num {
|
| 272 |
+
background: var(--primary); color: white; width: 28px; height: 28px;
|
| 273 |
+
border-radius: 50%; display: flex; align-items: center; justify-content: center;
|
| 274 |
+
font-weight: bold; margin-right: 10px; font-size: 0.9em;
|
| 275 |
+
}
|
| 276 |
+
.step-title { font-size: 1.2em; font-weight: 600; color: var(--text-main); }
|
| 277 |
+
.step-desc { font-size: 0.93em; color: var(--text-sub); margin-left: auto; font-style: italic;}
|
| 278 |
+
.model-card {
|
| 279 |
+
background: var(--bg-card); border: 1px solid var(--border-sub);
|
| 280 |
+
border-radius: 10px; padding: 16px; margin-bottom: 15px;
|
| 281 |
+
transition: all 0.2s; position: relative; overflow: hidden;
|
| 282 |
+
}
|
| 283 |
+
.model-card.populated { border-left: 5px solid var(--primary); box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.05); }
|
| 284 |
+
.model-card.empty { border: 2px dashed var(--border-sub); text-align: center; opacity: 0.7; padding: 30px 16px; }
|
| 285 |
+
.card-title { font-weight: bold; color: var(--text-sub); }
|
| 286 |
+
.card-subtitle { font-size: 0.8em; color: var(--text-sub); }
|
| 287 |
+
.card-header { display: flex; justify-content: space-between; align-items: center; margin-bottom: 15px; }
|
| 288 |
+
.model-name { font-weight: bold; font-size: 1.1em; color: var(--text-main); }
|
| 289 |
+
.acc-badge { background: var(--primary-light); color: var(--primary-dark); font-size: 0.75em; padding: 3px 10px; border-radius: 12px; font-weight: 700; }
|
| 290 |
+
.history-container { display: flex; flex-direction: column; gap: 8px; margin-bottom: 15px; }
|
| 291 |
+
.history-row { display: flex; align-items: center; background: rgba(0,0,0,0.02); padding: 6px 8px; border-radius: 6px; }
|
| 292 |
+
.status-box {
|
| 293 |
+
width: 24px; height: 24px; border-radius: 6px; display: flex; align-items: center; justify-content: center;
|
| 294 |
+
color: white; font-size: 0.8em; font-weight: bold; margin-right: 10px; flex-shrink: 0;
|
| 295 |
+
}
|
| 296 |
+
.status-green { background-color: var(--success); }
|
| 297 |
+
.status-red { background-color: var(--fail); }
|
| 298 |
+
.prompt-text {
|
| 299 |
+
font-size: 0.9em; color: var(--text-main); white-space: nowrap; overflow: hidden; text-overflow: ellipsis;
|
| 300 |
+
}
|
| 301 |
+
.history-more { font-size: 0.95em; color: var(--text-sub); text-align: center; font-style: italic; margin-top: -4px; }
|
| 302 |
+
.custom-btn { font-weight: bold !important; font-size: 1.1em !important; }
|
| 303 |
+
.paper-link {
|
| 304 |
+
font-size: 0.5em; vertical-align: middle; color: var(--primary); text-decoration: none;
|
| 305 |
+
border: 1px solid var(--primary); padding: 4px 10px; border-radius: 15px; font-weight: normal;
|
| 306 |
+
transition: all 0.2s; background: transparent;
|
| 307 |
+
}
|
| 308 |
+
.paper-link:hover { background: var(--primary); color: white; }
|
| 309 |
+
"""
|
| 310 |
+
|
| 311 |
+
# ==========================================
|
| 312 |
+
# 3. UI 构建
|
| 313 |
+
# ==========================================
|
| 314 |
+
with gr.Blocks(theme=gr.themes.Soft(primary_hue="emerald"), css=css, title="V0 Predictor") as demo:
|
| 315 |
+
|
| 316 |
+
state_default = gr.State(value=history_default)
|
| 317 |
+
state_custom = gr.State(value=None)
|
| 318 |
+
|
| 319 |
+
gr.HTML("""
|
| 320 |
+
<div class="concept-banner">
|
| 321 |
+
<div class="concept-title">
|
| 322 |
+
V<sub>0</sub> Value Model
|
| 323 |
+
<a href="TBD" target="_blank" class="paper-link">Paper ↗</a>
|
| 324 |
+
<a href="TBD" target="_blank" class="paper-link">Code ↗</a>
|
| 325 |
+
</div>
|
| 326 |
+
<div class="concept-subtitle">
|
| 327 |
+
<span style="color: var(--primary); font-weight: bold;">Function:</span> V<sub>0</sub> uses a model's historical performance to predict<br>
|
| 328 |
+
how it will perform on unseen instructions<br>
|
| 329 |
+
without running the model itself.
|
| 330 |
+
</div>
|
| 331 |
+
<div class="equation-box">
|
| 332 |
+
Historical Perf. + Instruction → Predicted Perf.
|
| 333 |
+
</div>
|
| 334 |
+
</div>
|
| 335 |
+
""")
|
| 336 |
+
|
| 337 |
+
with gr.Row(equal_height=False):
|
| 338 |
+
|
| 339 |
+
with gr.Column(scale=1, variant="panel"):
|
| 340 |
+
gr.HTML("""
|
| 341 |
+
<div class="step-header">
|
| 342 |
+
<div class="step-num">1</div>
|
| 343 |
+
<div class="step-title">Represent Any Model with <span style="color: var(--primary);">Performance-Instruction Pairs</span></div>
|
| 344 |
+
</div>
|
| 345 |
+
""")
|
| 346 |
+
|
| 347 |
+
preview_default = gr.HTML(format_model_card(history_default, "Qwen3-4B-Instruct-2507"))
|
| 348 |
+
|
| 349 |
+
gr.HTML("""
|
| 350 |
+
<div class="label-row">
|
| 351 |
+
<span class="upload-label-text"><span style="font-weight: 800;">[Optional]</span> Upload Your Model</span>
|
| 352 |
+
<div class="format-hint-wrapper">
|
| 353 |
+
Required JSONL Format ⓘ
|
| 354 |
+
<div class="format-popup">
|
| 355 |
+
<div style="font-weight: bold; margin-bottom:4px;">File Content Example:</div>
|
| 356 |
+
<code class="code-snippet">
|
| 357 |
+
{"prompt": "Calculate 1+1", "is_correct": true}
|
| 358 |
+
{"prompt": "Write a poem", "is_correct": false}
|
| 359 |
+
</code>
|
| 360 |
+
<div style="margin-top:6px; font-size:0.9em; opacity: 0.8;">
|
| 361 |
+
Each line must be a valid JSON object containing <b>'prompt'</b> (string) and <b>'is_correct'</b> (boolean).
|
| 362 |
+
</div>
|
| 363 |
+
</div>
|
| 364 |
+
</div>
|
| 365 |
+
</div>
|
| 366 |
+
""")
|
| 367 |
+
|
| 368 |
+
upload_btn = gr.File(
|
| 369 |
+
label=None,
|
| 370 |
+
show_label=False,
|
| 371 |
+
file_types=[".jsonl"],
|
| 372 |
+
height=130
|
| 373 |
+
)
|
| 374 |
+
preview_custom = gr.HTML(format_model_card(None, "Custom", True))
|
| 375 |
+
|
| 376 |
+
with gr.Column(scale=1, variant="panel"):
|
| 377 |
+
gr.HTML("""
|
| 378 |
+
<div class="step-header" style="margin-top: 80px;">
|
| 379 |
+
<div class="step-num">2</div>
|
| 380 |
+
<div class="step-title">Enter Instructions</div>
|
| 381 |
+
<div class="step-desc">trigger V<sub>0</sub> to predict the expected perf. for each model</div>
|
| 382 |
+
</div>
|
| 383 |
+
""")
|
| 384 |
+
t1 = gr.Textbox(label="Instruction 1", value="What is the largest $n$ such that there exists a non-degenerate convex $n$-gon where each of its angles is an integer number of degrees, and all angles are distinct?", lines=2)
|
| 385 |
+
t2 = gr.Textbox(label="Instruction 2", value="已知四面体 \\(A B C D\\) 内接于球 \\(O\\),且 \\(A D\\) 是球 \\(O\\) 的直径。若 \\(\\triangle A B C\\) 和 \\(\\triangle B C D\\) 都是边长为 1 的等边三角形,则四面体 \\(A B C D\\) 的体积是多少?原始答案的形式为 \\(\\frac{\\sqrt{c}}{b}\\),请给出a+b+c的值。", lines=2)
|
| 386 |
+
t3 = gr.Textbox(label="Instruction 3", placeholder="Your instruction here ...", lines=2)
|
| 387 |
+
|
| 388 |
+
gr.HTML("""
|
| 389 |
+
<div style="margin-top: 15px; font-size: 1.05em; color: var(--text-main);">
|
| 390 |
+
<span style="color: var(--primary); font-weight: bold;">Next:</span> Clicking <b>Run V<sub>0</sub> Prediction!</b>
|
| 391 |
+
</div>
|
| 392 |
+
""")
|
| 393 |
+
|
| 394 |
+
with gr.Row():
|
| 395 |
+
with gr.Column():
|
| 396 |
+
predict_btn = gr.Button("Run V₀ Prediction", variant="primary", size="lg", elem_classes=["custom-btn"])
|
| 397 |
+
|
| 398 |
+
gr.HTML("""
|
| 399 |
+
<div class="step-header" style="margin-top: 20px; border-bottom: none;">
|
| 400 |
+
<div class="step-num">3</div>
|
| 401 |
+
<div class="step-title">Results</div>
|
| 402 |
+
</div>
|
| 403 |
+
""")
|
| 404 |
+
|
| 405 |
+
output_df = gr.Dataframe(
|
| 406 |
+
headers=["Model Entity", "Unseen Instruction", "Predicted Value Score", "Prediction"],
|
| 407 |
+
datatype=["str", "str", "number", "str"],
|
| 408 |
+
interactive=False,
|
| 409 |
+
column_widths=["20%", "40%", "20%", "20%"]
|
| 410 |
+
)
|
| 411 |
+
|
| 412 |
+
upload_btn.change(
|
| 413 |
+
fn=process_upload,
|
| 414 |
+
inputs=[upload_btn],
|
| 415 |
+
outputs=[state_custom, preview_custom]
|
| 416 |
+
)
|
| 417 |
+
|
| 418 |
+
predict_btn.click(
|
| 419 |
+
fn=predict_performance,
|
| 420 |
+
inputs=[state_default, state_custom, t1, t2, t3],
|
| 421 |
+
outputs=[output_df]
|
| 422 |
+
)
|
| 423 |
+
|
| 424 |
+
if __name__ == "__main__":
|
| 425 |
+
demo.launch()
|
requirements.txt
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio>=4.0.0
|
| 2 |
+
pandas
|
| 3 |
+
einops
|
| 4 |
+
numpy==2.2.6
|
| 5 |
+
scikit-learn==1.7.2
|
| 6 |
+
-e git+https://github.com/PriorLabs/TabPFN.git@2cd2326038e789a26f7a07e70e1ea986ffd040c9#egg=tabpfn
|
| 7 |
+
torch==2.7.1
|
| 8 |
+
tqdm==4.67.1
|
| 9 |
+
transformers==4.55.4
|
| 10 |
+
wandb==0.21.3
|
v0_core/config/__init__.py
ADDED
|
File without changes
|
v0_core/config/arguments.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import argparse
|
| 3 |
+
|
| 4 |
+
# =============================================================================
|
| 5 |
+
# 参数解析配置
|
| 6 |
+
# =============================================================================
|
| 7 |
+
def parse_args():
|
| 8 |
+
parser = argparse.ArgumentParser(description="Generalist Value Model")
|
| 9 |
+
|
| 10 |
+
# --- 路径相关 ---
|
| 11 |
+
parser.add_argument("--time_str", type=str, required=True)
|
| 12 |
+
parser.add_argument("--qwen_path", type=str, required=True, help="Qwen 模型路径")
|
| 13 |
+
parser.add_argument("--tabpfn_checkpoint", type=str, required=True, help="TabPFN Checkpoint 路径")
|
| 14 |
+
|
| 15 |
+
# 数据路径配置
|
| 16 |
+
parser.add_argument("--context_data_paths", type=str, required=True, help="Context Pool Jsonl路径 (支持多个,逗号分隔)")
|
| 17 |
+
parser.add_argument("--train_data_paths", type=str, default=None, help="Train Query Pool Jsonl路径 (支持多个,逗号分隔)")
|
| 18 |
+
parser.add_argument("--eval_data_paths", type=str, default=None, help="Test Query Pool Jsonl路径 (支持多个,逗号分隔)")
|
| 19 |
+
parser.add_argument("--validity_data_paths", type=str, default=None, help="Validity Test Pool Jsonl路径 (支持多个,逗号分隔)")
|
| 20 |
+
|
| 21 |
+
parser.add_argument("--prompt_dict_path", type=str, required=True, help="Prompt 字典 JSON 路径")
|
| 22 |
+
|
| 23 |
+
# --- Checkpoint 保存相关 ---
|
| 24 |
+
parser.add_argument("--checkpoint_dir", type=str, default=None, help="模型保存目录")
|
| 25 |
+
parser.add_argument("--save_interval", type=int, default=1, help="每隔多少个 Epoch 保存一次模型")
|
| 26 |
+
parser.add_argument("--max_keep_checkpoints", type=int, default=2, help="最多保留多少个最新的 Checkpoint")
|
| 27 |
+
parser.add_argument("--resume", action="store_true", help="是否尝试从 checkpoint_dir 恢复训练")
|
| 28 |
+
parser.add_argument("--resume_from_specific_epoch", type=int, default=None, help="指定要 resume 的 epoch")
|
| 29 |
+
|
| 30 |
+
# --- 日志相关 ---
|
| 31 |
+
parser.add_argument("--log_path", type=str, default=None)
|
| 32 |
+
parser.add_argument("--log_interval", type=int, default=10, help="保存间隔")
|
| 33 |
+
parser.add_argument("--metric_path", type=str, default=None)
|
| 34 |
+
parser.add_argument("--wandb_project", type=str, default="context-v", help="Wandb 项目名称")
|
| 35 |
+
parser.add_argument("--wandb_interval", type=int, default=1, help="Wandb 记录间隔")
|
| 36 |
+
parser.add_argument("--wandb_id", type=str, default=None)
|
| 37 |
+
|
| 38 |
+
# --- 运行模式与策略 ---
|
| 39 |
+
parser.add_argument("--run_mode", type=str, default="eval", choices=["train", "eval"])
|
| 40 |
+
parser.add_argument("--pooling_strategy", type=str, default="dynamic_query",
|
| 41 |
+
choices=["last_token", "fixed_query", "dynamic_query"],
|
| 42 |
+
help="Embedding 提取策略")
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
parser.add_argument("--label_strategy", type=str, default="binary",
|
| 46 |
+
choices=["binary", "minmax_norm"],
|
| 47 |
+
help="Label 处理策略")
|
| 48 |
+
parser.add_argument("--loss_type", type=str, default="ce_hard",
|
| 49 |
+
choices=["ce_hard", "ce_soft", "kl_div", "pairwise", "combined"],
|
| 50 |
+
help="Loss 函数类型: combined = pairwise + ce_soft")
|
| 51 |
+
parser.add_argument("--loss_alpha", type=float, default=0.5,
|
| 52 |
+
help="Combined Loss 中 Pairwise 的权重 (0.0-1.0)。Total = alpha * Pair + (1-alpha) * CE")
|
| 53 |
+
parser.add_argument("--loss_balance", action="store_true", help="是否对正负样本加权")
|
| 54 |
+
|
| 55 |
+
parser.add_argument("--kl_temperature", type=float, default=1.0,
|
| 56 |
+
help="KL 散度或 Softmax 的温度系数 T")
|
| 57 |
+
# --- 降维参数 ---
|
| 58 |
+
parser.add_argument("--reduce_method", type=str, default="none",
|
| 59 |
+
choices=["none", "avg_pool", "max_pool"])
|
| 60 |
+
parser.add_argument("--target_dim", type=int, default=1024)
|
| 61 |
+
parser.add_argument("--num_heads", type=int, default=4)
|
| 62 |
+
|
| 63 |
+
parser.add_argument("--context_clustering", action="store_true", help="是否启用 Support Set 聚类筛选")
|
| 64 |
+
parser.add_argument("--context_num_clusters", type=int, default=128, help="聚类保留的原型数量 (k值)")
|
| 65 |
+
|
| 66 |
+
# --- 模型超参数 ---
|
| 67 |
+
parser.add_argument("--num_queries", type=int, default=10)
|
| 68 |
+
parser.add_argument("--embed_dim", type=int, default=32)
|
| 69 |
+
parser.add_argument("--tabpfn_estimators", type=int, default=4)
|
| 70 |
+
parser.add_argument("--dynamic_query_generator_bottleneck_dim", type=int, default=128)
|
| 71 |
+
parser.add_argument("--dynamic_query_generator_dropout_rate", type=float, default=0.2)
|
| 72 |
+
|
| 73 |
+
# --- 训练超参数 ---
|
| 74 |
+
parser.add_argument("--epochs", type=int, default=5)
|
| 75 |
+
parser.add_argument("--meta_batch_size", type=int, default=1, help="每次forward处理多少个Task(一个Task包含Support+Query)")
|
| 76 |
+
parser.add_argument("--grad_accum_steps", type=int, default=4)
|
| 77 |
+
|
| 78 |
+
parser.add_argument("--train_query_batch_size", type=int, default=8, help="每个Task包含多少个Query样本 (必须来自同一个Step)")
|
| 79 |
+
parser.add_argument("--eval_query_batch_size", type=int, default=8, help="每个Task包含��少个Query样本 (必须来自同一个Step)")
|
| 80 |
+
parser.add_argument("--support_size", type=int, default=256, help="每个Task采样的Context样本数量")
|
| 81 |
+
|
| 82 |
+
parser.add_argument("--lr_backbone", type=float, default=1e-5)
|
| 83 |
+
parser.add_argument("--lr_adapter", type=float, default=1e-4)
|
| 84 |
+
parser.add_argument("--lr_tabpfn", type=float, default=1e-5)
|
| 85 |
+
|
| 86 |
+
parser.add_argument("--weight_decay", type=float, default=0.01)
|
| 87 |
+
parser.add_argument("--warmup_ratio", type=float, default=0.05)
|
| 88 |
+
parser.add_argument("--lr_scheduler_type", type=str, default="cosine")
|
| 89 |
+
parser.add_argument("--max_grad_norm", type=float, default=1.0)
|
| 90 |
+
|
| 91 |
+
parser.add_argument("--train_embed_bs", type=int, default=4)
|
| 92 |
+
parser.add_argument("--eval_embed_bs", type=int, default=4)
|
| 93 |
+
|
| 94 |
+
parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
|
| 95 |
+
|
| 96 |
+
args = parser.parse_args()
|
| 97 |
+
|
| 98 |
+
def split_paths(path_str):
|
| 99 |
+
if not path_str: return []
|
| 100 |
+
return [p.strip() for p in path_str.split(',') if p.strip()]
|
| 101 |
+
|
| 102 |
+
args.context_data_paths = split_paths(args.context_data_paths)
|
| 103 |
+
args.train_data_paths = split_paths(args.train_data_paths)
|
| 104 |
+
args.eval_data_paths = split_paths(args.eval_data_paths)
|
| 105 |
+
args.validity_data_paths = split_paths(args.validity_data_paths)
|
| 106 |
+
args.prompt_dict_path = split_paths(args.prompt_dict_path)
|
| 107 |
+
|
| 108 |
+
return args
|
| 109 |
+
|
| 110 |
+
def print_elegant_args(args):
|
| 111 |
+
"""
|
| 112 |
+
打印参数列表
|
| 113 |
+
"""
|
| 114 |
+
args_dict = vars(args)
|
| 115 |
+
keys = sorted(args_dict.keys())
|
| 116 |
+
# 计算最长键名以便对齐
|
| 117 |
+
max_k = max([len(k) for k in keys]) if keys else 10
|
| 118 |
+
|
| 119 |
+
# 定义颜色
|
| 120 |
+
C_KEY = "\033[36m" # 青色用于键
|
| 121 |
+
C_VALUE = "\033[33m" # 黄色用于值(如果不想要颜色,设为 "" 即可)
|
| 122 |
+
C_RESET = "\033[0m" # 重置
|
| 123 |
+
|
| 124 |
+
print(f"\n{C_VALUE}Arguments:{C_RESET}")
|
| 125 |
+
|
| 126 |
+
for k in keys:
|
| 127 |
+
val = str(args_dict[k])
|
| 128 |
+
# 格式说明:
|
| 129 |
+
# {k:<{max_k}} : 让键名左对齐并填充空格
|
| 130 |
+
# val : 完整打印值,不截断
|
| 131 |
+
print(f" {C_KEY}{k:<{max_k}}{C_RESET} : {val}")
|
| 132 |
+
|
| 133 |
+
print() # 打印末尾空行
|
v0_core/data/__init__.py
ADDED
|
File without changes
|
v0_core/data/collator.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
def meta_collate_fn(batch):
|
| 4 |
+
all_prompts = []
|
| 5 |
+
all_labels = []
|
| 6 |
+
metadata = []
|
| 7 |
+
current_start = 0
|
| 8 |
+
for item in batch:
|
| 9 |
+
t_len = len(item['prompts'])
|
| 10 |
+
all_prompts.extend(item['prompts'])
|
| 11 |
+
all_labels.append(item['labels'])
|
| 12 |
+
metadata.append({
|
| 13 |
+
'start': current_start,
|
| 14 |
+
'len': t_len,
|
| 15 |
+
'split': item['split_idx'],
|
| 16 |
+
'q_ids': item['q_ids'],
|
| 17 |
+
'pair_ids': item['pair_ids'],
|
| 18 |
+
'pair_types': item['pair_types'],
|
| 19 |
+
'key': item['key'],
|
| 20 |
+
'stats': item['stats']
|
| 21 |
+
})
|
| 22 |
+
current_start += t_len
|
| 23 |
+
return {
|
| 24 |
+
'flat_prompts': all_prompts,
|
| 25 |
+
'flat_labels': torch.cat(all_labels),
|
| 26 |
+
'metadata': metadata
|
| 27 |
+
}
|
v0_core/data/dataset.py
ADDED
|
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
import random
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
from torch.utils.data import Dataset
|
| 7 |
+
from v0_core.data.utils import load_jsonl_lines
|
| 8 |
+
|
| 9 |
+
# =============================================================================
|
| 10 |
+
# 数据与日志工具
|
| 11 |
+
# =============================================================================
|
| 12 |
+
class ValueModelDataset(Dataset):
|
| 13 |
+
def __init__(self,
|
| 14 |
+
context_paths,
|
| 15 |
+
query_paths,
|
| 16 |
+
prompt_dict_path,
|
| 17 |
+
label_strategy='binary',
|
| 18 |
+
query_batch_size=8,
|
| 19 |
+
support_size=256,
|
| 20 |
+
mode='train'):
|
| 21 |
+
"""
|
| 22 |
+
args:
|
| 23 |
+
context_paths: List of paths to context_pool jsonl files
|
| 24 |
+
query_paths: List of paths to query_pool jsonl files (train/test/validity)
|
| 25 |
+
prompt_dict_path: List of paths to prompt dictionaries
|
| 26 |
+
query_batch_size: Number of queries in one forward pass (all from same step)
|
| 27 |
+
support_size: Number of context samples to sample
|
| 28 |
+
mode: 'train' (shuffle queries before chunking) or 'eval' (sequential)
|
| 29 |
+
"""
|
| 30 |
+
self.label_strategy = label_strategy
|
| 31 |
+
self.query_batch_size = query_batch_size
|
| 32 |
+
self.support_size = support_size
|
| 33 |
+
self.mode = mode
|
| 34 |
+
|
| 35 |
+
# 1. Load Prompt Dictionary
|
| 36 |
+
print(f"Loading prompts from {prompt_dict_path}...")
|
| 37 |
+
self.prompt_map = {}
|
| 38 |
+
for path in prompt_dict_path:
|
| 39 |
+
with open(path, 'r', encoding='utf-8') as f:
|
| 40 |
+
self.prompt_map.update(json.load(f))
|
| 41 |
+
|
| 42 |
+
# 2. Load Context Pool and Index it
|
| 43 |
+
# Structure: {(dataset, model, step): [list of sample dicts]}
|
| 44 |
+
print("Loading Context Pool...")
|
| 45 |
+
self.context_pool = defaultdict(list)
|
| 46 |
+
self.context_pool_fallback = defaultdict(list)
|
| 47 |
+
raw_context = load_jsonl_lines(context_paths)
|
| 48 |
+
for item in raw_context:
|
| 49 |
+
key = (item['dataset'], item['model'], item['step'])
|
| 50 |
+
self.context_pool[key].append(item)
|
| 51 |
+
fallback_key = (item['model'], item['step'])
|
| 52 |
+
self.context_pool_fallback[fallback_key].append(item)
|
| 53 |
+
print(f"Loaded Context Pool with {len(self.context_pool)} unique (dataset, model, step) keys.")
|
| 54 |
+
|
| 55 |
+
# 3. Load Query Pool
|
| 56 |
+
print(f"Loading Query Pool from {query_paths}...")
|
| 57 |
+
raw_queries = load_jsonl_lines(query_paths)
|
| 58 |
+
|
| 59 |
+
# 4. Group Queries by Key
|
| 60 |
+
self.queries_by_key = defaultdict(list)
|
| 61 |
+
print("Grouping Queries...")
|
| 62 |
+
for item in raw_queries:
|
| 63 |
+
key = (item['dataset'], item['model'], item['step'])
|
| 64 |
+
# Pre-fetch prompt text to save time later, if ID exists
|
| 65 |
+
s_id_str = f"{item['dataset']}_{item['id']}"
|
| 66 |
+
item['text'] = self.prompt_map[s_id_str]
|
| 67 |
+
self.queries_by_key[key].append(item)
|
| 68 |
+
|
| 69 |
+
# 5. Pre-calculate Class Statistics for Context-Aware Re-weighting
|
| 70 |
+
# 统计每个Context Key下,Query Pool中的正负样本总数,用于计算加权Loss
|
| 71 |
+
print("Calculating Global Context Statistics for Re-weighting...")
|
| 72 |
+
self.context_stats = {}
|
| 73 |
+
if mode == 'train':
|
| 74 |
+
for key, items in self.queries_by_key.items():
|
| 75 |
+
# 定义正样本: score >= 0
|
| 76 |
+
n_pos = sum(1 for x in items if float(x.get('score', -1)) >= 0)
|
| 77 |
+
n_neg = len(items) - n_pos
|
| 78 |
+
self.context_stats[key] = {'n_pos': n_pos, 'n_neg': n_neg}
|
| 79 |
+
|
| 80 |
+
print("\n" + "="*60)
|
| 81 |
+
print(f"Top 10 Steps Statistics ({mode} mode)")
|
| 82 |
+
print(f"{'Dataset':<15} | {'Model':<15} | {'Step':<6} | {'n_pos':<6} | {'n_neg':<6} | {'Total':<6}")
|
| 83 |
+
print("-" * 60)
|
| 84 |
+
|
| 85 |
+
sorted_keys = sorted(list(self.context_stats.keys()))
|
| 86 |
+
|
| 87 |
+
for i, key in enumerate(sorted_keys[:10]):
|
| 88 |
+
|
| 89 |
+
dataset_name, model_name, step_val = key
|
| 90 |
+
stats = self.context_stats[key]
|
| 91 |
+
total = stats['n_pos'] + stats['n_neg']
|
| 92 |
+
print(f"{dataset_name:<15} | {model_name:<15} | {str(step_val):<6} | "
|
| 93 |
+
f"{stats['n_pos']:<6} | {stats['n_neg']:<6} | {total:<6}")
|
| 94 |
+
|
| 95 |
+
print(f"... (Total {len(sorted_keys)} steps loaded)")
|
| 96 |
+
print("="*60 + "\n")
|
| 97 |
+
|
| 98 |
+
# 6. Create Tasks (Chunks of Queries)
|
| 99 |
+
self.tasks = []
|
| 100 |
+
self.generate_tasks(shuffle=(self.mode == 'train'))
|
| 101 |
+
|
| 102 |
+
print(f"Dataset Initialized. Total Tasks: {len(self.tasks)}")
|
| 103 |
+
|
| 104 |
+
def generate_tasks(self, shuffle=True):
|
| 105 |
+
"""
|
| 106 |
+
Pairwise Task Generation with Cyclic Oversampling.
|
| 107 |
+
目标:保留所有样本,不进行丢弃。对于数量较少的一方,循环重复使用以匹配数量较多的一方。
|
| 108 |
+
"""
|
| 109 |
+
new_tasks = []
|
| 110 |
+
keys = sorted(list(self.queries_by_key.keys()))
|
| 111 |
+
|
| 112 |
+
if shuffle:
|
| 113 |
+
random.shuffle(keys)
|
| 114 |
+
|
| 115 |
+
dropped_steps = 0
|
| 116 |
+
total_pairs = 0
|
| 117 |
+
|
| 118 |
+
for key in keys:
|
| 119 |
+
samples = list(self.queries_by_key[key])
|
| 120 |
+
|
| 121 |
+
if self.mode == 'train':
|
| 122 |
+
# 1. 分离正负样本
|
| 123 |
+
pos_list = [x for x in samples if self._process_label(x['score']) >= 0.5]
|
| 124 |
+
neg_list = [x for x in samples if self._process_label(x['score']) < 0.5]
|
| 125 |
+
|
| 126 |
+
n_pos = len(pos_list)
|
| 127 |
+
n_neg = len(neg_list)
|
| 128 |
+
|
| 129 |
+
# 2. 如果某一方完全缺失,不得不跳过 (无法构建 Pair)
|
| 130 |
+
if n_pos == 0 or n_neg == 0:
|
| 131 |
+
dropped_steps += 1
|
| 132 |
+
continue
|
| 133 |
+
|
| 134 |
+
# 3. Shuffle (保证每次 Epoch 重复使用的样本是随机顺序的)
|
| 135 |
+
if shuffle:
|
| 136 |
+
random.shuffle(pos_list)
|
| 137 |
+
random.shuffle(neg_list)
|
| 138 |
+
|
| 139 |
+
# 4. Maximize Pairs via Cyclic Oversampling
|
| 140 |
+
# 取最大长度,保证所有样本至少被用到一次
|
| 141 |
+
n_pairs = max(n_pos, n_neg)
|
| 142 |
+
|
| 143 |
+
paired_samples = []
|
| 144 |
+
for i in range(n_pairs):
|
| 145 |
+
p = pos_list[i % n_pos]
|
| 146 |
+
n = neg_list[i % n_neg]
|
| 147 |
+
|
| 148 |
+
paired_samples.append(p)
|
| 149 |
+
paired_samples.append(n)
|
| 150 |
+
|
| 151 |
+
total_pairs += n_pairs
|
| 152 |
+
|
| 153 |
+
# 5. Chunking
|
| 154 |
+
# query_batch_size 必须是偶数
|
| 155 |
+
bs = self.query_batch_size
|
| 156 |
+
if bs % 2 != 0:
|
| 157 |
+
bs -= 1
|
| 158 |
+
if bs < 2: bs = 2
|
| 159 |
+
|
| 160 |
+
for i in range(0, len(paired_samples), bs):
|
| 161 |
+
chunk = paired_samples[i : i + bs]
|
| 162 |
+
|
| 163 |
+
# 丢弃末尾不完整的 Pair (极少发生,仅当 chunk 长度为奇数时)
|
| 164 |
+
if len(chunk) % 2 != 0:
|
| 165 |
+
chunk = chunk[:-1]
|
| 166 |
+
|
| 167 |
+
context_key_to_use = None
|
| 168 |
+
if key in self.context_pool and len(self.context_pool[key]) > 0:
|
| 169 |
+
context_key_to_use = key
|
| 170 |
+
else:
|
| 171 |
+
fallback_key = (key[1], key[2]) # (model, step)
|
| 172 |
+
if fallback_key in self.context_pool_fallback and len(self.context_pool_fallback[fallback_key]) > 0:
|
| 173 |
+
context_key_to_use = fallback_key
|
| 174 |
+
|
| 175 |
+
if len(chunk) > 0 and context_key_to_use is not None:
|
| 176 |
+
new_tasks.append({
|
| 177 |
+
'key': key,
|
| 178 |
+
'context_key': context_key_to_use,
|
| 179 |
+
'queries': chunk,
|
| 180 |
+
'is_pairwise': True
|
| 181 |
+
})
|
| 182 |
+
|
| 183 |
+
else:
|
| 184 |
+
if shuffle: random.shuffle(samples)
|
| 185 |
+
for i in range(0, len(samples), self.query_batch_size):
|
| 186 |
+
chunk = samples[i : i + self.query_batch_size]
|
| 187 |
+
context_key_to_use = None
|
| 188 |
+
if key in self.context_pool and len(self.context_pool[key]) > 0:
|
| 189 |
+
context_key_to_use = key
|
| 190 |
+
else:
|
| 191 |
+
fallback_key = (key[1], key[2]) # (model, step)
|
| 192 |
+
if fallback_key in self.context_pool_fallback and len(self.context_pool_fallback[fallback_key]) > 0:
|
| 193 |
+
context_key_to_use = fallback_key
|
| 194 |
+
|
| 195 |
+
if context_key_to_use is not None:
|
| 196 |
+
new_tasks.append({
|
| 197 |
+
'key': key,
|
| 198 |
+
'context_key': context_key_to_use,
|
| 199 |
+
'queries': chunk,
|
| 200 |
+
'is_pairwise': False
|
| 201 |
+
})
|
| 202 |
+
|
| 203 |
+
self.tasks = new_tasks
|
| 204 |
+
if self.mode == 'train':
|
| 205 |
+
print(f" >>> [Dataset] Generated {len(self.tasks)} tasks from {len(keys)} contexts.")
|
| 206 |
+
print(f" >>> [Pairwise Stats] Total Pairs: {total_pairs} (Using Oversampling). Dropped Steps (0 pos or 0 neg): {dropped_steps}")
|
| 207 |
+
|
| 208 |
+
def _process_label(self, reward):
|
| 209 |
+
val = float(reward)
|
| 210 |
+
if self.label_strategy == "binary":
|
| 211 |
+
return 1.0 if val >= 0 else 0.0
|
| 212 |
+
elif self.label_strategy == "minmax_norm":
|
| 213 |
+
return (np.clip(val, -1.0, 1.0) + 1.0) / 2.0
|
| 214 |
+
return val
|
| 215 |
+
|
| 216 |
+
def __len__(self):
|
| 217 |
+
return len(self.tasks)
|
| 218 |
+
|
| 219 |
+
def __getitem__(self, idx):
|
| 220 |
+
task = self.tasks[idx]
|
| 221 |
+
key = task['key'] # (dataset, model, step)
|
| 222 |
+
query_samples = task['queries']
|
| 223 |
+
|
| 224 |
+
# 1. Sample Context
|
| 225 |
+
context_key = task.get('context_key', key)
|
| 226 |
+
available_context = self.context_pool[key] if context_key == key else self.context_pool_fallback[context_key]
|
| 227 |
+
|
| 228 |
+
if len(available_context) >= self.support_size:
|
| 229 |
+
support_samples = random.sample(available_context, self.support_size)
|
| 230 |
+
else:
|
| 231 |
+
support_samples = available_context
|
| 232 |
+
|
| 233 |
+
# 2. Format Data
|
| 234 |
+
prompts = []
|
| 235 |
+
labels = []
|
| 236 |
+
|
| 237 |
+
# Process Support
|
| 238 |
+
for item in support_samples:
|
| 239 |
+
s_id_str = f"{item['dataset']}_{item['id']}"
|
| 240 |
+
text = self.prompt_map[s_id_str]
|
| 241 |
+
if text:
|
| 242 |
+
prompts.append(text)
|
| 243 |
+
labels.append(self._process_label(item['score']))
|
| 244 |
+
|
| 245 |
+
split_idx = len(prompts) # Boundary
|
| 246 |
+
|
| 247 |
+
# Process Query
|
| 248 |
+
q_ids = []
|
| 249 |
+
pair_ids = []
|
| 250 |
+
pair_types = []
|
| 251 |
+
|
| 252 |
+
for item in query_samples:
|
| 253 |
+
prompts.append(item['text'])
|
| 254 |
+
labels.append(self._process_label(item['score']))
|
| 255 |
+
q_ids.append(item['id'])
|
| 256 |
+
if 'pair_id' in item:
|
| 257 |
+
pair_ids.append(item['pair_id'])
|
| 258 |
+
if 'pair_type' in item:
|
| 259 |
+
pair_types.append(item['pair_type'])
|
| 260 |
+
|
| 261 |
+
# 获取该Context的全局正负样本统计量
|
| 262 |
+
stats = self.context_stats.get(key, {'n_pos': 0, 'n_neg': 0})
|
| 263 |
+
|
| 264 |
+
return {
|
| 265 |
+
"prompts": prompts,
|
| 266 |
+
"labels": torch.tensor(labels, dtype=torch.float),
|
| 267 |
+
"split_idx": split_idx,
|
| 268 |
+
"q_ids": q_ids,
|
| 269 |
+
"pair_ids": pair_ids,
|
| 270 |
+
"pair_types": pair_types,
|
| 271 |
+
"key": key,
|
| 272 |
+
"stats": stats # Pass stats to collate
|
| 273 |
+
}
|
v0_core/data/utils.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
|
| 4 |
+
def load_jsonl_lines(paths):
|
| 5 |
+
"""读取多个文件路径并将所有行合并为一个列表"""
|
| 6 |
+
all_lines = []
|
| 7 |
+
if not isinstance(paths, list): paths = [paths]
|
| 8 |
+
for p in paths:
|
| 9 |
+
if not p or not os.path.exists(p):
|
| 10 |
+
print(f"Warning: Path not found {p}")
|
| 11 |
+
continue
|
| 12 |
+
print(f"Loading {p}...")
|
| 13 |
+
try:
|
| 14 |
+
with open(p, 'r', encoding='utf-8') as f:
|
| 15 |
+
for line in f:
|
| 16 |
+
if line.strip():
|
| 17 |
+
all_lines.append(json.loads(line.strip()))
|
| 18 |
+
except Exception as e:
|
| 19 |
+
print(f"Error reading {p}: {e}")
|
| 20 |
+
return all_lines
|
v0_core/models/__init__.py
ADDED
|
File without changes
|
v0_core/models/v0.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from torch import Tensor
|
| 5 |
+
from transformers import AutoTokenizer, AutoModel
|
| 6 |
+
|
| 7 |
+
# =============================================================================
|
| 8 |
+
# TabPFN 修复补丁
|
| 9 |
+
# =============================================================================
|
| 10 |
+
try:
|
| 11 |
+
from tabpfn import TabPFNClassifier
|
| 12 |
+
except ImportError as e:
|
| 13 |
+
print(f"导入 TabPFN 模块失败: {e}")
|
| 14 |
+
print("请确保已安装 tabpfn,并且处于包含 tabpfn 源代码的环境中。")
|
| 15 |
+
exit(1)
|
| 16 |
+
|
| 17 |
+
from v0_core.utils.tabpfn_patches import fixed_fit, fixed_forward
|
| 18 |
+
|
| 19 |
+
# Apply Patches
|
| 20 |
+
TabPFNClassifier.fit = fixed_fit
|
| 21 |
+
TabPFNClassifier.forward = fixed_forward
|
| 22 |
+
# print("已应用 TabPFNClassifier 的 fit 和 forward 最终修复补丁。")
|
| 23 |
+
|
| 24 |
+
# =============================================================================
|
| 25 |
+
# Qwen Official Pooling
|
| 26 |
+
# =============================================================================
|
| 27 |
+
def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
|
| 28 |
+
left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
|
| 29 |
+
if left_padding:
|
| 30 |
+
return last_hidden_states[:, -1]
|
| 31 |
+
else:
|
| 32 |
+
sequence_lengths = attention_mask.sum(dim=1) - 1
|
| 33 |
+
batch_size = last_hidden_states.shape[0]
|
| 34 |
+
return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
|
| 35 |
+
|
| 36 |
+
# =============================================================================
|
| 37 |
+
# Adapter 策略模块
|
| 38 |
+
# =============================================================================
|
| 39 |
+
class FixedQueryAdapter(nn.Module):
|
| 40 |
+
def __init__(self, input_dim, num_queries=10, embed_dim=32, num_heads=4):
|
| 41 |
+
super().__init__()
|
| 42 |
+
self.proj_kv = nn.Linear(input_dim, embed_dim)
|
| 43 |
+
self.queries = nn.Parameter(torch.randn(1, num_queries, embed_dim))
|
| 44 |
+
self.mha = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)
|
| 45 |
+
self.ln_q = nn.LayerNorm(embed_dim)
|
| 46 |
+
self.ln_kv = nn.LayerNorm(embed_dim)
|
| 47 |
+
|
| 48 |
+
def forward(self, hidden_states, attention_mask=None):
|
| 49 |
+
batch_size = hidden_states.size(0)
|
| 50 |
+
kv = self.proj_kv(hidden_states)
|
| 51 |
+
q = self.queries.repeat(batch_size, 1, 1)
|
| 52 |
+
key_padding_mask = ~attention_mask.bool() if attention_mask is not None else None
|
| 53 |
+
attn_out, _ = self.mha(query=self.ln_q(q), key=self.ln_kv(kv), value=kv, key_padding_mask=key_padding_mask)
|
| 54 |
+
return attn_out.reshape(batch_size, -1)
|
| 55 |
+
|
| 56 |
+
class DynamicQueryAdapter(nn.Module):
|
| 57 |
+
def __init__(self, input_dim, num_queries=10, embed_dim=32, num_heads=4, generator_bottleneck_dim=128, generator_dropout_rate=0.2):
|
| 58 |
+
super().__init__()
|
| 59 |
+
self.num_queries = num_queries
|
| 60 |
+
self.embed_dim = embed_dim
|
| 61 |
+
self.static_queries = nn.Parameter(torch.randn(1, num_queries, embed_dim))
|
| 62 |
+
self.generator = nn.Sequential(
|
| 63 |
+
nn.Linear(input_dim, generator_bottleneck_dim),
|
| 64 |
+
nn.LayerNorm(generator_bottleneck_dim),
|
| 65 |
+
nn.GELU(),
|
| 66 |
+
nn.Dropout(generator_dropout_rate),
|
| 67 |
+
nn.Linear(generator_bottleneck_dim, num_queries * embed_dim)
|
| 68 |
+
)
|
| 69 |
+
nn.init.zeros_(self.generator[-1].weight)
|
| 70 |
+
nn.init.zeros_(self.generator[-1].bias)
|
| 71 |
+
self.proj_kv = nn.Linear(input_dim, embed_dim)
|
| 72 |
+
self.mha = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)
|
| 73 |
+
self.ln_q = nn.LayerNorm(embed_dim)
|
| 74 |
+
self.ln_kv = nn.LayerNorm(embed_dim)
|
| 75 |
+
|
| 76 |
+
def forward(self, hidden_states, attention_mask):
|
| 77 |
+
batch_size = hidden_states.size(0)
|
| 78 |
+
v_global = last_token_pool(hidden_states, attention_mask)
|
| 79 |
+
delta_q = self.generator(v_global).view(batch_size, self.num_queries, self.embed_dim)
|
| 80 |
+
q_final = self.static_queries.repeat(batch_size, 1, 1) + delta_q
|
| 81 |
+
kv = self.proj_kv(hidden_states)
|
| 82 |
+
key_padding_mask = ~attention_mask.bool() if attention_mask is not None else None
|
| 83 |
+
attn_out, _ = self.mha(query=self.ln_q(q_final), key=self.ln_kv(kv), value=kv, key_padding_mask=key_padding_mask)
|
| 84 |
+
return attn_out.reshape(batch_size, -1)
|
| 85 |
+
|
| 86 |
+
# =============================================================================
|
| 87 |
+
# Qwen Embedding 模型封装
|
| 88 |
+
# =============================================================================
|
| 89 |
+
class QwenEmbeddingModel(nn.Module):
|
| 90 |
+
def __init__(self, model_path, pooling_type='last_token', num_queries=10, embed_dim=32,
|
| 91 |
+
reduce_method='avg_pool', target_dim=1024, num_heads=4, generator_bottleneck_dim=128, generator_dropout_rate=0.2, device='cuda'):
|
| 92 |
+
super().__init__()
|
| 93 |
+
self.device = device
|
| 94 |
+
self.pooling_type = pooling_type
|
| 95 |
+
self.reduce_method = reduce_method
|
| 96 |
+
self.target_dim = target_dim
|
| 97 |
+
|
| 98 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side='left')
|
| 99 |
+
self.backbone = AutoModel.from_pretrained(model_path, trust_remote_code=True).to(device)
|
| 100 |
+
self.backbone.train()
|
| 101 |
+
|
| 102 |
+
with torch.no_grad(): hidden_size = self.backbone.config.hidden_size
|
| 103 |
+
|
| 104 |
+
if self.pooling_type == 'fixed_query':
|
| 105 |
+
self.adapter_layer = FixedQueryAdapter(input_dim=hidden_size, num_queries=num_queries, embed_dim=embed_dim, num_heads=num_heads).to(device)
|
| 106 |
+
elif self.pooling_type == 'dynamic_query':
|
| 107 |
+
self.adapter_layer = DynamicQueryAdapter(input_dim=hidden_size, num_queries=num_queries, embed_dim=embed_dim, num_heads=num_heads, generator_bottleneck_dim=generator_bottleneck_dim, generator_dropout_rate=generator_dropout_rate).to(device)
|
| 108 |
+
elif self.pooling_type == 'last_token':
|
| 109 |
+
self.adapter_layer = last_token_pool
|
| 110 |
+
|
| 111 |
+
def forward(self, prompts, batch_size=32):
|
| 112 |
+
embeddings = []
|
| 113 |
+
for i in range(0, len(prompts), batch_size):
|
| 114 |
+
batch_prompts = prompts[i : i + batch_size]
|
| 115 |
+
batch_dict = self.tokenizer(batch_prompts, max_length=2048, padding=True, truncation=True, return_tensors="pt").to(self.device)
|
| 116 |
+
with torch.no_grad():
|
| 117 |
+
outputs = self.backbone(**batch_dict)
|
| 118 |
+
last_hidden_state = outputs.last_hidden_state
|
| 119 |
+
emb = self.adapter_layer(last_hidden_state, batch_dict['attention_mask'])
|
| 120 |
+
|
| 121 |
+
if self.reduce_method == 'avg_pool' and emb.shape[1] > self.target_dim:
|
| 122 |
+
emb = F.adaptive_avg_pool1d(emb.unsqueeze(1), self.target_dim).squeeze(1)
|
| 123 |
+
elif self.reduce_method == 'max_pool' and emb.shape[1] > self.target_dim:
|
| 124 |
+
emb = F.adaptive_max_pool1d(emb.unsqueeze(1), self.target_dim).squeeze(1)
|
| 125 |
+
embeddings.append(emb)
|
| 126 |
+
return torch.cat(embeddings, dim=0)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class V0:
|
| 130 |
+
def __init__(self, embedding_model, tabpfn_model, device):
|
| 131 |
+
self.embedding_model = embedding_model
|
| 132 |
+
self.tabpfn = tabpfn_model
|
| 133 |
+
self.device = device
|
| 134 |
+
|
| 135 |
+
@classmethod
|
| 136 |
+
def from_pretrained(cls,
|
| 137 |
+
checkpoint_path,
|
| 138 |
+
embedding_model_path,
|
| 139 |
+
tabpfn_head_path,
|
| 140 |
+
device="cuda",
|
| 141 |
+
num_queries=168,
|
| 142 |
+
embed_dim=6,
|
| 143 |
+
num_heads=3,
|
| 144 |
+
bottleneck_dim=128,
|
| 145 |
+
tabpfn_estimators=4):
|
| 146 |
+
|
| 147 |
+
# 1. Initialize Embedding Model (Qwen + Adapter)
|
| 148 |
+
embedding_model = QwenEmbeddingModel(
|
| 149 |
+
model_path=embedding_model_path,
|
| 150 |
+
num_queries=num_queries,
|
| 151 |
+
embed_dim=embed_dim,
|
| 152 |
+
num_heads=num_heads,
|
| 153 |
+
generator_bottleneck_dim=bottleneck_dim,
|
| 154 |
+
generator_dropout_rate=0.0, # Dropout not needed for inference
|
| 155 |
+
device=device
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
# 2. Load Trained Weights (Adapter + potentially Backbone)
|
| 159 |
+
ckpt = torch.load(checkpoint_path, map_location=device)
|
| 160 |
+
state_dict = ckpt['model_state_dict']
|
| 161 |
+
|
| 162 |
+
# Clean DDP 'module.' prefix if present
|
| 163 |
+
if list(state_dict.keys())[0].startswith('module.'):
|
| 164 |
+
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
|
| 165 |
+
|
| 166 |
+
# Load weights
|
| 167 |
+
msg = embedding_model.load_state_dict(state_dict, strict=False)
|
| 168 |
+
|
| 169 |
+
# 3. Initialize TabPFN
|
| 170 |
+
tabpfn = TabPFNClassifier(
|
| 171 |
+
model_path=tabpfn_head_path,
|
| 172 |
+
device=device,
|
| 173 |
+
n_estimators=tabpfn_estimators,
|
| 174 |
+
inference_precision=torch.float32,
|
| 175 |
+
differentiable_input=True # As per training script
|
| 176 |
+
)
|
| 177 |
+
# Manual init to ensure weights are loaded
|
| 178 |
+
tabpfn._initialize_model_variables()
|
| 179 |
+
|
| 180 |
+
return cls(embedding_model, tabpfn, device)
|
| 181 |
+
|
| 182 |
+
def predict(self, context_prompts, context_labels, target_prompts, batch_size=32):
|
| 183 |
+
"""
|
| 184 |
+
Args:
|
| 185 |
+
context_prompts: List[str] - Support Set Texts
|
| 186 |
+
context_labels: List[float] - Support Set Scores (0.0 to 1.0)
|
| 187 |
+
target_prompts: List[str] - Query Set Texts to be scored
|
| 188 |
+
Returns:
|
| 189 |
+
scores: List[float] - Predicted scores (probability of class 1)
|
| 190 |
+
"""
|
| 191 |
+
# 1. Encode Context (Support Set)
|
| 192 |
+
X_sup = self.embedding_model(context_prompts, batch_size=batch_size)
|
| 193 |
+
|
| 194 |
+
# 2. Process Labels (Training script logic: >= 0.5 is Positive)
|
| 195 |
+
y_sup = torch.tensor(context_labels, device=self.device)
|
| 196 |
+
y_sup_hard = (y_sup >= 0.5).long() # Convert to class indices 0 or 1
|
| 197 |
+
|
| 198 |
+
# 3. Fit TabPFN (In-Context Learning)
|
| 199 |
+
# TabPFN learns from this specific batch of context
|
| 200 |
+
self.tabpfn.fit(X_sup, y_sup_hard)
|
| 201 |
+
|
| 202 |
+
# 4. Encode Targets (Query Set)
|
| 203 |
+
X_que = self.embedding_model(target_prompts)
|
| 204 |
+
|
| 205 |
+
# 5. Predict
|
| 206 |
+
# use_inference_mode=True as per eval logic in run_epoch
|
| 207 |
+
with torch.no_grad():
|
| 208 |
+
logits = self.tabpfn.forward(X_que, use_inference_mode=True, return_logits=True)
|
| 209 |
+
probs = torch.softmax(logits, dim=1)
|
| 210 |
+
|
| 211 |
+
# Return probability of the positive class (class 1)
|
| 212 |
+
# If batch size is 1, output might be squeezed, handling that:
|
| 213 |
+
if probs.dim() == 1:
|
| 214 |
+
return [probs[1].item()]
|
| 215 |
+
else:
|
| 216 |
+
return probs[:, 1].tolist()
|
v0_core/utils/__init__.py
ADDED
|
File without changes
|
v0_core/utils/checkpoint.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
import glob
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
# =============================================================================
|
| 7 |
+
# Checkpoint 管理器
|
| 8 |
+
# =============================================================================
|
| 9 |
+
class CheckpointManager:
|
| 10 |
+
def __init__(self, checkpoint_dir, max_keep=2, is_master=False):
|
| 11 |
+
self.checkpoint_dir = checkpoint_dir
|
| 12 |
+
self.max_keep = max_keep
|
| 13 |
+
self.is_master = is_master
|
| 14 |
+
if self.is_master and self.checkpoint_dir:
|
| 15 |
+
os.makedirs(self.checkpoint_dir, exist_ok=True)
|
| 16 |
+
|
| 17 |
+
def save(self, model, optimizer, scheduler, epoch, args, wandb_run_id=None):
|
| 18 |
+
if not self.is_master or not self.checkpoint_dir: return
|
| 19 |
+
raw_model = model.module if hasattr(model, 'module') else model
|
| 20 |
+
state = {
|
| 21 |
+
'epoch': epoch,
|
| 22 |
+
'model_state_dict': raw_model.state_dict(),
|
| 23 |
+
'optimizer_state_dict': optimizer.state_dict() if optimizer else None,
|
| 24 |
+
'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
|
| 25 |
+
'args': vars(args),
|
| 26 |
+
'wandb_run_id': wandb_run_id
|
| 27 |
+
}
|
| 28 |
+
filename = f"checkpoint_epoch_{epoch:04d}.pt"
|
| 29 |
+
filepath = os.path.join(self.checkpoint_dir, filename)
|
| 30 |
+
tmp_filepath = filepath + ".tmp"
|
| 31 |
+
print(f">> Saving Checkpoint to {filepath} (Atomic)...")
|
| 32 |
+
try:
|
| 33 |
+
# 1. 先写入临时文件
|
| 34 |
+
torch.save(state, tmp_filepath)
|
| 35 |
+
# 2. 强制刷盘,确保数据落盘
|
| 36 |
+
if os.path.exists(tmp_filepath):
|
| 37 |
+
with open(tmp_filepath, 'rb') as f:
|
| 38 |
+
os.fsync(f.fileno())
|
| 39 |
+
# 3. 原子重命名 (如果掉电发生在这里之前,旧文件还在;之后,新文件生效)
|
| 40 |
+
os.replace(tmp_filepath, filepath)
|
| 41 |
+
except Exception as e:
|
| 42 |
+
print(f"Error saving checkpoint: {e}")
|
| 43 |
+
if os.path.exists(tmp_filepath):
|
| 44 |
+
os.remove(tmp_filepath)
|
| 45 |
+
return
|
| 46 |
+
self._rotate_checkpoints()
|
| 47 |
+
|
| 48 |
+
def _rotate_checkpoints(self):
|
| 49 |
+
# 保持原逻辑不变,但增加健壮性检查
|
| 50 |
+
files = glob.glob(os.path.join(self.checkpoint_dir, "checkpoint_epoch_*.pt"))
|
| 51 |
+
# 过滤掉 .tmp 文件
|
| 52 |
+
files = [f for f in files if not f.endswith('.tmp')]
|
| 53 |
+
|
| 54 |
+
def extract_epoch(f):
|
| 55 |
+
try:
|
| 56 |
+
match = re.search(r"epoch_(\d+)", f)
|
| 57 |
+
return int(match.group(1)) if match else -1
|
| 58 |
+
except: return -1
|
| 59 |
+
|
| 60 |
+
files.sort(key=extract_epoch)
|
| 61 |
+
if len(files) > self.max_keep:
|
| 62 |
+
to_delete = files[: -self.max_keep]
|
| 63 |
+
for f in to_delete:
|
| 64 |
+
try:
|
| 65 |
+
print(f"Removing old checkpoint: {f}")
|
| 66 |
+
os.remove(f)
|
| 67 |
+
except OSError as e:
|
| 68 |
+
print(f"Error removing {f}: {e}")
|
| 69 |
+
|
| 70 |
+
def find_latest_epoch_num(self):
|
| 71 |
+
if not self.checkpoint_dir or not os.path.exists(self.checkpoint_dir): return 0
|
| 72 |
+
files = glob.glob(os.path.join(self.checkpoint_dir, "checkpoint_epoch_*.pt"))
|
| 73 |
+
files = [f for f in files if not f.endswith('.tmp')]
|
| 74 |
+
if not files: return 0
|
| 75 |
+
def extract_epoch(f):
|
| 76 |
+
match = re.search(r"epoch_(\d+)", f)
|
| 77 |
+
return int(match.group(1)) if match else -1
|
| 78 |
+
files.sort(key=extract_epoch)
|
| 79 |
+
return extract_epoch(files[-1])
|
| 80 |
+
|
| 81 |
+
def load_specific_epoch(self, target_epoch, model, optimizer, scheduler, device):
|
| 82 |
+
if target_epoch <= 0: return 1
|
| 83 |
+
filename = f"checkpoint_epoch_{target_epoch:04d}.pt"
|
| 84 |
+
filepath = os.path.join(self.checkpoint_dir, filename)
|
| 85 |
+
if not os.path.exists(filepath):
|
| 86 |
+
import time
|
| 87 |
+
print(f">> [Warning] Checkpoint {filepath} not found immediately. Waiting for FS sync...")
|
| 88 |
+
time.sleep(5)
|
| 89 |
+
if not os.path.exists(filepath): raise FileNotFoundError(f"Checkpoint {filepath} does not exist.")
|
| 90 |
+
print(f">> Resuming from checkpoint: {filepath}")
|
| 91 |
+
checkpoint = torch.load(filepath, map_location=device)
|
| 92 |
+
|
| 93 |
+
state_dict = checkpoint['model_state_dict']
|
| 94 |
+
raw_model = model.module if hasattr(model, 'module') else model
|
| 95 |
+
|
| 96 |
+
# 检查是否 key 不匹配 (例如保存时有 module. 读取时没有,或者反之)
|
| 97 |
+
model_keys = set(raw_model.state_dict().keys())
|
| 98 |
+
ckpt_keys = set(state_dict.keys())
|
| 99 |
+
|
| 100 |
+
# 简单的 key 修正逻辑
|
| 101 |
+
if list(model_keys)[0].startswith('module.') and not list(ckpt_keys)[0].startswith('module.'):
|
| 102 |
+
state_dict = {f"module.{k}": v for k, v in state_dict.items()}
|
| 103 |
+
elif not list(model_keys)[0].startswith('module.') and list(ckpt_keys)[0].startswith('module.'):
|
| 104 |
+
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
|
| 105 |
+
|
| 106 |
+
raw_model.load_state_dict(state_dict)
|
| 107 |
+
|
| 108 |
+
if optimizer is not None and 'optimizer_state_dict' in checkpoint:
|
| 109 |
+
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
| 110 |
+
if scheduler is not None and 'scheduler_state_dict' in checkpoint:
|
| 111 |
+
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
| 112 |
+
|
| 113 |
+
start_epoch = checkpoint['epoch'] + 1
|
| 114 |
+
wandb_id = checkpoint.get('wandb_run_id', None)
|
| 115 |
+
|
| 116 |
+
print(f"✅ Successfully resumed. Next epoch: {start_epoch}. WandB ID: {wandb_id}")
|
| 117 |
+
return start_epoch, wandb_id
|
v0_core/utils/metrics.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
from collections import defaultdict
|
| 4 |
+
from sklearn.metrics import roc_auc_score, accuracy_score
|
| 5 |
+
import torch.distributed as dist
|
| 6 |
+
|
| 7 |
+
def append_jsonl(path, data):
|
| 8 |
+
try:
|
| 9 |
+
with open(path, 'a', encoding='utf-8') as f:
|
| 10 |
+
f.write(json.dumps(data, ensure_ascii=False) + '\n')
|
| 11 |
+
except Exception as e:
|
| 12 |
+
print(f"Error appending to jsonl: {e}")
|
| 13 |
+
|
| 14 |
+
# =============================================================================
|
| 15 |
+
# Global Metrics Calculation & Aggregation
|
| 16 |
+
# =============================================================================
|
| 17 |
+
def calculate_metrics_by_group(all_results, phase, epoch, is_master=True, output_dir=None, dataset_name_tag="", avg_loss=None):
|
| 18 |
+
|
| 19 |
+
# 1. Gather from all ranks
|
| 20 |
+
world_size = dist.get_world_size()
|
| 21 |
+
gathered_results = [None for _ in range(world_size)]
|
| 22 |
+
dist.all_gather_object(gathered_results, all_results)
|
| 23 |
+
|
| 24 |
+
if not is_master:
|
| 25 |
+
return {}
|
| 26 |
+
|
| 27 |
+
# Flatten list of lists
|
| 28 |
+
flat_results = []
|
| 29 |
+
for rank_res in gathered_results:
|
| 30 |
+
flat_results.extend(rank_res)
|
| 31 |
+
|
| 32 |
+
print(f"[{phase}] Collected {len(flat_results)} samples for evaluation.")
|
| 33 |
+
|
| 34 |
+
if len(flat_results) == 0:
|
| 35 |
+
return {}
|
| 36 |
+
|
| 37 |
+
metrics_summary = {"epoch": epoch}
|
| 38 |
+
# =========================================================================
|
| 39 |
+
# Part A: Pair-wise Metrics Calculation (Global)
|
| 40 |
+
# =========================================================================
|
| 41 |
+
|
| 42 |
+
pair_grouping = defaultdict(lambda: {'pos': [], 'neg': []})
|
| 43 |
+
|
| 44 |
+
for r in flat_results:
|
| 45 |
+
pid = r.get('pair_id')
|
| 46 |
+
if pid is not None:
|
| 47 |
+
if r['label'] >= 0.5:
|
| 48 |
+
pair_grouping[pid]['pos'].append(r)
|
| 49 |
+
else:
|
| 50 |
+
pair_grouping[pid]['neg'].append(r)
|
| 51 |
+
|
| 52 |
+
valid_pairs = []
|
| 53 |
+
|
| 54 |
+
for pid, group in pair_grouping.items():
|
| 55 |
+
if len(group['pos']) == 1 and len(group['neg']) == 1:
|
| 56 |
+
valid_pairs.append((group['pos'][0], group['neg'][0]))
|
| 57 |
+
|
| 58 |
+
total_valid_pairs = len(valid_pairs)
|
| 59 |
+
strict_pair_correct_count = 0
|
| 60 |
+
rlhf_pair_correct_count = 0
|
| 61 |
+
|
| 62 |
+
# 3. Calculate Global Pair Metrics
|
| 63 |
+
for pos_item, neg_item in valid_pairs:
|
| 64 |
+
# Strict
|
| 65 |
+
if (pos_item['pred'] == 1) and (neg_item['pred'] == 0):
|
| 66 |
+
strict_pair_correct_count += 1
|
| 67 |
+
# RLHF
|
| 68 |
+
if pos_item['prob'] > neg_item['prob']:
|
| 69 |
+
rlhf_pair_correct_count += 1
|
| 70 |
+
|
| 71 |
+
metrics_summary[f"{phase}/global_strict_pair_acc"] = strict_pair_correct_count / total_valid_pairs if total_valid_pairs > 0 else -1
|
| 72 |
+
metrics_summary[f"{phase}/global_rlhf_pair_acc"] = rlhf_pair_correct_count / total_valid_pairs if total_valid_pairs > 0 else -1
|
| 73 |
+
metrics_summary[f"{phase}/num_valid_pairs"] = total_valid_pairs
|
| 74 |
+
|
| 75 |
+
# =========================================================================
|
| 76 |
+
# Part B: Standard Global Metrics (Acc / AUC)
|
| 77 |
+
# =========================================================================
|
| 78 |
+
y_true_binary = [1 if r['label'] >= 0.5 else 0 for r in flat_results]
|
| 79 |
+
y_scores = [r['prob'] for r in flat_results]
|
| 80 |
+
y_preds = [r['pred'] for r in flat_results]
|
| 81 |
+
|
| 82 |
+
def get_auc_strict(y_t, y_s):
|
| 83 |
+
try:
|
| 84 |
+
return roc_auc_score(y_t, y_s) if len(set(y_t)) > 1 else -1
|
| 85 |
+
except:
|
| 86 |
+
return -1
|
| 87 |
+
|
| 88 |
+
g_auc = get_auc_strict(y_true_binary, y_scores)
|
| 89 |
+
metrics_summary[f"{phase}/global_acc"] = accuracy_score(y_true_binary, y_preds)
|
| 90 |
+
metrics_summary[f"{phase}/global_auc"] = g_auc
|
| 91 |
+
|
| 92 |
+
if avg_loss is not None:
|
| 93 |
+
metrics_summary[f"{phase}/loss"] = avg_loss
|
| 94 |
+
|
| 95 |
+
# =========================================================================
|
| 96 |
+
# Part C: Step-wise Metrics
|
| 97 |
+
# =========================================================================
|
| 98 |
+
step_groups = defaultdict(list)
|
| 99 |
+
for r in flat_results:
|
| 100 |
+
step_groups[r['step']].append(r)
|
| 101 |
+
|
| 102 |
+
step_valid_pairs = defaultdict(list)
|
| 103 |
+
for pos_item, neg_item in valid_pairs:
|
| 104 |
+
if pos_item['step'] == neg_item['step']:
|
| 105 |
+
step_valid_pairs[pos_item['step']].append((pos_item, neg_item))
|
| 106 |
+
|
| 107 |
+
print(f"[{phase}] Calculating metrics for {len(step_groups)} distinct steps...")
|
| 108 |
+
|
| 109 |
+
gauc_weighted_sum = 0.0
|
| 110 |
+
gauc_total_weight = 0.0
|
| 111 |
+
valid_gauc_steps = 0
|
| 112 |
+
|
| 113 |
+
step_details_list = []
|
| 114 |
+
|
| 115 |
+
for step_val, items in step_groups.items():
|
| 116 |
+
s_true = [1 if x['label'] >= 0.5 else 0 for x in items]
|
| 117 |
+
s_scores = [x['prob'] for x in items]
|
| 118 |
+
s_preds = [x['pred'] for x in items]
|
| 119 |
+
|
| 120 |
+
# 1. Basic Step Metrics
|
| 121 |
+
step_acc = accuracy_score(s_true, s_preds)
|
| 122 |
+
step_auc = get_auc_strict(s_true, s_scores) # Returns None if only 1 class
|
| 123 |
+
|
| 124 |
+
step_record = {
|
| 125 |
+
"step": step_val,
|
| 126 |
+
"count": len(items),
|
| 127 |
+
"acc": step_acc,
|
| 128 |
+
"auc": step_auc
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
if step_auc != -1:
|
| 132 |
+
weight = len(items)
|
| 133 |
+
gauc_weighted_sum += step_auc * weight
|
| 134 |
+
gauc_total_weight += weight
|
| 135 |
+
valid_gauc_steps += 1
|
| 136 |
+
|
| 137 |
+
# 3. Step Pair Metrics
|
| 138 |
+
pairs_in_step = step_valid_pairs.get(step_val, [])
|
| 139 |
+
n_pairs = len(pairs_in_step)
|
| 140 |
+
|
| 141 |
+
if n_pairs > 0:
|
| 142 |
+
s_strict_corr = sum(1 for p, n in pairs_in_step if (p['pred'] == 1 and n['pred'] == 0))
|
| 143 |
+
s_rlhf_corr = sum(1 for p, n in pairs_in_step if p['prob'] > n['prob'])
|
| 144 |
+
|
| 145 |
+
step_record["pair_count"] = n_pairs
|
| 146 |
+
step_record["strict_pair_acc"] = s_strict_corr / n_pairs
|
| 147 |
+
step_record["rlhf_pair_acc"] = s_rlhf_corr / n_pairs
|
| 148 |
+
else:
|
| 149 |
+
step_record["pair_count"] = 0
|
| 150 |
+
step_record["strict_pair_acc"] = -1
|
| 151 |
+
step_record["rlhf_pair_acc"] = -1
|
| 152 |
+
|
| 153 |
+
step_details_list.append(step_record)
|
| 154 |
+
|
| 155 |
+
# Calculate Weighted gAUC
|
| 156 |
+
final_gauc = gauc_weighted_sum / gauc_total_weight if gauc_total_weight > 0 else -1
|
| 157 |
+
|
| 158 |
+
metrics_summary[f"{phase}/gAUC"] = final_gauc
|
| 159 |
+
metrics_summary[f"{phase}/gAUC_valid_steps"] = valid_gauc_steps
|
| 160 |
+
|
| 161 |
+
print(f"[{phase}] gAUC: {final_gauc:.4f} (Computed over {valid_gauc_steps} valid steps out of {len(step_groups)})")
|
| 162 |
+
|
| 163 |
+
# =========================================================================
|
| 164 |
+
# Part D: Save Logs
|
| 165 |
+
# =========================================================================
|
| 166 |
+
if output_dir:
|
| 167 |
+
# 1. Save Raw Predictions (Keep as is)
|
| 168 |
+
log_filename = f"{phase}_predictions_epoch_{epoch}{dataset_name_tag}.jsonl"
|
| 169 |
+
log_path = os.path.join(output_dir, log_filename)
|
| 170 |
+
valid_pair_ids = set(p[0]['pair_id'] for p in valid_pairs)
|
| 171 |
+
|
| 172 |
+
print(f"Saving raw predictions to {log_path}...")
|
| 173 |
+
with open(log_path, 'w', encoding='utf-8') as f:
|
| 174 |
+
for item in flat_results:
|
| 175 |
+
item['is_valid_pair_part'] = item.get('pair_id') in valid_pair_ids
|
| 176 |
+
f.write(json.dumps(item, ensure_ascii=False) + '\n')
|
| 177 |
+
|
| 178 |
+
# 2. Save Global Metrics (Only summary)
|
| 179 |
+
metric_filename = "all_metrics.jsonl"
|
| 180 |
+
metric_path = os.path.join(output_dir, metric_filename)
|
| 181 |
+
append_jsonl(metric_path, metrics_summary)
|
| 182 |
+
|
| 183 |
+
# 3. [NEW] Save Step-wise Details to a separate file
|
| 184 |
+
step_log_filename = f"{phase}_step_metrics_epoch_{epoch}{dataset_name_tag}.jsonl"
|
| 185 |
+
step_log_path = os.path.join(output_dir, step_log_filename)
|
| 186 |
+
print(f"Saving step-wise metrics to {step_log_path}...")
|
| 187 |
+
|
| 188 |
+
# Sort by step for readability
|
| 189 |
+
step_details_list.sort(key=lambda x: x['step'] if isinstance(x['step'], int) else -1)
|
| 190 |
+
|
| 191 |
+
with open(step_log_path, 'w', encoding='utf-8') as f:
|
| 192 |
+
for item in step_details_list:
|
| 193 |
+
f.write(json.dumps(item, ensure_ascii=False) + '\n')
|
| 194 |
+
|
| 195 |
+
return metrics_summary
|
v0_core/utils/tabpfn_patches.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import logging
|
| 3 |
+
import numpy as np
|
| 4 |
+
try:
|
| 5 |
+
from tabpfn import TabPFNClassifier
|
| 6 |
+
from tabpfn.base import create_inference_engine, determine_precision
|
| 7 |
+
from tabpfn.utils import infer_random_state
|
| 8 |
+
from tabpfn.classifier import _validate_eval_metric
|
| 9 |
+
from tabpfn.inference import InferenceEngineBatchedNoPreprocessing
|
| 10 |
+
except ImportError as e:
|
| 11 |
+
print(f"导入 TabPFN 模块失败: {e}")
|
| 12 |
+
print("请确保已安装 tabpfn,并且处于包含 tabpfn 源代码的环境中。")
|
| 13 |
+
exit(1)
|
| 14 |
+
|
| 15 |
+
def fixed_fit(self, X, y) -> "TabPFNClassifier":
|
| 16 |
+
"""修复 fit 方法:解决 differentiable_input=True 时 ensemble_configs 未定义的问题"""
|
| 17 |
+
self.eval_metric_ = _validate_eval_metric(self.eval_metric)
|
| 18 |
+
|
| 19 |
+
if self.fit_mode == "batched":
|
| 20 |
+
logging.warning("Switching from 'batched' to 'fit_preprocessors' mode...")
|
| 21 |
+
self.fit_mode = "fit_preprocessors"
|
| 22 |
+
|
| 23 |
+
if not hasattr(self, "models_") or not self.differentiable_input:
|
| 24 |
+
byte_size, rng = self._initialize_model_variables()
|
| 25 |
+
ensemble_configs, X, y = self._initialize_dataset_preprocessing(X, y, rng)
|
| 26 |
+
else:
|
| 27 |
+
_, rng = infer_random_state(self.random_state)
|
| 28 |
+
_, _, byte_size = determine_precision(self.inference_precision, self.devices_)
|
| 29 |
+
ensemble_configs, X, y = self._initialize_dataset_preprocessing(X, y, rng)
|
| 30 |
+
|
| 31 |
+
self._maybe_calibrate_temperature_and_tune_decision_thresholds(X=X, y=y)
|
| 32 |
+
|
| 33 |
+
self.executor_ = create_inference_engine(
|
| 34 |
+
X_train=X,
|
| 35 |
+
y_train=y,
|
| 36 |
+
models=self.models_,
|
| 37 |
+
ensemble_configs=ensemble_configs,
|
| 38 |
+
cat_ix=self.inferred_categorical_indices_,
|
| 39 |
+
fit_mode=self.fit_mode,
|
| 40 |
+
devices_=self.devices_,
|
| 41 |
+
rng=rng,
|
| 42 |
+
n_preprocessing_jobs=self.n_preprocessing_jobs,
|
| 43 |
+
byte_size=byte_size,
|
| 44 |
+
forced_inference_dtype_=self.forced_inference_dtype_,
|
| 45 |
+
memory_saving_mode=self.memory_saving_mode,
|
| 46 |
+
use_autocast_=self.use_autocast_,
|
| 47 |
+
inference_mode=not self.differentiable_input,
|
| 48 |
+
)
|
| 49 |
+
return self
|
| 50 |
+
|
| 51 |
+
def fixed_forward(
|
| 52 |
+
self,
|
| 53 |
+
X: list[torch.Tensor] | torch.Tensor,
|
| 54 |
+
*,
|
| 55 |
+
use_inference_mode: bool = False,
|
| 56 |
+
return_logits: bool = False,
|
| 57 |
+
return_raw_logits: bool = False,
|
| 58 |
+
) -> torch.Tensor:
|
| 59 |
+
"""修复 forward 方法:允许 standard inference 下保留梯度"""
|
| 60 |
+
if return_logits and return_raw_logits:
|
| 61 |
+
raise ValueError("Cannot return both logits and raw logits.")
|
| 62 |
+
|
| 63 |
+
is_standard_inference = not isinstance(
|
| 64 |
+
self.executor_, InferenceEngineBatchedNoPreprocessing
|
| 65 |
+
)
|
| 66 |
+
is_batched_for_grads = (
|
| 67 |
+
not use_inference_mode
|
| 68 |
+
and isinstance(self.executor_, InferenceEngineBatchedNoPreprocessing)
|
| 69 |
+
and isinstance(X, list)
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
assert is_standard_inference or is_batched_for_grads, "Invalid forward pass."
|
| 73 |
+
|
| 74 |
+
if self.fit_mode in ["fit_preprocessors", "batched"]:
|
| 75 |
+
self.executor_.use_torch_inference_mode(use_inference=use_inference_mode)
|
| 76 |
+
|
| 77 |
+
outputs = []
|
| 78 |
+
for output, config in self.executor_.iter_outputs(X, autocast=self.use_autocast_):
|
| 79 |
+
processed_output = output.unsqueeze(1) if output.ndim == 2 else output
|
| 80 |
+
config_list = [config] if output.ndim == 2 else config
|
| 81 |
+
|
| 82 |
+
output_batch = []
|
| 83 |
+
for i, batch_config in enumerate(config_list):
|
| 84 |
+
if batch_config.class_permutation is None:
|
| 85 |
+
output_batch.append(processed_output[:, i, : self.n_classes_])
|
| 86 |
+
else:
|
| 87 |
+
use_perm = batch_config.class_permutation
|
| 88 |
+
if len(use_perm) != self.n_classes_:
|
| 89 |
+
full_perm = np.arange(self.n_classes_)
|
| 90 |
+
full_perm[:len(use_perm)] = use_perm
|
| 91 |
+
use_perm = full_perm
|
| 92 |
+
output_batch.append(processed_output[:, i, use_perm])
|
| 93 |
+
outputs.append(torch.stack(output_batch, dim=1))
|
| 94 |
+
|
| 95 |
+
stacked_outputs = torch.stack(outputs) # (Chunks, Samples, Est, Classes)
|
| 96 |
+
|
| 97 |
+
if return_logits:
|
| 98 |
+
temp_scaled = self._apply_temperature(stacked_outputs)
|
| 99 |
+
output = temp_scaled.mean(dim=(0, 2))
|
| 100 |
+
elif return_raw_logits:
|
| 101 |
+
output = stacked_outputs
|
| 102 |
+
else:
|
| 103 |
+
temp_scaled = self._apply_temperature(stacked_outputs)
|
| 104 |
+
avg_logits = temp_scaled.mean(dim=(0, 2))
|
| 105 |
+
output = torch.nn.functional.softmax(avg_logits, dim=-1)
|
| 106 |
+
|
| 107 |
+
if not use_inference_mode:
|
| 108 |
+
if return_logits and output.ndim == 2:
|
| 109 |
+
return output
|
| 110 |
+
if output.ndim == 2:
|
| 111 |
+
output = output.unsqueeze(0)
|
| 112 |
+
output = output.transpose(0, 1).transpose(1, 2)
|
| 113 |
+
elif output.ndim > 2 and use_inference_mode:
|
| 114 |
+
output = output.squeeze(1) if not return_raw_logits else output.squeeze(2)
|
| 115 |
+
|
| 116 |
+
return output
|