# maplebb's picture
# Update app.py
# ddff00f verified
import os
import json
import gradio as gr
from pathlib import Path
from io import BytesIO
import base64
import requests
from PIL import Image
# =============== Config: HF dataset repository ID ===============
HF_DATASET_ID = "maplebb/UniREditBench-Results"
# Base URL for raw file downloads from the dataset's main branch.
HF_BASE_URL = f"https://huggingface.co/datasets/{HF_DATASET_ID}/resolve/main"
# Alias kept for existing callers; derived from HF_BASE_URL so the two
# spellings can never drift apart.
BASE_HF_URL = HF_BASE_URL
# Module-level memo caches shared by all requests handled by this process.
_image_cache = {}
_html_cache = {}
def get_url_response(url, timeout=10):
    """Fetch *url* with an HTTP GET and return the response, or None on failure.

    Args:
        url: Absolute URL to fetch.
        timeout: Seconds to wait for the server, passed to ``requests.get``.
            Defaults to 10, the previously hard-coded value.

    Returns:
        The ``requests.Response`` when the request succeeded and
        ``raise_for_status`` did not raise (i.e. status below 400), otherwise
        ``None``. Failures are printed rather than raised.
    """
    try:
        resp = requests.get(url, timeout=timeout)
        resp.raise_for_status()
    except requests.RequestException as e:
        # Best-effort fetch: callers treat None as "resource unavailable".
        # RequestException is the base of connection, timeout, URL and
        # HTTPError failures raised by requests.
        print(f"Error fetching {url}: {e}")
        return None
    return resp
# Upper bound on decoded images kept in memory; oldest entries are evicted first.
_IMAGE_CACHE_MAX = 64


def load_image_uniredit(rel_path: str):
    """Download an image from the UniREditBench dataset repo and return it as RGB.

    The fetch happens server-side via ``get_url_response``. Successful results
    are memoized in the module-level ``_image_cache``, which is bounded to
    ``_IMAGE_CACHE_MAX`` entries with first-in-first-out eviction. Failures are
    not cached, so a later call retries.

    Args:
        rel_path: File path relative to the dataset root; leading "/" is ignored.

    Returns:
        A ``PIL.Image.Image`` in RGB mode, or ``None`` if the download failed
        or the bytes could not be decoded as an image.
    """
    rel_path = rel_path.lstrip("/")
    cached = _image_cache.get(rel_path)
    if cached is not None:
        return cached
    resp = get_url_response(f"{BASE_HF_URL}/{rel_path}")
    if not resp:
        return None
    try:
        img = Image.open(BytesIO(resp.content)).convert("RGB")
    except OSError as e:
        # Corrupt, truncated or non-image payloads surface as OSError
        # subclasses (e.g. PIL.UnidentifiedImageError); treat them like a
        # failed download instead of crashing the UI handler.
        print(f"Error decoding image {rel_path}: {e}")
        return None
    if len(_image_cache) >= _IMAGE_CACHE_MAX:
        # dicts preserve insertion order, so the first key is the oldest entry.
        _image_cache.pop(next(iter(_image_cache), None), None)
    _image_cache[rel_path] = img
    return img
def pil_to_base64(img):
    """Encode *img* as a PNG ``data:`` URL string.

    Returns an empty string when *img* is None, so callers can pass a failed
    load straight through.
    """
    if img is not None:
        with BytesIO() as png_buffer:
            img.save(png_buffer, format="PNG")
            payload = png_buffer.getvalue()
        encoded = base64.b64encode(payload).decode("utf-8")
        return "data:image/png;base64," + encoded
    return ""
# Directory containing this script; data.json is resolved relative to it.
ROOT_DIR = Path(__file__).resolve().parent
# =============== data.json is still stored locally in the Space ===============
# To host data.json in the dataset as well, this could switch to
# hf_hub_download (not implemented yet).
JSON_PATH = ROOT_DIR / "data.json"

# =============== Load the JSON and build lookup indexes ===============
with JSON_PATH.open("r", encoding="utf-8") as f:
    data = json.load(f)

