import os
import json
import gradio as gr
from pathlib import Path
from io import BytesIO
import base64
import functools
from html import escape
import requests
from PIL import Image

# =============== Config: HF dataset repository ID =================
HF_DATASET_ID = "maplebb/UniREditBench-Results"
HF_BASE_URL = f"https://huggingface.co/datasets/{HF_DATASET_ID}/resolve/main"
# Alias kept for backward compatibility; it previously duplicated HF_BASE_URL.
BASE_HF_URL = HF_BASE_URL

# In-process caches. Entries are never evicted, so memory grows with every
# distinct image that is viewed.
# rel_path -> decoded RGB PIL image
_image_cache = {}
# (rel_path, None) -> PNG data URL. The display height is applied at render
# time, so it is deliberately not part of the key.
_html_cache = {}


def get_url_response(url, timeout=10):
    """GET *url* and return the response, or None on failure.

    Args:
        url: Absolute URL to fetch.
        timeout: Per-request timeout in seconds (default 10).

    Returns:
        The ``requests.Response`` when the request succeeded with a 2xx status;
        otherwise ``None``. Errors are printed rather than raised so a single
        missing file does not break the page.
    """
    try:
        resp = requests.get(url, timeout=timeout)
        resp.raise_for_status()
    except requests.RequestException as e:
        print(f"Error fetching {url}: {e}")
        return None
    return resp


def load_image_uniredit(rel_path: str):
    """Fetch an image from the UniREditBench results dataset (server-side).

    Args:
        rel_path: Path relative to the dataset root; a leading "/" is ignored.

    Returns:
        An RGB PIL image, or ``None`` if the download or decoding fails.
        Successful results are cached by relative path.
    """
    rel_path = rel_path.lstrip("/")
    if rel_path in _image_cache:
        return _image_cache[rel_path]

    resp = get_url_response(f"{BASE_HF_URL}/{rel_path}")
    if resp is None:
        return None
    try:
        img = Image.open(BytesIO(resp.content)).convert("RGB")
    except OSError as e:  # includes PIL.UnidentifiedImageError (non-image body)
        print(f"Error decoding image {rel_path}: {e}")
        return None

    _image_cache[rel_path] = img
    return img


def pil_to_base64(img):
    """Encode a PIL image as a ``data:image/png;base64,...`` URL ("" for None)."""
    if img is None:
        return ""
    buf = BytesIO()
    img.save(buf, format="PNG")
    s = base64.b64encode(buf.getvalue()).decode("utf-8")
    return f"data:image/png;base64,{s}"


ROOT_DIR = Path(__file__).resolve().parent

# =============== data.json stays local to the Space =================
# To host data.json in the dataset instead, it could be fetched with
# huggingface_hub.hf_hub_download.
JSON_PATH = ROOT_DIR / "data.json"

# =============== Load json & build indexes =================
with open(JSON_PATH, "r", encoding="utf-8") as f:
    data = json.load(f)

ALL_NAMES = sorted({item["name"] for item in data})
# (name, idx_str) -> item
INDEX_MAP = {(item["name"], str(item["idx"])): item for item in data}

# =============== Baseline model list =================
# Images are no longer stored locally, so the list cannot come from
# os.listdir; it is written out by hand and must list every available model.
ALL_MODELS = [
    "Bagel-Think",
    "DreamOmni2",
    "GPT-4o",
    "Lumina-DiMOO",
    "MagicBrush",
    "Nano-Banana",
    "Omnigen2",
    "Qwen-Image-Edit",
    "Seedream4.0",
    "UniREdit-Bagel",
    "UniWorld-V2",
]
PRIORITY_MODELS = ["Bagel-Think", "GPT-4o", "Qwen-Image-Edit", "UniREdit-Bagel", "Nano-Banana"]
# Priority models first (in their listed order), then the rest alphabetically.
BASELINE_MODELS = (
    [m for m in PRIORITY_MODELS if m in ALL_MODELS]
    + sorted(m for m in ALL_MODELS if m not in PRIORITY_MODELS)
)
# Default checkbox selection: the priority models that are actually offered.
DEFAULT_SELECTED_MODELS = [m for m in PRIORITY_MODELS if m in BASELINE_MODELS]

HTML_HEAD = '<table style="width: 100%; border-collapse: collapse; table-layout: fixed;">'
HTML_TAIL = "</table>"

