| | |
| | """ |
| | ReflectionBench Case 展示 Gradio 应用 (HuggingFace Spaces 版本 v2) |
| | |
| | 启动时自动解压 images.zip 和 edited_images.zip |
| | """ |
| |
|
| | import gradio as gr |
| | import json |
| | import zipfile |
| | from pathlib import Path |
| | from PIL import Image |
| | import numpy as np |
| | from typing import Dict, List, Optional |
| |
|
| | |
| | |
| | |
| |
|
# Directory that contains this script; the zip archives, the JSON case files
# and the image folders are all resolved relative to it.
CURRENT_DIR = Path(__file__).parents[0]
| |
|
def extract_if_needed():
    """Unpack the bundled image archives if they have not been unpacked yet.

    For ``images.zip`` and ``edited_images.zip`` (looked up in
    ``CURRENT_DIR``), the archive is extracted into ``CURRENT_DIR`` only when
    the zip file exists and a directory named after the archive (with
    ``.zip`` removed) does not exist yet.
    """
    for archive_name in ("images.zip", "edited_images.zip"):
        archive_path = CURRENT_DIR / archive_name
        target_dir = CURRENT_DIR / archive_name.replace(".zip", "")

        # Nothing to do when the archive is absent or was already unpacked.
        if not archive_path.exists() or target_dir.exists():
            continue

        print(f"Extracting {archive_name}...")
        with zipfile.ZipFile(archive_path, 'r') as archive:
            archive.extractall(CURRENT_DIR)
        print(f" Done: {target_dir}")


# Runs at import time, before the rest of the module is evaluated.
extract_if_needed()
| |
|
| | |
| | |
| | |
| |
|
# Evaluator keys offered in the UI. The key is also the suffix of the
# ``detailed_cases_<key>.json`` file that holds that evaluator's cases; the
# mapped values are not read anywhere in this file.
EVAL_MODELS = {
    "qwen": "qwen_eval",
    "qwen235b": "qwen235b_eval",
}
| |
|
# Fine-tuned checkpoints selectable as the baseline ("Reflector") model.
# The first entry is the default selection in the UI.
BASELINE_MODELS = [
    "qwen_ft_v1_step900",
    "qwen_ft_v3_step730",
    "qwen_ft_v4_step765",
    "qwen_ft_v5_step920",
    "qwen_ft_v6_step910",
    "qwen_ft_v7_step445",
    "qwen_ft_v8_step365",
    "qwen_ft_v8_step795",
]
| |
|
# Models a baseline can be compared against, in the order they are listed in
# the comparison dropdown and in the statistics table.
COMPARISON_MODELS = [
    "qwen",
    "qwen3vl",
    "qwen3vl_thinking",
    "bagel",
    "omnigen2",
    "omniverifier",
    "unicot",
    "sld",
    "reflect_dit",
    "reflectionflow_qwen8b",
    "thinkgen",
    "reasonedit",
]
| |
|
# Model name -> name of the image editor paired with it. The editor name is
# used as the prefix of edited-image file names and shown in the image caption.
EDITOR_CONFIG = {
    # Models paired with the shared "qwen_image_2511" editor.
    **dict.fromkeys(
        (
            "qwen_ft_v1_step900",
            "qwen_ft_v3_step730",
            "qwen_ft_v4_step765",
            "qwen_ft_v5_step920",
            "qwen_ft_v6_step910",
            "qwen_ft_v7_step445",
            "qwen_ft_v8_step365",
            "qwen_ft_v8_step795",
            "qwen",
            "qwen3vl",
            "qwen3vl_thinking",
            "omniverifier",
            "sld",
        ),
        "qwen_image_2511",
    ),
    # Models paired with their own editor.
    "bagel": "bagel",
    "omnigen2": "omnigen2",
    "unicot": "unicot",
    "reflect_dit": "reflect_dit",
    "reflectionflow_qwen8b": "reflectionflow",
    "thinkgen": "thinkgen",
    "reasonedit": "reasonedit",
}
| |
|
# Case-type key (as stored in the case JSON) -> label shown in the case details.
CASE_TYPE_NAMES = {
    "type1_answer_wrong": "Type 1: Answer 错误",
    "type2_explanation_wrong": "Type 2: Explanation 错误→编辑失败",
    "type3_edit_better": "Type 3: Edit Prompt 更优",
}
| |
|
| |
|
| | |
| | |
| | |
| |
|
def load_cases(eval_model_key: str) -> Dict:
    """Read the detailed-case JSON file for one evaluator.

    Args:
        eval_model_key: Evaluator key; the file read is
            ``detailed_cases_<eval_model_key>.json`` inside ``CURRENT_DIR``.

    Returns:
        The parsed JSON content, or an empty dict when the file does not exist.
    """
    cases_file = CURRENT_DIR / f"detailed_cases_{eval_model_key}.json"
    if not cases_file.exists():
        return {}
    with cases_file.open("r", encoding="utf-8") as handle:
        return json.load(handle)