# Sorted, de-duplicated category names.
ALL_NAMES = sorted(set(rec["name"] for rec in data))

# (name, idx as str) -> record; a later duplicate key overwrites an earlier one.
INDEX_MAP = {(rec["name"], str(rec["idx"])): rec for rec in data}
# =============== Baseline model list ===============
# The images are no longer stored locally, so the models cannot be discovered
# with os.listdir; every available model is listed here by hand.
ALL_MODELS = [
    "Bagel-Think",
    "DreamOmni2",
    "GPT-4o",
    "Lumina-DiMOO",
    "MagicBrush",
    "Nano-Banana",
    "Omnigen2",
    "Qwen-Image-Edit",
    "Seedream4.0",
    "UniREdit-Bagel",
    "UniWorld-V2",
]
# Models listed first, in this order.
PRIORITY_MODELS = ["Bagel-Think", "GPT-4o", "Qwen-Image-Edit", "UniREdit-Bagel", "Nano-Banana"]

# Display order: priority models that exist in ALL_MODELS, then the rest alphabetically.
BASELINE_MODELS = [m for m in PRIORITY_MODELS if m in ALL_MODELS]
BASELINE_MODELS += sorted(m for m in ALL_MODELS if m not in PRIORITY_MODELS)

# Gallery table markup and layout.
HTML_HEAD = '<table class="center">'
HTML_TAIL = "</table>"
N_COL = 4  # baseline images per table row
WIDTH = 100 // N_COL  # width of each column, in percent
def get_indices_for_name(name: str):
    """Return the distinct idx values (as strings) of category *name*, numerically sorted.

    An empty or None *name* yields an empty list.
    """
    if name:
        found = set()
        for record in data:
            if record["name"] == name:
                found.add(str(record["idx"]))
        return sorted(found, key=int)
    return []
# Upper bound on cached data URLs; oldest entries are evicted first.
_HTML_CACHE_MAX = 256


def render_img_html(rel_path: str, max_h=512):
    """Return an ``<img>`` tag that embeds the dataset image at *rel_path* as a data URL.

    The image is fetched with ``load_image_uniredit`` and PNG/base64-encoded at
    its original resolution (no resizing). The encoded data URL is memoized in
    ``_html_cache``, bounded to ``_HTML_CACHE_MAX`` entries with
    first-in-first-out eviction. Display size is set only through CSS, so one
    cache entry serves every *max_h*.

    Args:
        rel_path: Path relative to the dataset root; empty/None renders a placeholder.
        max_h: Maximum displayed height in CSS pixels.

    Returns:
        The ``<img>`` HTML, or a ``<p>`` placeholder when there is no path or
        the image could not be loaded (failures are not cached).
    """
    if not rel_path:
        return "<p>No original image.</p>"
    key = (rel_path, None)  # height-independent: max_h only affects CSS
    data_url = _html_cache.get(key)
    if data_url is None:
        img = load_image_uniredit(rel_path)
        if img is None:
            return "<p>Failed to load image.</p>"
        data_url = pil_to_base64(img)
        if len(_html_cache) >= _HTML_CACHE_MAX:
            # dicts preserve insertion order, so the first key is the oldest entry.
            _html_cache.pop(next(iter(_html_cache), None), None)
        _html_cache[key] = data_url
    return f'<img src="{data_url}" style="max-width:100%; max-height:{max_h}px;">'
