# Hugging Face Space page residue kept as a comment so the file stays valid Python:
# author: zhangyikai — "Update UI notes" (commit aba7b4a)
import gradio as gr
import pandas as pd
import json
import random
import os
from huggingface_hub import hf_hub_download, snapshot_download
# ==========================================
# V0 Model Import & Initialization
# ==========================================
MODEL_REPO_ID = "Now-Join-Us/Generalist-Value-Model-V0"
EMBEDDING_REPO_ID = "Qwen/Qwen3-Embedding-0.6B"
# Global model handle; stays None if any loading step below fails, and
# predict_performance() checks for that before running inference.
v0_model = None
print(">>> Starting V0 App...")
try:
    # Ensure v0_core is accessible.
    # If v0_core is in the root directory, this import works.
    from v0_core.models.v0 import V0
    print(f">>> Downloading models from Hugging Face Hub...")
    # 1. Download V0 Checkpoint
    print(f" - Fetching checkpoint from {MODEL_REPO_ID}...")
    checkpoint_path = hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename="v_0_for_router_demo.pt"
    )
    # 2. Download TabPFN Classifier
    print(f" - Fetching TabPFN head from {MODEL_REPO_ID}...")
    tabpfn_path = hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename="pretrained/tabpfn-v2.5-classifier-v2.5_default.ckpt"
    )
    # 3. Download Qwen Embedding Model (Snapshot because we need config, tokenizer, etc.)
    print(f" - Fetching Embedding model from {EMBEDDING_REPO_ID}...")
    embedding_path = snapshot_download(
        repo_id=EMBEDDING_REPO_ID
    )
    print(">>> All assets downloaded. Initializing V0 class...")
    # Determine device (Defaulting to CPU as per request, but CUDA is preferred if available)
    # device = "cuda" if torch.cuda.is_available() else "cpu"
    device = "cpu"
    print(f">>> Device selected: {device}")
    # Initialize the Model
    v0_model = V0.from_pretrained(
        checkpoint_path=checkpoint_path,
        embedding_model_path=embedding_path,
        tabpfn_head_path=tabpfn_path,
        device=device
    )
    # Set to eval mode if applicable (V0 usually handles this internally, but good practice)
    if hasattr(v0_model, 'eval'):
        v0_model.eval()
    print(f">>> V0 Model Loaded Successfully on {device}!")
except Exception as e:
    # Broad catch is deliberate at this top-level boundary: the app must still
    # start (degraded) when downloads or the v0_core import fail.
    print(f"Warning: Failed to load V0 model or dependencies. Prediction will fail. Error: {e}")
    v0_model = None
# ==========================================
# 0. Configuration & data-loading logic
# ==========================================
# Models selectable in the UI; names must match the per-model files under
# data/router_context_sampled/performance/ and the performance JSONL prefixes.
TARGET_MODELS = [
    'Qwen3-30B-A3B-Instruct-2507',
    'Qwen3-4B-Instruct-2507',
    'DeepSeek-R1-Distill-Qwen-1.5B',
    "Qwen3-32B",
    "DeepSeek-R1-Distill-Qwen-7B",
    "Qwen3-0.6B",
]
# Datasets selectable in the UI; each maps to a JSONL under
# data/router_val_prompt_and_gts/.
TARGET_DATASETS = [
    'aime_2024', 'aime_2025', 'amc23', 'gaokao_math_cloze', 'gpqa_diamond', 'olympiad'
]
# Global cache for the prompt dictionary (lazy-loaded by load_prompt_dict)
PROMPT_DICT_CACHE = None
# Global cache for performance data: { "Model_Dataset": { id_int: score_float } }
PERFORMANCE_DB = {}
def load_prompt_dict():
    """Load and memoize the prompt dictionary from disk.

    Subsequent calls return the cached dict. When the file is missing, the
    cache becomes an empty dict so later callers do not retry the read.
    """
    global PROMPT_DICT_CACHE
    if PROMPT_DICT_CACHE is None:
        path = "data/router_context_sampled/prompt_dict.json"
        try:
            with open(path, "r", encoding="utf-8") as f:
                PROMPT_DICT_CACHE = json.load(f)
            print(f"Loaded {len(PROMPT_DICT_CACHE)} prompts from {path}")
        except FileNotFoundError:
            print(f"Warning: {path} not found. Using empty dict.")
            PROMPT_DICT_CACHE = {}
    return PROMPT_DICT_CACHE
