import json
import os
import random

import gradio as gr
import pandas as pd
from huggingface_hub import hf_hub_download, snapshot_download

# ==========================================
# V0 Model Import & Initialization
# ==========================================
MODEL_REPO_ID = "Now-Join-Us/Generalist-Value-Model-V0"
EMBEDDING_REPO_ID = "Qwen/Qwen3-Embedding-0.6B"

# Global model handle; stays None when any download/initialization step fails,
# and predict_performance() reports the failure to the user instead of crashing.
v0_model = None

print(">>> Starting V0 App...")
try:
    # Ensure v0_core is accessible.
    # If v0_core is in the root directory, this import works.
    from v0_core.models.v0 import V0

    print(f">>> Downloading models from Hugging Face Hub...")

    # 1. Download V0 Checkpoint
    print(f"  - Fetching checkpoint from {MODEL_REPO_ID}...")
    checkpoint_path = hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename="v_0_for_router_demo.pt",
    )

    # 2. Download TabPFN Classifier
    print(f"  - Fetching TabPFN head from {MODEL_REPO_ID}...")
    tabpfn_path = hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename="pretrained/tabpfn-v2.5-classifier-v2.5_default.ckpt",
    )

    # 3. Download Qwen Embedding Model (Snapshot because we need config, tokenizer, etc.)
    print(f"  - Fetching Embedding model from {EMBEDDING_REPO_ID}...")
    embedding_path = snapshot_download(repo_id=EMBEDDING_REPO_ID)

    print(">>> All assets downloaded. Initializing V0 class...")

    # Determine device (Defaulting to CPU as per request, but CUDA is preferred if available)
    # device = "cuda" if torch.cuda.is_available() else "cpu"
    device = "cpu"
    print(f">>> Device selected: {device}")

    # Initialize the Model
    v0_model = V0.from_pretrained(
        checkpoint_path=checkpoint_path,
        embedding_model_path=embedding_path,
        tabpfn_head_path=tabpfn_path,
        device=device,
    )

    # Set to eval mode if applicable (V0 usually handles this internally, but good practice)
    if hasattr(v0_model, 'eval'):
        v0_model.eval()

    print(f">>> V0 Model Loaded Successfully on {device}!")
except Exception as e:
    # FIX: the original message embedded a raw newline inside the f-string
    # literal (a syntax error once the source is laid out normally); use an
    # explicit escape instead.
    print(f"Warning: Failed to load V0 model or dependencies. Prediction will fail.\nError: {e}")
    v0_model = None

# ==========================================
# 0. Configuration & data-loading logic
# ==========================================

# Candidate models the user can load as context providers.
TARGET_MODELS = [
    'Qwen3-30B-A3B-Instruct-2507',
    'Qwen3-4B-Instruct-2507',
    'DeepSeek-R1-Distill-Qwen-1.5B',
    "Qwen3-32B",
    "DeepSeek-R1-Distill-Qwen-7B",
    "Qwen3-0.6B",
]

# Datasets the user can sample target instructions from.
TARGET_DATASETS = [
    'aime_2024',
    'aime_2025',
    'amc23',
    'gaokao_math_cloze',
    'gpqa_diamond',
    'olympiad'
]

# Global cache: prompt dictionary loaded from disk (None until first use).
PROMPT_DICT_CACHE = None
# Global cache: performance data { "Model_Dataset": { id_int: score_float } }
PERFORMANCE_DB = {}


def load_prompt_dict():
    """Read and cache the prompt dictionary file (key -> prompt text)."""
    global PROMPT_DICT_CACHE
    if PROMPT_DICT_CACHE is not None:
        return PROMPT_DICT_CACHE

    path = "data/router_context_sampled/prompt_dict.json"
    try:
        with open(path, "r", encoding="utf-8") as f:
            PROMPT_DICT_CACHE = json.load(f)
        print(f"Loaded {len(PROMPT_DICT_CACHE)} prompts from {path}")
    except FileNotFoundError:
        print(f"Warning: {path} not found. Using empty dict.")
        PROMPT_DICT_CACHE = {}

    return PROMPT_DICT_CACHE