def get_baseline_gallery(name: str, idx: str, models):
    """Build an HTML table with each selected baseline model's output for one sample.

    Models are laid out ``N_COL`` per row: a header row of model names followed
    by a row of images. Images are fetched on the server and embedded as base64
    data URLs via ``render_img_html``; the table does not link to remote URLs.

    Args:
        name: Category name of the sample.
        idx: Sample index within the category (used verbatim in the file name).
        models: Iterable of model names to show; empty/None means all of
            ``BASELINE_MODELS``.

    Returns:
        The table HTML, or a ``<p>`` prompt when *name* or *idx* is missing.
    """
    from html import escape

    if not name or not idx:
        return "<p>Please select name and idx.</p>"
    # Nothing ticked -> show every baseline.
    models = list(models) if models else list(BASELINE_MODELS)

    empty_cell = f'<td width="{WIDTH}%"></td>'
    parts = [HTML_HEAD]
    for start in range(0, len(models), N_COL):
        row_models = models[start : start + N_COL]
        # Pad short (last) rows so every row has N_COL cells.
        padding = empty_cell * (N_COL - len(row_models))

        # Header row: model names (escaped, since they end up in HTML).
        parts.append("<tr>")
        for m in row_models:
            parts.append(
                f'<td width="{WIDTH}%" style="text-align:center;"><h4>{escape(m)}</h4></td>'
            )
        parts.append(padding)
        parts.append("</tr>")

        # Image row: each model's output for this sample.
        parts.append("<tr>")
        for m in row_models:
            rel_path = f"Unireditbench_baseline_images/{m}/{name}/{idx}.png"
            cell = render_img_html(rel_path, max_h=256)
            parts.append(f'<td width="{WIDTH}%" style="text-align:center;">{cell}</td>')
        parts.append(padding)
        parts.append("</tr>")
    parts.append(HTML_TAIL)
    return "".join(parts)
def update_idx_dropdown(name):
    """Category-change handler: repopulate the idx dropdown and pre-select its first entry.

    When the category has no samples, the dropdown is emptied and left unselected.
    """
    options = get_indices_for_name(name)
    return gr.Dropdown(choices=options, value=next(iter(options), None))
def load_sample(name, idx, selected_models):
    """Look up one sample and render everything the UI shows for it.

    Args:
        name: Category name.
        idx: Sample index; compared against the index as ``str``.
        selected_models: Baseline models to include in the gallery (empty = all).

    Returns:
        A 5-tuple ``(info_markdown, instruction, rules, original_image_html,
        gallery_html)``. An unknown sample yields a "Not found" message, empty
        text fields and placeholder HTML.
    """
    key = (name, str(idx))
    item = INDEX_MAP.get(key)
    if item is None:
        info_md = f"**Not found:** name = {name}, idx = {idx}"
        return info_md, "", "", "<p>Sample not found.</p>", "<p>Sample not found.</p>"

    # Separate paragraphs: a plain "\n" is a soft break in Markdown (rendered
    # on the same line), and trailing-space hard breaks are easily lost.
    info_md = (
        f"**Category (name):** {item['name']}\n\n"
        f"**Index (idx):** {item['idx']}"
    )
    instruction = item.get("instruction", "")
    rules = item.get("rules", "")
    # original_image_path is expected to be relative to the dataset root, e.g.
    # "original_image/jewel2/0001.png".
    orig_rel = item.get("original_image_path", "")
    orig_html = render_img_html(orig_rel, max_h=512)
    gallery_html = get_baseline_gallery(name, str(idx), selected_models)
    return info_md, instruction, rules, orig_html, gallery_html
def _step_idx(name: str, idx: str, direction: int):
    """Return the idx *direction* steps from *idx* within category *name*, clamped at the ends.

    Args:
        name: Category name.
        idx: Current idx; compared as ``str``.
        direction: -1 for previous, +1 for next.

    Returns:
        The neighbouring idx string. At a boundary, returns the first/last idx
        (no wrap-around). If *idx* is not part of the category (for example,
        when the category was switched before the idx dropdown refreshed),
        returns the category's first idx. Returns ``None`` if the category has
        no samples.
    """
    idxs = get_indices_for_name(name)
    if not idxs:
        return None
    idx = str(idx)
    if idx not in idxs:
        # Unknown position: start at the beginning. Stepping from position 0
        # here would make "Next" skip the first sample.
        return idxs[0]
    new_pos = idxs.index(idx) + direction
    return idxs[max(0, min(len(idxs) - 1, new_pos))]
