File size: 9,904 Bytes
d6808f3
 
 
 
ddff00f
 
 
 
d6808f3
 
 
 
ddff00f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d6808f3
 
 
 
ddff00f
d6808f3
 
 
 
 
 
 
 
 
 
 
ddff00f
 
d6808f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ddff00f
 
 
 
 
 
 
 
 
 
 
 
 
 
d6808f3
 
 
ddff00f
d6808f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ddff00f
 
 
d6808f3
 
 
 
 
 
 
 
ddff00f
a0450b5
d6808f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ddff00f
d6808f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ddff00f
d6808f3
 
 
ddff00f
 
d6808f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ddff00f
d6808f3
 
 
 
 
 
 
 
 
 
 
 
 
ddff00f
d6808f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ddff00f
d6808f3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
import os
import json
import gradio as gr
from pathlib import Path
from io import BytesIO
import base64
import requests
from PIL import Image

# =============== Config: HF dataset repo ID =================
HF_DATASET_ID = "maplebb/UniREditBench-Results"
# Base URL for raw file downloads from the dataset's main branch.
HF_BASE_URL = f"https://huggingface.co/datasets/{HF_DATASET_ID}/resolve/main"
# Alias of HF_BASE_URL kept for existing callers; derived from it so the two
# names can never point at different locations.
BASE_HF_URL = HF_BASE_URL

# In-process caches. Nothing in this file evicts entries, so they grow with
# every distinct image that is viewed.
_image_cache = {}  # rel_path -> decoded RGB PIL image
_html_cache = {}  # (rel_path, None) -> PNG data URL string

def get_url_response(url):
    """Perform a GET request on *url* and return the response object.

    Any failure, including a non-2xx HTTP status (via ``raise_for_status``),
    is printed and turned into a ``None`` return value rather than raised.
    """
    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
    except Exception as err:
        print(f"Error fetching {url}: {err}")
        return None
    return response

def load_image_uniredit(rel_path: str):
    """Download an image from the UniREditBench dataset and decode it as RGB.

    Args:
        rel_path: Path of the file relative to ``BASE_HF_URL``; leading
            slashes are stripped before use.

    Returns:
        The decoded RGB PIL image, or None if the download fails or the
        downloaded bytes cannot be decoded as an image. Successful results
        are memoized in ``_image_cache`` (keyed by the stripped path);
        failures are not cached, so they are retried on the next call.
    """
    rel_path = rel_path.lstrip("/")
    if rel_path in _image_cache:
        return _image_cache[rel_path]

    url = f"{BASE_HF_URL}/{rel_path}"
    resp = get_url_response(url)
    if resp is None:
        return None

    try:
        img = Image.open(BytesIO(resp.content)).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError) as e:
        # Non-image or truncated content (e.g. an HTML error page) must not
        # crash the UI handler; report it like a download failure instead.
        print(f"Error decoding image {url}: {e}")
        return None

    _image_cache[rel_path] = img
    return img

def pil_to_base64(img):
    """Serialize *img* to PNG and return it as a ``data:image/png;base64,`` URL.

    An empty string is returned when *img* is None.
    """
    if img is None:
        return ""
    with BytesIO() as buffer:
        img.save(buffer, format="PNG")
        encoded = base64.b64encode(buffer.getvalue())
    return "data:image/png;base64," + encoded.decode("utf-8")

ROOT_DIR = Path(__file__).resolve().parent

# =============== data.json is kept locally in the Space =================
# To host data.json in the dataset instead, it could be fetched with
# hf_hub_download.
JSON_PATH = ROOT_DIR / "data.json"

# =============== Load the JSON & build lookup indexes =================
data = json.loads(JSON_PATH.read_text(encoding="utf-8"))

# Every distinct category name, sorted alphabetically.
ALL_NAMES = sorted({record["name"] for record in data})

# (name, idx as string) -> sample record. A later duplicate key would
# overwrite an earlier one.
INDEX_MAP = {(record["name"], str(record["idx"])): record for record in data}

# =============== Baseline model list =================
# Baseline outputs are served from the remote dataset, not from local disk,
# so they cannot be discovered with os.listdir; list every available model
# here by hand.
ALL_MODELS = [
    "Bagel-Think",
    "DreamOmni2",
    "GPT-4o",
    "Lumina-DiMOO",
    "MagicBrush",
    "Nano-Banana",
    "Omnigen2",
    "Qwen-Image-Edit",
    "Seedream4.0",
    "UniREdit-Bagel",
    "UniWorld-V2",
]

# Models shown first, in this order.
PRIORITY_MODELS = ["Bagel-Think", "GPT-4o", "Qwen-Image-Edit", "UniREdit-Bagel", "Nano-Banana"]

# Display order: the priority models that exist in ALL_MODELS, followed by
# every remaining model in alphabetical order.
BASELINE_MODELS = [model for model in PRIORITY_MODELS if model in ALL_MODELS]
BASELINE_MODELS += sorted(model for model in ALL_MODELS if model not in PRIORITY_MODELS)