def load_performance_data(model_name, dataset_name):
    """Lazily load one model's per-question scores on one dataset.

    Reads data/router_val_rollouts_and_performance/{model}_{dataset}.jsonl and
    builds { id: normalized_score }, mapping "mean@10" from [-1, 1] to [0, 1]
    via (score + 1) / 2. Results are memoized in PERFORMANCE_DB under the key
    f"{model_name}_{dataset_name}"; a missing file caches an empty map.
    """
    cache_key = f"{model_name}_{dataset_name}"
    # Return the memoized map when this pair has been loaded before.
    if cache_key in PERFORMANCE_DB:
        return PERFORMANCE_DB[cache_key]
    file_path = f"data/router_val_rollouts_and_performance/{model_name}_{dataset_name}.jsonl"
    data_map = {}
    if os.path.exists(file_path):
        try:
            with open(file_path, "r", encoding="utf-8") as f:
                for line in f:
                    try:
                        item = json.loads(line)
                        idx = item.get("id")
                        score = item.get("mean@10")
                        if idx is not None and score is not None:
                            # Normalize mean@10 from [-1, 1] to [0, 1].
                            data_map[idx] = (float(score) + 1.0) / 2.0
                    except json.JSONDecodeError:
                        # Skip malformed lines rather than failing the file.
                        continue
            print(f"Loaded performance map for {cache_key}: {len(data_map)} entries.")
        except Exception as e:
            print(f"Error loading performance file {file_path}: {e}")
    # Cache even an empty map so a missing file is not re-checked every call.
    # (Removed a dead `else: pass` branch from the original.)
    PERFORMANCE_DB[cache_key] = data_map
    return data_map
def load_real_model_data(model_name):
    """Load a model's (prompt, score) history and sample a 16-item context.

    Reads data/router_context_sampled/performance/{model_name}.jsonl, joins
    each row with its prompt text via load_prompt_dict() (keyed
    "{dataset}_{id}"), normalizes "mean@10" from [-1, 1] to [0, 1], then
    randomly samples 16 entries. The sample is nudged so it is never purely
    positive or purely negative (threshold: normalized score > 0.6).
    """
    prompt_db = load_prompt_dict()
    file_path = f"data/router_context_sampled/performance/{model_name}.jsonl"
    # Temporary holder for every parsed row before sampling
    all_parsed_data = []
    if not os.path.exists(file_path):
        # Surface the missing file as a single pseudo-entry the UI can render.
        return [{
            "prompt": f"Error: File {file_path} not found.",
            "score": 0.0,
            "is_correct": False
        }]
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            lines = f.readlines()
        # 1. Parse all lines first so label statistics are known before sampling
        for line in lines:
            try:
                item = json.loads(line)
                dataset = item.get("dataset", "")
                idx = item.get("id", "")
                key = f"{dataset}_{idx}"
                prompt_text = prompt_db.get(key, f"[Prompt missing for key: {key}]")
                raw_score = float(item.get("mean@10", -1.0))
                # Map mean@10 from [-1, 1] to [0, 1]
                normalized_score = (raw_score + 1.0) / 2.0
                is_correct = normalized_score > 0.6
                all_parsed_data.append({
                    "prompt": prompt_text,
                    "score": normalized_score,
                    "is_correct": is_correct
                })
            except json.JSONDecodeError:
                continue
        # 2. Sampling (target context size = 16)
        target_count = 16
        if len(all_parsed_data) <= target_count:
            return all_parsed_data
        # Initial uniform random sample of 16
        sampled_data = random.sample(all_parsed_data, target_count)
        # Inspect the positive/negative balance of the sample
        positives = [x for x in sampled_data if x['is_correct']]
        negatives = [x for x in sampled_data if not x['is_correct']]
        # Case A: all positives -> force one negative into the sample
        if len(negatives) == 0:
            # Collect every negative from the full pool
            all_negatives_pool = [x for x in all_parsed_data if not x['is_correct']]
            if all_negatives_pool:
                # Pick one negative at random
                replacement = random.choice(all_negatives_pool)
                # Overwrite a random slot in the current sample
                replace_idx = random.randint(0, target_count - 1)
                sampled_data[replace_idx] = replacement
        # Case B: all negatives -> force one positive into the sample
        elif len(positives) == 0:
            # Collect every positive from the full pool
            all_positives_pool = [x for x in all_parsed_data if x['is_correct']]
            if all_positives_pool:
                # Pick one positive at random
                replacement = random.choice(all_positives_pool)
                # Overwrite a random slot in the current sample
                replace_idx = random.randint(0, target_count - 1)
                sampled_data[replace_idx] = replacement
        return sampled_data
    except Exception as e:
        print(f"Error reading {file_path}: {e}")
        return []
