#!/usr/bin/env python3
"""
ReflectionBench case viewer Gradio app (HuggingFace Spaces version, v2).

On startup, images.zip and edited_images.zip are extracted next to this file
unless their target folders already exist.
"""

import json
import zipfile
from pathlib import Path
from typing import Dict, List, Optional

import gradio as gr
import numpy as np
from PIL import Image

# ============================================================================
# Initialization: extract image archives
# ============================================================================

CURRENT_DIR = Path(__file__).parent


def extract_if_needed():
    """Extract each bundled image archive unless its target folder already exists.

    NOTE(review): archives are extracted into CURRENT_DIR, while the
    "already extracted" check looks for CURRENT_DIR/<archive stem>. This
    assumes each zip contains a top-level folder named after the archive
    (e.g. ``images/``) — confirm against the packaged archives.
    """
    for zip_name in ["images.zip", "edited_images.zip"]:
        zip_path = CURRENT_DIR / zip_name
        extract_dir = zip_path.with_suffix("")
        if zip_path.exists() and not extract_dir.exists():
            print(f"Extracting {zip_name}...")
            with zipfile.ZipFile(zip_path, 'r') as zf:
                zf.extractall(CURRENT_DIR)
            print(f" Done: {extract_dir}")


# Extract at import time so the image folders exist before the UI is built.
extract_if_needed()

# ============================================================================
# Configuration
# ============================================================================

# Evaluator keys selectable in the UI; each key selects the file
# detailed_cases_<key>.json. The mapped values are not read in this file.
EVAL_MODELS = {
    "qwen": "qwen_eval",
    "qwen235b": "qwen235b_eval",
}

# Fine-tuned Reflector checkpoints shown as the "baseline" side.
BASELINE_MODELS = [
    "qwen_ft_v1_step900",
    "qwen_ft_v3_step730",
    "qwen_ft_v4_step765",
    "qwen_ft_v5_step920",
    "qwen_ft_v6_step910",
    "qwen_ft_v7_step445",
    "qwen_ft_v8_step365",
    "qwen_ft_v8_step795",
]

# Models compared against the baseline, in dropdown order.
COMPARISON_MODELS = [
    "qwen",
    "qwen3vl",
    "qwen3vl_thinking",
    "bagel",
    "omnigen2",
    "omniverifier",
    "unicot",
    "sld",
    "reflect_dit",
    "reflectionflow_qwen8b",
    "thinkgen",
    "reasonedit",
]

# Default image editor per model. Used only when a case record has no
# "baseline_editor" / "comparison_editor" field; the editor name is the
# prefix of the edited-image file name (see get_edited_image_path).
EDITOR_CONFIG = {
    "qwen_ft_v1_step900": "qwen_image_2511",
    "qwen_ft_v3_step730": "qwen_image_2511",
    "qwen_ft_v4_step765": "qwen_image_2511",
    "qwen_ft_v5_step920": "qwen_image_2511",
    "qwen_ft_v6_step910": "qwen_image_2511",
    "qwen_ft_v7_step445": "qwen_image_2511",
    "qwen_ft_v8_step365": "qwen_image_2511",
    "qwen_ft_v8_step795": "qwen_image_2511",
    "qwen": "qwen_image_2511",
    "qwen3vl": "qwen_image_2511",
    "qwen3vl_thinking": "qwen_image_2511",
    "omniverifier": "qwen_image_2511",
    "sld": "qwen_image_2511",
    "bagel": "bagel",
    "omnigen2": "omnigen2",
    "unicot": "unicot",
    "reflect_dit": "reflect_dit",
    "reflectionflow_qwen8b": "reflectionflow",
    "thinkgen": "thinkgen",
    "reasonedit": "reasonedit",
}

# Case-type keys (as stored in the JSON) -> display names.
CASE_TYPE_NAMES = {
    "type1_answer_wrong": "Type 1: Answer 错误",
    "type2_explanation_wrong": "Type 2: Explanation 错误→编辑失败",
    "type3_edit_better": "Type 3: Edit Prompt 更优",
}

# ============================================================================
# Data loading
# ============================================================================


def load_cases(eval_model_key: str) -> Dict:
    """Load ``detailed_cases_<eval_model_key>.json`` from CURRENT_DIR.

    Returns an empty dict when the file does not exist. Elsewhere in this
    file the data is indexed as
    ``data[baseline][comparison][case_type] -> list of case dicts``.
    """
    filepath = CURRENT_DIR / f"detailed_cases_{eval_model_key}.json"
    if filepath.exists():
        with open(filepath, "r", encoding="utf-8") as f:
            return json.load(f)
    return {}


# Per-evaluator cache of parsed case files. Empty results for missing files
# are cached too, so a missing file is not re-checked until restart.
_cases_cache = {}


