# -*- encoding: utf-8 -*-
"""
@File    :   app.py
@Time    :   2025/8/29 15:25:00
@Author  :   lh9171338
@Version :   1.0
@Contact :   2909171338@qq.com
"""

import os
import gradio as gr
from PIL import Image
import io
import logging
import matplotlib.pyplot as plt
import numpy as np
from datasets import load_dataset, DatasetDict
import utils.camera as cam
import utils.bezier as bez


# Cache of loaded datasets, keyed by the dataset name or local path passed to get_dataset.
dataset_dict = dict()
# Currently selected dataset: a DatasetDict, or a {split_name: Dataset} dict when a single
# split was loaded (see submit_callback); None when nothing is loaded.
# NOTE(review): `dataset` and `sample_info` are module-level, so every browser session
# connected to this app shares them.
dataset = None

default_split_selector_info = dict(
    choices=["train", "test"],
    label="Split",
    value="train",
    interactive=False,
)
default_index_slider_info = dict(
    minimum=0,
    maximum=1,
    step=1,
    label="Index",
    value=0,
    interactive=False,
)
default_order_slider_info = dict(
    minimum=0,
    maximum=6,
    step=1,
    label="Order",
    value=0,
    interactive=False,
)
# Render cache used by show_image: identifies the last drawn sample and holds its two images.
sample_info = dict(
    dataset=dataset,
    split="train",
    index=0,
    order=0,
    image1=None,
    image2=None,
)


def get_dataset(dataset_name):
    """
    Get dataset, loading it on first use and caching it in `dataset_dict`

    Args:
        dataset_name (str): dataset name on the Hugging Face Hub, or a local directory path
            (an existing path is loaded with the "imagefolder" builder)

    Returns:
        dataset (datasets.DatasetDict | datasets.Dataset): loaded dataset, normally a
            DatasetDict since no split is requested

    Raises:
        Exception: whatever `load_dataset` raises when the dataset cannot be loaded
    """
    if dataset_name in dataset_dict:
        return dataset_dict[dataset_name]
    if os.path.exists(dataset_name):
        loaded = load_dataset("imagefolder", data_dir=dataset_name)
    else:
        loaded = load_dataset(dataset_name)
    dataset_dict[dataset_name] = loaded
    return loaded


def _default_outputs():
    """
    Build the submit_callback outputs used when no dataset is available

    Returns:
        tuple: reset split selector, index slider and order slider updates, plus two empty images
    """
    return (
        gr.update(**default_split_selector_info),
        gr.update(**default_index_slider_info),
        gr.update(**default_order_slider_info),
        None,
        None,
    )


def submit_callback(dataset_name, order):
    """
    Submit callback function: load the dataset and show its first sample

    Args:
        dataset_name (str): dataset name or path
        order (int): order of the Bezier curve

    Returns:
        split_selector_info (dict): updated split selector info
        index_slider_info (dict): updated index slider info
        order_slider_info (dict): updated slider info
        image1 (np.ndarray | PIL.Image | None): image with the original labels
        image2 (np.ndarray | PIL.Image | None): image with the fitted Bezier curves
    """
    global dataset
    try:
        dataset = get_dataset(dataset_name)
    except Exception as e:
        # UI boundary: report the failure (with traceback) and reset the widgets.
        dataset = None
        logging.exception("Load dataset failed: %s", e)
        return _default_outputs()

    if not isinstance(dataset, DatasetDict):
        # A bare Dataset was returned; wrap it so it can be indexed by split name.
        dataset = {str(dataset.split): dataset}

    splits = list(dataset.keys())
    if not splits:
        logging.error("Dataset %s has no splits", dataset_name)
        dataset = None
        return _default_outputs()

    split = splits[0]
    # Clamp so that an empty split still yields a valid slider range.
    maximum = max(len(dataset[split]) - 1, 0)
    index = 0
    split_selector_info = gr.update(choices=splits, value=split, interactive=True)
    index_slider_info = gr.update(minimum=0, maximum=maximum, value=index, interactive=True)
    order_slider_info = gr.update(interactive=True)
    image1, image2 = show_image(split=split, index=index, order=order)
    return split_selector_info, index_slider_info, order_slider_info, image1, image2


def selector_change_callback(split, order):
    """
    Selector change callback function: reset the index slider and show the split's first sample

    Args:
        split (str): selected split, one of the keys of the loaded dataset
        order (int): order of the Bezier curve

    Returns:
        index_slider_info (dict): updated slider info
        image1 (np.ndarray | PIL.Image | None): image with the original labels
        image2 (np.ndarray | PIL.Image | None): image with the fitted Bezier curves
    """
    if dataset is None or split not in dataset:
        return gr.update(**default_index_slider_info), None, None

    # Clamp so that an empty split still yields a valid slider range.
    maximum = max(len(dataset[split]) - 1, 0)
    index = 0
    index_slider_info = gr.update(minimum=0, maximum=maximum, value=index)
    image1, image2 = show_image(split=split, index=index, order=order)
    return index_slider_info, image1, image2


