maplebb commited on
Commit
d7f1928
·
verified ·
1 Parent(s): 2607abb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -63
app.py CHANGED
@@ -2,52 +2,14 @@ import os
2
  import json
3
  import gradio as gr
4
  from pathlib import Path
5
- from io import BytesIO
6
- import base64
7
- import requests
8
- from PIL import Image
9
 
10
  # =============== 配置:HF dataset 仓库 ID =================
11
  HF_DATASET_ID = "maplebb/UniREditBench-Results"
12
  HF_BASE_URL = f"https://huggingface.co/datasets/{HF_DATASET_ID}/resolve/main"
13
- BASE_HF_URL = f"https://huggingface.co/datasets/{HF_DATASET_ID}/resolve/main"
14
-
15
- _image_cache = {}
16
-
17
- def get_url_response(url):
18
- try:
19
- resp = requests.get(url, timeout=10)
20
- resp.raise_for_status()
21
- return resp
22
- except Exception as e:
23
- print(f"Error fetching {url}: {e}")
24
- return None
25
-
26
- def load_image_uniredit(rel_path: str):
27
- """从 UniREditBench 数据集拉图(后端 requests)"""
28
- rel_path = rel_path.lstrip("/")
29
- if rel_path in _image_cache:
30
- return _image_cache[rel_path]
31
- url = f"{BASE_HF_URL}/{rel_path}"
32
- resp = get_url_response(url)
33
- if not resp:
34
- return None
35
- img = Image.open(BytesIO(resp.content)).convert("RGB")
36
- _image_cache[rel_path] = img
37
- return img
38
-
39
- def pil_to_base64(img):
40
- if img is None:
41
- return ""
42
- buf = BytesIO()
43
- img.save(buf, format="PNG")
44
- s = base64.b64encode(buf.getvalue()).decode("utf-8")
45
- return f"data:image/png;base64,{s}"
46
 
47
  ROOT_DIR = Path(__file__).resolve().parent
48
 
49
  # =============== data.json 还是放在 Space 本地 =================
50
- # 如果想把 data.json 也放到 dataset 里,可以改成用 hf_hub_download(后面我再写)
51
  JSON_PATH = ROOT_DIR / "data.json"
52
 
53
  # =============== 读 json & 建索引 =================
@@ -59,8 +21,7 @@ ALL_NAMES = sorted({item["name"] for item in data})
59
  # (name, idx_str) -> item
60
  INDEX_MAP = {(item["name"], str(item["idx"])): item for item in data}
61
 
