| | import gradio as gr |
| | import pandas as pd |
| | import json |
| | import random |
| | import os |
| | from huggingface_hub import hf_hub_download, snapshot_download |
| |
|
| | |
| | |
| | |
# ---------------------------------------------------------------------------
# Model bootstrap: download checkpoints from the Hugging Face Hub and build
# the global V0 predictor used by the Gradio callbacks below.
# ---------------------------------------------------------------------------
MODEL_REPO_ID = "Now-Join-Us/Generalist-Value-Model-V0"
EMBEDDING_REPO_ID = "Qwen/Qwen3-Embedding-0.6B"

# Global handle to the loaded predictor; stays None when loading fails so the
# UI can surface a friendly error instead of crashing at import time.
v0_model = None

print(">>> Starting V0 App...")

try:
    # Imported inside the try so a missing v0_core package degrades
    # gracefully (v0_model stays None) rather than killing the whole app.
    from v0_core.models.v0 import V0

    print(f">>> Downloading models from Hugging Face Hub...")

    # Main V0 checkpoint.
    print(f" - Fetching checkpoint from {MODEL_REPO_ID}...")
    checkpoint_path = hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename="v_0_for_router_demo.pt"
    )

    # TabPFN classifier head used by V0.
    print(f" - Fetching TabPFN head from {MODEL_REPO_ID}...")
    tabpfn_path = hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename="pretrained/tabpfn-v2.5-classifier-v2.5_default.ckpt"
    )

    # Full snapshot of the sentence-embedding model repository.
    print(f" - Fetching Embedding model from {EMBEDDING_REPO_ID}...")
    embedding_path = snapshot_download(
        repo_id=EMBEDDING_REPO_ID
    )

    print(">>> All assets downloaded. Initializing V0 class...")

    # CPU only — presumably targeting the HF Spaces free tier (the Results
    # section below mentions CPU limits); confirm if a GPU path is wanted.
    device = "cpu"
    print(f">>> Device selected: {device}")

    v0_model = V0.from_pretrained(
        checkpoint_path=checkpoint_path,
        embedding_model_path=embedding_path,
        tabpfn_head_path=tabpfn_path,
        device=device
    )

    # Switch to inference mode when the wrapper exposes a torch-style eval().
    if hasattr(v0_model, 'eval'):
        v0_model.eval()

    print(f">>> V0 Model Loaded Successfully on {device}!")

except Exception as e:
    # Best-effort: keep the UI alive; predict_performance() checks for None.
    print(f"Warning: Failed to load V0 model or dependencies. Prediction will fail. Error: {e}")
    v0_model = None
| |
|
| | |
| | |
| | |
| |
|
| | |
# Models for which pre-computed instruction/performance history files exist
# under data/router_context_sampled/performance/.
TARGET_MODELS = [
    'Qwen3-30B-A3B-Instruct-2507',
    'Qwen3-4B-Instruct-2507',
    'DeepSeek-R1-Distill-Qwen-1.5B',
    "Qwen3-32B",
    "DeepSeek-R1-Distill-Qwen-7B",
    "Qwen3-0.6B",
]

# Evaluation datasets with prompt/ground-truth files under
# data/router_val_prompt_and_gts/.
TARGET_DATASETS = [
    'aime_2024', 'aime_2025', 'amc23', 'gaokao_math_cloze', 'gpqa_diamond', 'olympiad'
]

# Lazy cache for data/router_context_sampled/prompt_dict.json (loaded once
# by load_prompt_dict()).
PROMPT_DICT_CACHE = None
# In-memory cache of per-(model, dataset) performance maps, keyed
# "{model}_{dataset}" -> {item_id: normalized score}.
PERFORMANCE_DB = {}
| |
|
def load_prompt_dict():
    """Return the prompt dictionary, loading it from disk on first use.

    The JSON file maps "{dataset}_{id}" keys to prompt text. It is read at
    most once; subsequent calls return the module-level cache. A missing
    file yields an empty dict (with a warning) so callers can proceed.
    """
    global PROMPT_DICT_CACHE
    if PROMPT_DICT_CACHE is None:
        path = "data/router_context_sampled/prompt_dict.json"
        try:
            with open(path, "r", encoding="utf-8") as f:
                loaded = json.load(f)
        except FileNotFoundError:
            print(f"Warning: {path} not found. Using empty dict.")
            loaded = {}
        else:
            print(f"Loaded {len(loaded)} prompts from {path}")
        PROMPT_DICT_CACHE = loaded
    return PROMPT_DICT_CACHE