def get_cases(eval_model_key: str) -> Dict:
    """Return the case data for ``eval_model_key``, loading it on first use."""
    if eval_model_key not in _cases_cache:
        _cases_cache[eval_model_key] = load_cases(eval_model_key)
    return _cases_cache[eval_model_key]


# ============================================================================
# Image handling
# ============================================================================


def load_image_256(path: Path) -> Optional[np.ndarray]:
    """Load an image, shrink it to fit within 256x256, and return it as an array.

    Returns None when ``path`` is falsy or missing, or when the file cannot
    be decoded (a broken image should not crash the viewer).
    """
    if not path or not path.exists():
        return None
    try:
        # The context manager closes the underlying file handle; np.array()
        # materializes the pixels before the file is closed.
        with Image.open(path) as img:
            img.thumbnail((256, 256), Image.Resampling.LANCZOS)
            return np.array(img)
    except Exception as err:  # best-effort display: report and fall back to no image
        print(f"Failed to load image {path}: {err}")
        return None


def get_bad_image_path(bad_image_rel: str) -> Path:
    """Resolve a case's ``bad_image`` relative path inside the images/ folder."""
    return CURRENT_DIR / "images" / bad_image_rel


def get_edited_image_path(bad_image_rel: str, verifier: str, editor: str) -> Path:
    """Build ``edited_images/<editor>_<verifier>_<stem><suffix>`` for a source image.

    Only the file name of ``bad_image_rel`` is used; its directories are dropped.
    """
    src = Path(bad_image_rel)
    return CURRENT_DIR / "edited_images" / f"{editor}_{verifier}_{src.stem}{src.suffix}"


# ============================================================================
# Gradio callbacks
# ============================================================================


def get_case_list(eval_model_key: str, baseline: str, comparison: str, case_type: str) -> List[str]:
    """Return dropdown labels for the cases of one (baseline, comparison, case_type) cell.

    Each label starts with ``idx=<idx> |``; show_case parses the idx back out.
    """
    cases_data = get_cases(eval_model_key)
    if not cases_data:
        return []
    cases = cases_data.get(baseline, {}).get(comparison, {}).get(case_type, [])
    return [f"idx={c['idx']} | {c['category']} | {c['original_prompt'][:40]}..." for c in cases]


def update_case_list(eval_model_key: str, baseline: str, comparison: str, case_type: str):
    """Return a dropdown update listing the cases and selecting the first one (or None)."""
    cases = get_case_list(eval_model_key, baseline, comparison, case_type)
    return gr.update(choices=cases, value=cases[0] if cases else None)


def get_comparison_choices(eval_model_key: str, baseline: str) -> List[str]:
    """Return the comparison models that have at least one case for ``baseline``.

    Falls back to the full COMPARISON_MODELS list when no data is loaded,
    the baseline is absent, or no comparison model has any case.
    """
    cases_data = get_cases(eval_model_key)
    if not cases_data or baseline not in cases_data:
        return list(COMPARISON_MODELS)

    per_baseline = cases_data.get(baseline, {})
    available = [
        comp
        for comp in COMPARISON_MODELS
        if any(per_baseline.get(comp, {}).get(t, []) for t in CASE_TYPE_NAMES)
    ]
    return available if available else list(COMPARISON_MODELS)


def _pick_comparison(choices: List[str], current: Optional[str]) -> Optional[str]:
    """Keep ``current`` if it is still offered, otherwise fall back to the first choice."""
    if current in choices:
        return current
    return choices[0] if choices else None


def update_comparison_choices(eval_model_key: str, baseline: str, current: Optional[str] = None):
    """Return a dropdown update with the available comparison models.

    ``current`` is kept selected when it is still available; otherwise (and
    when omitted) the first available model is selected.
    """
    choices = get_comparison_choices(eval_model_key, baseline)
    return gr.update(choices=choices, value=_pick_comparison(choices, current))


def _message_result(message: str) -> tuple:
    """Build a show_case result that shows only ``message`` and clears every other output."""
    return (message, "", "", "", "", "", None, None, None)


def _mark(flag) -> str:
    """Render a correctness flag as ✓ / ✗."""
    return "✓" if flag else "✗"