62
- # =============== baseline 模型列表(现在不能再 os.listdir,本地没图了) =================
63
- # 建议直接手动写齐你有的模型
64
  ALL_MODELS = [
65
  "Bagel-Think",
66
  "DreamOmni2",
@@ -97,26 +58,20 @@ def get_indices_for_name(name: str):
97
 
98
 
99
  def render_img_html(rel_path: str, max_h=512):
 
 
 
 
100
  if not rel_path:
101
  return "<p>No original image.</p>"
102
 
103
- img = load_image_uniredit(rel_path)
104
- if img is None:
105
- return "<p>Failed to load image.</p>"
106
-
107
- # 可选:缩放一下高度
108
- h = img.height
109
- w = img.width
110
- if h > max_h:
111
- ratio = max_h / h
112
- img = img.resize((int(w * ratio), max_h), Image.Resampling.LANCZOS)
113
-
114
- data_url = pil_to_base64(img)
115
- return f'<img src="{data_url}" style="max-width:100%; max-height:{max_h}px;">'
116
 
117
 
118
  def get_baseline_gallery(name: str, idx: str, models):
119
- """生成 baseline 图像的 HTML 表格(使用远程 URL)."""
120
  if not name or not idx:
121
  return "<p>Please select name and idx.</p>"
122
 
@@ -135,9 +90,7 @@ def get_baseline_gallery(name: str, idx: str, models):
135
  # 第一行:模型名
136
  html += "<tr>"
137
  for m in sub_models:
138
- html += (
139
- f'<td width="{WIDTH}%" style="text-align:center;"><h4>{m}</h4></td>'
140
- )
141
  for _ in range(N_COL - len(sub_models)):
142
  html += f'<td width="{WIDTH}%"></td>'
143
  html += "</tr>"
@@ -145,8 +98,10 @@ def get_baseline_gallery(name: str, idx: str, models):
145
  # 第二行:对应图片
146
  html += "<tr>"
147
  for m in sub_models:
 
148
  rel_path = f"Unireditbench_baseline_images/{m}/{name}/{idx}.png"
149
- cell = render_img_html(rel_path, max_h=256)
 
150
  html += f'<td width="{WIDTH}%" style="text-align:center;">{cell}</td>'
151
  for _ in range(N_COL - len(sub_models)):
152
  html += f'<td width="{WIDTH}%"></td>'
@@ -178,7 +133,7 @@ def load_sample(name, idx, selected_models):
178
  instruction = item.get("instruction", "")
179
  rules = item.get("rules", "")
180
 
181
- # data.json 里原本的 original_image_path 建议直接保持为相对路径:
182
  # "original_image_path": "original_image/jewel2/0001.png"
183
  orig_rel = item.get("original_image_path", "")
184
  orig_html = render_img_html(orig_rel, max_h=512)
@@ -187,6 +142,7 @@ def load_sample(name, idx, selected_models):
187
 
188
  return info_md, instruction, rules, orig_html, gallery_html
189
 
 
190
  def _step_idx(name: str, idx: str, direction: int):
191
  """direction=-1 上一个, direction=+1 下一个;到边界就停住。"""
192
  idxs = get_indices_for_name(name)
@@ -250,7 +206,6 @@ with gr.Blocks() as demo:
250
  prev_button = gr.Button("Prev")
251
  next_button = gr.Button("Next")
252
  load_button = gr.Button("Load sample")
253
-
254
 
255
  model_checkboxes = gr.CheckboxGroup(
256
  label="Baselines to show",
@@ -271,7 +226,6 @@ with gr.Blocks() as demo:
271
  gr.Markdown("### Original Image")
272
  orig_image_html = gr.HTML()
273
 
274
-
275
  gallery_html = gr.HTML()
276
 
277
  name_dropdown.change(
@@ -285,7 +239,7 @@ with gr.Blocks() as demo:
285
  inputs=[name_dropdown, idx_dropdown, model_checkboxes],
286
  outputs=[info_markdown, instruction_box, rules_box, orig_image_html, gallery_html]
287
  )
288
-
289
  # Prev / Next:会同时更新 idx_dropdown + 刷新所有内容
290
  prev_button.click(
291
  fn=prev_sample,
@@ -301,5 +255,4 @@ with gr.Blocks() as demo:
301
 
302
 
303
  if __name__ == "__main__":
304
- # 不再需要 allowed_paths / set_static_paths
305
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
2
  import json
3
  import gradio as gr
4
  from pathlib import Path
 
 
 
 
5
 
6
  # =============== 配置:HF dataset 仓库 ID =================
7
  HF_DATASET_ID = "maplebb/UniREditBench-Results"
8
  HF_BASE_URL = f"https://huggingface.co/datasets/{HF_DATASET_ID}/resolve/main"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  ROOT_DIR = Path(__file__).resolve().parent
11
 
12
  # =============== data.json 还是放在 Space 本地 =================
 
13
  JSON_PATH = ROOT_DIR / "data.json"
14
 
15
  # =============== 读 json & 建索引 =================
 
21
  # (name, idx_str) -> item
22
  INDEX_MAP = {(item["name"], str(item["idx"])): item for item in data}
23
 
24
+ # =============== baseline 模型列表 =================
 
25
  ALL_MODELS = [
26
  "Bagel-Think",
27
  "DreamOmni2",
 
58
 
59
 
60
  def render_img_html(rel_path: str, max_h=512):
61
+ """
62
+ 像第二段代码那样,直接返回远程 URL,不在后端拉图 / 编码。
63
+ rel_path 例如: 'original_image/jewel2/0001.png'
64
+ """
65
  if not rel_path:
66
  return "<p>No original image.</p>"
67
 
68
+ rel_path = str(rel_path).lstrip("/") # 防止开头有 '/'
69
+ src = f"{HF_BASE_URL}/{rel_path}"
70
+ return f'<img src="{src}" style="max-width:100%; max-height:{max_h}px;">'
 
 
 
 
 
 
 
 
 
 
71
 
72
 
73
  def get_baseline_gallery(name: str, idx: str, models):
74
+ """生成 baseline 图像的 HTML 表格(和第二段类似,img src 直接是 URL)."""
75
  if not name or not idx:
76
  return "<p>Please select name and idx.</p>"
77
 
 
90
  # 第一行:模型名
91
  html += "<tr>"
92
  for m in sub_models:
93
+ html += f'<td width="{WIDTH}%" style="text-align:center;"><h4>{m}</h4></td>'
 
 
94
  for _ in range(N_COL - len(sub_models)):
95
  html += f'<td width="{WIDTH}%"></td>'
96
  html += "</tr>"
 
98
  # 第二行:对应图片
99
  html += "<tr>"
100
  for m in sub_models:
101
+ # 对应 dataset 目录:Unireditbench_baseline_images/{model}/{name}/{idx}.png
102
  rel_path = f"Unireditbench_baseline_images/{m}/{name}/{idx}.png"
103
+ src = f"{HF_BASE_URL}/{rel_path}"
104
+ cell = f'<img src="{src}" style="max-width:100%; max-height:256px;">'
105
  html += f'<td width="{WIDTH}%" style="text-align:center;">{cell}</td>'
106
  for _ in range(N_COL - len(sub_models)):
107
  html += f'<td width="{WIDTH}%"></td>'
 
133
  instruction = item.get("instruction", "")
134
  rules = item.get("rules", "")
135
 
136
+ # data.json original_image_path 仍然保持为相对路径:
137
  # "original_image_path": "original_image/jewel2/0001.png"
138
  orig_rel = item.get("original_image_path", "")
139
  orig_html = render_img_html(orig_rel, max_h=512)
 
142
 
143
  return info_md, instruction, rules, orig_html, gallery_html
144
 
145
+
146
  def _step_idx(name: str, idx: str, direction: int):
147
  """direction=-1 上一个, direction=+1 下一个;到边界就停住。"""
148
  idxs = get_indices_for_name(name)
 
206
  prev_button = gr.Button("Prev")
207
  next_button = gr.Button("Next")
208
  load_button = gr.Button("Load sample")
 
209
 
210
  model_checkboxes = gr.CheckboxGroup(
211
  label="Baselines to show",
 
226
  gr.Markdown("### Original Image")
227
  orig_image_html = gr.HTML()
228
 
 
229
  gallery_html = gr.HTML()
230
 
231
  name_dropdown.change(
 
239
  inputs=[name_dropdown, idx_dropdown, model_checkboxes],
240
  outputs=[info_markdown, instruction_box, rules_box, orig_image_html, gallery_html]
241
  )
242
+
243
  # Prev / Next:会同时更新 idx_dropdown + 刷新所有内容
244
  prev_button.click(
245
  fn=prev_sample,
 
255
 
256
 
257
  if __name__ == "__main__":
 
258
  demo.launch(server_name="0.0.0.0", server_port=7860)