| |
|
def load_performance_data(model_name, dataset_name):
    """Load (and cache) one model's per-item scores on one dataset.

    Reads data/router_val_rollouts_and_performance/{model}_{dataset}.jsonl
    and maps each item id to its "mean@10" score rescaled from [-1, 1] to
    [0, 1]. Results are memoized in PERFORMANCE_DB under the key
    "{model_name}_{dataset_name}"; a missing file caches an empty map so it
    is not re-probed on every call.

    Returns:
        dict: {item_id: normalized score in [0, 1]}
    """
    cache_key = f"{model_name}_{dataset_name}"
    if cache_key in PERFORMANCE_DB:
        return PERFORMANCE_DB[cache_key]

    file_path = f"data/router_val_rollouts_and_performance/{model_name}_{dataset_name}.jsonl"
    data_map = {}

    if os.path.exists(file_path):
        try:
            with open(file_path, "r", encoding="utf-8") as f:
                for line in f:
                    try:
                        item = json.loads(line)
                        idx = item.get("id")
                        score = item.get("mean@10")
                        if idx is not None and score is not None:
                            # Rescale mean@10 from [-1, 1] to [0, 1].
                            data_map[idx] = (float(score) + 1.0) / 2.0
                    except (json.JSONDecodeError, TypeError, ValueError):
                        # Skip malformed lines (bad JSON or non-numeric score)
                        # without aborting the rest of the file.
                        continue
            print(f"Loaded performance map for {cache_key}: {len(data_map)} entries.")
        except Exception as e:
            print(f"Error loading performance file {file_path}: {e}")

    PERFORMANCE_DB[cache_key] = data_map
    return data_map
| |
|
def load_real_model_data(model_name):
    """Build a context set of (prompt, score) examples for one model.

    Reads data/router_context_sampled/performance/{model_name}.jsonl, joins
    each row with its prompt text via load_prompt_dict(), rescales "mean@10"
    from [-1, 1] to [0, 1], and randomly samples up to 16 items. The sample
    is then nudged to contain at least one positive and one negative example
    (when the full pool has any), so the in-context predictor sees both
    classes.

    Returns:
        list[dict]: items with keys "prompt", "score", "is_correct"; a
        single error record when the file is missing, or [] on read errors.
    """
    prompt_db = load_prompt_dict()
    file_path = f"data/router_context_sampled/performance/{model_name}.jsonl"

    all_parsed_data = []

    if not os.path.exists(file_path):
        # Surface the problem inside the UI card instead of raising.
        return [{
            "prompt": f"Error: File {file_path} not found.",
            "score": 0.0,
            "is_correct": False
        }]

    try:
        with open(file_path, "r", encoding="utf-8") as f:
            lines = f.readlines()

        for line in lines:
            try:
                item = json.loads(line)
                dataset = item.get("dataset", "")
                idx = item.get("id", "")
                key = f"{dataset}_{idx}"
                prompt_text = prompt_db.get(key, f"[Prompt missing for key: {key}]")

                # Rescale mean@10 from [-1, 1] to [0, 1]; > 0.6 counts as
                # "correct" for display/balancing purposes.
                raw_score = float(item.get("mean@10", -1.0))
                normalized_score = (raw_score + 1.0) / 2.0
                is_correct = normalized_score > 0.6

                all_parsed_data.append({
                    "prompt": prompt_text,
                    "score": normalized_score,
                    "is_correct": is_correct
                })

            except json.JSONDecodeError:
                continue  # skip malformed lines

        # Down-sample to a small context (16) — presumably to keep CPU
        # inference fast on the hosted demo.
        target_count = 16

        if len(all_parsed_data) <= target_count:
            return all_parsed_data

        sampled_data = random.sample(all_parsed_data, target_count)

        positives = [x for x in sampled_data if x['is_correct']]
        negatives = [x for x in sampled_data if not x['is_correct']]

        if len(negatives) == 0:
            # All-positive sample: swap one random slot for a real negative
            # (if the full pool has any) so both labels are represented.
            all_negatives_pool = [x for x in all_parsed_data if not x['is_correct']]
            if all_negatives_pool:
                replacement = random.choice(all_negatives_pool)
                replace_idx = random.randint(0, target_count - 1)
                sampled_data[replace_idx] = replacement

        elif len(positives) == 0:
            # All-negative sample: mirror of the branch above.
            all_positives_pool = [x for x in all_parsed_data if x['is_correct']]
            if all_positives_pool:
                replacement = random.choice(all_positives_pool)
                replace_idx = random.randint(0, target_count - 1)
                sampled_data[replace_idx] = replacement

        return sampled_data

    except Exception as e:
        # Best-effort loader: log and return an empty context on any failure.
        print(f"Error reading {file_path}: {e}")
        return []