def _step_and_load(name, idx, selected_models, direction):
    """Step *direction* from *idx* in category *name* and load the resulting sample.

    Returns a 6-tuple: the idx-dropdown update followed by ``load_sample``'s
    five outputs. If the category has no samples, the dropdown is left
    unchanged and the current *idx* is reloaded.
    """
    new_idx = _step_idx(name, idx, direction=direction)
    if new_idx is None:
        # Nothing to step to: keep the dropdown as-is and reload the current idx.
        return (gr.update(), *load_sample(name, idx, selected_models))
    return (gr.update(value=new_idx), *load_sample(name, new_idx, selected_models))


def prev_sample(name, idx, selected_models):
    """Prev-button handler: select and load the previous sample in the category."""
    return _step_and_load(name, idx, selected_models, direction=-1)


def next_sample(name, idx, selected_models):
    """Next-button handler: select and load the next sample in the category."""
    return _step_and_load(name, idx, selected_models, direction=+1)
# ================== Gradio UI ==================
with gr.Blocks() as demo:
    gr.Markdown("# UniREditBench Gallery")
    with gr.Row():
        with gr.Column(scale=1):
            default_name = ALL_NAMES[0] if ALL_NAMES else None
            default_idxs = get_indices_for_name(default_name) if default_name else []
            default_idx = default_idxs[0] if default_idxs else None
            name_dropdown = gr.Dropdown(
                label="Category (name)",
                choices=ALL_NAMES,
                value=default_name,
            )
            idx_dropdown = gr.Dropdown(
                label="Idx",
                choices=default_idxs,
                value=default_idx,
            )
            with gr.Row():
                prev_button = gr.Button("Prev")
                next_button = gr.Button("Next")
            load_button = gr.Button("Load sample")
            model_checkboxes = gr.CheckboxGroup(
                label="Baselines to show",
                choices=BASELINE_MODELS,  # every model remains selectable
                # Pre-check only the priority models, restricted to the actual
                # choices so the initial value never holds an entry outside
                # `choices` (Gradio validates selections against them).
                value=[m for m in PRIORITY_MODELS if m in BASELINE_MODELS],
            )
        with gr.Column(scale=2):
            info_markdown = gr.Markdown(label="Info")
            instruction_box = gr.Textbox(label="Instruction", lines=4, interactive=False)
            rules_box = gr.Textbox(label="Rules", lines=3, interactive=False)
        with gr.Column(scale=2):
            gr.Markdown("### Original Image")
            orig_image_html = gr.HTML()
    gallery_html = gr.HTML()

    # Shared wiring: every loader takes the same inputs and fills the same panels.
    sample_inputs = [name_dropdown, idx_dropdown, model_checkboxes]
    sample_outputs = [info_markdown, instruction_box, rules_box, orig_image_html, gallery_html]

    # Changing the category repopulates the idx dropdown.
    name_dropdown.change(fn=update_idx_dropdown, inputs=name_dropdown, outputs=idx_dropdown)
    load_button.click(fn=load_sample, inputs=sample_inputs, outputs=sample_outputs)
    # Prev / Next update idx_dropdown and refresh all sample content.
    prev_button.click(
        fn=prev_sample, inputs=sample_inputs, outputs=[idx_dropdown, *sample_outputs]
    )
    next_button.click(
        fn=next_sample, inputs=sample_inputs, outputs=[idx_dropdown, *sample_outputs]
    )
    # Render the default sample when the page opens instead of starting blank.
    demo.load(fn=load_sample, inputs=sample_inputs, outputs=sample_outputs)
if __name__ == "__main__":
    # allowed_paths / set_static_paths are no longer needed: images are embedded
    # as data URLs rather than served from disk. Host and port default to the
    # previous hard-coded values but can be overridden through Gradio's
    # GRADIO_SERVER_NAME / GRADIO_SERVER_PORT environment variables (explicit
    # launch() arguments would otherwise take precedence over them).
    demo.launch(
        server_name=os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"),
        server_port=int(os.environ.get("GRADIO_SERVER_PORT", "7860")),
    )