def show_case(eval_model_key: str, baseline: str, comparison: str, case_type: str, case_idx_str: str):
    """Render one case for the detail panes.

    Returns a 9-tuple matching the UI ``outputs`` order: (info, ground truth,
    baseline, comparison, summary, image caption, original image,
    baseline-edited image, comparison-edited image). The case is looked up by
    the ``idx`` parsed from the dropdown label; on any failure, a message
    tuple is returned instead.
    """
    if not case_idx_str:
        return _message_result("请选择 case")

    cases_data = get_cases(eval_model_key)
    if not cases_data:
        return _message_result("数据未加载")

    try:
        idx = int(case_idx_str.split("|")[0].replace("idx=", "").strip())
    except (ValueError, IndexError):
        return _message_result("无法解析 idx")

    cases = cases_data.get(baseline, {}).get(comparison, {}).get(case_type, [])
    case = next((c for c in cases if c["idx"] == idx), None)
    if case is None:
        return _message_result("未找到 case")

    bad_image_rel = case["bad_image"]
    # Per-case editor fields take precedence over the static EDITOR_CONFIG defaults.
    baseline_editor = case.get("baseline_editor", EDITOR_CONFIG.get(baseline, "qwen_image_2511"))
    comparison_editor = case.get("comparison_editor", EDITOR_CONFIG.get(comparison, "qwen_image_2511"))
    bad_img = load_image_256(get_bad_image_path(bad_image_rel))
    baseline_img = load_image_256(get_edited_image_path(bad_image_rel, baseline, baseline_editor))
    comparison_img = load_image_256(get_edited_image_path(bad_image_rel, comparison, comparison_editor))

    info_col = f"""### Case 详情
- **评估模型**: {eval_model_key.upper()}
- **类型**: {CASE_TYPE_NAMES.get(case_type, case_type)}
- **idx**: {case['idx']} | **类别**: {case['category']}
- **Prompt**: *{case['original_prompt']}*"""

    gt_col = f"""### Ground Truth
- **Answer**: `{case['gt_answer']}`
- **Explanation**: {case.get('gt_explanation', 'N/A')}"""

    baseline_col = f"""### {baseline} (基准)
- **Answer**: `{case['baseline_answer']}` {_mark(case['baseline_answer_correct'])} | **Exp评估**: {_mark(case['baseline_explanation_correct'])}
- **Explanation**: {case['baseline_explanation']}
- **Edit指令**: {case['baseline_edit_prompt']}
- **I_Score**: **{case['baseline_i_score']:.3f}** | Edited Acc: **{case['baseline_edited_acc']:.3f}**"""

    # Comparison models may have no explanation / edit prompt; show a dash instead.
    comp_explanation = case["comparison_explanation"] if case["comparison_explanation"] else "—"
    comp_edit_prompt = case["comparison_edit_prompt"] if case["comparison_edit_prompt"] else "—"
    comp_col = f"""### {comparison} (对比)
- **Answer**: `{case['comparison_answer']}` {_mark(case['comparison_answer_correct'])} | **Exp评估**: {_mark(case['comparison_explanation_correct'])}
- **Explanation**: {comp_explanation}
- **Edit指令**: {comp_edit_prompt}
- **I_Score**: **{case['comparison_i_score']:.3f}** | Edited Acc: **{case['comparison_edited_acc']:.3f}**"""

    base_score = case["baseline_i_score"]
    comp_score = case["comparison_i_score"]
    i_diff = base_score - comp_score
    # "✓" only when the baseline is right and the comparison model is wrong.
    ans_adv = "✓" if case["baseline_answer_correct"] and not case["comparison_answer_correct"] else "-"
    exp_adv = "✓" if case["baseline_explanation_correct"] and not case["comparison_explanation_correct"] else "-"
    # The ":+" format prints the sign explicitly, so a negative gap renders
    # as "-0.123" rather than "+-0.123".
    summary = (
        f"**对比总结**: Answer({ans_adv}) | Explanation({exp_adv}) | "
        f"I_Score: {baseline} **{base_score:.3f}** vs {comparison} {comp_score:.3f} = **{i_diff:+.3f}** | "
        f"Edited Acc: **{case['baseline_edited_acc']:.3f}** vs {case['comparison_edited_acc']:.3f}"
    )

    img_labels = f"原图 → {baseline} 编辑 ({baseline_editor}) → {comparison} 编辑 ({comparison_editor})"

    return (info_col, gt_col, baseline_col, comp_col, summary, img_labels, bad_img, baseline_img, comparison_img)


def get_statistics(eval_model_key: str) -> str:
    """Build a markdown summary of case counts per type, per baseline and comparison model.

    Baselines absent from the data are skipped; comparison models with no
    cases are omitted from their baseline's table.
    """
    cases_data = get_cases(eval_model_key)
    if not cases_data:
        return "数据未加载"

    lines = [f"### 统计摘要 ({eval_model_key.upper()})\n"]
    for baseline in BASELINE_MODELS:
        if baseline not in cases_data:
            continue
        lines.append(f"\n**{baseline}**:\n")
        lines.append("| 对比模型 | Type1 | Type2 | Type3 | 总计 |")
        lines.append("|----------|-------|-------|-------|------|")
        for comparison in COMPARISON_MODELS:
            comp_data = cases_data.get(baseline, {}).get(comparison, {})
            t1 = len(comp_data.get("type1_answer_wrong", []))
            t2 = len(comp_data.get("type2_explanation_wrong", []))
            t3 = len(comp_data.get("type3_edit_better", []))
            total = t1 + t2 + t3
            if total > 0:
                lines.append(f"| {comparison} | {t1} | {t2} | {t3} | {total} |")

    return "\n".join(lines)


