Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import gradio as gr | |
| from pathlib import Path | |
| from io import BytesIO | |
| import base64 | |
| import requests | |
| from PIL import Image | |
# =============== Config: HF dataset repo ID =================
HF_DATASET_ID = "maplebb/UniREditBench-Results"
# Base URL for downloading raw files from the dataset repo's main branch.
BASE_HF_URL = f"https://huggingface.co/datasets/{HF_DATASET_ID}/resolve/main"
# Alias kept for backward compatibility; derived from BASE_HF_URL so the two
# names can never drift apart.
HF_BASE_URL = BASE_HF_URL
# In-process memo caches (nothing in this file evicts entries):
#   _image_cache: relative path -> decoded PIL image
#   _html_cache:  (relative path, None) -> base64 data URL
_image_cache = {}
_html_cache = {}
def get_url_response(url, timeout=10):
    """GET ``url`` and return the response, or ``None`` if the request fails.

    Args:
        url: Absolute URL to fetch.
        timeout: Seconds passed to ``requests.get`` as its timeout
            (defaults to 10, the previously hard-coded value).

    Returns:
        The ``requests.Response`` when the request completed with a
        non-error status code; otherwise ``None``, after printing the error.
    """
    try:
        resp = requests.get(url, timeout=timeout)
        # Treat 4xx/5xx responses as failures too.
        resp.raise_for_status()
    except requests.RequestException as e:
        # Network errors, timeouts and HTTP error statuses are expected
        # best-effort failures; anything else is a real bug and propagates.
        print(f"Error fetching {url}: {e}")
        return None
    return resp