| |
|
# Per-process cache: evaluator key -> case data returned by load_cases().
_cases_cache = {}


def get_cases(eval_model_key: str) -> Dict:
    """Return the case data for an evaluator, loading it on first use.

    Args:
        eval_model_key: Evaluator key passed through to ``load_cases``.

    Returns:
        The cached case data; the same object is returned on every call for
        a given key.
    """
    try:
        return _cases_cache[eval_model_key]
    except KeyError:
        # First request for this key: load once and remember the result.
        loaded = _cases_cache[eval_model_key] = load_cases(eval_model_key)
        return loaded
| |
|
| |
|
| | |
| | |
| | |
| |
|
def load_image_256(path: Path) -> Optional[np.ndarray]:
    """Load an image, shrink it to fit within 256x256, and return it as an array.

    Args:
        path: Location of the image file. A falsy value or a path that does
            not exist yields ``None``.

    Returns:
        The downscaled image as a NumPy array, or ``None`` when the file is
        missing or could not be read.
    """
    if not path or not path.exists():
        return None
    try:
        # The context manager closes the underlying file handle once the
        # pixel data has been copied into the array.
        with Image.open(path) as img:
            # thumbnail() resizes in place and preserves the aspect ratio.
            img.thumbnail((256, 256), Image.Resampling.LANCZOS)
            return np.array(img)
    except Exception as exc:
        # Loading stays best effort (the UI just shows an empty slot), but the
        # failure is reported instead of being dropped silently.
        print(f"Failed to load image {path}: {exc}")
        return None
| |
|
| |
|
def get_bad_image_path(bad_image_rel: str) -> Path:
    """Resolve a case's original image path inside the ``images`` folder.

    Args:
        bad_image_rel: Path of the image relative to ``CURRENT_DIR / "images"``.

    Returns:
        The combined path; existence is not checked here.
    """
    return CURRENT_DIR.joinpath("images", bad_image_rel)
| |
|
| |
|
def get_edited_image_path(bad_image_rel: str, verifier: str, editor: str) -> Path:
    """Build the path of the edited image produced for a case.

    The file lives directly in ``CURRENT_DIR / "edited_images"`` and is named
    ``<editor>_<verifier>_<original file name>``; any directory part of
    ``bad_image_rel`` is dropped.

    Args:
        bad_image_rel: Relative path of the original image.
        verifier: Name of the model that produced the edit instruction.
        editor: Name of the image editor.

    Returns:
        The combined path; existence is not checked here.
    """
    # Path.name is the final component, i.e. stem followed by suffix.
    original_name = Path(bad_image_rel).name
    return CURRENT_DIR / "edited_images" / f"{editor}_{verifier}_{original_name}"
| |
|
| |
|
| | |
| | |
| | |
| |
|
def get_case_list(eval_model_key: str, baseline: str, comparison: str, case_type: str) -> List[str]:
    """List the dropdown labels for the cases of one selection.

    Args:
        eval_model_key: Evaluator key used to fetch the case data.
        baseline: Baseline model name (first-level key of the case data).
        comparison: Comparison model name (second-level key).
        case_type: Case-type key (third-level key).

    Returns:
        One label per case in the form
        ``idx=<idx> | <category> | <first 40 chars of the prompt>...``, or an
        empty list when no data or no matching cases are available.
    """
    cases_data = get_cases(eval_model_key)
    if not cases_data:
        return []

    # Walk the nested mapping level by level; a missing level yields no cases.
    per_baseline = cases_data.get(baseline, {})
    per_comparison = per_baseline.get(comparison, {})
    selected_cases = per_comparison.get(case_type, [])

    return [
        f"idx={entry['idx']} | {entry['category']} | {entry['original_prompt'][:40]}..."
        for entry in selected_cases
    ]
| |
|
| |
|
def update_case_list(eval_model_key: str, baseline: str, comparison: str, case_type: str):
    """Build the Gradio update that refreshes the case dropdown.

    Returns:
        A ``gr.update`` carrying the labels from ``get_case_list`` as choices
        and the first label (or ``None`` when there are none) as the value.
    """
    labels = get_case_list(eval_model_key, baseline, comparison, case_type)
    first_label = labels[0] if labels else None
    return gr.update(choices=labels, value=first_label)