| |
|
def load_dataset_batch(dataset_name):
    """Randomly sample 4 instructions from one validation dataset.

    Reads data/router_val_prompt_and_gts/{dataset_name}.jsonl and picks up
    to 4 random lines without replacement. A missing file or read error
    produces placeholder strings so the four UI textboxes always get values.

    Returns:
        tuple: (input1, input2, input3, input4, id_list, dataset_name) where
        id_list holds each sampled item's id (-1 for padding/errors).
    """
    file_path = f"data/router_val_prompt_and_gts/{dataset_name}.jsonl"

    inputs = []
    ids = []

    try:
        if os.path.exists(file_path):
            with open(file_path, "r", encoding="utf-8") as f:
                lines = f.readlines()

            count = 4
            total_lines = len(lines)

            if total_lines > 0:
                # Sample without replacement, capped by the file size.
                sampled_indices = random.sample(range(total_lines), min(count, total_lines))

                for idx in sampled_indices:
                    line = lines[idx]
                    try:
                        item = json.loads(line)
                        inputs.append(item.get("input", ""))
                        # Prefer "id", fall back to "idx", then the line number.
                        item_id = item.get("id", item.get("idx", idx))
                        ids.append(item_id)
                    except (json.JSONDecodeError, AttributeError):
                        # Narrowed from a bare `except:`: bad JSON or a
                        # non-dict line; keep a placeholder in this slot.
                        inputs.append("Error parsing JSON")
                        ids.append(-1)
        else:
            inputs = [f"File not found: {file_path}"] * 4
            ids = [-1] * 4

    except Exception as e:
        inputs = [f"Error loading dataset: {str(e)}"] * 4
        ids = [-1] * 4

    # Pad so the four textboxes (and predict-time id lookup) always line up.
    while len(inputs) < 4:
        inputs.append("")
        ids.append(-1)

    return inputs[0], inputs[1], inputs[2], inputs[3], ids, dataset_name