def load_image_uniredit(rel_path: str):
    """Fetch an image from the UniREditBench HF dataset (server-side) and cache it.

    Args:
        rel_path: Path of the image relative to the dataset repo root; leading
            slashes are ignored.

    Returns:
        The image converted to RGB, or ``None`` if the download failed or the
        payload could not be decoded as an image. Only successful loads are
        memoized in ``_image_cache``, so failures are retried on later calls.
    """
    rel_path = rel_path.lstrip("/")
    if rel_path in _image_cache:
        return _image_cache[rel_path]
    url = f"{BASE_HF_URL}/{rel_path}"
    resp = get_url_response(url)
    if resp is None:
        return None
    try:
        img = Image.open(BytesIO(resp.content)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as e:
        # A corrupt / non-image payload (UnidentifiedImageError is an OSError)
        # should degrade to "failed to load" rather than crash the page render.
        print(f"Error decoding image from {url}: {e}")
        return None
    _image_cache[rel_path] = img
    return img
def pil_to_base64(img):
    """Encode an image as a ``data:image/png;base64,...`` URL.

    ``img`` must provide a PIL-style ``save(fp, format=...)`` method; it is
    always written as PNG. Passing ``None`` yields an empty string.
    """
    if img is not None:
        with BytesIO() as buffer:
            img.save(buffer, format="PNG")
            encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
        return "data:image/png;base64," + encoded
    return ""
# Directory that contains this script.
ROOT_DIR = Path(__file__).resolve().parent
# =============== data.json is kept locally in the Space =================
# (It could instead live in the dataset repo and be fetched with hf_hub_download.)
JSON_PATH = ROOT_DIR / "data.json"
# =============== Load the JSON and build lookup indexes =================
data = json.loads(JSON_PATH.read_text(encoding="utf-8"))
# Sorted, de-duplicated category names.
ALL_NAMES = sorted({entry["name"] for entry in data})
# (name, str(idx)) -> sample dict; idx is stringified so callers can look
# samples up with str(idx).
INDEX_MAP = {(entry["name"], str(entry["idx"])): entry for entry in data}
# =============== Baseline model list =================
# Images are no longer on local disk, so models cannot be discovered with
# os.listdir; list every model available in the dataset here by hand.
ALL_MODELS = [
    "Bagel-Think",
    "DreamOmni2",
    "GPT-4o",
    "Lumina-DiMOO",
    "MagicBrush",
    "Nano-Banana",
    "Omnigen2",
    "Qwen-Image-Edit",
    "Seedream4.0",
    "UniREdit-Bagel",
    "UniWorld-V2",
]
# Models listed first, in this order.
PRIORITY_MODELS = ["Bagel-Think", "GPT-4o", "Qwen-Image-Edit", "UniREdit-Bagel", "Nano-Banana"]
# Display order: the priority models that exist, then all others alphabetically.
BASELINE_MODELS = [m for m in PRIORITY_MODELS if m in ALL_MODELS]
BASELINE_MODELS += sorted(m for m in ALL_MODELS if m not in PRIORITY_MODELS)

# Gallery table markup: N_COL models per row, each column an equal integer
# percentage of the table width.
HTML_HEAD, HTML_TAIL = '<table class="center">', "</table>"
N_COL = 4
WIDTH = 100 // N_COL
def get_indices_for_name(name: str):
    """Return all idx values (as strings) of category ``name``.

    The result is de-duplicated and ordered numerically (e.g. "9" before
    "10"); every idx must therefore be int-convertible. A falsy ``name``
    yields an empty list.
    """
    if name:
        matching = set()
        for entry in data:
            if entry["name"] == name:
                matching.add(str(entry["idx"]))
        return sorted(matching, key=int)
    return []
def render_img_html(rel_path: str, max_h=512):
    """Return an ``<img>`` tag that inlines the dataset image at ``rel_path``.

    The image is embedded as a base64 PNG data URL at its original resolution;
    ``max_h`` only sets the CSS max height. Data URLs are memoized in
    ``_html_cache`` under ``(rel_path, None)``: the height is deliberately
    not part of the key.

    Returns a short ``<p>`` message instead when ``rel_path`` is empty or the
    image cannot be loaded.
    """
    if not rel_path:
        return "<p>No original image.</p>"
    cache_key = (rel_path, None)  # height-independent
    data_url = _html_cache.get(cache_key)
    if data_url is None:
        image = load_image_uniredit(rel_path)
        if image is None:
            return "<p>Failed to load image.</p>"
        # No resizing: encode at the original resolution.
        data_url = pil_to_base64(image)
        _html_cache[cache_key] = data_url
    return f'<img src="{data_url}" style="max-width:100%; max-height:{max_h}px;">'
def _gallery_row(cells):
    """Return one ``<tr>`` of ``cells``, right-padded with empty cells to N_COL columns."""
    padding = f'<td width="{WIDTH}%"></td>' * (N_COL - len(cells))
    return "<tr>" + "".join(cells) + padding + "</tr>"


def get_baseline_gallery(name: str, idx: str, models):
    """Build the HTML table comparing baseline model outputs for one sample.

    Models are laid out ``N_COL`` per row group. Each group produces two table
    rows: one with the model names and one with each model's image, loaded
    from ``Unireditbench_baseline_images/<model>/<name>/<idx>.png`` in the
    dataset.

    Args:
        name: Sample category.
        idx: Sample index as used in the image file name.
        models: Model names to show; if empty or ``None``, every model in
            ``BASELINE_MODELS`` is shown.

    Returns:
        The table HTML, or a short ``<p>`` prompt if ``name`` or ``idx`` is
        missing.
    """
    from html import escape  # local import: keeps the file's import block untouched

    if not name or not idx:
        return "<p>Please select name and idx.</p>"
    # Nothing ticked -> show every baseline.
    models = list(models) if models else list(BASELINE_MODELS)

    parts = [HTML_HEAD]
    for start in range(0, len(models), N_COL):
        row_models = models[start:start + N_COL]
        # Row 1: model names. ``models`` may come from user-submitted UI
        # values, so escape them before inlining into HTML.
        parts.append(_gallery_row([
            f'<td width="{WIDTH}%" style="text-align:center;"><h4>{escape(str(m))}</h4></td>'
            for m in row_models
        ]))
        # Row 2: each model's image for this sample.
        image_cells = []
        for m in row_models:
            rel_path = f"Unireditbench_baseline_images/{m}/{name}/{idx}.png"
            cell = render_img_html(rel_path, max_h=256)
            image_cells.append(f'<td width="{WIDTH}%" style="text-align:center;">{cell}</td>')
        parts.append(_gallery_row(image_cells))
    parts.append(HTML_TAIL)
    return "".join(parts)
def update_idx_dropdown(name):
    """Rebuild the idx dropdown after the category changes.

    The choices become all indices of ``name``. The first one (or ``None`` if
    there are none) is selected.
    """
    choices = get_indices_for_name(name)
    return gr.Dropdown(choices=choices, value=next(iter(choices), None))
def load_sample(name, idx, selected_models):
    """Render every viewer panel for the sample ``(name, idx)``.

    Returns a 5-tuple in UI-output order: info markdown, instruction text,
    rules text, original-image HTML and baseline-gallery HTML. If the pair is
    not in ``INDEX_MAP``, a "not found" message fills the panels instead.
    """
    idx_str = str(idx)
    item = INDEX_MAP.get((name, idx_str))
    if item is None:
        not_found_html = "<p>Sample not found.</p>"
        return (
            f"**Not found:** name = {name}, idx = {idx}",
            "",
            "",
            not_found_html,
            not_found_html,
        )

    header = (
        f"**Category (name):** {item['name']} \n"
        f"**Index (idx):** {item['idx']}"
    )
    instruction_text = item.get("instruction", "")
    rules_text = item.get("rules", "")
    # original_image_path is expected to be relative to the dataset root,
    # e.g. "original_image/jewel2/0001.png".
    original_html = render_img_html(item.get("original_image_path", ""), max_h=512)
    baselines_html = get_baseline_gallery(name, idx_str, selected_models)
    return header, instruction_text, rules_text, original_html, baselines_html
def _step_idx(name: str, idx: str, direction: int):
    """Return the idx ``direction`` steps away from ``idx`` within category ``name``.

    Args:
        name: Category whose indices are navigated.
        idx: Current idx (compared as a string).
        direction: -1 for the previous sample, +1 for the next one.

    Returns:
        The neighbouring idx, clamped so navigation stops at the first/last
        sample. If ``idx`` is not in the category (e.g. nothing selected),
        the category's first idx is returned. If the category has no
        samples, ``None`` is returned.
    """
    idxs = get_indices_for_name(name)
    if not idxs:
        return None
    idx = str(idx)
    if idx not in idxs:
        # Unknown current position: start at the first sample. Stepping from
        # position 0 here would make "Next" skip the first sample.
        return idxs[0]
    new_pos = idxs.index(idx) + direction
    return idxs[max(0, min(len(idxs) - 1, new_pos))]
def _navigate(name, idx, selected_models, direction):
    """Step ``direction`` samples away from ``idx`` and reload every panel.

    Returns a 6-tuple: an update for the idx dropdown, followed by the five
    values from ``load_sample``. When ``_step_idx`` returns ``None`` (no
    indices for ``name``), the dropdown is left unchanged and the current
    ``idx`` is reloaded.
    """
    new_idx = _step_idx(name, idx, direction=direction)
    if new_idx is None:
        # Nothing to move to: keep the dropdown as-is.
        return (gr.update(), *load_sample(name, idx, selected_models))
    return (gr.update(value=new_idx), *load_sample(name, new_idx, selected_models))


def prev_sample(name, idx, selected_models):
    """Go to the previous sample of ``name`` (stopping at the first one)."""
    return _navigate(name, idx, selected_models, -1)


def next_sample(name, idx, selected_models):
    """Go to the next sample of ``name`` (stopping at the last one)."""
    return _navigate(name, idx, selected_models, +1)
# ================== Gradio UI ==================
with gr.Blocks() as demo:
    gr.Markdown("# UniREditBench Gallery")
    with gr.Row():
        # Left column: sample selection, navigation and baseline filter.
        with gr.Column(scale=1):
            # Start on the first category and its first idx.
            default_name = ALL_NAMES[0] if ALL_NAMES else None
            default_idxs = get_indices_for_name(default_name) if default_name else []
            default_idx = default_idxs[0] if default_idxs else None
            name_dropdown = gr.Dropdown(
                label="Category (name)",
                choices=ALL_NAMES,
                value=default_name,
            )
            idx_dropdown = gr.Dropdown(
                label="Idx",
                choices=default_idxs,
                value=default_idx,
            )
            with gr.Row():
                prev_button = gr.Button("Prev")
                next_button = gr.Button("Next")
            load_button = gr.Button("Load sample")
            model_checkboxes = gr.CheckboxGroup(
                label="Baselines to show",
                choices=BASELINE_MODELS,  # every model stays selectable
                # Pre-tick only the priority models. Filter against the choices
                # so a priority model missing from ALL_MODELS cannot become an
                # invalid default value.
                value=[m for m in PRIORITY_MODELS if m in BASELINE_MODELS],
            )
        # Middle column: sample metadata and text.
        with gr.Column(scale=2):
            info_markdown = gr.Markdown(label="Info")
            instruction_box = gr.Textbox(
                label="Instruction", lines=4, interactive=False
            )
            rules_box = gr.Textbox(
                label="Rules", lines=3, interactive=False
            )
        # Right column: the sample's original image.
        with gr.Column(scale=2):
            gr.Markdown("### Original Image")
            orig_image_html = gr.HTML()
    # Baseline comparison table below the columns.
    gallery_html = gr.HTML()

    # Changing the category refreshes the idx choices.
    name_dropdown.change(
        fn=update_idx_dropdown,
        inputs=name_dropdown,
        outputs=idx_dropdown,
    )
    load_button.click(
        fn=load_sample,
        inputs=[name_dropdown, idx_dropdown, model_checkboxes],
        outputs=[info_markdown, instruction_box, rules_box, orig_image_html, gallery_html]
    )
    # Prev / Next update idx_dropdown and refresh every panel at once.
    prev_button.click(
        fn=prev_sample,
        inputs=[name_dropdown, idx_dropdown, model_checkboxes],
        outputs=[idx_dropdown, info_markdown, instruction_box, rules_box, orig_image_html, gallery_html]
    )
    next_button.click(
        fn=next_sample,
        inputs=[name_dropdown, idx_dropdown, model_checkboxes],
        outputs=[idx_dropdown, info_markdown, instruction_box, rules_box, orig_image_html, gallery_html]
    )
if __name__ == "__main__":
    # Images are inlined as data URLs, so allowed_paths / set_static_paths
    # are no longer needed.
    demo.launch(server_port=7860, server_name="0.0.0.0")