| |
|
| |
|
def get_comparison_choices(eval_model_key: str, baseline: str) -> List[str]:
    """Return the comparison models that have at least one case for a baseline.

    Args:
        eval_model_key: Evaluator key used to fetch the case data.
        baseline: Baseline model name.

    Returns:
        The entries of ``COMPARISON_MODELS`` (in that order) with a non-zero
        number of cases across all case types. Falls back to the full
        ``COMPARISON_MODELS`` list when the data is missing, the baseline is
        unknown, or no model has any case.
    """
    cases_data = get_cases(eval_model_key)
    if not cases_data or baseline not in cases_data:
        return COMPARISON_MODELS

    def _case_count(model: str) -> int:
        # Total number of cases for this model over every known case type.
        per_model = cases_data.get(baseline, {}).get(model, {})
        return sum(len(per_model.get(type_key, [])) for type_key in CASE_TYPE_NAMES)

    with_cases = [model for model in COMPARISON_MODELS if _case_count(model) > 0]
    return with_cases or COMPARISON_MODELS
| |
|
| |
|
def update_comparison_choices(eval_model_key: str, baseline: str):
    """Build the Gradio update that refreshes the comparison dropdown.

    Returns:
        A ``gr.update`` carrying the models from ``get_comparison_choices`` as
        choices and the first of them (or ``None`` when empty) as the value.
    """
    models = get_comparison_choices(eval_model_key, baseline)
    first_model = models[0] if models else None
    return gr.update(choices=models, value=first_model)
| |
|
| |
|
def _case_message(message: str):
    """Return the nine-slot ``show_case`` result that carries only a status message."""
    return (message, "", "", "", "", "", None, None, None)


def _mark(flag) -> str:
    """Render a truthy value as a check mark and a falsy one as a cross."""
    return "✓" if flag else "✗"