N_COL = 4
WIDTH = 100 // N_COL  # column width in percent

_CELL_STYLE = (
    f'style="width: {WIDTH}%; padding: 4px; text-align: center; vertical-align: top;"'
)


def _message_html(text):
    """Return *text* HTML-escaped inside a centered, greyed placeholder div."""
    return f'<div style="text-align: center; color: #888; padding: 1em;">{escape(text)}</div>'


@functools.lru_cache(maxsize=None)
def _sorted_indices(name):
    """Return the idx strings of category *name*, numerically sorted (memoized)."""
    idxs = {str(item["idx"]) for item in data if item["name"] == name}
    return tuple(sorted(idxs, key=int))


def get_indices_for_name(name: str):
    """Return every idx (as a string) of category *name*, sorted numerically.

    Returns [] for an empty/None name or an unknown category. A fresh list is
    returned on each call, so callers may mutate it.
    """
    if not name:
        return []
    return list(_sorted_indices(name))


def render_img_html(rel_path: str, max_h=512):
    """Return an ``<img>`` tag embedding the dataset image at *rel_path*.

    The image is downloaded server-side and inlined as a base64 PNG data URL
    at its original resolution. *max_h* caps the displayed height in CSS
    pixels. A placeholder message is returned when *rel_path* is empty or the
    image cannot be loaded.
    """
    if not rel_path:
        return _message_html("No original image.")

    key = (rel_path, None)  # height-independent, see _html_cache
    data_url = _html_cache.get(key)
    if data_url is None:
        img = load_image_uniredit(rel_path)
        if img is None:
            return _message_html("Failed to load image.")
        data_url = pil_to_base64(img)
        _html_cache[key] = data_url

    return (
        f'<img src="{data_url}" alt="{escape(rel_path)}" '
        f'style="max-width: 100%; max-height: {max_h}px; object-fit: contain;">'
    )


def get_baseline_gallery(name: str, idx: str, models):
    """Build an HTML table of baseline outputs for one sample.

    Models are laid out N_COL per row. Each row of model names is followed by
    a row with that model's image for ``{name}/{idx}.png``. Images are fetched
    server-side and embedded as data URLs.

    Args:
        name: Category name.
        idx: Sample index (string).
        models: Iterable of model names; empty/None means all BASELINE_MODELS.
    """
    if not name or not idx:
        return _message_html("Please select name and idx.")

    # Nothing ticked -> show every baseline.
    models = list(models) if models else list(BASELINE_MODELS)

    empty_cell = f"<td {_CELL_STYLE}></td>"
    parts = [HTML_HEAD]
    for start in range(0, len(models), N_COL):
        row_models = models[start : start + N_COL]
        # Pad short final rows so the column widths stay aligned.
        padding = empty_cell * (N_COL - len(row_models))

        # Row 1: model names.
        parts.append("<tr>")
        parts.extend(f"<td {_CELL_STYLE}><b>{escape(m)}</b></td>" for m in row_models)
        parts.append(padding)
        parts.append("</tr>")

        # Row 2: the matching images.
        parts.append("<tr>")
        for m in row_models:
            rel_path = f"Unireditbench_baseline_images/{m}/{name}/{idx}.png"
            parts.append(f"<td {_CELL_STYLE}>{render_img_html(rel_path, max_h=256)}</td>")
        parts.append(padding)
        parts.append("</tr>")
    parts.append(HTML_TAIL)
    return "".join(parts)


def update_idx_dropdown(name):
    """Refresh the idx dropdown choices when the category changes."""
    idxs = get_indices_for_name(name)
    default = idxs[0] if idxs else None
    return gr.Dropdown(choices=idxs, value=default)


def load_sample(name, idx, selected_models):
    """Load one sample.

    Returns:
        A tuple ``(info_md, instruction, rules, orig_html, gallery_html)``.
    """
    item = INDEX_MAP.get((name, str(idx)))
    if item is None:
        not_found = _message_html("Sample not found.")
        info_md = f"**Not found:** name = {name}, idx = {idx}"
        return info_md, "", "", not_found, not_found

    info_md = (
        f"**Category (name):** {item['name']}  \n"
        f"**Index (idx):** {item['idx']}"
    )
    instruction = item.get("instruction", "")
    rules = item.get("rules", "")

    # original_image_path is expected to be a dataset-relative path, e.g.
    # "original_image/jewel2/0001.png".
    orig_rel = item.get("original_image_path", "")
    orig_html = render_img_html(orig_rel, max_h=512)

    gallery_html = get_baseline_gallery(name, str(idx), selected_models)
    return info_md, instruction, rules, orig_html, gallery_html