def draw_lines(image, lines, camera_type="pinhole", camera_coeff=None, order=None):
    """
    Draw lines on image

    Args:
        image (np.ndarray): input image array; a 2-D array is rendered as grayscale
        lines (np.ndarray): list of lines, with shape [N, 2, 2]
        camera_type (str): camera type, value must be one of ["pinhole", "fisheye", "spherical"]
        camera_coeff (dict | None): dict of camera coefficients
        order (int | None): order of the Bezier curve; None draws the camera-interpolated
            labels without Bezier fitting, 0 returns `image` untouched

    Returns:
        image (np.ndarray | PIL.Image): `image` itself when order == 0, otherwise the
            rendered PNG as a PIL image of the same width and height

    Raises:
        ValueError: if camera_type is not supported
    """
    if order == 0:
        # Show original image
        return image

    if camera_type not in ("pinhole", "fisheye", "spherical"):
        raise ValueError(f"Unsupported camera type: {camera_type}")

    height, width = image.shape[:2]
    if camera_type == "pinhole":
        camera = cam.Pinhole(coeff=camera_coeff)
    elif camera_type == "fisheye":
        camera = cam.Fisheye(coeff=camera_coeff)
    else:
        camera = cam.Spherical(image_size=(width, height), coeff=camera_coeff)

    if image.ndim == 2:
        # Without this imshow would apply its default colormap and render false colors.
        image = np.stack([image] * 3, axis=-1)

    # A figure 1 inch tall saved with dpi=height comes out at exactly width x height pixels.
    fig = plt.figure()
    buf = io.BytesIO()
    try:
        fig.set_size_inches(width / height, 1, forward=False)
        ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
        ax.set_axis_off()
        fig.add_axes(ax)
        # Draw through `ax` instead of pyplot's implicit current axes, so concurrent
        # callbacks cannot draw onto each other's figure.
        ax.set_xlim([-0.5, width - 0.5])
        ax.set_ylim([height - 0.5, -0.5])
        ax.imshow(image)

        lines = camera.truncate_line(lines)
        pts_list = camera.interp_line(lines)
        if order is not None:
            # Draw Bezier curve
            bezier = bez.Bezier(order=order)
            lines, t_list = bezier.fit_line(pts_list)
            pts_list = bezier.interp_line(lines, t_list)

        for pts in pts_list:
            # Half-pixel shift to imshow's pixel-center convention
            # (presumably label coordinates put pixel corners at integers — TODO confirm).
            pts = pts - 0.5
            ax.plot(pts[:, 0], pts[:, 1], color="orange", linewidth=0.5)
            # Mark both endpoints of each curve.
            ax.scatter(pts[[0, -1], 0], pts[[0, -1], 1], color="#33FFFF", s=1.2, edgecolors="none", zorder=5)

        fig.savefig(buf, format="png", dpi=height, bbox_inches=0)
    finally:
        # Always release the figure, even if fitting or saving fails.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)


def show_image(split, index, order):
    """
    Show image, reusing the cached rendering when nothing relevant changed

    Args:
        split (str): split name, one of the keys of the loaded dataset
        index (int): index of the sample
        order (int): order of the Bezier curve

    Returns:
        image1 (np.ndarray | PIL.Image | None): image with the original labels
        image2 (np.ndarray | PIL.Image | None): image with the fitted Bezier curves
    """
    if dataset is None:
        return None, None

    # Slider values are not guaranteed to be Python ints; indexing and Bezier order need ints.
    index = int(index)
    order = int(order)
    if split not in dataset or not 0 <= index < len(dataset[split]):
        # e.g. an empty split or a stale index from the slider.
        logging.warning("Sample %s[%s] does not exist", split, index)
        return None, None

    new_key = dict(dataset=dataset, split=split, index=index)
    cached_key = dict(dataset=sample_info["dataset"], split=sample_info["split"], index=sample_info["index"])
    same_sample = new_key == cached_key
    if same_sample and sample_info["order"] == order:
        # No need to update
        logging.info("No need to update")
        return sample_info["image1"], sample_info["image2"]

    sample = dataset[split][index]
    image = np.array(sample["image"])
    lines = np.array(sample["lines"])
    # Fall back to pinhole also when the column exists but holds None for this row.
    camera_type = sample.get("camera_type") or "pinhole"
    camera_coeff = sample.get("camera_coeff")

    if same_sample:
        # Only the Bezier order changed, so the original-label image can be reused.
        image1 = sample_info["image1"]
        logging.info("Only update Bezier curve")
    else:
        image1 = draw_lines(image, lines, camera_type, camera_coeff)
    image2 = draw_lines(image, lines, camera_type, camera_coeff, order)

    sample_info.update(new_key)
    sample_info["order"] = order
    sample_info["image1"] = image1
    sample_info["image2"] = image2
    logging.info("Update")
    return image1, image2


def main():
    """
    Main: build the Gradio UI, wire the callbacks and launch the app

    Args:
        None

    Returns:
        None
    """
    with gr.Blocks() as demo:
        dataset_textbox = gr.Textbox(value="lh9171338/Wireframe", label="Dataset name or path")
        split_selector = gr.Dropdown(**default_split_selector_info)
        index_slider = gr.Slider(**default_index_slider_info)
        order_slider = gr.Slider(**default_order_slider_info)
        with gr.Row():
            image1 = gr.Image(label="Original Label")
            image2 = gr.Image(label="Bezier Curve")

        dataset_textbox.submit(
            submit_callback,
            [dataset_textbox, order_slider],
            [split_selector, index_slider, order_slider, image1, image2],
        )
        split_selector.change(selector_change_callback, [split_selector, order_slider], [index_slider, image1, image2])
        index_slider.change(show_image, [split_selector, index_slider, order_slider], [image1, image2])
        order_slider.change(show_image, [split_selector, index_slider, order_slider], [image1, image2])
        # Load the default dataset as soon as the page opens.
        demo.load(
            submit_callback,
            [dataset_textbox, order_slider],
            [split_selector, index_slider, order_slider, image1, image2],
        )

    demo.launch(share=False)


if __name__ == "__main__":
    # set base logging config
    fmt = "[%(asctime)s - %(levelname)s - %(filename)s:%(lineno)s] %(message)s"
    logging.basicConfig(format=fmt, level=logging.INFO)

    main()