def show_case(eval_model_key: str, baseline: str, comparison: str, case_type: str, case_idx_str: str):
    """Render one case as Markdown panels plus its three images.

    Args:
        eval_model_key: Evaluator key used to fetch the case data.
        baseline: Baseline model name.
        comparison: Comparison model name.
        case_type: Case-type key.
        case_idx_str: Dropdown label of the case; the text before the first
            ``|`` must look like ``idx=<int>``.

    Returns:
        A 9-tuple ``(info, ground_truth, baseline_panel, comparison_panel,
        summary, image_caption, original_image, baseline_edited_image,
        comparison_edited_image)``. The first six items are Markdown strings,
        the last three are image arrays or ``None``. When nothing is selected,
        the data is missing, the label cannot be parsed, or the case is not
        found, only the first item holds a status message and the rest are
        empty.
    """
    if not case_idx_str:
        return _case_message("请选择 case")

    cases_data = get_cases(eval_model_key)
    if not cases_data:
        return _case_message("数据未加载")

    try:
        idx = int(case_idx_str.split("|")[0].replace("idx=", "").strip())
    except (ValueError, IndexError):
        return _case_message("无法解析 idx")

    cases = cases_data.get(baseline, {}).get(comparison, {}).get(case_type, [])
    case = next((c for c in cases if c["idx"] == idx), None)
    if not case:
        return _case_message("未找到 case")

    # A per-case editor overrides the static model -> editor table.
    baseline_editor = case.get("baseline_editor", EDITOR_CONFIG.get(baseline, "qwen_image_2511"))
    comparison_editor = case.get("comparison_editor", EDITOR_CONFIG.get(comparison, "qwen_image_2511"))

    bad_img = load_image_256(get_bad_image_path(case["bad_image"]))
    baseline_img = load_image_256(get_edited_image_path(case["bad_image"], baseline, baseline_editor))
    comparison_img = load_image_256(get_edited_image_path(case["bad_image"], comparison, comparison_editor))

    info_col = "\n".join([
        "### Case 详情",
        f"- **评估模型**: {eval_model_key.upper()}",
        f"- **类型**: {CASE_TYPE_NAMES.get(case_type, case_type)}",
        f"- **idx**: {case['idx']} | **类别**: {case['category']}",
        f"- **Prompt**: *{case['original_prompt']}*",
    ])

    gt_col = "\n".join([
        "### Ground Truth",
        f"- **Answer**: `{case['gt_answer']}`",
        f"- **Explanation**: {case.get('gt_explanation', 'N/A')}",
    ])

    baseline_col = "\n".join([
        f"### {baseline} (基准)",
        f"- **Answer**: `{case['baseline_answer']}` {_mark(case['baseline_answer_correct'])}"
        f" | **Exp评估**: {_mark(case['baseline_explanation_correct'])}",
        f"- **Explanation**: {case['baseline_explanation']}",
        f"- **Edit指令**: {case['baseline_edit_prompt']}",
        f"- **I_Score**: **{case['baseline_i_score']:.3f}**"
        f" | Edited Acc: **{case['baseline_edited_acc']:.3f}**",
    ])

    comp_col = "\n".join([
        f"### {comparison} (对比)",
        f"- **Answer**: `{case['comparison_answer']}` {_mark(case['comparison_answer_correct'])}"
        f" | **Exp评估**: {_mark(case['comparison_explanation_correct'])}",
        # Empty explanation / edit prompt is shown as a dash.
        f"- **Explanation**: {case['comparison_explanation'] or '—'}",
        f"- **Edit指令**: {case['comparison_edit_prompt'] or '—'}",
        f"- **I_Score**: **{case['comparison_i_score']:.3f}**"
        f" | Edited Acc: **{case['comparison_edited_acc']:.3f}**",
    ])

    i_diff = case["baseline_i_score"] - case["comparison_i_score"]
    # "✓" marks a dimension where the baseline is right and the comparison is not.
    ans_adv = "✓" if case["baseline_answer_correct"] and not case["comparison_answer_correct"] else "-"
    exp_adv = "✓" if case["baseline_explanation_correct"] and not case["comparison_explanation_correct"] else "-"
    # The "+" format flag prints the real sign, so a negative difference is
    # shown as "-0.123" rather than "+-0.123".
    summary = (
        f"**对比总结**: Answer({ans_adv}) | Explanation({exp_adv})"
        f" | I_Score: {baseline} **{case['baseline_i_score']:.3f}**"
        f" vs {comparison} {case['comparison_i_score']:.3f} = **{i_diff:+.3f}**"
        f" | Edited Acc: **{case['baseline_edited_acc']:.3f}**"
        f" vs {case['comparison_edited_acc']:.3f}"
    )

    img_labels = f"原图 → {baseline} 编辑 ({baseline_editor}) → {comparison} 编辑 ({comparison_editor})"

    return (info_col, gt_col, baseline_col, comp_col, summary, img_labels, bad_img, baseline_img, comparison_img)
| |
|
| |
|
def get_statistics(eval_model_key: str) -> str:
    """Build a Markdown summary of case counts for one evaluator.

    For every baseline in ``BASELINE_MODELS`` that appears in the case data, a
    table is emitted with one row per comparison model that has at least one
    case, listing the counts of the three case types and their total.

    Args:
        eval_model_key: Evaluator key used to fetch the case data.

    Returns:
        The Markdown text, or a short "data not loaded" message when no case
        data is available.
    """
    cases_data = get_cases(eval_model_key)
    if not cases_data:
        return "数据未加载"

    type_keys = ("type1_answer_wrong", "type2_explanation_wrong", "type3_edit_better")
    report = [f"### 统计摘要 ({eval_model_key.upper()})\n"]

    for baseline in BASELINE_MODELS:
        if baseline not in cases_data:
            continue

        # Section title followed by the table header and separator row.
        report += [
            f"\n**{baseline}**:\n",
            "| 对比模型 | Type1 | Type2 | Type3 | 总计 |",
            "|----------|-------|-------|-------|------|",
        ]

        for comparison in COMPARISON_MODELS:
            comp_data = cases_data.get(baseline, {}).get(comparison, {})
            t1, t2, t3 = (len(comp_data.get(key, [])) for key in type_keys)
            total = t1 + t2 + t3
            # Comparison models without any case are left out of the table.
            if total > 0:
                report.append(f"| {comparison} | {t1} | {t2} | {t3} | {total} |")

    return "\n".join(report)
