maplebb commited on
Commit
ddff00f
·
verified ·
1 Parent(s): d7f1928

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -18
app.py CHANGED
@@ -2,14 +2,53 @@ import os
2
  import json
3
  import gradio as gr
4
  from pathlib import Path
 
 
 
 
5
 
6
  # =============== 配置:HF dataset 仓库 ID =================
7
  HF_DATASET_ID = "maplebb/UniREditBench-Results"
8
  HF_BASE_URL = f"https://huggingface.co/datasets/{HF_DATASET_ID}/resolve/main"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  ROOT_DIR = Path(__file__).resolve().parent
11
 
12
  # =============== data.json 还是放在 Space 本地 =================
 
13
  JSON_PATH = ROOT_DIR / "data.json"
14
 
15
  # =============== 读 json & 建索引 =================
@@ -21,7 +60,8 @@ ALL_NAMES = sorted({item["name"] for item in data})
21
  # (name, idx_str) -> item
22
  INDEX_MAP = {(item["name"], str(item["idx"])): item for item in data}
23
 
24
- # =============== baseline 模型列表 =================
 
25
  ALL_MODELS = [
26
  "Bagel-Think",
27
  "DreamOmni2",
@@ -58,20 +98,27 @@ def get_indices_for_name(name: str):
58
 
59
 
60
  def render_img_html(rel_path: str, max_h=512):
61
- """
62
- 像第二段代码那样,直接返回远程 URL,不在后端拉图 / 编码。
63
- rel_path 例如: 'original_image/jewel2/0001.png'
64
- """
65
  if not rel_path:
66
  return "<p>No original image.</p>"
67
 
68
- rel_path = str(rel_path).lstrip("/") # 防止开头有 '/'
69
- src = f"{HF_BASE_URL}/{rel_path}"
70
- return f'<img src="{src}" style="max-width:100%; max-height:{max_h}px;">'
 
 
 
 
 
 
 
 
 
 
 
71
 
72
 
73
  def get_baseline_gallery(name: str, idx: str, models):
74
- """生成 baseline 图像的 HTML 表格(和第二段类似,img src 直接是 URL)."""
75
  if not name or not idx:
76
  return "<p>Please select name and idx.</p>"
77
 
@@ -90,7 +137,9 @@ def get_baseline_gallery(name: str, idx: str, models):
90
  # 第一行:模型名
91
  html += "<tr>"
92
  for m in sub_models:
93
- html += f'<td width="{WIDTH}%" style="text-align:center;"><h4>{m}</h4></td>'
 
 
94
  for _ in range(N_COL - len(sub_models)):
95
  html += f'<td width="{WIDTH}%"></td>'
96
  html += "</tr>"
@@ -98,10 +147,8 @@ def get_baseline_gallery(name: str, idx: str, models):
98
  # 第二行:对应图片
99
  html += "<tr>"
100
  for m in sub_models:
101
- # 对应 dataset 目录:Unireditbench_baseline_images/{model}/{name}/{idx}.png
102
  rel_path = f"Unireditbench_baseline_images/{m}/{name}/{idx}.png"
103
- src = f"{HF_BASE_URL}/{rel_path}"
104
- cell = f'<img src="{src}" style="max-width:100%; max-height:256px;">'
105
  html += f'<td width="{WIDTH}%" style="text-align:center;">{cell}</td>'
106
  for _ in range(N_COL - len(sub_models)):
107
  html += f'<td width="{WIDTH}%"></td>'
@@ -133,7 +180,7 @@ def load_sample(name, idx, selected_models):
133
  instruction = item.get("instruction", "")
134
  rules = item.get("rules", "")
135
 
136
- # data.json original_image_path 仍然保持为相对路径:
137
  # "original_image_path": "original_image/jewel2/0001.png"
138
  orig_rel = item.get("original_image_path", "")
139
  orig_html = render_img_html(orig_rel, max_h=512)
@@ -142,7 +189,6 @@ def load_sample(name, idx, selected_models):
142
 
143
  return info_md, instruction, rules, orig_html, gallery_html
144
 
145
-
146
  def _step_idx(name: str, idx: str, direction: int):
147
  """direction=-1 上一个, direction=+1 下一个;到边界就停住。"""
148
  idxs = get_indices_for_name(name)
@@ -206,11 +252,12 @@ with gr.Blocks() as demo:
206
  prev_button = gr.Button("Prev")
207
  next_button = gr.Button("Next")
208
  load_button = gr.Button("Load sample")
 
209
 
210
  model_checkboxes = gr.CheckboxGroup(
211
  label="Baselines to show",
212
- choices=BASELINE_MODELS,
213
- value=BASELINE_MODELS,
214
  )
215
 
216
  with gr.Column(scale=2):
@@ -226,6 +273,7 @@ with gr.Blocks() as demo:
226
  gr.Markdown("### Original Image")
227
  orig_image_html = gr.HTML()
228
 
 
229
  gallery_html = gr.HTML()
230
 
231
  name_dropdown.change(
@@ -239,7 +287,7 @@ with gr.Blocks() as demo:
239
  inputs=[name_dropdown, idx_dropdown, model_checkboxes],
240
  outputs=[info_markdown, instruction_box, rules_box, orig_image_html, gallery_html]
241
  )
242
-
243
  # Prev / Next:会同时更新 idx_dropdown + 刷新所有内容
244
  prev_button.click(
245
  fn=prev_sample,
@@ -255,4 +303,5 @@ with gr.Blocks() as demo:
255
 
256
 
257
  if __name__ == "__main__":
 
258
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
2
  import json
3
  import gradio as gr
4
  from pathlib import Path
5
+ from io import BytesIO
6
+ import base64
7
+ import requests
8
+ from PIL import Image
9
 
10
  # =============== 配置:HF dataset 仓库 ID =================
11
  HF_DATASET_ID = "maplebb/UniREditBench-Results"
12
  HF_BASE_URL = f"https://huggingface.co/datasets/{HF_DATASET_ID}/resolve/main"
13
+ BASE_HF_URL = f"https://huggingface.co/datasets/{HF_DATASET_ID}/resolve/main"
14
+
15
+ _image_cache = {}
16
+ _html_cache = {}
17
+
18
+ def get_url_response(url):
19
+ try:
20
+ resp = requests.get(url, timeout=10)
21
+ resp.raise_for_status()
22
+ return resp
23
+ except Exception as e:
24
+ print(f"Error fetching {url}: {e}")
25
+ return None
26
+
27
+ def load_image_uniredit(rel_path: str):
28
+ """从 UniREditBench 数据集拉图(后端 requests)"""
29
+ rel_path = rel_path.lstrip("/")
30
+ if rel_path in _image_cache:
31
+ return _image_cache[rel_path]
32
+ url = f"{BASE_HF_URL}/{rel_path}"
33
+ resp = get_url_response(url)
34
+ if not resp:
35
+ return None
36
+ img = Image.open(BytesIO(resp.content)).convert("RGB")
37
+ _image_cache[rel_path] = img
38
+ return img
39
+
40
+ def pil_to_base64(img):
41
+ if img is None:
42
+ return ""
43
+ buf = BytesIO()
44
+ img.save(buf, format="PNG")
45
+ s = base64.b64encode(buf.getvalue()).decode("utf-8")
46
+ return f"data:image/png;base64,{s}"
47
 
48
  ROOT_DIR = Path(__file__).resolve().parent
49
 
50
  # =============== data.json 还是放在 Space 本地 =================
51
+ # 如果想把 data.json 也放到 dataset 里,可以改成用 hf_hub_download(后面我再写)
52
  JSON_PATH = ROOT_DIR / "data.json"
53
 
54
  # =============== 读 json & 建索引 =================
 
60
  # (name, idx_str) -> item
61
  INDEX_MAP = {(item["name"], str(item["idx"])): item for item in data}
62
 
63
+ # =============== baseline 模型列表(现在不能再 os.listdir,本地没图了) =================
64
+ # 建议直接手动写齐你有的模型
65
  ALL_MODELS = [
66
  "Bagel-Think",
67
  "DreamOmni2",
 
98
 
99
 
100
  def render_img_html(rel_path: str, max_h=512):
 
 
 
 
101
  if not rel_path:
102
  return "<p>No original image.</p>"
103
 
104
+ key = (rel_path, None) # 不区分高度
105
+ if key in _html_cache:
106
+ data_url = _html_cache[key]
107
+ return f'<img src="{data_url}" style="max-width:100%; max-height:{max_h}px;">'
108
+
109
+ img = load_image_uniredit(rel_path)
110
+ if img is None:
111
+ return "<p>Failed to load image.</p>"
112
+
113
+ # 不再 resize,直接按原分辨率编码
114
+ data_url = pil_to_base64(img)
115
+ _html_cache[key] = data_url
116
+
117
+ return f'<img src="{data_url}" style="max-width:100%; max-height:{max_h}px;">'
118
 
119
 
120
  def get_baseline_gallery(name: str, idx: str, models):
121
+ """生成 baseline 图像的 HTML 表格(使用远程 URL)."""
122
  if not name or not idx:
123
  return "<p>Please select name and idx.</p>"
124
 
 
137
  # 第一行:模型名
138
  html += "<tr>"
139
  for m in sub_models:
140
+ html += (
141
+ f'<td width="{WIDTH}%" style="text-align:center;"><h4>{m}</h4></td>'
142
+ )
143
  for _ in range(N_COL - len(sub_models)):
144
  html += f'<td width="{WIDTH}%"></td>'
145
  html += "</tr>"
 
147
  # 第二行:对应图片
148
  html += "<tr>"
149
  for m in sub_models:
 
150
  rel_path = f"Unireditbench_baseline_images/{m}/{name}/{idx}.png"
151
+ cell = render_img_html(rel_path, max_h=256)
 
152
  html += f'<td width="{WIDTH}%" style="text-align:center;">{cell}</td>'
153
  for _ in range(N_COL - len(sub_models)):
154
  html += f'<td width="{WIDTH}%"></td>'
 
180
  instruction = item.get("instruction", "")
181
  rules = item.get("rules", "")
182
 
183
+ # data.json 里原本的 original_image_path 建议直接保持为相对路径:
184
  # "original_image_path": "original_image/jewel2/0001.png"
185
  orig_rel = item.get("original_image_path", "")
186
  orig_html = render_img_html(orig_rel, max_h=512)
 
189
 
190
  return info_md, instruction, rules, orig_html, gallery_html
191
 
 
192
  def _step_idx(name: str, idx: str, direction: int):
193
  """direction=-1 上一个, direction=+1 下一个;到边界就停住。"""
194
  idxs = get_indices_for_name(name)
 
252
  prev_button = gr.Button("Prev")
253
  next_button = gr.Button("Next")
254
  load_button = gr.Button("Load sample")
255
+
256
 
257
  model_checkboxes = gr.CheckboxGroup(
258
  label="Baselines to show",
259
+ choices=BASELINE_MODELS, # 还是所有模型都可以选
260
+ value=PRIORITY_MODELS, # ✅ 默认只勾选优先那几个
261
  )
262
 
263
  with gr.Column(scale=2):
 
273
  gr.Markdown("### Original Image")
274
  orig_image_html = gr.HTML()
275
 
276
+
277
  gallery_html = gr.HTML()
278
 
279
  name_dropdown.change(
 
287
  inputs=[name_dropdown, idx_dropdown, model_checkboxes],
288
  outputs=[info_markdown, instruction_box, rules_box, orig_image_html, gallery_html]
289
  )
290
+
291
  # Prev / Next:会同时更新 idx_dropdown + 刷新所有内容
292
  prev_button.click(
293
  fn=prev_sample,
 
303
 
304
 
305
  if __name__ == "__main__":
306
+ # 不再需要 allowed_paths / set_static_paths
307
  demo.launch(server_name="0.0.0.0", server_port=7860)