# Vis_base64_data / app.py
# (HF Space file header — author xmu-xiaoma666, commit 513765b "Update app.py")
import base64
from io import BytesIO
from PIL import Image, ImageDraw
from datasets import load_dataset
import gradio as gr
import json
import ast
import re
def _parse_json_from_text(text):
"""
提取text中的第一个```json ...```代码块并解析为列表.
如果没有代码块,则尝试直接解析text内容为列表.
"""
if not text or not isinstance(text, str):
return []
# 尝试提取```json ...```代码块
blocks = re.findall(r"```json\s*([\s\S]+?)```", text)
for blk in blocks:
try:
return json.loads(blk)
except Exception:
try:
return ast.literal_eval(blk)
except Exception:
continue
# 没有代码块则直接解析
try:
return json.loads(text)
except Exception:
try:
# 尝试剥离结尾如 \n3 等
arrtxt = text.strip()
arrtxt = re.split(r"\n\d+$", arrtxt)[0]
return ast.literal_eval(arrtxt)
except Exception:
return []
return []
def _parse_conversations(conversations):
    """
    Collect box/point annotations from all assistant replies.

    *conversations* is either a list of {"from": ..., "value": ...} dicts or
    a newline-separated string of Python dict literals (one per line).

    Returns (all_boxes, all_points): the annotation dicts that contain a
    "bbox" key and those that contain a "point" key, respectively.  Both
    lists are empty when the input cannot be interpreted.
    """
    all_boxes = []
    all_points = []
    if isinstance(conversations, str):
        try:
            lines = [line for line in conversations.split("\n") if line.strip()]
            conversations = [ast.literal_eval(line) for line in lines]
        except Exception:
            return [], []
    if not isinstance(conversations, (list, tuple)):
        # Unexpected type (e.g. None): nothing to extract.
        return [], []
    for utter in conversations:
        if not isinstance(utter, dict):
            # Skip malformed turns instead of raising AttributeError.
            continue
        if utter.get("from", "") != "assistant":
            continue
        obj_list = _parse_json_from_text(utter.get("value", ""))
        if isinstance(obj_list, list):
            for obj in obj_list:
                if isinstance(obj, dict):
                    if "bbox" in obj:
                        all_boxes.append(obj)
                    if "point" in obj:
                        all_points.append(obj)
    return all_boxes, all_points
def _coords_normalized(c, wh):
# 判断c是归一化还是绝对,如任一值是float且<2,则认为是归一化
if not c: return False
for v in c:
if isinstance(v, float) and abs(v) < 2:
return True
return False
def _denorm_bbox(bbox, wh):
# 归一化box转像素坐标
if _coords_normalized(bbox, wh):
w, h = wh
return [bbox[0]*w, bbox[1]*h, bbox[2]*w, bbox[3]*h]
else:
return bbox
def _denorm_point(point, wh):
if _coords_normalized(point, wh):
w, h = wh
return [point[0]*w, point[1]*h]
else:
return point
def _draw_annotations(img, boxes, points):
    """Return a copy of *img* with red boxes, points and their labels drawn on it."""
    annotated = img.copy()
    draw = ImageDraw.Draw(annotated)
    wh = annotated.size
    # Rectangles for bbox annotations.
    for obj in boxes:
        bbox = obj.get("bbox")
        if not bbox or len(bbox) != 4:
            continue
        label = obj.get("label", "")
        x1, y1, x2, y2 = (int(v) for v in _denorm_bbox(bbox, wh))
        draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
        if label:
            # Label sits just above the box, clamped to the top edge.
            draw.text((x1, max(0, y1 - 10)), label, fill="red")
    # Circles for point annotations.
    for obj in points:
        pt = obj.get("point")
        if not pt or len(pt) != 2:
            continue
        label = obj.get("label", "")
        x, y = (int(v) for v in _denorm_point(pt, wh))
        radius = 5
        draw.ellipse([x - radius, y - radius, x + radius, y + radius], outline="red", width=2)
        if label:
            draw.text((x + radius + 2, y), label, fill="red")
    return annotated