| |
|
| | |
| | |
| | |
| |
|
def process_custom_file(file_obj):
    """Parse a user-uploaded JSONL file into context items.

    Each line must be a JSON object with a "prompt" field plus one of:
    "score" (already in [0, 1]), "mean@10" (rescaled from [-1, 1]), or
    "is_correct" (bool -> 1.0/0.0); missing all three defaults the score
    to 0.0. Lines without a prompt, blank lines, and unparseable lines are
    skipped.

    Args:
        file_obj: object exposing a `.name` path (Gradio upload), or None.

    Returns:
        list[dict]: items with keys "prompt", "score", "is_correct";
        [] when file_obj is None or the file cannot be read at all.
    """
    if file_obj is None:
        return []

    data_list = []
    try:
        with open(file_obj.name, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    item = json.loads(line)
                    prompt = item.get('prompt')
                    if not prompt:
                        continue

                    # Score priority: explicit score > mean@10 > is_correct.
                    score = 0.0
                    if 'score' in item:
                        score = float(item['score'])
                    elif 'mean@10' in item:
                        score = (float(item['mean@10']) + 1.0) / 2.0
                    elif 'is_correct' in item:
                        score = 1.0 if item['is_correct'] else 0.0

                    is_correct = score > 0.6
                    data_list.append({
                        "prompt": prompt,
                        "score": score,
                        "is_correct": is_correct
                    })
                except (json.JSONDecodeError, TypeError, ValueError, AttributeError):
                    # Narrowed from a bare `except:`: skip only bad lines
                    # (invalid JSON, non-dict rows, non-numeric scores).
                    continue
    except Exception as e:
        print(f"Error parsing custom file: {e}")
        return []

    return data_list
| |
|
def format_model_card(data_list, model_name, is_empty=False):
    """Render one model's context history as an HTML card.

    Shows up to three (prompt, score) rows — each with a colored status box
    (green for score >= 0.8, yellow for >= 0.4, red otherwise) — plus a
    "+N more" footer, or a dashed placeholder card when the slot is empty.
    Prompt text is not HTML-escaped; inputs are trusted local files.
    """
    if is_empty or not data_list:
        return """
    <div class='model-card empty'>
        <div class='card-title'>Empty Slot</div>
        <div class='card-subtitle'>Select a model above to load</div>
    </div>
    """

    def classify(score):
        # Map a [0, 1] score to its (css class, icon) pair.
        if score >= 0.8:
            return "status-green", "✔"
        if score >= 0.4:
            return "status-yellow", "~"
        return "status-red", "✘"

    total = len(data_list)
    preview_limit = 3
    row_parts = []

    for entry in data_list[:preview_limit]:
        text = entry.get('prompt', '')
        if len(text) > 45:
            text = text[:45] + " ... "
        score = entry.get('score', 0.0)
        css_class, icon = classify(score)
        row_parts.append(f"""
        <div class='history-row'>
            <div class='status-box {css_class}'>{icon}</div>
            <div class='prompt-text'>{text}</div>
            <span style='opacity:0.9; font-size:0.9em; white-space: nowrap; margin-left: 8px;'>{score:.2f}</span>
        </div>
        """)

    hidden = total - preview_limit
    if hidden > 0:
        row_parts.append(f"<div class='history-more'>+ {hidden} more items</div>")

    rows_html = "".join(row_parts)
    return f"""
    <div class='model-card populated'>
        <div class='card-header'>
            <span class='model-name'>{model_name}</span>
            <span class='acc-badge'>Context Size: {total}</span>
        </div>
        <div class='card-body'>
            <div class='history-container'>
                {rows_html}
            </div>
        </div>
    </div>
    """
| |
|
def select_model(btn_name, current_pointer, data_s1, name_s1, data_s2, name_s2):
    """Handle a model-button click: fill one of the two card slots.

    If the model is already displayed in either slot, re-render both cards
    without changing any state. Otherwise load its history and place it in
    the slot indicated by current_pointer (0 -> slot 1, otherwise slot 2),
    then flip the pointer so consecutive clicks alternate slots.

    Returns:
        (slot1_html, slot2_html, pointer, data1, name1, data2, name2)
    """
    def render(data, name):
        return format_model_card(data, name, is_empty=(name is None))

    if btn_name in (name_s1, name_s2):
        # Already shown: no state change, just refresh both cards.
        return (render(data_s1, name_s1), render(data_s2, name_s2),
                current_pointer, data_s1, name_s1, data_s2, name_s2)

    new_data = load_real_model_data(btn_name)

    if current_pointer == 0:
        return (format_model_card(new_data, btn_name), render(data_s2, name_s2),
                1, new_data, btn_name, data_s2, name_s2)
    return (render(data_s1, name_s1), format_model_card(new_data, btn_name),
            0, data_s1, name_s1, new_data, btn_name)
| |
|
def handle_custom_upload(file_obj, current_pointer, data_s1, name_s1, data_s2, name_s2):
    """Handle a custom JSONL upload: parse it and fill a card slot.

    A None file re-renders the current state unchanged. Parsed data is
    labelled "Custom: <truncated filename>"; a parse failure inserts a
    single placeholder record so the card shows what went wrong. Slot
    selection and pointer flipping mirror select_model().
    """
    def render(data, name):
        return format_model_card(data, name, is_empty=(name is None))

    if file_obj is None:
        return (render(data_s1, name_s1), render(data_s2, name_s2),
                current_pointer, data_s1, name_s1, data_s2, name_s2)

    new_data = process_custom_file(file_obj)
    custom_name = f"Custom: {os.path.basename(file_obj.name)[:15]}"

    if not new_data:
        # Nothing usable was parsed: show an explicit failure card instead.
        custom_name = "Upload Failed"
        new_data = [{"prompt": "Invalid JSONL format", "score": 0.0, "is_correct": False}]

    if current_pointer == 0:
        return (format_model_card(new_data, custom_name), render(data_s2, name_s2),
                1, new_data, custom_name, data_s2, name_s2)
    return (render(data_s1, name_s1), format_model_card(new_data, custom_name),
            0, data_s1, name_s1, new_data, custom_name)
| |
|
def invalidate_specific_id(current_ids, index):
    """Mark slot `index` of the id list as stale (-1) after a manual edit.

    Called when the user types into one of the four instruction boxes: the
    edited text no longer matches its sampled dataset item, so its id must
    not be used for ground-truth lookup. The list is copied (never mutated
    in place) and padded with -1 up to `index` when it is too short; a None
    list yields four invalid slots.
    """
    if current_ids is None:
        return [-1, -1, -1, -1]

    updated = list(current_ids)
    # Pad so position `index` exists, then invalidate it.
    if len(updated) <= index:
        updated.extend([-1] * (index + 1 - len(updated)))
    updated[index] = -1
    return updated
| |
|
def predict_performance(data1, name1, data2, name2, t1, t2, t3, t4, current_ids, current_dataset):
    """Predict each active model's success on the target instructions with V0.

    For every selected model, its history becomes in-context examples
    (prompt + binary label at a 0.5 threshold), V0 predicts a success score
    per non-empty instruction, and the result table additionally shows the
    real mean@10 when the instruction's id is still valid and a ground-truth
    file exists for (model, dataset).

    Args:
        data1/name1, data2/name2: context items and display name per slot.
        t1..t4: the four instruction textboxes (blank ones are skipped).
        current_ids: ids of the sampled instructions (-1 = edited/unknown).
        current_dataset: dataset the current instructions were sampled from.

    Returns:
        pd.DataFrame: one row per (model, instruction), or a single-row
        error frame when preconditions fail.
    """
    if v0_model is None:
        return pd.DataFrame([{"Error": "V0 Model not loaded successfully on server."}])

    # Collect non-empty instructions along with their (possibly stale) ids.
    inputs_list = [t1, t2, t3, t4]
    target_prompts = []
    target_ids = []
    target_indices = []

    for i, t in enumerate(inputs_list):
        if t.strip():
            target_prompts.append(t)
            target_indices.append(i)
            if current_ids and i < len(current_ids):
                target_ids.append(current_ids[i])
            else:
                target_ids.append(-1)

    if not target_prompts:
        return pd.DataFrame([{"Error": "Please enter at least one target instruction."}])

    # Only slots holding both a name and data count as active models.
    active_models = []
    if name1 and data1: active_models.append((name1, data1))
    if name2 and data2: active_models.append((name2, data2))

    if not active_models:
        return pd.DataFrame([{"Error": "Please select at least one model above."}])

    results = []

    for m_name, m_history in active_models:
        # Build the in-context set: prompts plus binary labels derived from
        # the stored normalized scores.
        context_prompts = []
        context_labels = []

        # Drop malformed history entries (non-numeric score or empty prompt).
        valid_items = [x for x in m_history if isinstance(x.get('score'), (int, float)) and x.get('prompt')]

        # Cap the context size at 256 examples (the UI typically supplies 16;
        # custom uploads may be larger).
        batch_context = valid_items[:256]

        for item in batch_context:
            context_prompts.append(item['prompt'])

            # Binarize at 0.5 for the in-context label.
            score_val = item['score']
            label = 1 if score_val > 0.5 else 0
            context_labels.append(label)

        if not context_prompts:
            results.append({"Model": m_name, "Error": "No valid context data found."})
            continue

        # Ground-truth lookup only applies to built-in models; custom uploads
        # have no performance files on disk.
        real_perf_map = {}
        if current_dataset and "Custom:" not in m_name:
            real_perf_map = load_performance_data(m_name, current_dataset)

        try:
            pred_scores = v0_model.predict(
                context_prompts=context_prompts,
                context_labels=context_labels,
                target_prompts=target_prompts
            )
        except Exception as e:
            # Keep the table renderable even when inference fails: fall back
            # to all-zero scores for this model.
            print(f"Inference error for {m_name}: {e}")
            pred_scores = [0.0] * len(target_prompts)

        for idx, t_text in enumerate(target_prompts):
            cur_id = target_ids[idx]

            # Show the real score only for unedited instructions with a
            # known ground-truth entry.
            if cur_id != -1 and cur_id in real_perf_map:
                real_val = real_perf_map[cur_id]
                real_mean_10_str = f"{real_val:.2f}"
            else:
                real_mean_10_str = "N/A"

            # Binarize the predicted score at 0.5 for the class column.
            final_score = float(pred_scores[idx])
            pred_str = "✔ Success" if final_score > 0.5 else "✘ Failure"

            results.append({
                "Model": m_name,
                "Instruction": t_text,
                "Actual mean@10": real_mean_10_str,
                "Predicted Score": f"{final_score:.4f}",
                "Predicted Class": pred_str
            })

    return pd.DataFrame(results)
| |
|
| | |
| | |
| | |
# CSS injected into gr.Blocks below. Class names here must stay in sync with
# the HTML emitted by format_model_card() and the gr.HTML sections of the UI.
# (Comments inside the string are CSS comments shipped to the browser and are
# deliberately left untranslated — the string is runtime content.)
css = """
/* 全局变量 */
:root {
    --primary: #10b981;
    --primary-light: #ecfdf5;
    --primary-dark: #047857;
    --bg-card: #ffffff;
    --border-sub: #e5e7eb;
    --text-main: #1f2937;
    --text-sub: #6b7280;
    --success: #10b981;
    --fail: #ef4444;
    --warning: #f59e0b;
}
.dark {
    --bg-card: #1f2937;
    --border-sub: #374151;
    --text-main: #f3f4f6;
    --text-sub: #9ca3af;
    --primary-light: #064e3b;
    --primary-dark: #ecfdf5;
}

/* 顶部 Banner */
.concept-banner {
    background: linear-gradient(135deg, rgba(16, 185, 129, 0.08) 0%, rgba(59, 130, 246, 0.05) 100%);
    border: 1px solid var(--primary-light);
    border-radius: 12px;
    padding: 24px;
    text-align: center;
    margin-bottom: 20px;
}
.concept-title {
    font-size: 1.8em; font-weight: 700; color: var(--text-main); margin-bottom: 8px;
}
.concept-subtitle { font-size: 1em; color: var(--text-sub); max-width: 600px; margin: 0 auto; line-height: 1.5; }

/* 步骤标题 */
.step-header {
    display: flex; align-items: center; margin-bottom: 15px;
    border-bottom: 2px solid var(--border-sub); padding-bottom: 10px; margin-top: 20px;
}
.step-num {
    background: var(--primary); color: white; width: 28px; height: 28px;
    border-radius: 50%; display: flex; align-items: center; justify-content: center;
    font-weight: bold; margin-right: 10px; font-size: 0.9em;
}
.step-title { font-size: 1.2em; font-weight: 600; color: var(--text-main); }
.step-desc { font-size: 0.93em; color: var(--text-sub); margin-left: auto; font-style: italic;}

/* --- 模型选择按钮容器 --- */
.model-select-container {
    display: flex !important;
    justify-content: flex-start !important;
    flex-wrap: wrap !important;
    align-items: center !important;
    gap: 10px !important;
    width: 100% !important;
    margin-bottom: 20px !important;
}

/* 按钮通用样式 */
button.model-btn {
    background: linear-gradient(145deg, #ecfdf5 0%, #d1fae5 100%);
    border: 1px solid #6ee7b7 !important;
    color: #065f46 !important;
    font-weight: 600 !important;
    font-size: 0.9em !important;
    border-radius: 12px !important;
    box-shadow: 0 2px 4px rgba(16, 185, 129, 0.1);
    transition: all 0.2s ease;
    height: 46px !important;
    min-height: 46px !important;
    white-space: nowrap;
    padding: 0 30px !important;
    min-width: 220px !important;
    text-align: center !important;
    flex-grow: 0 !important;
    width: auto !important;
}
.dark button.model-btn {
    background: linear-gradient(145deg, #064e3b 0%, #065f46 100%);
    border-color: #059669 !important;
    color: #ecfdf5 !important;
}

button.model-btn:hover {
    transform: translateY(-2px);
    box-shadow: 0 4px 8px rgba(16, 185, 129, 0.25);
    background: linear-gradient(145deg, #d1fae5 0%, #a7f3d0 100%);
}

.dark button.model-btn:hover {
    background: linear-gradient(145deg, #059669 0%, #047857 100%) !important;
    color: #ffffff !important;
    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.4);
}

/* --- 卡片样式 --- */
.model-card {
    background: var(--bg-card); border: 1px solid var(--border-sub);
    border-radius: 10px; padding: 16px; margin-bottom: 15px; height: 100%;
    transition: all 0.2s; position: relative; overflow: hidden;
    min-height: 200px;
}
.model-card.populated { border-left: 5px solid var(--primary); box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.05); }
.model-card.empty {
    border: 2px dashed var(--border-sub);
    display: flex; flex-direction: column; justify-content: center; align-items: center;
    opacity: 0.6;
}
.card-title { font-weight: bold; color: var(--text-sub); font-size: 1.1em; }
.card-subtitle { font-size: 0.85em; color: var(--text-sub); margin-top: 5px;}
.card-header { display: flex; justify-content: space-between; align-items: center; margin-bottom: 12px; }
.model-name { font-weight: bold; font-size: 1.05em; color: var(--text-main); }
.acc-badge { background: var(--primary-light); color: var(--primary-dark); font-size: 0.75em; padding: 3px 8px; border-radius: 12px; font-weight: 700; }
.history-container { display: flex; flex-direction: column; gap: 6px; }
.history-row { display: flex; align-items: center; background: rgba(0,0,0,0.03); padding: 5px 8px; border-radius: 6px; }
.status-box {
    width: 20px; height: 20px; border-radius: 5px; display: flex; align-items: center; justify-content: center;
    color: white; font-size: 0.7em; font-weight: bold; margin-right: 8px; flex-shrink: 0;
}
.status-green { background-color: var(--success); }
.status-red { background-color: var(--fail); }
.status-yellow { background-color: var(--warning); }

.prompt-text {
    flex: 1;
    min-width: 0;
    font-size: 0.85em; color: var(--text-main);
    white-space: nowrap; overflow: hidden; text-overflow: ellipsis;
}
.history-more { font-size: 0.8em; color: var(--text-sub); text-align: center; font-style: italic; margin-top: 4px; }

/* 按钮和链接 */
.custom-btn { font-weight: bold !important; font-size: 1.1em !important; }
.sample-btn {
    border: 1px dashed rgba(16, 185, 129, 0.3) !important;
    color: var(--primary-dark) !important;
    background: rgba(16, 185, 129, 0.05) !important;
}
.sample-btn:hover {
    background: rgba(16, 185, 129, 0.15) !important;
}
.paper-link {
    font-size: 0.5em; vertical-align: middle; color: var(--primary);
    text-decoration: none; border: 1px solid var(--primary); padding: 4px 10px;
    border-radius: 15px; font-weight: normal; transition: all 0.2s; background: transparent;
}
.paper-link:hover { background: var(--primary); color: white; }

/* Custom Upload Styling */
.upload-container {
    display: flex; align-items: center; gap: 10px; margin-bottom: 20px;
    padding: 8px 12px; background: rgba(0,0,0,0.02); border-radius: 8px;
    border: 1px dashed var(--border-sub);
}
.upload-label { font-size: 0.9em; font-weight: 600; color: var(--text-sub); white-space: nowrap;}
.code-inline {
    font-family: monospace; background: rgba(0,0,0,0.05); padding: 2px 4px;
    border-radius: 4px; font-size: 0.85em; color: var(--primary-dark);
}
"""
| |
|
| | |
| | |
| | |
# ---------------------------------------------------------------------------
# Gradio UI: initial state, layout, and event wiring.
# ---------------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft(primary_hue="emerald"), css=css, title="V0 Predictor") as demo:

    # Pre-load a default model into slot 1 and a first batch of instructions
    # so the demo renders with meaningful content on first visit.
    initial_model = TARGET_MODELS[2]
    initial_data = load_real_model_data(initial_model)

    init_in1, init_in2, init_in3, init_in4, init_ids, init_ds_name = load_dataset_batch(TARGET_DATASETS[0])

    # Pointer 0/1 decides which card slot the next model selection fills;
    # starts at 1 because slot 1 is already occupied by the default model.
    state_pointer = gr.State(value=1)
    state_data1 = gr.State(value=initial_data)
    state_name1 = gr.State(value=initial_model)
    state_data2 = gr.State(value=None)
    state_name2 = gr.State(value=None)

    # Ids of the sampled instructions (-1 after manual edits) and their
    # source dataset, used for the "Actual mean@10" lookup at predict time.
    state_current_ids = gr.State(value=init_ids)
    state_current_dataset = gr.State(value=init_ds_name)

    # Header banner with paper/code links.
    gr.HTML("""
    <div class="concept-banner">
        <div class="concept-title">
            Generalist Value Model V<sub>0</sub>
            <a href="https://arxiv.org/abs/2602.03584" class="paper-link">Paper ↗</a>
            <a href="https://github.com/Now-Join-Us/V0" class="paper-link">Code ↗</a>
        </div>
        <div class="concept-subtitle">
            <span style="color: var(--primary); font-weight: bold;">Function:</span> V<sub>0</sub> uses a model's historical performance to predict<br>
            how it will perform on unseen instructions without running the model itself.
        </div>
    </div>
    """)

    # Step 1: model selection.
    gr.HTML("""
    <div class="step-header">
        <div class="step-num">1</div>
        <div class="step-title">Select Model (with instruction-performance pairs)</div>
        <div class="step-desc">Click buttons below to load real pairs, or upload your own.</div>
    </div>
    """)

    with gr.Row(elem_classes=["model-select-container"]):
        model_buttons = []
        for m_name in TARGET_MODELS:
            btn = gr.Button(m_name, elem_classes=["model-btn"], scale=0)
            model_buttons.append((btn, m_name))

    # Two card slots showing the selected models' context history.
    with gr.Row():
        with gr.Column(variant="panel"):
            slot1_html = gr.HTML(format_model_card(initial_data, initial_model))
        with gr.Column(variant="panel"):
            slot2_html = gr.HTML(format_model_card(None, None, True))

    # Optional custom-model upload (JSONL with one prompt/score per line).
    with gr.Row(elem_classes=["upload-container"]):
        gr.HTML(
            f"""<div class='upload-label'>
            [Optional] Upload Custom Model (.jsonl) |
            <span class='code-inline'>Format: {{"prompt": "...", "score": 0.8 }}\\n{{"prompt": "...", "score": 0.2}} ...</span>
            </div>"""
        )
        upload_btn = gr.File(
            file_count="single",
            file_types=[".jsonl"],
            label=None,
            show_label=False,
            container=False,
            height=30,
            scale=1
        )

    # Step 2: instruction sampling.
    gr.HTML("""
    <div class="step-header" style="margin-top: 40px;">
        <div class="step-num">2</div>
        <div class="step-title">Select Instructions (to predict how well the above models perform on them)</div>
        <div class="step-desc">Click to randomly sample 4 from the corresponding dataset.</div>
    </div>
    """)

    with gr.Row(elem_classes=["model-select-container"]):
        dataset_buttons = []
        for d_name in TARGET_DATASETS:
            d_btn = gr.Button(d_name, elem_classes=["model-btn"], scale=0)
            dataset_buttons.append((d_btn, d_name))

    with gr.Row():
        t1 = gr.Textbox(label="Instruction 1", value=init_in1, lines=2)
        t2 = gr.Textbox(label="Instruction 2", value=init_in2, lines=2)

    with gr.Row():
        t3 = gr.Textbox(label="Instruction 3", value=init_in3, lines=2)
        t4 = gr.Textbox(label="Instruction 4", value=init_in4, lines=2)

    gr.HTML("<div style='height: 8px;'></div>")

    # Step 3: prediction trigger and results table.
    with gr.Row():
        with gr.Column():
            predict_btn = gr.Button("Run V₀ Prediction", variant="primary", size="lg", elem_classes=["custom-btn"])

    gr.HTML("""
    <div class="step-header" style="margin-top: 20px; border-bottom: none;">
        <div class="step-num">3</div>
        <div class="step-title">Results</div>
        <div class="step-desc">Limited to 16 context size (vs. standard 256) by HF CPU limits; GPU enables ms-level speeds.</div>
    </div>
    """)

    output_df = gr.Dataframe(
        headers=["Model", "Instruction", "Actual mean@10", "Predicted Score", "Predicted Class"],
        datatype=["str", "str", "str", "str", "str"],
        interactive=False,
        column_widths=["20%", "44%", "10%", "13%", "13%"]
    )

    # Factory closes over model_name explicitly (as a bound argument) to
    # avoid Python's late-binding-in-loops pitfall for the button lambdas.
    def make_click_handler(model_name):
        return dict(
            fn=lambda ptr, d1, n1, d2, n2: select_model(model_name, ptr, d1, n1, d2, n2),
            inputs=[state_pointer, state_data1, state_name1, state_data2, state_name2],
            outputs=[slot1_html, slot2_html, state_pointer, state_data1, state_name1, state_data2, state_name2]
        )

    for btn, name in model_buttons:
        btn.click(**make_click_handler(name))

    # A custom upload fills a card slot just like a model button would.
    upload_btn.upload(
        fn=handle_custom_upload,
        inputs=[upload_btn, state_pointer, state_data1, state_name1, state_data2, state_name2],
        outputs=[slot1_html, slot2_html, state_pointer, state_data1, state_name1, state_data2, state_name2]
    )

    # Dataset buttons re-sample four instructions and refresh the id/dataset
    # state; same closure-binding pattern as make_click_handler above.
    def make_dataset_handler(ds_name):
        def handler():
            # Handle a dataset-button click: sample a fresh batch of 4.
            i1, i2, i3, i4, ids, d_name = load_dataset_batch(ds_name)
            return i1, i2, i3, i4, ids, d_name
        return dict(
            fn=handler,
            inputs=None,
            outputs=[t1, t2, t3, t4, state_current_ids, state_current_dataset],
            show_progress="hidden"
        )

    for d_btn, d_name in dataset_buttons:
        d_btn.click(**make_dataset_handler(d_name))

    # Editing a textbox invalidates that instruction's id so stale ground
    # truth is never shown for user-modified text.
    t1.input(fn=lambda ids: invalidate_specific_id(ids, 0), inputs=[state_current_ids], outputs=[state_current_ids])
    t2.input(fn=lambda ids: invalidate_specific_id(ids, 1), inputs=[state_current_ids], outputs=[state_current_ids])
    t3.input(fn=lambda ids: invalidate_specific_id(ids, 2), inputs=[state_current_ids], outputs=[state_current_ids])
    t4.input(fn=lambda ids: invalidate_specific_id(ids, 3), inputs=[state_current_ids], outputs=[state_current_ids])

    predict_btn.click(
        fn=predict_performance,
        inputs=[state_data1, state_name1, state_data2, state_name2, t1, t2, t3, t4, state_current_ids, state_current_dataset],
        outputs=[output_df]
    )


if __name__ == "__main__":
    demo.launch()