# Gallery table markup and layout.
HTML_HEAD = '<table class="center">'
HTML_TAIL = "</table>"
N_COL = 4  # baseline columns per gallery row
WIDTH = 100 // N_COL  # width of each column, in percent


def get_indices_for_name(name: str):
    """List the distinct idx values (as strings) of category *name*.

    The result is sorted by numeric value; an empty or None *name* yields [].
    """
    if not name:
        return []
    matching = set()
    for record in data:
        if record["name"] == name:
            matching.add(str(record["idx"]))
    return sorted(matching, key=int)


def render_img_html(rel_path: str, max_h=512):
    """Build an ``<img>`` tag that inlines the dataset image at *rel_path*.

    The image is embedded as a PNG data URL at its native resolution; only
    CSS (``max-width``/``max-height``) limits its displayed size. Encoded data
    URLs are memoized in ``_html_cache`` independently of *max_h*. Returns a
    short ``<p>`` message when *rel_path* is empty or the image cannot be
    loaded.
    """
    if not rel_path:
        return "<p>No original image.</p>"

    cache_key = (rel_path, None)  # the cached data URL does not depend on max_h
    data_url = _html_cache.get(cache_key)
    if data_url is None:
        image = load_image_uniredit(rel_path)
        if image is None:
            return "<p>Failed to load image.</p>"
        # No resizing: encode at the original resolution.
        data_url = pil_to_base64(image)
        _html_cache[cache_key] = data_url

    return f'<img src="{data_url}" style="max-width:100%; max-height:{max_h}px;">'


def get_baseline_gallery(name: str, idx: str, models):
    """Render the selected baselines' outputs for one sample as an HTML table.

    Each group of ``N_COL`` models produces two table rows: one with the model
    names and one with their images, loaded from
    ``Unireditbench_baseline_images/{model}/{name}/{idx}.png`` in the dataset.
    Short rows are padded with empty cells.

    Args:
        name: Category name of the sample.
        idx: Sample index (string).
        models: Iterable of model names to show; when empty or None, every
            model in ``BASELINE_MODELS`` is shown.

    Returns:
        The table HTML, or a ``<p>`` prompt when *name* or *idx* is missing.
    """
    from html import escape

    if not name or not idx:
        return "<p>Please select name and idx.</p>"

    # Nothing ticked -> fall back to showing every baseline.
    models = list(models) if models else list(BASELINE_MODELS)

    parts = [HTML_HEAD]
    for start in range(0, len(models), N_COL):
        row_models = models[start : start + N_COL]
        padding = f'<td width="{WIDTH}%"></td>' * (N_COL - len(row_models))

        # Header row: model names (escaped, since they are inserted as HTML).
        parts.append("<tr>")
        parts.extend(
            f'<td width="{WIDTH}%" style="text-align:center;"><h4>{escape(m)}</h4></td>'
            for m in row_models
        )
        parts.append(padding)
        parts.append("</tr>")

        # Image row: one cell per model, in the same order as the header row.
        parts.append("<tr>")
        for m in row_models:
            rel_path = f"Unireditbench_baseline_images/{m}/{name}/{idx}.png"
            cell = render_img_html(rel_path, max_h=256)
            parts.append(f'<td width="{WIDTH}%" style="text-align:center;">{cell}</td>')
        parts.append(padding)
        parts.append("</tr>")

    parts.append(HTML_TAIL)
    return "".join(parts)


def update_idx_dropdown(name):
    """Rebuild the idx dropdown after the category (*name*) changes.

    The choices become the category's idxs and the selection resets to the
    first one, or None when the category has no samples.
    """
    choices = get_indices_for_name(name)
    return gr.Dropdown(choices=choices, value=choices[0] if choices else None)


def load_sample(name, idx, selected_models):
    """Look up sample (*name*, *idx*) and render every display panel for it.

    Returns a 5-tuple: (info markdown, instruction text, rules text,
    original-image HTML, baseline-gallery HTML). When the pair is not in
    ``INDEX_MAP``, a "not found" message and placeholders are returned.
    """
    idx_str = str(idx)
    item = INDEX_MAP.get((name, idx_str))

    if item is None:
        not_found = "<p>Sample not found.</p>"
        return f"**Not found:** name = {name}, idx = {idx}", "", "", not_found, not_found

    # Two markdown lines separated by a hard line break ("  \n").
    info_md = "  \n".join(
        [
            f"**Category (name):** {item['name']}",
            f"**Index (idx):** {item['idx']}",
        ]
    )

    # original_image_path is used as a path relative to the dataset root,
    # e.g. "original_image/jewel2/0001.png".
    orig_html = render_img_html(item.get("original_image_path", ""), max_h=512)
    gallery_html = get_baseline_gallery(name, idx_str, selected_models)

    return (
        info_md,
        item.get("instruction", ""),
        item.get("rules", ""),
        orig_html,
        gallery_html,
    )