| |
|
| |
|
| | |
| | |
| | |
| |
|
def create_app():
    """Build the Gradio Blocks UI of the case viewer and wire its events.

    The page contains a collapsible statistics panel, selectors for the
    evaluator, baseline model, comparison model, case type and case, four
    Markdown panels with the case details, a summary line, an image caption
    and three images (original, baseline edit, comparison edit).

    Returns:
        The assembled ``gr.Blocks`` application; it is not launched here.
    """
    with gr.Blocks(title="ReflectionBench Case Viewer", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# ReflectionBench Case 展示")
        gr.Markdown("对比 Reflector (基准模型) 与其他 baseline 模型的表现差异")

        with gr.Accordion("统计摘要", open=False):
            stats_md = gr.Markdown()

        with gr.Row():
            eval_model = gr.Radio(choices=list(EVAL_MODELS.keys()), label="评估模型", value="qwen", scale=1)
            baseline = gr.Dropdown(choices=BASELINE_MODELS, label="基准模型 (Reflector)", value=BASELINE_MODELS[0], scale=1)
            comparison = gr.Dropdown(choices=COMPARISON_MODELS, label="对比模型", value=COMPARISON_MODELS[0], scale=1)

        with gr.Row():
            case_type = gr.Radio(
                choices=[
                    ("Type1: Answer错误", "type1_answer_wrong"),
                    ("Type2: Exp错误→编辑失败", "type2_explanation_wrong"),
                    ("Type3: Edit更优", "type3_edit_better"),
                ],
                label="Case类型", value="type1_answer_wrong", scale=2
            )
            case_dropdown = gr.Dropdown(choices=[], label="选择Case", scale=2)

        with gr.Row():
            info_md = gr.Markdown()
            gt_md = gr.Markdown()

        with gr.Row():
            baseline_md = gr.Markdown()
            comparison_md = gr.Markdown()

        summary_md = gr.Markdown()
        img_label = gr.Markdown()

        with gr.Row():
            bad_image = gr.Image(label="原图", scale=1, height=200)
            baseline_edited = gr.Image(label="基准模型编辑后", scale=1, height=200)
            comparison_edited = gr.Image(label="对比模型编辑后", scale=1, height=200)

        # Output order must match the tuple returned by show_case().
        outputs = [info_md, gt_md, baseline_md, comparison_md, summary_md, img_label, bad_image, baseline_edited, comparison_edited]

        def _refresh_comparison_and_cases(eval_key, base, ctype):
            """Reset the comparison dropdown and list the cases of the model it is reset to."""
            choices = get_comparison_choices(eval_key, base)
            selected = choices[0] if choices else None
            comp_update = gr.update(choices=choices, value=selected)
            # Use the model being selected now, not the previous dropdown
            # value, so the case list never belongs to a stale comparison model.
            case_update = update_case_list(eval_key, base, selected, ctype)
            return comp_update, case_update

        def on_eval_model_change(eval_key, base, comp, ctype):
            """Refresh statistics, comparison choices and the case list for a new evaluator."""
            # `comp` stays in the signature because it is wired as an input,
            # but the dropdown is reset, so its old value is not used.
            comp_update, case_update = _refresh_comparison_and_cases(eval_key, base, ctype)
            return get_statistics(eval_key), comp_update, case_update

        eval_model.change(fn=on_eval_model_change, inputs=[eval_model, baseline, comparison, case_type], outputs=[stats_md, comparison, case_dropdown])

        def on_baseline_change(eval_key, base, comp, ctype):
            """Refresh comparison choices and the case list for a new baseline."""
            return _refresh_comparison_and_cases(eval_key, base, ctype)

        baseline.change(fn=on_baseline_change, inputs=[eval_model, baseline, comparison, case_type], outputs=[comparison, case_dropdown])
        comparison.change(fn=update_case_list, inputs=[eval_model, baseline, comparison, case_type], outputs=[case_dropdown])
        case_type.change(fn=update_case_list, inputs=[eval_model, baseline, comparison, case_type], outputs=[case_dropdown])
        case_dropdown.change(fn=show_case, inputs=[eval_model, baseline, comparison, case_type, case_dropdown], outputs=outputs)

        def on_load(eval_key, base, comp, ctype):
            """Fill the statistics, the case dropdown and the first case on page load."""
            stats = get_statistics(eval_key)
            cases = get_case_list(eval_key, base, comp, ctype)
            case_val = cases[0] if cases else None
            if case_val:
                case_result = show_case(eval_key, base, comp, ctype, case_val)
            else:
                case_result = ("请选择 case", "", "", "", "", "", None, None, None)
            return (stats, gr.update(choices=cases, value=case_val)) + case_result

        demo.load(fn=on_load, inputs=[eval_model, baseline, comparison, case_type], outputs=[stats_md, case_dropdown] + outputs)

    return demo
| |
|
| |
|
if __name__ == "__main__":
    # Script entry point: build the UI and start the Gradio server with
    # default launch settings.
    demo = create_app()
    demo.launch()
| |
|