| import base64 |
| from io import BytesIO |
| from PIL import Image, ImageDraw |
| from datasets import load_dataset |
| import gradio as gr |
| import json |
| import ast |
| import re |
|
|
| def _parse_json_from_text(text): |
| """ |
| 提取text中的第一个```json ...```代码块并解析为列表. |
| 如果没有代码块,则尝试直接解析text内容为列表. |
| """ |
| if not text or not isinstance(text, str): |
| return [] |
| |
| blocks = re.findall(r"```json\s*([\s\S]+?)```", text) |
| for blk in blocks: |
| try: |
| return json.loads(blk) |
| except Exception: |
| try: |
| return ast.literal_eval(blk) |
| except Exception: |
| continue |
| |
| try: |
| return json.loads(text) |
| except Exception: |
| try: |
| |
| arrtxt = text.strip() |
| arrtxt = re.split(r"\n\d+$", arrtxt)[0] |
| return ast.literal_eval(arrtxt) |
| except Exception: |
| return [] |
| return [] |
|
|
def _parse_conversations(conversations):
    """Collect box/point annotation dicts from all assistant turns.

    *conversations* is either a list of {"from": ..., "value": ...} dicts or
    a newline-separated string of Python-literal dicts (one per line).
    Returns (all_boxes, all_points): dicts containing a "bbox" key and dicts
    containing a "point" key, respectively. Malformed input yields ([], []).
    """
    all_boxes = []
    all_points = []
    if isinstance(conversations, str):
        try:
            lines = [line for line in conversations.split("\n") if line.strip()]
            conversations = [ast.literal_eval(line) for line in lines]
        except Exception:
            return [], []
    for utter in conversations:
        # Tolerate malformed turns (non-dict entries) instead of crashing.
        if not isinstance(utter, dict):
            continue
        if utter.get("from", "") != "assistant":
            continue
        obj_list = _parse_json_from_text(utter.get("value", ""))
        if isinstance(obj_list, list):
            for obj in obj_list:
                if isinstance(obj, dict):
                    if "bbox" in obj:
                        all_boxes.append(obj)
                    if "point" in obj:
                        all_points.append(obj)
    return all_boxes, all_points
|
|
| def _coords_normalized(c, wh): |
| |
| if not c: return False |
| for v in c: |
| if isinstance(v, float) and abs(v) < 2: |
| return True |
| return False |
|
|
def _denorm_bbox(bbox, wh):
    """Scale a normalized [x1, y1, x2, y2] bbox to pixel coordinates.

    Boxes already in pixel units are returned unchanged.
    """
    if not _coords_normalized(bbox, wh):
        return bbox
    width, height = wh
    scales = (width, height, width, height)
    return [coord * s for coord, s in zip(bbox, scales)]
|
|
def _denorm_point(point, wh):
    """Scale a normalized [x, y] point to pixel coordinates.

    Points already in pixel units are returned unchanged.
    """
    if not _coords_normalized(point, wh):
        return point
    width, height = wh
    return [point[0] * width, point[1] * height]
|
|
def _draw_annotations(img, boxes, points):
    """Return a copy of *img* with red boxes/points (and their labels) drawn on.

    *boxes* are dicts with a 4-element "bbox"; *points* are dicts with a
    2-element "point"; both may carry an optional "label". Coordinates are
    denormalized against the image size before drawing.
    """
    annotated = img.copy()
    drawer = ImageDraw.Draw(annotated)
    size = annotated.size

    for box_obj in boxes:
        bbox = box_obj.get("bbox")
        label = box_obj.get("label", "")
        if not bbox or len(bbox) != 4:
            continue
        x1, y1, x2, y2 = (int(v) for v in _denorm_bbox(bbox, size))
        drawer.rectangle([x1, y1, x2, y2], outline="red", width=2)
        if label:
            # Place the label just above the top-left corner, clamped to the image.
            drawer.text((x1, max(0, y1 - 10)), label, fill="red")

    for pt_obj in points:
        pt = pt_obj.get("point")
        label = pt_obj.get("label", "")
        if not pt or len(pt) != 2:
            continue
        x, y = (int(v) for v in _denorm_point(pt, size))
        radius = 5
        drawer.ellipse([x - radius, y - radius, x + radius, y + radius],
                       outline="red", width=2)
        if label:
            drawer.text((x + radius + 2, y), label, fill="red")
    return annotated
