maplebb commited on
Commit
d6808f3
·
verified ·
1 Parent(s): eb63f51

Upload 4 files

Browse files
Files changed (3) hide show
  1. app.py +267 -0
  2. data.json +0 -0
  3. requirements.txt +1 -0
app.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import gradio as gr
4
+ from pathlib import Path
5
+
6
+ # =============== 配置:HF dataset 仓库 ID =================
7
+ HF_DATASET_ID = "maplebb/UniREditBench-Results"
8
+ HF_BASE_URL = f"https://huggingface.co/datasets/{HF_DATASET_ID}/resolve/main"
9
+
10
+ ROOT_DIR = Path(__file__).resolve().parent
11
+
12
+ # =============== data.json 还是放在 Space 本地 =================
13
+ # 如果想把 data.json 也放到 dataset 里,可以改成用 hf_hub_download(后面我再写)
14
+ JSON_PATH = ROOT_DIR / "data.json"
15
+
16
+ # =============== 读 json & 建索引 =================
17
+ with open(JSON_PATH, "r", encoding="utf-8") as f:
18
+ data = json.load(f)
19
+
20
+ ALL_NAMES = sorted({item["name"] for item in data})
21
+
22
+ # (name, idx_str) -> item
23
+ INDEX_MAP = {(item["name"], str(item["idx"])): item for item in data}
24
+
25
+ # =============== baseline 模型列表(现在不能再 os.listdir,本地没图了) =================
26
+ # 建议直接手动写齐你有的模型
27
+ ALL_MODELS = [
28
+ "Bagel-Think",
29
+ "DreamOmni2",
30
+ "GPT-4o",
31
+ "Lumina-DiMOO",
32
+ "MagicBrush",
33
+ "Nano-Banana",
34
+ "Omnigen2",
35
+ "Qwen-Image-Edit",
36
+ "Seedream4.0",
37
+ "UniREdit-Bagel",
38
+ "UniWorld-V2",
39
+ ]
40
+
41
+ PRIORITY_MODELS = ["Bagel-Think", "GPT-4o", "Qwen-Image-Edit", "UniREdit-Bagel", "Nano-Banana"]
42
+
43
+ BASELINE_MODELS = (
44
+ [m for m in PRIORITY_MODELS if m in ALL_MODELS]
45
+ + sorted([m for m in ALL_MODELS if m not in PRIORITY_MODELS])
46
+ )
47
+
48
+ HTML_HEAD = '<table class="center">'
49
+ HTML_TAIL = "</table>"
50
+ N_COL = 4
51
+ WIDTH = 100 // N_COL
52
+
53
+
54
+ def get_indices_for_name(name: str):
55
+ """给定 name,返回该类下面所有 idx(字符串),按数字排序。"""
56
+ if not name:
57
+ return []
58
+ idxs = {str(item["idx"]) for item in data if item["name"] == name}
59
+ return sorted(idxs, key=lambda x: int(x))
60
+
61
+
62
+ def render_img_html(rel_path: str, max_h=512):
63
+ """
64
+ 现在不再检查本地文件,直接把相对路径拼到 HF dataset 的 URL 上。
65
+ rel_path 例如: 'original_image/jewel2/0001.png'
66
+ """
67
+ if not rel_path:
68
+ return "<p>No original image.</p>"
69
+
70
+ rel_path = str(rel_path).lstrip("/") # 防止开头有 '/'
71
+ src = f"{HF_BASE_URL}/{rel_path}"
72
+ return f'<img src="{src}" style="max-width:100%; max-height:{max_h}px;">'
73
+
74
+
75
+ def get_baseline_gallery(name: str, idx: str, models):
76
+ """生成 baseline 图像的 HTML 表格(使用远程 URL)."""
77
+ if not name or not idx:
78
+ return "<p>Please select name and idx.</p>"
79
+
80
+ # 没勾选就默认显示所有 baseline
81
+ if not models:
82
+ models = BASELINE_MODELS
83
+
84
+ models = list(models)
85
+
86
+ html = HTML_HEAD
87
+ num_models = len(models)
88
+
89
+ for row in range((num_models - 1) // N_COL + 1):
90
+ sub_models = models[row * N_COL : (row + 1) * N_COL]
91
+
92
+ # 第一行:模型名
93
+ html += "<tr>"
94
+ for m in sub_models:
95
+ html += (
96
+ f'<td width="{WIDTH}%" style="text-align:center;"><h4>{m}</h4></td>'
97
+ )
98
+ for _ in range(N_COL - len(sub_models)):
99
+ html += f'<td width="{WIDTH}%"></td>'
100
+ html += "</tr>"
101
+
102
+ # 第二行:对应图片
103
+ html += "<tr>"
104
+ for m in sub_models:
105
+ # 关键:这里直接拼 dataset 里的路径
106
+ # 对应 dataset 目录:Unireditbench_baseline_images/{model}/{name}/{idx}.png
107
+ rel_path = f"Unireditbench_baseline_images/{m}/{name}/{idx}.png"
108
+ src = f"{HF_BASE_URL}/{rel_path}"
109
+ cell = f'<img src="{src}" style="max-width:100%; max-height:256px;">'
110
+ html += (
111
+ f'<td width="{WIDTH}%" style="text-align:center;">{cell}</td>'
112
+ )
113
+ for _ in range(N_COL - len(sub_models)):
114
+ html += f'<td width="{WIDTH}%"></td>'
115
+ html += "</tr>"
116
+
117
+ html += HTML_TAIL
118
+ return html
119
+
120
+
121
+ def update_idx_dropdown(name):
122
+ """当 name 改变时,更新 idx 下拉选项."""
123
+ idxs = get_indices_for_name(name)
124
+ default = idxs[0] if idxs else None
125
+ return gr.Dropdown(choices=idxs, value=default)
126
+
127
+
128
+ def load_sample(name, idx, selected_models):
129
+ key = (name, str(idx))
130
+ item = INDEX_MAP.get(key)
131
+
132
+ if item is None:
133
+ info_md = f"**Not found:** name = {name}, idx = {idx}"
134
+ return info_md, "", "", "<p>Sample not found.</p>", "<p>Sample not found.</p>"
135
+
136
+ info_md = (
137
+ f"**Category (name):** {item['name']} \n"
138
+ f"**Index (idx):** {item['idx']}"
139
+ )
140
+ instruction = item.get("instruction", "")
141
+ rules = item.get("rules", "")
142
+
143
+ # data.json 里原本的 original_image_path 建议直接保持为相对路径:
144
+ # "original_image_path": "original_image/jewel2/0001.png"
145
+ orig_rel = item.get("original_image_path", "")
146
+ orig_html = render_img_html(orig_rel, max_h=512)
147
+
148
+ gallery_html = get_baseline_gallery(name, str(idx), selected_models)
149
+
150
+ return info_md, instruction, rules, orig_html, gallery_html
151
+
152
+ def _step_idx(name: str, idx: str, direction: int):
153
+ """direction=-1 上一个, direction=+1 下一个;到边界就停住。"""
154
+ idxs = get_indices_for_name(name)
155
+ if not idxs:
156
+ return None
157
+ idx = str(idx)
158
+ if idx not in idxs:
159
+ cur = 0
160
+ else:
161
+ cur = idxs.index(idx)
162
+
163
+ new_pos = cur + direction
164
+ new_pos = max(0, min(len(idxs) - 1, new_pos))
165
+ return idxs[new_pos]
166
+
167
+
168
+ def prev_sample(name, idx, selected_models):
169
+ new_idx = _step_idx(name, idx, direction=-1)
170
+ if new_idx is None:
171
+ # 保持不变
172
+ info_md, instruction, rules, orig_html, gallery_html = load_sample(name, idx, selected_models)
173
+ return gr.update(), info_md, instruction, rules, orig_html, gallery_html
174
+
175
+ info_md, instruction, rules, orig_html, gallery_html = load_sample(name, new_idx, selected_models)
176
+ return gr.update(value=new_idx), info_md, instruction, rules, orig_html, gallery_html
177
+
178
+
179
+ def next_sample(name, idx, selected_models):
180
+ new_idx = _step_idx(name, idx, direction=+1)
181
+ if new_idx is None:
182
+ info_md, instruction, rules, orig_html, gallery_html = load_sample(name, idx, selected_models)
183
+ return gr.update(), info_md, instruction, rules, orig_html, gallery_html
184
+
185
+ info_md, instruction, rules, orig_html, gallery_html = load_sample(name, new_idx, selected_models)
186
+ return gr.update(value=new_idx), info_md, instruction, rules, orig_html, gallery_html
187
+
188
+
189
+ # ================== Gradio UI ==================
190
+ with gr.Blocks() as demo:
191
+ gr.Markdown("# UniREditBench Gallery")
192
+
193
+ with gr.Row():
194
+ with gr.Column(scale=1):
195
+ default_name = ALL_NAMES[0] if ALL_NAMES else None
196
+ default_idxs = get_indices_for_name(default_name) if default_name else []
197
+ default_idx = default_idxs[0] if default_idxs else None
198
+
199
+ name_dropdown = gr.Dropdown(
200
+ label="Category (name)",
201
+ choices=ALL_NAMES,
202
+ value=default_name,
203
+ )
204
+
205
+ idx_dropdown = gr.Dropdown(
206
+ label="Idx",
207
+ choices=default_idxs,
208
+ value=default_idx,
209
+ )
210
+
211
+ with gr.Row():
212
+ prev_button = gr.Button("Prev")
213
+ next_button = gr.Button("Next")
214
+ load_button = gr.Button("Load sample")
215
+
216
+
217
+ model_checkboxes = gr.CheckboxGroup(
218
+ label="Baselines to show",
219
+ choices=BASELINE_MODELS,
220
+ value=BASELINE_MODELS,
221
+ )
222
+
223
+ with gr.Column(scale=2):
224
+ info_markdown = gr.Markdown(label="Info")
225
+ instruction_box = gr.Textbox(
226
+ label="Instruction", lines=4, interactive=False
227
+ )
228
+ rules_box = gr.Textbox(
229
+ label="Rules", lines=3, interactive=False
230
+ )
231
+
232
+ with gr.Column(scale=2):
233
+ gr.Markdown("### Original Image")
234
+ orig_image_html = gr.HTML()
235
+
236
+
237
+ gallery_html = gr.HTML()
238
+
239
+ name_dropdown.change(
240
+ fn=update_idx_dropdown,
241
+ inputs=name_dropdown,
242
+ outputs=idx_dropdown,
243
+ )
244
+
245
+ load_button.click(
246
+ fn=load_sample,
247
+ inputs=[name_dropdown, idx_dropdown, model_checkboxes],
248
+ outputs=[info_markdown, instruction_box, rules_box, orig_image_html, gallery_html]
249
+ )
250
+
251
+ # Prev / Next:会同时更新 idx_dropdown + 刷新所有内容
252
+ prev_button.click(
253
+ fn=prev_sample,
254
+ inputs=[name_dropdown, idx_dropdown, model_checkboxes],
255
+ outputs=[idx_dropdown, info_markdown, instruction_box, rules_box, orig_image_html, gallery_html]
256
+ )
257
+
258
+ next_button.click(
259
+ fn=next_sample,
260
+ inputs=[name_dropdown, idx_dropdown, model_checkboxes],
261
+ outputs=[idx_dropdown, info_markdown, instruction_box, rules_box, orig_image_html, gallery_html]
262
+ )
263
+
264
+
265
+ if __name__ == "__main__":
266
+ # 不再需要 allowed_paths / set_static_paths
267
+ demo.launch(server_name="0.0.0.0", server_port=7860)
data.json ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ gradio>=4.0.0