def load_performance_data(model_name, dataset_name):
    """
    Lazily load a specific model's performance on a specific dataset into the
    in-memory cache.

    Cache key format: f"{model_name}_{dataset_name}".
    Returns a dict mapping item id -> score normalized from [-1, 1] to [0, 1].
    """
    cache_key = f"{model_name}_{dataset_name}"

    # Already cached: return directly.
    if cache_key in PERFORMANCE_DB:
        return PERFORMANCE_DB[cache_key]

    file_path = f"data/router_val_rollouts_and_performance/{model_name}_{dataset_name}.jsonl"
    data_map = {}
    if os.path.exists(file_path):
        try:
            with open(file_path, "r", encoding="utf-8") as f:
                for line in f:
                    try:
                        item = json.loads(line)
                        idx = item.get("id")
                        score = item.get("mean@10")
                        if idx is not None and score is not None:
                            # mean@10 lives in [-1, 1]; map to [0, 1].
                            data_map[idx] = (float(score) + 1.0) / 2.0
                    except json.JSONDecodeError:
                        continue
            print(f"Loaded performance map for {cache_key}: {len(data_map)} entries.")
        except Exception as e:
            print(f"Error loading performance file {file_path}: {e}")
    else:
        # Missing file is expected for custom uploads; cache an empty map.
        pass

    PERFORMANCE_DB[cache_key] = data_map
    return data_map


def load_real_model_data(model_name):
    """
    Load a model's historical (prompt, score) pairs and sample a 16-item
    context, guaranteeing at least one positive and one negative example
    when the source pool allows it.
    """
    prompt_db = load_prompt_dict()
    file_path = f"data/router_context_sampled/performance/{model_name}.jsonl"

    # Temporary storage for all parsed rows.
    all_parsed_data = []

    if not os.path.exists(file_path):
        return [{
            "prompt": f"Error: File {file_path} not found.",
            "score": 0.0,
            "is_correct": False
        }]

    try:
        with open(file_path, "r", encoding="utf-8") as f:
            lines = f.readlines()

        # 1. Parse every line first to obtain label information.
        for line in lines:
            try:
                item = json.loads(line)
                dataset = item.get("dataset", "")
                idx = item.get("id", "")
                key = f"{dataset}_{idx}"
                prompt_text = prompt_db.get(key, f"[Prompt missing for key: {key}]")

                raw_score = float(item.get("mean@10", -1.0))
                # Normalize from [-1, 1] to [0, 1]; > 0.6 counts as "correct".
                normalized_score = (raw_score + 1.0) / 2.0
                is_correct = normalized_score > 0.6

                all_parsed_data.append({
                    "prompt": prompt_text,
                    "score": normalized_score,
                    "is_correct": is_correct
                })
            # FIX: also catch TypeError/ValueError from float() so one bad
            # row cannot abort the whole load via the outer handler.
            except (json.JSONDecodeError, TypeError, ValueError):
                continue

        # 2. Sampling logic (target = 16).
        target_count = 16
        if len(all_parsed_data) <= target_count:
            return all_parsed_data

        # Initial random sample of 16 items.
        sampled_data = random.sample(all_parsed_data, target_count)

        # Check positive/negative distribution.
        positives = [x for x in sampled_data if x['is_correct']]
        negatives = [x for x in sampled_data if not x['is_correct']]

        # Case A: all positive -> force in one negative sample.
        if len(negatives) == 0:
            all_negatives_pool = [x for x in all_parsed_data if not x['is_correct']]
            if all_negatives_pool:
                replacement = random.choice(all_negatives_pool)
                replace_idx = random.randint(0, target_count - 1)
                sampled_data[replace_idx] = replacement

        # Case B: all negative -> force in one positive sample.
        elif len(positives) == 0:
            all_positives_pool = [x for x in all_parsed_data if x['is_correct']]
            if all_positives_pool:
                replacement = random.choice(all_positives_pool)
                replace_idx = random.randint(0, target_count - 1)
                sampled_data[replace_idx] = replacement

        return sampled_data
    except Exception as e:
        print(f"Error reading {file_path}: {e}")
        return []