|
|
def get_sample_list(dataset_name, page=0, page_size=5):
    """Stream one page of samples from a Hugging Face dataset (train split).

    Returns a list of exactly *page_size* rows, each
    [original_image, annotated_image, conversations_text, data_type, index_str].
    Missing or failed samples are padded with blank rows so the UI always
    receives a full page.
    """
    # Blank row matching the UI layout: two images (None) and three strings.
    empty_row = [None, None, "", "", ""]
    try:
        ds = load_dataset(dataset_name, split="train", streaming=True)
    except Exception:
        # Dataset unavailable: return a full page of blank rows
        # (consistent with the padding used below, so text widgets get "").
        return [list(empty_row) for _ in range(page_size)]

    start = page * page_size
    results = []
    for i, ex in enumerate(ds.skip(start)):
        if i >= page_size:
            break
        meta = ex.get("meta", {}) or {}  # guard against an explicit None value
        img = None
        img_ann = None
        if "image_0" in meta:
            try:
                img_str = meta["image_0"]
                # Strip any data-URL prefix and repair missing base64 padding.
                if img_str.startswith("data:image"):
                    img_str = img_str.split(",")[-1]
                if len(img_str) % 4 != 0:
                    img_str += "=" * (4 - len(img_str) % 4)
                img_bytes = base64.b64decode(img_str)
                img = Image.open(BytesIO(img_bytes)).convert("RGB")
            except Exception:
                pass  # best-effort: leave img as None on any decode failure
        conversations = ex.get("conversations", "")
        if isinstance(conversations, list):
            conv_txt = "\n".join(str(x) for x in conversations)
        else:
            conv_txt = str(conversations)
        data_type = ex.get("data_type", "")
        all_boxes, all_points = _parse_conversations(conversations)
        if img is not None:
            img_ann = _draw_annotations(img, all_boxes, all_points)
        results.append([img, img_ann, conv_txt, data_type, str(start + i)])
    # Pad short pages so every widget in the row grid gets a value.
    while len(results) < page_size:
        results.append(list(empty_row))
    return results
|
|
def get_page(dataset_name, page=0, page_size=5):
    """Flatten one page of sample rows into a single list of widget values."""
    rows = get_sample_list(dataset_name, page, page_size)
    return [field for row in rows for field in row]
|
|
| labels = ["原图", "带标注图", "conversations", "data_type", "样本idx"] |
|
|
# Gradio UI: a paged viewer showing page_size samples per page, each as a row
# of (original image, annotated image, conversations text, data_type, index).
with gr.Blocks() as demo:
    gr.Markdown("## Huggingface 多样本可视化\n每页显示5个样本(原图/带标注/对话内容/类型/索引),点击按钮加载。")
    with gr.Row():
        with gr.Column():
            # The free-text box takes priority over the dropdown selection
            # (see select_dataset_name below).
            dataset_in = gr.Textbox(label="数据集名(建议手动优先)", value="", placeholder="优先使用此项")
            dataset_dropdown = gr.Dropdown(
                label="数据集名(下拉选择备选)",
                choices=[
                    "xmu-xiaoma666-dataset/PR1-Datasets-Counting",
                    "xmu-xiaoma666-dataset/PR1-Datasets-Grounding",
                    "xmu-xiaoma666-dataset/LVIS"
                ],
                value="xmu-xiaoma666-dataset/LVIS",
                interactive=True
            )
        with gr.Column():
            prev_btn = gr.Button("上一页")  # previous page
            load_btn = gr.Button("加载样本")  # load current page
            next_btn = gr.Button("下一页")  # next page
    # Hidden component holding the current page number as UI state.
    page = gr.Number(value=0, visible=False)
    page_size = 5


    # Build page_size rows of output widgets; flattened into image_blocks so
    # the click handlers can update all page_size * 5 components at once.
    # Order per row must match the row layout produced by get_sample_list.
    image_blocks = []
    for i in range(page_size):
        with gr.Row():
            img = gr.Image(label="原图", interactive=False)
            img_ann = gr.Image(label="带标注图", interactive=False)
            conv_out = gr.Textbox(label="conversations", lines=3)
            type_out = gr.Textbox(label="data_type")
            idx_out = gr.Textbox(label="样本idx", interactive=False)
        image_blocks.extend([img, img_ann, conv_out, type_out, idx_out])


    def select_dataset_name(text_val, dropdown_val):
        """Prefer the free-text dataset name; fall back to the dropdown value."""
        ds = text_val.strip() if text_val and text_val.strip() else dropdown_val.strip()
        return ds


    def prev_page(text_val, dropdown_val, page_num):
        """Load the previous page (clamped at page 0)."""
        ds = select_dataset_name(text_val, dropdown_val)
        page_num = max(0, int(page_num)-1)
        outs = get_page(ds, page_num)
        return [page_num] + outs


    def next_page(text_val, dropdown_val, page_num):
        """Load the next page (no upper bound; past-the-end pages come back blank)."""
        ds = select_dataset_name(text_val, dropdown_val)
        page_num = int(page_num) + 1
        outs = get_page(ds, page_num)
        return [page_num] + outs


    def load_current_page(text_val, dropdown_val, page_num):
        """(Re)load the page currently stored in the hidden page state."""
        ds = select_dataset_name(text_val, dropdown_val)
        page_num = int(page_num)
        outs = get_page(ds, page_num)
        return [page_num] + outs


    # Each handler returns [new_page_number, *flattened_widget_values].
    prev_btn.click(prev_page, [dataset_in, dataset_dropdown, page], [page] + image_blocks)
    next_btn.click(next_page, [dataset_in, dataset_dropdown, page], [page] + image_blocks)
    load_btn.click(load_current_page, [dataset_in, dataset_dropdown, page], [page] + image_blocks)





demo.launch()
|
|