def get_sample_list(dataset_name, page=0, page_size=5):
    """
    Stream one page of samples from a HuggingFace dataset.

    Returns exactly *page_size* rows of
    [original image, annotated image, conversations text, data_type, index str],
    padded with empty rows when the dataset is exhausted or fails to load.
    """
    empty_row = [None, None, "", "", ""]
    try:
        # Streaming avoids downloading the full dataset just to show one page.
        ds = load_dataset(dataset_name, split="train", streaming=True)
    except Exception:
        # Bug fix: match the padding-row shape (text fields must be "" not None)
        # so Gradio Textboxes always receive strings.
        return [list(empty_row) for _ in range(page_size)]
    start = page * page_size
    results = []
    for i, ex in enumerate(ds.skip(start)):
        if i >= page_size:
            break
        # Guard: "meta" may be present but None, which would break `in meta`.
        meta = ex.get("meta") or {}
        img = None
        img_ann = None
        if "image_0" in meta:
            try:
                img_str = meta["image_0"]
                # Strip a data-URI prefix and repair base64 padding if needed.
                if img_str.startswith("data:image"):
                    img_str = img_str.split(",")[-1]
                if len(img_str) % 4 != 0:
                    img_str += "=" * (4 - len(img_str) % 4)
                img_bytes = base64.b64decode(img_str)
                img = Image.open(BytesIO(img_bytes)).convert("RGB")
            except Exception:
                # Best-effort: a bad/undecodable image leaves img as None.
                pass
        conversations = ex.get("conversations", "")
        if isinstance(conversations, list):
            conv_txt = "\n".join(str(x) for x in conversations)
        else:
            conv_txt = str(conversations)
        data_type = ex.get("data_type", "")
        all_boxes, all_points = _parse_conversations(conversations)
        if img is not None:
            img_ann = _draw_annotations(img, all_boxes, all_points)
        results.append([img, img_ann, conv_txt, data_type, str(start + i)])
    # Pad short pages so the UI always receives page_size rows.
    while len(results) < page_size:
        results.append(list(empty_row))
    return results
def get_page(dataset_name, page=0, page_size=5):
    """Fetch one page of samples and flatten every row's fields into one list."""
    rows = get_sample_list(dataset_name, page, page_size)
    return [field for row in rows for field in row]
# Per-row column labels.  NOTE(review): not referenced anywhere visible in
# this file — the components below set their own labels directly.
labels = ["原图", "带标注图", "conversations", "data_type", "样本idx"]
with gr.Blocks() as demo:
    gr.Markdown("## Huggingface 多样本可视化\n每页显示5个样本(原图/带标注/对话内容/类型/索引),点击按钮加载。")
    with gr.Row():
        with gr.Column():
            # Free-text dataset name; takes priority over the dropdown.
            dataset_in = gr.Textbox(label="数据集名(建议手动优先)", value="", placeholder="优先使用此项")
            dataset_dropdown = gr.Dropdown(
                label="数据集名(下拉选择备选)",
                choices=[
                    "xmu-xiaoma666-dataset/PR1-Datasets-Counting",
                    "xmu-xiaoma666-dataset/PR1-Datasets-Grounding",
                    "xmu-xiaoma666-dataset/LVIS"
                ],
                value="xmu-xiaoma666-dataset/LVIS",
                interactive=True
            )
        with gr.Column():
            prev_btn = gr.Button("上一页")
            load_btn = gr.Button("加载样本")
            next_btn = gr.Button("下一页")
    # Hidden page counter: handlers read the current value and write the new one.
    page = gr.Number(value=0, visible=False)
    page_size = 5
    # Five components per sample row, flattened into image_blocks in the same
    # order that get_page() emits its fields.
    image_blocks = []
    for i in range(page_size):
        with gr.Row():
            img = gr.Image(label="原图", interactive=False)
            img_ann = gr.Image(label="带标注图", interactive=False)
            conv_out = gr.Textbox(label="conversations", lines=3)
            type_out = gr.Textbox(label="data_type")
            idx_out = gr.Textbox(label="样本idx", interactive=False)
            image_blocks.extend([img, img_ann, conv_out, type_out, idx_out])

    def select_dataset_name(text_val, dropdown_val):
        # Prefer the manually typed name; fall back to the dropdown if empty.
        ds = text_val.strip() if text_val and text_val.strip() else dropdown_val.strip()
        return ds

    def prev_page(text_val, dropdown_val, page_num):
        # Step back one page, clamped at 0, then reload.
        ds = select_dataset_name(text_val, dropdown_val)
        page_num = max(0, int(page_num)-1)
        outs = get_page(ds, page_num)
        return [page_num] + outs

    def next_page(text_val, dropdown_val, page_num):
        # Advance one page (no upper bound; past-the-end pages show empty rows).
        ds = select_dataset_name(text_val, dropdown_val)
        page_num = int(page_num) + 1
        outs = get_page(ds, page_num)
        return [page_num] + outs

    def load_current_page(text_val, dropdown_val, page_num):
        # Reload the current page without changing the page counter.
        ds = select_dataset_name(text_val, dropdown_val)
        page_num = int(page_num)
        outs = get_page(ds, page_num)
        return [page_num] + outs

    # Each handler returns [new page value] + the flattened per-row fields,
    # matching the [page] + image_blocks output list.
    prev_btn.click(prev_page, [dataset_in, dataset_dropdown, page], [page] + image_blocks)
    next_btn.click(next_page, [dataset_in, dataset_dropdown, page], [page] + image_blocks)
    load_btn.click(load_current_page, [dataset_in, dataset_dropdown, page], [page] + image_blocks)
# No auto-load on startup: the user must click "加载样本" to fetch samples.
demo.launch()