maplebb commited on
Commit
a0450b5
·
verified ·
1 Parent(s): d6808f3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -14
app.py CHANGED
@@ -2,10 +2,46 @@ import os
2
  import json
3
  import gradio as gr
4
  from pathlib import Path
 
 
 
5
 
6
  # =============== 配置:HF dataset 仓库 ID =================
7
  HF_DATASET_ID = "maplebb/UniREditBench-Results"
8
  HF_BASE_URL = f"https://huggingface.co/datasets/{HF_DATASET_ID}/resolve/main"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  ROOT_DIR = Path(__file__).resolve().parent
11
 
@@ -60,16 +96,22 @@ def get_indices_for_name(name: str):
60
 
61
 
62
  def render_img_html(rel_path: str, max_h=512):
63
- """
64
- 现在不再检查本地文件,直接把相对路径拼到 HF dataset 的 URL 上。
65
- rel_path 例如: 'original_image/jewel2/0001.png'
66
- """
67
  if not rel_path:
68
  return "<p>No original image.</p>"
69
 
70
- rel_path = str(rel_path).lstrip("/") # 防止开头有 '/'
71
- src = f"{HF_BASE_URL}/{rel_path}"
72
- return f'<img src="{src}" style="max-width:100%; max-height:{max_h}px;">'
 
 
 
 
 
 
 
 
 
 
73
 
74
 
75
  def get_baseline_gallery(name: str, idx: str, models):
@@ -102,14 +144,9 @@ def get_baseline_gallery(name: str, idx: str, models):
102
  # 第二行:对应图片
103
  html += "<tr>"
104
  for m in sub_models:
105
- # 关键:这里直接拼 dataset 里的路径
106
- # 对应 dataset 目录:Unireditbench_baseline_images/{model}/{name}/{idx}.png
107
  rel_path = f"Unireditbench_baseline_images/{m}/{name}/{idx}.png"
108
- src = f"{HF_BASE_URL}/{rel_path}"
109
- cell = f'<img src="{src}" style="max-width:100%; max-height:256px;">'
110
- html += (
111
- f'<td width="{WIDTH}%" style="text-align:center;">{cell}</td>'
112
- )
113
  for _ in range(N_COL - len(sub_models)):
114
  html += f'<td width="{WIDTH}%"></td>'
115
  html += "</tr>"
 
2
  import json
3
  import gradio as gr
4
  from pathlib import Path
5
+ from io import BytesIO
6
+ import base64
7
+ import requests
8
 
9
  # =============== 配置:HF dataset 仓库 ID =================
10
  HF_DATASET_ID = "maplebb/UniREditBench-Results"
11
  HF_BASE_URL = f"https://huggingface.co/datasets/{HF_DATASET_ID}/resolve/main"
12
+ BASE_HF_URL = f"https://huggingface.co/datasets/{HF_DATASET_ID}/resolve/main"
13
+
14
+ _image_cache = {}
15
+
16
+ def get_url_response(url):
17
+ try:
18
+ resp = requests.get(url, timeout=10)
19
+ resp.raise_for_status()
20
+ return resp
21
+ except Exception as e:
22
+ print(f"Error fetching {url}: {e}")
23
+ return None
24
+
25
+ def load_image_uniredit(rel_path: str):
26
+ """从 UniREditBench 数据集拉图(后端 requests)"""
27
+ rel_path = rel_path.lstrip("/")
28
+ if rel_path in _image_cache:
29
+ return _image_cache[rel_path]
30
+ url = f"{BASE_HF_URL}/{rel_path}"
31
+ resp = get_url_response(url)
32
+ if not resp:
33
+ return None
34
+ img = Image.open(BytesIO(resp.content)).convert("RGB")
35
+ _image_cache[rel_path] = img
36
+ return img
37
+
38
+ def pil_to_base64(img):
39
+ if img is None:
40
+ return ""
41
+ buf = BytesIO()
42
+ img.save(buf, format="PNG")
43
+ s = base64.b64encode(buf.getvalue()).decode("utf-8")
44
+ return f"data:image/png;base64,{s}"
45
 
46
  ROOT_DIR = Path(__file__).resolve().parent
47
 
 
96
 
97
 
98
  def render_img_html(rel_path: str, max_h=512):
 
 
 
 
99
  if not rel_path:
100
  return "<p>No original image.</p>"
101
 
102
+ img = load_image_uniredit(rel_path)
103
+ if img is None:
104
+ return "<p>Failed to load image.</p>"
105
+
106
+ # 可选:缩放一下高度
107
+ h = img.height
108
+ w = img.width
109
+ if h > max_h:
110
+ ratio = max_h / h
111
+ img = img.resize((int(w * ratio), max_h), Image.Resampling.LANCZOS)
112
+
113
+ data_url = pil_to_base64(img)
114
+ return f'<img src="{data_url}" style="max-width:100%; max-height:{max_h}px;">'
115
 
116
 
117
  def get_baseline_gallery(name: str, idx: str, models):
 
144
  # 第二行:对应图片
145
  html += "<tr>"
146
  for m in sub_models:
 
 
147
  rel_path = f"Unireditbench_baseline_images/{m}/{name}/{idx}.png"
148
+ cell = render_img_html(rel_path, max_h=256)
149
+ html += f'<td width="{WIDTH}%" style="text-align:center;">{cell}</td>'
 
 
 
150
  for _ in range(N_COL - len(sub_models)):
151
  html += f'<td width="{WIDTH}%"></td>'
152
  html += "</tr>"