def load_dataset_batch(dataset_name):
    """
    Randomly sample 4 items from the given dataset file.

    Returns: (input1, input2, input3, input4, id_list, dataset_name)
    """
    file_path = f"data/router_val_prompt_and_gts/{dataset_name}.jsonl"
    inputs = []
    ids = []
    try:
        if os.path.exists(file_path):
            with open(file_path, "r", encoding="utf-8") as f:
                lines = f.readlines()

            # Sample 4 lines; keep the original ids so results can be joined
            # back against the performance files.
            count = 4
            total_lines = len(lines)
            if total_lines > 0:
                sampled_indices = random.sample(range(total_lines), min(count, total_lines))
                for idx in sampled_indices:
                    line = lines[idx]
                    try:
                        item = json.loads(line)
                        inputs.append(item.get("input", ""))
                        # Prefer "id" from the JSON, then "idx", then the file
                        # line number. NOTE: this id must match the ids used in
                        # the performance files for the "Actual" column to fill.
                        item_id = item.get("id", item.get("idx", idx))
                        ids.append(item_id)
                    # FIX: was a bare `except:` which also swallowed
                    # KeyboardInterrupt/SystemExit.
                    except (json.JSONDecodeError, AttributeError, TypeError):
                        inputs.append("Error parsing JSON")
                        ids.append(-1)
            else:
                inputs = []
                ids = []
        else:
            inputs = [f"File not found: {file_path}"] * 4
            ids = [-1] * 4
    except Exception as e:
        inputs = [f"Error loading dataset: {str(e)}"] * 4
        ids = [-1] * 4

    # Pad to exactly 4 slots so the UI textboxes always receive a value.
    while len(inputs) < 4:
        inputs.append("")
        ids.append(-1)

    return inputs[0], inputs[1], inputs[2], inputs[3], ids, dataset_name