def _step_idx(name: str, idx: str, direction: int):
    """Return the idx one step away from *idx* within category *name*.

    Args:
        name: Category name.
        idx: Currently selected idx (compared as a string).
        direction: -1 for the previous sample, +1 for the next one.

    Returns:
        The neighbouring idx, clamped so it never moves past either end of the
        category. If *idx* is not one of the category's idxs (e.g. None, or
        stale after a category change), the category's first idx is returned.
        Returns None when the category has no samples.
    """
    idxs = get_indices_for_name(name)
    if not idxs:
        return None

    try:
        cur = idxs.index(str(idx))
    except ValueError:
        # Unknown current position: start from the beginning rather than
        # stepping from an assumed position 0, which made "Next" skip the
        # first sample.
        return idxs[0]

    new_pos = min(max(cur + direction, 0), len(idxs) - 1)
    return idxs[new_pos]


def _navigate(name, idx, selected_models, direction):
    """Shared implementation of prev_sample / next_sample.

    Steps the selection by *direction* (-1 or +1) via ``_step_idx`` and
    returns a 6-tuple: a ``gr.update`` for the idx dropdown followed by the
    five outputs of ``load_sample`` for the resulting selection.
    """
    new_idx = _step_idx(name, idx, direction=direction)
    if new_idx is None:
        # The category has no samples: leave the dropdown unchanged and
        # re-render the current selection.
        return (gr.update(), *load_sample(name, idx, selected_models))
    return (gr.update(value=new_idx), *load_sample(name, new_idx, selected_models))


def prev_sample(name, idx, selected_models):
    """Move to the previous sample of the category (stops at the first one)."""
    return _navigate(name, idx, selected_models, direction=-1)


def next_sample(name, idx, selected_models):
    """Move to the next sample of the category (stops at the last one)."""
    return _navigate(name, idx, selected_models, direction=+1)


# ================== Gradio UI ==================
with gr.Blocks() as demo:
    gr.Markdown("# UniREditBench Gallery")

    with gr.Row():
        with gr.Column(scale=1):
            # Preselect the first category and its first idx.
            default_name = ALL_NAMES[0] if ALL_NAMES else None
            default_idxs = get_indices_for_name(default_name) if default_name else []
            default_idx = default_idxs[0] if default_idxs else None

            name_dropdown = gr.Dropdown(
                label="Category (name)",
                choices=ALL_NAMES,
                value=default_name,
            )

            idx_dropdown = gr.Dropdown(
                label="Idx",
                choices=default_idxs,
                value=default_idx,
            )

            with gr.Row():
                prev_button = gr.Button("Prev")
                next_button = gr.Button("Next")
                load_button = gr.Button("Load sample")

            model_checkboxes = gr.CheckboxGroup(
                label="Baselines to show",
                # Every baseline can be selected...
                choices=BASELINE_MODELS,
                # ...but only the priority models are ticked by default. The
                # filter keeps the default a subset of the choices even if a
                # priority model is later removed from ALL_MODELS.
                value=[m for m in PRIORITY_MODELS if m in BASELINE_MODELS],
            )

        with gr.Column(scale=2):
            info_markdown = gr.Markdown(label="Info")
            instruction_box = gr.Textbox(
                label="Instruction", lines=4, interactive=False
            )
            rules_box = gr.Textbox(
                label="Rules", lines=3, interactive=False
            )

        with gr.Column(scale=2):
            gr.Markdown("### Original Image")
            orig_image_html = gr.HTML()

    gallery_html = gr.HTML()

    # Shared wiring: every loader reads the same three controls and fills the
    # same five panels; Prev/Next additionally update the idx dropdown.
    sample_inputs = [name_dropdown, idx_dropdown, model_checkboxes]
    sample_outputs = [info_markdown, instruction_box, rules_box, orig_image_html, gallery_html]
    nav_outputs = [idx_dropdown, *sample_outputs]

    name_dropdown.change(
        fn=update_idx_dropdown,
        inputs=name_dropdown,
        outputs=idx_dropdown,
    )

    load_button.click(fn=load_sample, inputs=sample_inputs, outputs=sample_outputs)

    # Prev / Next: update idx_dropdown and refresh all panels.
    prev_button.click(fn=prev_sample, inputs=sample_inputs, outputs=nav_outputs)
    next_button.click(fn=next_sample, inputs=sample_inputs, outputs=nav_outputs)

    # Render the preselected sample when the page opens, so the panels are not
    # empty until "Load sample" is clicked.
    demo.load(fn=load_sample, inputs=sample_inputs, outputs=sample_outputs)


if __name__ == "__main__":
    # Images are fetched over HTTP and inlined as data URLs, so no
    # allowed_paths / static paths are needed.
    # GRADIO_SERVER_NAME / GRADIO_SERVER_PORT override the defaults when set.
    demo.launch(
        server_name=os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"),
        server_port=int(os.environ.get("GRADIO_SERVER_PORT", "7860")),
    )