def load_dataset_batch(dataset_name):
    """Randomly sample 4 instructions (with their IDs) from a dataset file.

    Reads data/router_val_prompt_and_gts/{dataset_name}.jsonl and returns a
    6-tuple: (input1, input2, input3, input4, id_list, dataset_name).
    Missing slots are padded with "" and ID -1; a missing file yields four
    "File not found" placeholders. (Original docstring claimed a 3-tuple.)
    """
    file_path = f"data/router_val_prompt_and_gts/{dataset_name}.jsonl"
    inputs = []
    ids = []
    try:
        if os.path.exists(file_path):
            with open(file_path, "r", encoding="utf-8") as f:
                lines = f.readlines()
            # Sample 4 lines, keeping ID info so results can later be matched
            # against the per-model performance files.
            count = 4
            total_lines = len(lines)
            if total_lines > 0:
                sampled_indices = random.sample(range(total_lines), min(count, total_lines))
                for idx in sampled_indices:
                    line = lines[idx]
                    try:
                        item = json.loads(line)
                        inputs.append(item.get("input", ""))
                        # Prefer "id" from the JSON, fall back to "idx", then
                        # to the file line number.
                        # NOTE: this id must match the ids used in the
                        # performance files.
                        item_id = item.get("id", item.get("idx", idx))
                        ids.append(item_id)
                    except (json.JSONDecodeError, AttributeError):
                        # Narrowed from a bare `except:`: only parse failures
                        # (bad JSON, or a non-dict line) are recoverable here.
                        inputs.append("Error parsing JSON")
                        ids.append(-1)
        else:
            inputs = [f"File not found: {file_path}"] * 4
            ids = [-1] * 4
    except Exception as e:
        inputs = [f"Error loading dataset: {str(e)}"] * 4
        ids = [-1] * 4
    # Pad so exactly four textboxes are always filled.
    while len(inputs) < 4:
        inputs.append("")
        ids.append(-1)
    return inputs[0], inputs[1], inputs[2], inputs[3], ids, dataset_name