# ==========================================
# 1. Core logic
# ==========================================
def process_custom_file(file_obj):
    """Parse a user-uploaded JSONL file into the internal context format."""
    if file_obj is None:
        return []

    data_list = []
    try:
        with open(file_obj.name, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    item = json.loads(line)
                    prompt = item.get('prompt')
                    if not prompt:
                        continue

                    # Accept several score formats in priority order.
                    score = 0.0
                    if 'score' in item:
                        score = float(item['score'])
                    elif 'mean@10' in item:
                        score = (float(item['mean@10']) + 1.0) / 2.0
                    elif 'is_correct' in item:
                        score = 1.0 if item['is_correct'] else 0.0

                    is_correct = score > 0.6
                    data_list.append({
                        "prompt": prompt,
                        "score": score,
                        "is_correct": is_correct
                    })
                # FIX: was a bare `except:`; narrow to the parse/convert errors.
                except (json.JSONDecodeError, TypeError, ValueError):
                    continue
    except Exception as e:
        print(f"Error parsing custom file: {e}")
        return []

    return data_list


def format_model_card(data_list, model_name, is_empty=False):
    """Render a model's context slot as an HTML card (empty or populated).

    NOTE(review): the original HTML tags were lost in source extraction; the
    markup below is reconstructed from the CSS class names defined in `css`
    (.model-card, .status-box, .prompt-text, ...) — confirm against the
    deployed app's rendering.
    """
    if is_empty or not data_list:
        return f"""
        <div class="model-card empty">
            <div class="card-title">Empty Slot</div>
            <div class="card-subtitle">Select a model above to load</div>
        </div>
        """

    total = len(data_list)
    rows_html = ""
    preview_limit = 3
    preview_data = data_list[:preview_limit]

    for item in preview_data:
        p_text = item.get('prompt', '')
        if len(p_text) > 45:
            p_text = p_text[:45] + " ... "
        score = item.get('score', 0.0)
        # Traffic-light status by normalized score.
        if score >= 0.8:
            status_class = "status-green"
            icon = "✔"
        elif score >= 0.4:
            status_class = "status-yellow"
            icon = "~"
        else:
            status_class = "status-red"
            icon = "✘"
        rows_html += f"""
        <div class="history-row">
            <div class="status-box {status_class}">{icon}</div>
            <div class="prompt-text">{p_text}</div>
            <div class="acc-badge">{score:.2f}</div>
        </div>
        """

    remaining = total - preview_limit
    if remaining > 0:
        rows_html += f"""
        <div class="history-more">+ {remaining} more items</div>
        """

    return f"""
    <div class="model-card populated">
        <div class="card-header">
            <div class="model-name">{model_name}</div>
            <div class="acc-badge">Context Size: {total}</div>
        </div>
        <div class="history-container">
            {rows_html}
        </div>
    </div>
    """


def select_model(btn_name, current_pointer, data_s1, name_s1, data_s2, name_s2):
    """
    Place the clicked model into one of the two slots, alternating via
    `current_pointer` (0 -> slot 1 next, 1 -> slot 2 next). Re-clicking an
    already-loaded model is a no-op.
    """
    if btn_name == name_s1 or btn_name == name_s2:
        html1 = format_model_card(data_s1, name_s1, is_empty=(name_s1 is None))
        html2 = format_model_card(data_s2, name_s2, is_empty=(name_s2 is None))
        return (html1, html2, current_pointer, data_s1, name_s1, data_s2, name_s2)

    new_data = load_real_model_data(btn_name)

    if current_pointer == 0:
        html1 = format_model_card(new_data, btn_name)
        html2 = format_model_card(data_s2, name_s2, is_empty=(name_s2 is None))
        return (html1, html2, 1, new_data, btn_name, data_s2, name_s2)
    else:
        html1 = format_model_card(data_s1, name_s1, is_empty=(name_s1 is None))
        html2 = format_model_card(new_data, btn_name)
        return (html1, html2, 0, data_s1, name_s1, new_data, btn_name)


def handle_custom_upload(file_obj, current_pointer, data_s1, name_s1, data_s2, name_s2):
    """Handle a custom JSONL upload, placing it into the next slot."""
    if file_obj is None:
        html1 = format_model_card(data_s1, name_s1, is_empty=(name_s1 is None))
        html2 = format_model_card(data_s2, name_s2, is_empty=(name_s2 is None))
        return (html1, html2, current_pointer, data_s1, name_s1, data_s2, name_s2)

    new_data = process_custom_file(file_obj)
    custom_name = f"Custom: {os.path.basename(file_obj.name)[:15]}"

    if not new_data:
        # Parsing failed: show a placeholder entry instead of an empty card.
        custom_name = "Upload Failed"
        new_data = [{"prompt": "Invalid JSONL format", "score": 0.0, "is_correct": False}]

    if current_pointer == 0:
        html1 = format_model_card(new_data, custom_name)
        html2 = format_model_card(data_s2, name_s2, is_empty=(name_s2 is None))
        return (html1, html2, 1, new_data, custom_name, data_s2, name_s2)
    else:
        html1 = format_model_card(data_s1, name_s1, is_empty=(name_s1 is None))
        html2 = format_model_card(new_data, custom_name)
        return (html1, html2, 0, data_s1, name_s1, new_data, custom_name)


def invalidate_specific_id(current_ids, index):
    """
    When the user edits textbox `index`, its id no longer matches the sampled
    dataset item, so mark it as -1 (N/A).
    """
    if current_ids is None:
        return [-1, -1, -1, -1]
    new_ids = list(current_ids)
    while len(new_ids) <= index:
        new_ids.append(-1)
    new_ids[index] = -1
    return new_ids


def predict_performance(data1, name1, data2, name2, t1, t2, t3, t4, current_ids, current_dataset):
    """
    Run the V0 model: for each active model slot, use its history as context
    and predict success on each non-empty target instruction. Returns a
    pandas DataFrame for the results table.
    """
    if v0_model is None:
        return pd.DataFrame([{"Error": "V0 Model not loaded successfully on server."}])

    # 1. Collect target instructions, remembering original textbox indices.
    inputs_list = [t1, t2, t3, t4]
    target_prompts = []
    target_ids = []
    target_indices = []

    for i, t in enumerate(inputs_list):
        # FIX: guard against None — Gradio may deliver a cleared textbox as
        # None, and the original `t.strip()` would raise AttributeError.
        if t and t.strip():
            target_prompts.append(t)
            target_indices.append(i)
            if current_ids and i < len(current_ids):
                target_ids.append(current_ids[i])
            else:
                target_ids.append(-1)

    if not target_prompts:
        return pd.DataFrame([{"Error": "Please enter at least one target instruction."}])

    # 2. Determine which model slots are active.
    active_models = []
    if name1 and data1:
        active_models.append((name1, data1))
    if name2 and data2:
        active_models.append((name2, data2))

    if not active_models:
        return pd.DataFrame([{"Error": "Please select at least one model above."}])

    results = []

    for m_name, m_history in active_models:
        # --- Prepare context (history) ---
        # Extract prompts and scores; the model accepts at most 256 context items.
        context_prompts = []
        context_labels = []

        # Drop malformed entries, then take the first 256. The history is
        # assumed to be pre-sampled (done in load_real_model_data).
        valid_items = [x for x in m_history
                       if isinstance(x.get('score'), (int, float)) and x.get('prompt')]
        batch_context = valid_items[:256]

        for item in batch_context:
            context_prompts.append(item['prompt'])
            # Scores are already normalized to [0, 1]; binarize at 0.5.
            score_val = item['score']
            label = 1 if score_val > 0.5 else 0
            context_labels.append(label)

        if not context_prompts:
            results.append({"Model": m_name, "Error": "No valid context data found."})
            continue

        # --- Fetch ground truth (when available; never for custom uploads) ---
        real_perf_map = {}
        if current_dataset and "Custom:" not in m_name:
            real_perf_map = load_performance_data(m_name, current_dataset)

        # --- Run V0 prediction ---
        try:
            # V0 processes all targets at once, sharing a single context.
            pred_scores = v0_model.predict(
                context_prompts=context_prompts,
                context_labels=context_labels,
                target_prompts=target_prompts
            )
        except Exception as e:
            print(f"Inference error for {m_name}: {e}")
            pred_scores = [0.0] * len(target_prompts)

        # --- Format results ---
        for idx, t_text in enumerate(target_prompts):
            cur_id = target_ids[idx]

            # Actual score, when the id is known and present in the map.
            if cur_id != -1 and cur_id in real_perf_map:
                real_val = real_perf_map[cur_id]
                real_mean_10_str = f"{real_val:.2f}"
            else:
                real_mean_10_str = "N/A"

            # Predicted score and binarized class.
            final_score = float(pred_scores[idx])
            pred_str = "✔ Success" if final_score > 0.5 else "✘ Failure"

            results.append({
                "Model": m_name,
                "Instruction": t_text,
                "Actual mean@10": real_mean_10_str,
                "Predicted Score": f"{final_score:.4f}",
                "Predicted Class": pred_str
            })

    return pd.DataFrame(results)


# ==========================================
# 2. CSS styles
# ==========================================
css = """
/* Global variables */
:root {
    --primary: #10b981; --primary-light: #ecfdf5; --primary-dark: #047857;
    --bg-card: #ffffff; --border-sub: #e5e7eb; --text-main: #1f2937;
    --text-sub: #6b7280; --success: #10b981; --fail: #ef4444; --warning: #f59e0b;
}
.dark {
    --bg-card: #1f2937; --border-sub: #374151; --text-main: #f3f4f6;
    --text-sub: #9ca3af; --primary-light: #064e3b; --primary-dark: #ecfdf5;
}

/* Top banner */
.concept-banner {
    background: linear-gradient(135deg, rgba(16, 185, 129, 0.08) 0%, rgba(59, 130, 246, 0.05) 100%);
    border: 1px solid var(--primary-light); border-radius: 12px; padding: 24px;
    text-align: center; margin-bottom: 20px;
}
.concept-title { font-size: 1.8em; font-weight: 700; color: var(--text-main); margin-bottom: 8px; }
.concept-subtitle { font-size: 1em; color: var(--text-sub); max-width: 600px; margin: 0 auto; line-height: 1.5; }

/* Step headers */
.step-header {
    display: flex; align-items: center; margin-bottom: 15px;
    border-bottom: 2px solid var(--border-sub); padding-bottom: 10px; margin-top: 20px;
}
.step-num {
    background: var(--primary); color: white; width: 28px; height: 28px; border-radius: 50%;
    display: flex; align-items: center; justify-content: center; font-weight: bold;
    margin-right: 10px; font-size: 0.9em;
}
.step-title { font-size: 1.2em; font-weight: 600; color: var(--text-main); }
.step-desc { font-size: 0.93em; color: var(--text-sub); margin-left: auto; font-style: italic;}

/* --- Model-select button container --- */
.model-select-container {
    display: flex !important; justify-content: flex-start !important; flex-wrap: wrap !important;
    align-items: center !important; gap: 10px !important; width: 100% !important;
    margin-bottom: 20px !important;
}

/* Shared button styling */
button.model-btn {
    background: linear-gradient(145deg, #ecfdf5 0%, #d1fae5 100%);
    border: 1px solid #6ee7b7 !important; color: #065f46 !important;
    font-weight: 600 !important; font-size: 0.9em !important; border-radius: 12px !important;
    box-shadow: 0 2px 4px rgba(16, 185, 129, 0.1); transition: all 0.2s ease;
    height: 46px !important; min-height: 46px !important; white-space: nowrap;
    padding: 0 30px !important; min-width: 220px !important; text-align: center !important;
    flex-grow: 0 !important; width: auto !important;
}
.dark button.model-btn {
    background: linear-gradient(145deg, #064e3b 0%, #065f46 100%);
    border-color: #059669 !important; color: #ecfdf5 !important;
}
button.model-btn:hover {
    transform: translateY(-2px); box-shadow: 0 4px 8px rgba(16, 185, 129, 0.25);
    background: linear-gradient(145deg, #d1fae5 0%, #a7f3d0 100%);
}
.dark button.model-btn:hover {
    background: linear-gradient(145deg, #059669 0%, #047857 100%) !important;
    color: #ffffff !important; box-shadow: 0 4px 8px rgba(0, 0, 0, 0.4);
}

/* --- Card styling --- */
.model-card {
    background: var(--bg-card); border: 1px solid var(--border-sub); border-radius: 10px;
    padding: 16px; margin-bottom: 15px; height: 100%; transition: all 0.2s;
    position: relative; overflow: hidden; min-height: 200px;
}
.model-card.populated { border-left: 5px solid var(--primary); box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.05); }
.model-card.empty {
    border: 2px dashed var(--border-sub); display: flex; flex-direction: column;
    justify-content: center; align-items: center; opacity: 0.6;
}
.card-title { font-weight: bold; color: var(--text-sub); font-size: 1.1em; }
.card-subtitle { font-size: 0.85em; color: var(--text-sub); margin-top: 5px;}
.card-header { display: flex; justify-content: space-between; align-items: center; margin-bottom: 12px; }
.model-name { font-weight: bold; font-size: 1.05em; color: var(--text-main); }
.acc-badge {
    background: var(--primary-light); color: var(--primary-dark); font-size: 0.75em;
    padding: 3px 8px; border-radius: 12px; font-weight: 700;
}
.history-container { display: flex; flex-direction: column; gap: 6px; }
.history-row { display: flex; align-items: center; background: rgba(0,0,0,0.03); padding: 5px 8px; border-radius: 6px; }
.status-box {
    width: 20px; height: 20px; border-radius: 5px; display: flex; align-items: center;
    justify-content: center; color: white; font-size: 0.7em; font-weight: bold;
    margin-right: 8px; flex-shrink: 0;
}
.status-green { background-color: var(--success); }
.status-red { background-color: var(--fail); }
.status-yellow { background-color: var(--warning); }
.prompt-text {
    flex: 1; min-width: 0; font-size: 0.85em; color: var(--text-main);
    white-space: nowrap; overflow: hidden; text-overflow: ellipsis;
}
.history-more { font-size: 0.8em; color: var(--text-sub); text-align: center; font-style: italic; margin-top: 4px; }

/* Buttons and links */
.custom-btn { font-weight: bold !important; font-size: 1.1em !important; }
.sample-btn {
    border: 1px dashed rgba(16, 185, 129, 0.3) !important;
    color: var(--primary-dark) !important; background: rgba(16, 185, 129, 0.05) !important;
}
.sample-btn:hover { background: rgba(16, 185, 129, 0.15) !important; }
.paper-link {
    font-size: 0.5em; vertical-align: middle; color: var(--primary); text-decoration: none;
    border: 1px solid var(--primary); padding: 4px 10px; border-radius: 15px;
    font-weight: normal; transition: all 0.2s; background: transparent;
}
.paper-link:hover {
    background: var(--primary);
    color: white;
}

/* Custom Upload Styling */
.upload-container {
    display: flex; align-items: center; gap: 10px; margin-bottom: 20px; padding: 8px 12px;
    background: rgba(0,0,0,0.02); border-radius: 8px; border: 1px dashed var(--border-sub);
}
.upload-label { font-size: 0.9em; font-weight: 600; color: var(--text-sub); white-space: nowrap;}
.code-inline {
    font-family: monospace; background: rgba(0,0,0,0.05); padding: 2px 4px;
    border-radius: 4px; font-size: 0.85em; color: var(--primary-dark);
}
"""

# ==========================================
# 3. UI construction
# ==========================================
with gr.Blocks(theme=gr.themes.Soft(primary_hue="emerald"), css=css, title="V0 Predictor") as demo:
    # Pre-load one model into the left slot so the page is not empty on start.
    initial_model = TARGET_MODELS[2]  # DeepSeek-R1-Distill-Qwen-1.5B
    initial_data = load_real_model_data(initial_model)
    init_in1, init_in2, init_in3, init_in4, init_ids, init_ds_name = load_dataset_batch(TARGET_DATASETS[0])

    # Pointer alternates which slot receives the next selection (1 = slot 2 next).
    state_pointer = gr.State(value=1)
    state_data1 = gr.State(value=initial_data)
    state_name1 = gr.State(value=initial_model)
    state_data2 = gr.State(value=None)
    state_name2 = gr.State(value=None)
    state_current_ids = gr.State(value=init_ids)
    state_current_dataset = gr.State(value=init_ds_name)

    # NOTE(review): HTML markup in the blocks below was reconstructed from the
    # CSS class names; the original anchors' href URLs were lost in extraction.
    gr.HTML("""
    <div class="concept-banner">
        <div class="concept-title">
            Generalist Value Model V0
            <a href="#" class="paper-link" target="_blank">Paper ↗</a>
            <a href="#" class="paper-link" target="_blank">Code ↗</a>
        </div>
        <div class="concept-subtitle">
            <b>Function:</b> V0 uses a model's historical performance to predict
            how it will perform on unseen instructions without running the model itself.
        </div>
    </div>
    """)

    # STEP 1: MODEL SELECTION
    gr.HTML("""
    <div class="step-header">
        <div class="step-num">1</div>
        <div class="step-title">Select Model (with instruction-performance pairs)</div>
        <div class="step-desc">Click buttons below to load real pairs, or upload your own.</div>
    </div>
    """)

    with gr.Row(elem_classes=["model-select-container"]):
        model_buttons = []
        for m_name in TARGET_MODELS:
            btn = gr.Button(m_name, elem_classes=["model-btn"], scale=0)
            model_buttons.append((btn, m_name))

    with gr.Row():
        with gr.Column(variant="panel"):
            slot1_html = gr.HTML(format_model_card(initial_data, initial_model))
        with gr.Column(variant="panel"):
            slot2_html = gr.HTML(format_model_card(None, None, True))

    with gr.Row(elem_classes=["upload-container"]):
        gr.HTML(
            f"""
            <div class="upload-label">[Optional] Upload Custom Model (.jsonl)&nbsp;&nbsp;|&nbsp;&nbsp;
            Format: <span class="code-inline">{{"prompt": "...", "score": 0.8 }}\\n{{"prompt": "...", "score": 0.2}}</span> ...</div>
            """
        )
        upload_btn = gr.File(
            file_count="single",
            file_types=[".jsonl"],
            label=None,
            show_label=False,
            container=False,
            height=30,
            scale=1
        )

    # STEP 2: INSTRUCTION INPUT
    gr.HTML("""
    <div class="step-header">
        <div class="step-num">2</div>
        <div class="step-title">Select Instructions (to predict how well the above models perform on them)</div>
        <div class="step-desc">Click to randomly sample 4 from the corresponding dataset.</div>
    </div>
    """)

    with gr.Row(elem_classes=["model-select-container"]):
        dataset_buttons = []
        for d_name in TARGET_DATASETS:
            d_btn = gr.Button(d_name, elem_classes=["model-btn"], scale=0)
            dataset_buttons.append((d_btn, d_name))

    with gr.Row():
        t1 = gr.Textbox(label="Instruction 1", value=init_in1, lines=2)
        t2 = gr.Textbox(label="Instruction 2", value=init_in2, lines=2)
    with gr.Row():
        t3 = gr.Textbox(label="Instruction 3", value=init_in3, lines=2)
        t4 = gr.Textbox(label="Instruction 4", value=init_in4, lines=2)

    gr.HTML("<br>")

    # STEP 3: PREDICTION
    with gr.Row():
        with gr.Column():
            predict_btn = gr.Button("Run V₀ Prediction", variant="primary", size="lg", elem_classes=["custom-btn"])

    gr.HTML("""
    <div class="step-header">
        <div class="step-num">3</div>
        <div class="step-title">Results</div>
        <div class="step-desc">Limited to 16 context size (vs. standard 256) by HF CPU limits; GPU enables ms-level speeds.</div>
    </div>
    """)

    output_df = gr.Dataframe(
        headers=["Model", "Instruction", "Actual mean@10", "Predicted Score", "Predicted Class"],
        datatype=["str", "str", "str", "str", "str"],
        interactive=False,
        column_widths=["20%", "44%", "10%", "13%", "13%"]
    )

    # Event Binding
    def make_click_handler(model_name):
        # Factory binds model_name now, avoiding the late-binding closure bug.
        return dict(
            fn=lambda ptr, d1, n1, d2, n2: select_model(model_name, ptr, d1, n1, d2, n2),
            inputs=[state_pointer, state_data1, state_name1, state_data2, state_name2],
            outputs=[slot1_html, slot2_html, state_pointer,
                     state_data1, state_name1, state_data2, state_name2]
        )

    for btn, name in model_buttons:
        btn.click(**make_click_handler(name))

    upload_btn.upload(
        fn=handle_custom_upload,
        inputs=[upload_btn, state_pointer, state_data1, state_name1, state_data2, state_name2],
        outputs=[slot1_html, slot2_html, state_pointer,
                 state_data1, state_name1, state_data2, state_name2]
    )

    def make_dataset_handler(ds_name):
        # Factory binds ds_name now, avoiding the late-binding closure bug.
        def handler():
            i1, i2, i3, i4, ids, d_name = load_dataset_batch(ds_name)
            return i1, i2, i3, i4, ids, d_name
        return dict(
            fn=handler,
            inputs=None,
            outputs=[t1, t2, t3, t4, state_current_ids, state_current_dataset],
            show_progress="hidden"
        )

    for d_btn, d_name in dataset_buttons:
        d_btn.click(**make_dataset_handler(d_name))

    # Editing any textbox invalidates its sampled id (Actual column becomes N/A).
    t1.input(fn=lambda ids: invalidate_specific_id(ids, 0), inputs=[state_current_ids], outputs=[state_current_ids])
    t2.input(fn=lambda ids: invalidate_specific_id(ids, 1), inputs=[state_current_ids], outputs=[state_current_ids])
    t3.input(fn=lambda ids: invalidate_specific_id(ids, 2), inputs=[state_current_ids], outputs=[state_current_ids])
    t4.input(fn=lambda ids: invalidate_specific_id(ids, 3), inputs=[state_current_ids], outputs=[state_current_ids])

    predict_btn.click(
        fn=predict_performance,
        inputs=[state_data1, state_name1, state_data2, state_name2,
                t1, t2, t3, t4, state_current_ids, state_current_dataset],
        outputs=[output_df]
    )

if __name__ == "__main__":
    demo.launch()