def _step_idx(name: str, idx: str, direction: int):
    """Return the neighbouring idx of *idx* within category *name*.

    direction=-1 moves to the previous sample and +1 to the next; the result
    is clamped at both ends. If *idx* is not a known index (e.g. None), the
    first index is returned. Returns None when the category has no indices.
    """
    idxs = get_indices_for_name(name)
    if not idxs:
        return None
    idx = str(idx)
    if idx not in idxs:
        # Start from the first sample instead of stepping past it.
        return idxs[0]
    new_pos = idxs.index(idx) + direction
    return idxs[max(0, min(len(idxs) - 1, new_pos))]


def _step_and_load(name, idx, selected_models, direction):
    """Step the idx by *direction*, then return (dropdown update, *load_sample output)."""
    new_idx = _step_idx(name, idx, direction)
    if new_idx is None:
        # Nothing to step through: leave the dropdown unchanged.
        return (gr.update(), *load_sample(name, idx, selected_models))
    return (gr.update(value=new_idx), *load_sample(name, new_idx, selected_models))


def prev_sample(name, idx, selected_models):
    """Move to the previous sample and reload everything."""
    return _step_and_load(name, idx, selected_models, direction=-1)


def next_sample(name, idx, selected_models):
    """Move to the next sample and reload everything."""
    return _step_and_load(name, idx, selected_models, direction=+1)


# ================== Gradio UI ==================
with gr.Blocks() as demo:
    gr.Markdown("# UniREditBench Gallery")

    with gr.Row():
        with gr.Column(scale=1):
            default_name = ALL_NAMES[0] if ALL_NAMES else None
            default_idxs = get_indices_for_name(default_name) if default_name else []
            default_idx = default_idxs[0] if default_idxs else None

            name_dropdown = gr.Dropdown(
                label="Category (name)",
                choices=ALL_NAMES,
                value=default_name,
            )
            idx_dropdown = gr.Dropdown(
                label="Idx",
                choices=default_idxs,
                value=default_idx,
            )

            with gr.Row():
                prev_button = gr.Button("Prev")
                next_button = gr.Button("Next")

            load_button = gr.Button("Load sample")

            model_checkboxes = gr.CheckboxGroup(
                label="Baselines to show",
                choices=BASELINE_MODELS,  # every model can be selected
                value=DEFAULT_SELECTED_MODELS,  # only the priority models are ticked by default
            )

        with gr.Column(scale=2):
            info_markdown = gr.Markdown(label="Info")
            instruction_box = gr.Textbox(
                label="Instruction",
                lines=4,
                interactive=False,
            )
            rules_box = gr.Textbox(
                label="Rules",
                lines=3,
                interactive=False,
            )

        with gr.Column(scale=2):
            gr.Markdown("### Original Image")
            orig_image_html = gr.HTML()

    gallery_html = gr.HTML()

    name_dropdown.change(
        fn=update_idx_dropdown,
        inputs=name_dropdown,
        outputs=idx_dropdown,
    )

    load_button.click(
        fn=load_sample,
        inputs=[name_dropdown, idx_dropdown, model_checkboxes],
        outputs=[info_markdown, instruction_box, rules_box, orig_image_html, gallery_html],
    )

    # Prev / Next update idx_dropdown and refresh all displayed content.
    prev_button.click(
        fn=prev_sample,
        inputs=[name_dropdown, idx_dropdown, model_checkboxes],
        outputs=[idx_dropdown, info_markdown, instruction_box, rules_box, orig_image_html, gallery_html],
    )

    next_button.click(
        fn=next_sample,
        inputs=[name_dropdown, idx_dropdown, model_checkboxes],
        outputs=[idx_dropdown, info_markdown, instruction_box, rules_box, orig_image_html, gallery_html],
    )


if __name__ == "__main__":
    # Images are embedded as data URLs, so allowed_paths / set_static_paths
    # are not needed.
    demo.launch(server_name="0.0.0.0", server_port=7860)