# ==========================================
# 1. Core logic
# ==========================================
def process_custom_file(file_obj):
    """Parse a user-uploaded JSONL file into a list of history entries.

    Each line must be a JSON object with a "prompt" plus one of (in priority
    order):
      - "score":      a float, used as-is;
      - "mean@10":    a float in [-1, 1], normalized to [0, 1];
      - "is_correct": a bool, mapped to 1.0 / 0.0.
    Blank, malformed, or prompt-less lines are skipped. Returns [] when
    file_obj is None or the file cannot be read at all.
    """
    if file_obj is None:
        return []
    data_list = []
    try:
        with open(file_obj.name, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    item = json.loads(line)
                    prompt = item.get('prompt')
                    if not prompt:
                        continue
                    # Accept several score formats, in priority order.
                    score = 0.0
                    if 'score' in item:
                        score = float(item['score'])
                    elif 'mean@10' in item:
                        # Normalize mean@10 from [-1, 1] to [0, 1].
                        score = (float(item['mean@10']) + 1.0) / 2.0
                    elif 'is_correct' in item:
                        score = 1.0 if item['is_correct'] else 0.0
                    is_correct = score > 0.6
                    data_list.append({
                        "prompt": prompt,
                        "score": score,
                        "is_correct": is_correct
                    })
                except (json.JSONDecodeError, TypeError, ValueError, AttributeError):
                    # Narrowed from a bare `except:` (which also swallowed
                    # KeyboardInterrupt/SystemExit): skip only lines with bad
                    # JSON, a non-dict payload, or a non-numeric score.
                    continue
    except Exception as e:
        print(f"Error parsing custom file: {e}")
        return []
    return data_list
def format_model_card(data_list, model_name, is_empty=False):
    """Render one model slot as an HTML card.

    An empty slot shows a placeholder; a populated slot previews up to three
    history rows (status icon, truncated prompt, score) followed by a
    "+N more items" line when the history is longer.
    """
    if is_empty or not data_list:
        return """
        <div class='model-card empty'>
            <div class='card-title'>Empty Slot</div>
            <div class='card-subtitle'>Select a model above to load</div>
        </div>
        """
    total = len(data_list)
    preview_limit = 3
    row_parts = []
    for entry in data_list[:preview_limit]:
        text = entry.get('prompt', '')
        if len(text) > 45:
            text = text[:45] + " ... "
        val = entry.get('score', 0.0)
        # Traffic-light status: >=0.8 green, >=0.4 yellow, else red.
        if val >= 0.8:
            status_class, icon = "status-green", "✔"
        elif val >= 0.4:
            status_class, icon = "status-yellow", "~"
        else:
            status_class, icon = "status-red", "✘"
        row_parts.append(f"""
        <div class='history-row'>
            <div class='status-box {status_class}'>{icon}</div>
            <div class='prompt-text'>{text}</div>
            <span style='opacity:0.9; font-size:0.9em; white-space: nowrap; margin-left: 8px;'>{val:.2f}</span>
        </div>
        """)
    hidden = total - preview_limit
    if hidden > 0:
        row_parts.append(f"<div class='history-more'>+ {hidden} more items</div>")
    rows_html = "".join(row_parts)
    return f"""
    <div class='model-card populated'>
        <div class='card-header'>
            <span class='model-name'>{model_name}</span>
            <span class='acc-badge'>Context Size: {total}</span>
        </div>
        <div class='card-body'>
            <div class='history-container'>
                {rows_html}
            </div>
        </div>
    </div>
    """
def select_model(btn_name, current_pointer, data_s1, name_s1, data_s2, name_s2):
    """Load btn_name's history into the slot indicated by current_pointer.

    Slots alternate round-robin (pointer 0 fills slot 1, pointer 1 fills
    slot 2). Re-selecting an already-loaded model leaves both slots as-is.
    Returns (html1, html2, next_pointer, data1, name1, data2, name2).
    """
    if btn_name in (name_s1, name_s2):
        # Already present in a slot: just re-render everything unchanged.
        return (
            format_model_card(data_s1, name_s1, is_empty=(name_s1 is None)),
            format_model_card(data_s2, name_s2, is_empty=(name_s2 is None)),
            current_pointer, data_s1, name_s1, data_s2, name_s2,
        )
    new_data = load_real_model_data(btn_name)
    if current_pointer == 0:
        # Fill slot 1; next selection goes to slot 2.
        return (
            format_model_card(new_data, btn_name),
            format_model_card(data_s2, name_s2, is_empty=(name_s2 is None)),
            1, new_data, btn_name, data_s2, name_s2,
        )
    # Fill slot 2; next selection goes to slot 1.
    return (
        format_model_card(data_s1, name_s1, is_empty=(name_s1 is None)),
        format_model_card(new_data, btn_name),
        0, data_s1, name_s1, new_data, btn_name,
    )
def handle_custom_upload(file_obj, current_pointer, data_s1, name_s1, data_s2, name_s2):
    """Parse an uploaded JSONL file and place it into the next model slot.

    Mirrors select_model's round-robin slot behavior. A failed parse shows an
    "Upload Failed" placeholder entry instead of real history.
    """
    def render(d, n):
        # Render an existing slot, treating a missing name as empty.
        return format_model_card(d, n, is_empty=(n is None))

    if file_obj is None:
        # Nothing uploaded: re-render both slots untouched.
        return (render(data_s1, name_s1), render(data_s2, name_s2),
                current_pointer, data_s1, name_s1, data_s2, name_s2)
    new_data = process_custom_file(file_obj)
    custom_name = f"Custom: {os.path.basename(file_obj.name)[:15]}"
    if not new_data:
        # File could not be parsed: show a placeholder card.
        custom_name = "Upload Failed"
        new_data = [{"prompt": "Invalid JSONL format", "score": 0.0, "is_correct": False}]
    if current_pointer == 0:
        return (format_model_card(new_data, custom_name), render(data_s2, name_s2),
                1, new_data, custom_name, data_s2, name_s2)
    return (render(data_s1, name_s1), format_model_card(new_data, custom_name),
            0, data_s1, name_s1, new_data, custom_name)
def invalidate_specific_id(current_ids, index):
    """Mark slot `index` as stale (-1) after its textbox was hand-edited.

    Returns a new list; pads with -1 up to `index` when the list is short,
    and falls back to four -1s when no ID list exists yet.
    """
    if current_ids is None:
        return [-1, -1, -1, -1]
    updated = list(current_ids)
    # Pad (if needed) so `index` is addressable, then invalidate that slot.
    updated.extend([-1] * (index + 1 - len(updated)))
    updated[index] = -1
    return updated
def predict_performance(data1, name1, data2, name2, t1, t2, t3, t4, current_ids, current_dataset):
    """Run V0 inference for every active model slot over the target instructions.

    Builds a binary-labeled context from each slot's history, calls
    v0_model.predict once per model (all targets share one context), and joins
    in ground-truth scores when the instruction IDs are still valid. Returns a
    pandas DataFrame with one row per (model, instruction) pair, or a
    single-row error frame when inputs are incomplete or the model is missing.
    """
    if v0_model is None:
        return pd.DataFrame([{"Error": "V0 Model not loaded successfully on server."}])
    # 1. Collect the target instructions (skip blank textboxes)
    inputs_list = [t1, t2, t3, t4]
    target_prompts = []
    target_ids = []
    target_indices = []  # original textbox indices, kept for mapping back
    for i, t in enumerate(inputs_list):
        if t.strip():
            target_prompts.append(t)
            target_indices.append(i)
            if current_ids and i < len(current_ids):
                target_ids.append(current_ids[i])
            else:
                # No ID known for this textbox -> no ground truth available
                target_ids.append(-1)
    if not target_prompts:
        return pd.DataFrame([{"Error": "Please enter at least one target instruction."}])
    # 2. Determine which model slots are populated
    active_models = []
    if name1 and data1: active_models.append((name1, data1))
    if name2 and data2: active_models.append((name2, data2))
    if not active_models:
        return pd.DataFrame([{"Error": "Please select at least one model above."}])
    results = []
    for m_name, m_history in active_models:
        # --- Prepare the context (history) ---
        # Extract prompt/score pairs; the model caps context length at 256.
        context_prompts = []
        context_labels = []
        # Drop malformed entries, then keep at most the first 256
        valid_items = [x for x in m_history if isinstance(x.get('score'), (int, float)) and x.get('prompt')]
        # m_history is assumed already randomly sampled (load_real_model_data)
        batch_context = valid_items[:256]
        for item in batch_context:
            context_prompts.append(item['prompt'])
            # Scores are already normalized to [0, 1] upstream; binarize at
            # 0.5 into 0/1 labels for the classifier head.
            score_val = item['score']
            label = 1 if score_val > 0.5 else 0
            context_labels.append(label)
        if not context_prompts:
            results.append({"Model": m_name, "Error": "No valid context data found."})
            continue
        # --- Fetch ground truth when available (never for custom uploads) ---
        real_perf_map = {}
        if current_dataset and "Custom:" not in m_name:
            real_perf_map = load_performance_data(m_name, current_dataset)
        # --- Run V0 prediction ---
        try:
            # V0 scores all targets in one call, sharing the same context
            pred_scores = v0_model.predict(
                context_prompts=context_prompts,
                context_labels=context_labels,
                target_prompts=target_prompts
            )
        except Exception as e:
            # Fall back to zeros so the results table still renders.
            print(f"Inference error for {m_name}: {e}")
            pred_scores = [0.0] * len(target_prompts)
        # --- Format the result rows ---
        for idx, t_text in enumerate(target_prompts):
            cur_id = target_ids[idx]
            # Actual score (only when the sampled ID is still valid and known)
            if cur_id != -1 and cur_id in real_perf_map:
                real_val = real_perf_map[cur_id]
                real_mean_10_str = f"{real_val:.2f}"
            else:
                real_mean_10_str = "N/A"
            # Predicted score and its binarized class at the 0.5 threshold
            final_score = float(pred_scores[idx])
            pred_str = "✔ Success" if final_score > 0.5 else "✘ Failure"
            results.append({
                "Model": m_name,
                "Instruction": t_text,
                "Actual mean@10": real_mean_10_str,
                "Predicted Score": f"{final_score:.4f}",
                "Predicted Class": pred_str
            })
    return pd.DataFrame(results)
# ==========================================
# 2. CSS Styles
# ==========================================
# NOTE: the string below is passed verbatim to gr.Blocks(css=...); its content
# (including the CSS comments inside it) is runtime data, not Python comments.
css = """
/* 全局变量 */
:root {
--primary: #10b981;
--primary-light: #ecfdf5;
--primary-dark: #047857;
--bg-card: #ffffff;
--border-sub: #e5e7eb;
--text-main: #1f2937;
--text-sub: #6b7280;
--success: #10b981;
--fail: #ef4444;
--warning: #f59e0b;
}
.dark {
--bg-card: #1f2937;
--border-sub: #374151;
--text-main: #f3f4f6;
--text-sub: #9ca3af;
--primary-light: #064e3b;
--primary-dark: #ecfdf5;
}
/* 顶部 Banner */
.concept-banner {
background: linear-gradient(135deg, rgba(16, 185, 129, 0.08) 0%, rgba(59, 130, 246, 0.05) 100%);
border: 1px solid var(--primary-light);
border-radius: 12px;
padding: 24px;
text-align: center;
margin-bottom: 20px;
}
.concept-title {
font-size: 1.8em; font-weight: 700; color: var(--text-main); margin-bottom: 8px;
}
.concept-subtitle { font-size: 1em; color: var(--text-sub); max-width: 600px; margin: 0 auto; line-height: 1.5; }
/* 步骤标题 */
.step-header {
display: flex; align-items: center; margin-bottom: 15px;
border-bottom: 2px solid var(--border-sub); padding-bottom: 10px; margin-top: 20px;
}
.step-num {
background: var(--primary); color: white; width: 28px; height: 28px;
border-radius: 50%; display: flex; align-items: center; justify-content: center;
font-weight: bold; margin-right: 10px; font-size: 0.9em;
}
.step-title { font-size: 1.2em; font-weight: 600; color: var(--text-main); }
.step-desc { font-size: 0.93em; color: var(--text-sub); margin-left: auto; font-style: italic;}
/* --- 模型选择按钮容器 --- */
.model-select-container {
display: flex !important;
justify-content: flex-start !important;
flex-wrap: wrap !important;
align-items: center !important;
gap: 10px !important;
width: 100% !important;
margin-bottom: 20px !important;
}
/* 按钮通用样式 */
button.model-btn {
background: linear-gradient(145deg, #ecfdf5 0%, #d1fae5 100%);
border: 1px solid #6ee7b7 !important;
color: #065f46 !important;
font-weight: 600 !important;
font-size: 0.9em !important;
border-radius: 12px !important;
box-shadow: 0 2px 4px rgba(16, 185, 129, 0.1);
transition: all 0.2s ease;
height: 46px !important;
min-height: 46px !important;
white-space: nowrap;
padding: 0 30px !important;
min-width: 220px !important;
text-align: center !important;
flex-grow: 0 !important;
width: auto !important;
}
.dark button.model-btn {
background: linear-gradient(145deg, #064e3b 0%, #065f46 100%);
border-color: #059669 !important;
color: #ecfdf5 !important;
}
button.model-btn:hover {
transform: translateY(-2px);
box-shadow: 0 4px 8px rgba(16, 185, 129, 0.25);
background: linear-gradient(145deg, #d1fae5 0%, #a7f3d0 100%);
}
.dark button.model-btn:hover {
background: linear-gradient(145deg, #059669 0%, #047857 100%) !important;
color: #ffffff !important;
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.4);
}
/* --- 卡片样式 --- */
.model-card {
background: var(--bg-card); border: 1px solid var(--border-sub);
border-radius: 10px; padding: 16px; margin-bottom: 15px; height: 100%;
transition: all 0.2s; position: relative; overflow: hidden;
min-height: 200px;
}
.model-card.populated { border-left: 5px solid var(--primary); box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.05); }
.model-card.empty {
border: 2px dashed var(--border-sub);
display: flex; flex-direction: column; justify-content: center; align-items: center;
opacity: 0.6;
}
.card-title { font-weight: bold; color: var(--text-sub); font-size: 1.1em; }
.card-subtitle { font-size: 0.85em; color: var(--text-sub); margin-top: 5px;}
.card-header { display: flex; justify-content: space-between; align-items: center; margin-bottom: 12px; }
.model-name { font-weight: bold; font-size: 1.05em; color: var(--text-main); }
.acc-badge { background: var(--primary-light); color: var(--primary-dark); font-size: 0.75em; padding: 3px 8px; border-radius: 12px; font-weight: 700; }
.history-container { display: flex; flex-direction: column; gap: 6px; }
.history-row { display: flex; align-items: center; background: rgba(0,0,0,0.03); padding: 5px 8px; border-radius: 6px; }
.status-box {
width: 20px; height: 20px; border-radius: 5px; display: flex; align-items: center; justify-content: center;
color: white; font-size: 0.7em; font-weight: bold; margin-right: 8px; flex-shrink: 0;
}
.status-green { background-color: var(--success); }
.status-red { background-color: var(--fail); }
.status-yellow { background-color: var(--warning); }
.prompt-text {
flex: 1;
min-width: 0;
font-size: 0.85em; color: var(--text-main);
white-space: nowrap; overflow: hidden; text-overflow: ellipsis;
}
.history-more { font-size: 0.8em; color: var(--text-sub); text-align: center; font-style: italic; margin-top: 4px; }
/* 按钮和链接 */
.custom-btn { font-weight: bold !important; font-size: 1.1em !important; }
.sample-btn {
border: 1px dashed rgba(16, 185, 129, 0.3) !important;
color: var(--primary-dark) !important;
background: rgba(16, 185, 129, 0.05) !important;
}
.sample-btn:hover {
background: rgba(16, 185, 129, 0.15) !important;
}
.paper-link {
font-size: 0.5em; vertical-align: middle; color: var(--primary);
text-decoration: none; border: 1px solid var(--primary); padding: 4px 10px;
border-radius: 15px; font-weight: normal; transition: all 0.2s; background: transparent;
}
.paper-link:hover { background: var(--primary); color: white; }
/* Custom Upload Styling */
.upload-container {
display: flex; align-items: center; gap: 10px; margin-bottom: 20px;
padding: 8px 12px; background: rgba(0,0,0,0.02); border-radius: 8px;
border: 1px dashed var(--border-sub);
}
.upload-label { font-size: 0.9em; font-weight: 600; color: var(--text-sub); white-space: nowrap;}
.code-inline {
font-family: monospace; background: rgba(0,0,0,0.05); padding: 2px 4px;
border-radius: 4px; font-size: 0.85em; color: var(--primary-dark);
}
"""
# ==========================================
# 3. UI Construction
# ==========================================
with gr.Blocks(theme=gr.themes.Soft(primary_hue="emerald"), css=css, title="V0 Predictor") as demo:
    # Preload a default model into the left slot and an initial instruction
    # batch from the first dataset.
    initial_model = TARGET_MODELS[2]  # DeepSeek-R1-Distill-Qwen-1.5B
    initial_data = load_real_model_data(initial_model)
    init_in1, init_in2, init_in3, init_in4, init_ids, init_ds_name = load_dataset_batch(TARGET_DATASETS[0])
    # Per-session state: pointer (which slot the next selection fills, 0/1),
    # two (data, name) slot pairs, the sampled instruction IDs, and the
    # dataset they came from.
    state_pointer = gr.State(value=1)
    state_data1 = gr.State(value=initial_data)
    state_name1 = gr.State(value=initial_model)
    state_data2 = gr.State(value=None)
    state_name2 = gr.State(value=None)
    state_current_ids = gr.State(value=init_ids)
    state_current_dataset = gr.State(value=init_ds_name)
    gr.HTML("""
    <div class="concept-banner">
        <div class="concept-title">
            Generalist Value Model V<sub>0</sub>
            <a href="https://arxiv.org/abs/2602.03584" class="paper-link">Paper ↗</a>
            <a href="https://github.com/Now-Join-Us/V0" class="paper-link">Code ↗</a>
        </div>
        <div class="concept-subtitle">
            <span style="color: var(--primary); font-weight: bold;">Function:</span> V<sub>0</sub> uses a model's historical performance to predict<br>
            how it will perform on unseen instructions without running the model itself.
        </div>
    </div>
    """)
    # STEP 1: MODEL SELECTION
    gr.HTML("""
    <div class="step-header">
        <div class="step-num">1</div>
        <div class="step-title">Select Model (with instruction-performance pairs)</div>
        <div class="step-desc">Click buttons below to load real pairs, or upload your own.</div>
    </div>
    """)
    with gr.Row(elem_classes=["model-select-container"]):
        model_buttons = []
        for m_name in TARGET_MODELS:
            btn = gr.Button(m_name, elem_classes=["model-btn"], scale=0)
            model_buttons.append((btn, m_name))
    with gr.Row():
        with gr.Column(variant="panel"):
            slot1_html = gr.HTML(format_model_card(initial_data, initial_model))
        with gr.Column(variant="panel"):
            slot2_html = gr.HTML(format_model_card(None, None, True))
    with gr.Row(elem_classes=["upload-container"]):
        gr.HTML(
            f"""<div class='upload-label'>
            [Optional] Upload Custom Model (.jsonl) &nbsp;|&nbsp;
            <span class='code-inline'>Format: {{"prompt": "...", "score": 0.8 }}\\n{{"prompt": "...", "score": 0.2}} ...</span>
            </div>"""
        )
        upload_btn = gr.File(
            file_count="single",
            file_types=[".jsonl"],
            label=None,
            show_label=False,
            container=False,
            height=30,
            scale=1
        )
    # STEP 2: INSTRUCTION INPUT
    gr.HTML("""
    <div class="step-header" style="margin-top: 40px;">
        <div class="step-num">2</div>
        <div class="step-title">Select Instructions (to predict how well the above models perform on them)</div>
        <div class="step-desc">Click to randomly sample 4 from the corresponding dataset.</div>
    </div>
    """)
    with gr.Row(elem_classes=["model-select-container"]):
        dataset_buttons = []
        for d_name in TARGET_DATASETS:
            d_btn = gr.Button(d_name, elem_classes=["model-btn"], scale=0)
            dataset_buttons.append((d_btn, d_name))
    with gr.Row():
        t1 = gr.Textbox(label="Instruction 1", value=init_in1, lines=2)
        t2 = gr.Textbox(label="Instruction 2", value=init_in2, lines=2)
    with gr.Row():
        t3 = gr.Textbox(label="Instruction 3", value=init_in3, lines=2)
        t4 = gr.Textbox(label="Instruction 4", value=init_in4, lines=2)
    gr.HTML("<div style='height: 8px;'></div>")
    # STEP 3: PREDICTION
    with gr.Row():
        with gr.Column():
            predict_btn = gr.Button("Run V₀ Prediction", variant="primary", size="lg", elem_classes=["custom-btn"])
    gr.HTML("""
    <div class="step-header" style="margin-top: 20px; border-bottom: none;">
        <div class="step-num">3</div>
        <div class="step-title">Results</div>
        <div class="step-desc">Limited to 16 context size (vs. standard 256) by HF CPU limits; GPU enables ms-level speeds.</div>
    </div>
    """)
    output_df = gr.Dataframe(
        headers=["Model", "Instruction", "Actual mean@10", "Predicted Score", "Predicted Class"],
        datatype=["str", "str", "str", "str", "str"],
        interactive=False,
        column_widths=["20%", "44%", "10%", "13%", "13%"]
    )
    # Event Binding
    def make_click_handler(model_name):
        # Bind model_name via closure so each button loads its own model
        # (avoids the classic late-binding-lambda-in-a-loop bug).
        return dict(
            fn=lambda ptr, d1, n1, d2, n2: select_model(model_name, ptr, d1, n1, d2, n2),
            inputs=[state_pointer, state_data1, state_name1, state_data2, state_name2],
            outputs=[slot1_html, slot2_html, state_pointer, state_data1, state_name1, state_data2, state_name2]
        )
    for btn, name in model_buttons:
        btn.click(**make_click_handler(name))
    upload_btn.upload(
        fn=handle_custom_upload,
        inputs=[upload_btn, state_pointer, state_data1, state_name1, state_data2, state_name2],
        outputs=[slot1_html, slot2_html, state_pointer, state_data1, state_name1, state_data2, state_name2]
    )
    def make_dataset_handler(ds_name):
        # Bind ds_name via closure; the handler re-samples 4 instructions
        # and refreshes the ID/dataset state alongside the textboxes.
        def handler():
            i1, i2, i3, i4, ids, d_name = load_dataset_batch(ds_name)
            return i1, i2, i3, i4, ids, d_name
        return dict(
            fn=handler,
            inputs=None,
            outputs=[t1, t2, t3, t4, state_current_ids, state_current_dataset],
            show_progress="hidden"
        )
    for d_btn, d_name in dataset_buttons:
        d_btn.click(**make_dataset_handler(d_name))
    # A manual edit invalidates that textbox's sampled ID so stale ground
    # truth is not shown for a hand-written instruction.
    t1.input(fn=lambda ids: invalidate_specific_id(ids, 0), inputs=[state_current_ids], outputs=[state_current_ids])
    t2.input(fn=lambda ids: invalidate_specific_id(ids, 1), inputs=[state_current_ids], outputs=[state_current_ids])
    t3.input(fn=lambda ids: invalidate_specific_id(ids, 2), inputs=[state_current_ids], outputs=[state_current_ids])
    t4.input(fn=lambda ids: invalidate_specific_id(ids, 3), inputs=[state_current_ids], outputs=[state_current_ids])
    predict_btn.click(
        fn=predict_performance,
        inputs=[state_data1, state_name1, state_data2, state_name2, t1, t2, t3, t4, state_current_ids, state_current_dataset],
        outputs=[output_df]
    )

if __name__ == "__main__":
    demo.launch()