# ============================================================================
# Gradio UI
# ============================================================================


def create_app():
    """Build the Gradio Blocks UI and wire its event handlers.

    Selector changes rebuild the comparison/case dropdowns and render the
    first case directly; choosing a case in the dropdown renders that case.
    """
    with gr.Blocks(title="ReflectionBench Case Viewer", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# ReflectionBench Case 展示")
        gr.Markdown("对比 Reflector (基准模型) 与其他 baseline 模型的表现差异")

        with gr.Accordion("统计摘要", open=False):
            stats_md = gr.Markdown()

        with gr.Row():
            eval_model = gr.Radio(choices=list(EVAL_MODELS.keys()), label="评估模型", value="qwen", scale=1)
            baseline = gr.Dropdown(
                choices=BASELINE_MODELS, label="基准模型 (Reflector)", value=BASELINE_MODELS[0], scale=1
            )
            comparison = gr.Dropdown(
                choices=COMPARISON_MODELS, label="对比模型", value=COMPARISON_MODELS[0], scale=1
            )

        with gr.Row():
            case_type = gr.Radio(
                choices=[
                    ("Type1: Answer错误", "type1_answer_wrong"),
                    ("Type2: Exp错误→编辑失败", "type2_explanation_wrong"),
                    ("Type3: Edit更优", "type3_edit_better"),
                ],
                label="Case类型",
                value="type1_answer_wrong",
                scale=2,
            )
            case_dropdown = gr.Dropdown(choices=[], label="选择Case", scale=2)

        with gr.Row():
            info_md = gr.Markdown()
            gt_md = gr.Markdown()
        with gr.Row():
            baseline_md = gr.Markdown()
            comparison_md = gr.Markdown()
        summary_md = gr.Markdown()
        img_label = gr.Markdown()
        with gr.Row():
            bad_image = gr.Image(label="原图", scale=1, height=200)
            baseline_edited = gr.Image(label="基准模型编辑后", scale=1, height=200)
            comparison_edited = gr.Image(label="对比模型编辑后", scale=1, height=200)

        # Order must match the tuple returned by show_case.
        outputs = [info_md, gt_md, baseline_md, comparison_md, summary_md, img_label,
                   bad_image, baseline_edited, comparison_edited]
        selectors = [eval_model, baseline, comparison, case_type]

        def refresh_cases(eval_key, base, comp, ctype):
            """Rebuild the case dropdown and render its first case.

            Rendering here, not only via case_dropdown.change, keeps the detail
            panes in sync when the new first label equals the old one (then the
            dropdown value does not change, so presumably no change event fires).
            """
            cases = get_case_list(eval_key, base, comp, ctype)
            first = cases[0] if cases else None
            return (gr.update(choices=cases, value=first),) + show_case(eval_key, base, comp, ctype, first)

        def refresh_comparison_and_cases(eval_key, base, comp, ctype):
            """Refilter comparison models (keeping the current one if available), then refresh cases.

            The case list is built for the comparison model that will actually
            be selected, not the possibly-stale incoming value.
            """
            choices = get_comparison_choices(eval_key, base)
            comp = _pick_comparison(choices, comp)
            return (gr.update(choices=choices, value=comp),) + refresh_cases(eval_key, base, comp, ctype)

        def on_eval_model_change(eval_key, base, comp, ctype):
            """Refresh statistics, comparison choices, the case list and the shown case."""
            return (get_statistics(eval_key),) + refresh_comparison_and_cases(eval_key, base, comp, ctype)

        full_refresh_outputs = [stats_md, comparison, case_dropdown] + outputs

        eval_model.change(fn=on_eval_model_change, inputs=selectors, outputs=full_refresh_outputs)
        baseline.change(fn=refresh_comparison_and_cases, inputs=selectors,
                        outputs=[comparison, case_dropdown] + outputs)
        comparison.change(fn=refresh_cases, inputs=selectors, outputs=[case_dropdown] + outputs)
        case_type.change(fn=refresh_cases, inputs=selectors, outputs=[case_dropdown] + outputs)
        case_dropdown.change(fn=show_case, inputs=selectors + [case_dropdown], outputs=outputs)

        # Initial page load performs the same full refresh as an evaluator change.
        demo.load(fn=on_eval_model_change, inputs=selectors, outputs=full_refresh_outputs)

    return demo


if __name__ == "__main__":
    demo = create_app()
    demo.launch()