Spaces:
Sleeping
Sleeping
| # -*- encoding: utf-8 -*- | |
| """ | |
| @File : app.py | |
| @Time : 2025/8/29 15:25:00 | |
| @Author : lh9171338 | |
| @Version : 1.0 | |
| @Contact : 2909171338@qq.com | |
| """ | |
| import os | |
| import gradio as gr | |
| from PIL import Image | |
| import io | |
| import logging | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from datasets import load_dataset, DatasetDict | |
| import utils.camera as cam | |
| import utils.bezier as bez | |
# Cache of loaded datasets, keyed by the name/path passed to get_dataset().
dataset_dict = {}
# Dataset currently displayed; None until a load succeeds.
# NOTE(review): this and the other module-level state here are shared by every
# Gradio session, so concurrent users affect each other — confirm single-user use.
dataset = None

# Initial (disabled) split dropdown configuration; also used to reset it on failure.
default_split_selector_info = {
    "choices": ["train", "test"],
    "label": "Split",
    "value": "train",
    "interactive": False,
}
# Initial (disabled) sample-index slider configuration.
default_index_slider_info = {
    "minimum": 0,
    "maximum": 1,
    "step": 1,
    "label": "Index",
    "value": 0,
    "interactive": False,
}
# Initial (disabled) Bezier-order slider configuration (0 shows the raw image).
default_order_slider_info = {
    "minimum": 0,
    "maximum": 6,
    "step": 1,
    "label": "Order",
    "value": 0,
    "interactive": False,
}
# Memo of the most recently rendered view, used by show_image() to skip redraws.
sample_info = {
    "dataset": dataset,
    "split": "train",
    "index": 0,
    "order": 0,
    "image1": None,
    "image2": None,
}
def get_dataset(dataset_name):
    """
    Get dataset, loading it on first request and serving it from cache afterwards

    A name that exists on the local filesystem is loaded with the "imagefolder"
    builder using that path as `data_dir`; any other name is handed directly to
    `load_dataset`. Results are memoised in the module-level `dataset_dict`.

    Args:
        dataset_name (str): dataset name or path
    Returns:
        dataset: the object returned by `load_dataset` (callers handle both a
            DatasetDict and a single Dataset)
    """
    try:
        return dataset_dict[dataset_name]
    except KeyError:
        pass

    if os.path.exists(dataset_name):
        loaded = load_dataset("imagefolder", data_dir=dataset_name)
    else:
        loaded = load_dataset(dataset_name)
    dataset_dict[dataset_name] = loaded
    return loaded
def _reset_control_updates():
    """
    Build the callback outputs used when no dataset is available

    Returns:
        tuple: gr.update payloads restoring the split dropdown and both sliders to
            their disabled defaults, followed by two None images
    """
    return (
        gr.update(**default_split_selector_info),
        gr.update(**default_index_slider_info),
        gr.update(**default_order_slider_info),
        None,
        None,
    )


def submit_callback(dataset_name, order):
    """
    Submit callback function

    Loads (or fetches from the cache) the requested dataset, stores it in the
    module-level `dataset`, enables the controls and renders sample 0 of the
    first split. On failure the controls are reset and no images are shown.

    Args:
        dataset_name (str): dataset name or path
        order (int): order of the Bezier curve
    Returns:
        split_selector_info (dict): updated split selector info
        index_slider_info (dict): updated index slider info
        order_slider_info (dict): updated order slider info
        image1: original-label rendering, or None on failure
        image2: Bezier-curve rendering, or None on failure
    """
    global dataset
    try:
        dataset = get_dataset(dataset_name)
    except Exception as e:
        # UI boundary: report the failure with its traceback and reset the
        # controls instead of letting the event handler crash.
        dataset = None
        logging.exception(f"Load dataset failed: {e}")
        return _reset_control_updates()

    if not isinstance(dataset, DatasetDict):
        # Wrap a single Dataset as {split_name: Dataset} so the rest of the app
        # can always index the current dataset by split name.
        dataset = {str(dataset.split): dataset}
    splits = list(dataset.keys())
    if not splits:
        # Nothing to select; splits[0] below would raise IndexError.
        dataset = None
        logging.error(f"Load dataset failed: no splits found in {dataset_name}")
        return _reset_control_updates()

    split = splits[0]
    index = 0
    maximum = len(dataset[split]) - 1
    split_selector_info = gr.update(choices=splits, value=split, interactive=True)
    index_slider_info = gr.update(minimum=0, maximum=maximum, value=index, interactive=True)
    order_slider_info = gr.update(interactive=True)
    image1, image2 = show_image(split=split, index=index, order=order)
    return split_selector_info, index_slider_info, order_slider_info, image1, image2
def selector_change_callback(split, order):
    """
    Selector change callback function

    Resets the index slider to the range of the newly selected split and renders
    sample 0 of that split.

    Args:
        split (str): selected split name; expected to be a key of the current
            dataset (the dropdown choices are populated from it in submit_callback)
        order (int): order of the Bezier curve
    Returns:
        index_slider_info (dict): updated slider info
        image1: original-label rendering, or None when nothing can be shown
        image2: Bezier-curve rendering, or None when nothing can be shown
    """
    # The module-level `dataset` may have been replaced (e.g. by another submit)
    # since the dropdown was populated, and the dropdown may be cleared (None);
    # fall back to the defaults instead of raising KeyError.
    if dataset is None or split not in dataset:
        return gr.update(**default_index_slider_info), None, None
    index = 0
    maximum = len(dataset[split]) - 1
    index_slider_info = gr.update(minimum=0, maximum=maximum, value=index)
    image1, image2 = show_image(split=split, index=index, order=order)
    return index_slider_info, image1, image2
def draw_lines(image, lines, camera_type="pinhole", camera_coeff=None, order=None):
    """
    Draw lines on image

    Args:
        image (np.ndarray): input image; image.shape[:2] is (height, width)
        lines (np.ndarray): list of lines, with shape [N, 2, 2]
        camera_type (str): camera type, value must be one of ["pinhole", "fisheye", "spherical"]
        camera_coeff (dict | None): dict of camera coefficients
        order (int | None): order of the Bezier curve; 0 returns `image` untouched,
            None draws the camera-interpolated lines without Bezier fitting
    Returns:
        image (np.ndarray | PIL.Image.Image): `image` itself when order == 0,
            otherwise the rendering as a PNG-backed PIL image
    Raises:
        ValueError: if camera_type is not supported
    """
    if order == 0:  # Show original image
        return image
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if camera_type not in ("pinhole", "fisheye", "spherical"):
        raise ValueError(f"Unsupported camera type: {camera_type!r}")
    height, width = image.shape[:2]
    if camera_type == "pinhole":
        camera = cam.Pinhole(coeff=camera_coeff)
    elif camera_type == "fisheye":
        camera = cam.Fisheye(coeff=camera_coeff)
    else:
        camera = cam.Spherical(image_size=(width, height), coeff=camera_coeff)

    # Compute the geometry before creating the figure so a failure here cannot
    # leave an unclosed pyplot figure behind.
    lines = camera.truncate_line(lines)
    pts_list = camera.interp_line(lines)
    if order is not None:  # Draw Bezier curve
        bezier = bez.Bezier(order=order)
        lines, t_list = bezier.fit_line(pts_list)
        pts_list = bezier.interp_line(lines, t_list)

    fig = plt.figure()
    try:
        # A (width/height) x 1 inch canvas saved at dpi=height yields a PNG at the
        # input image's resolution.
        fig.set_size_inches(width / height, 1, forward=False)
        ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
        ax.set_axis_off()
        fig.add_axes(ax)
        # Draw through `ax` rather than pyplot's implicit "current axes", which is
        # global state shared with any other figure being drawn.
        ax.set_xlim([-0.5, width - 0.5])
        ax.set_ylim([height - 0.5, -0.5])  # y grows downwards, as in image coordinates
        # Single-channel images would otherwise be drawn with the default colour map.
        ax.imshow(image, cmap="gray" if image.ndim == 2 else None)
        for pts in pts_list:
            # NOTE(review): presumably shifts from a frame where pixel centres sit at
            # +0.5 to imshow's integer pixel centres — confirm against utils.camera.
            pts = pts - 0.5
            ax.plot(pts[:, 0], pts[:, 1], color="orange", linewidth=0.5)
            ax.scatter(pts[[0, -1], 0], pts[[0, -1], 1], color="#33FFFF", s=1.2, edgecolors="none", zorder=5)
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=height, bbox_inches=0)
    finally:
        # Always deregister the figure from pyplot, even if drawing or saving fails.
        plt.close(fig)
    buf.seek(0)
    image = Image.open(buf)
    return image
def show_image(split, index, order):
    """
    Show image: render one sample as an original-label view and a Bezier-curve view

    The last rendered (dataset, split, index, order) and its two images are kept
    in the module-level `sample_info`. If nothing changed, the cached images are
    returned as-is; if only `order` changed, the original-label image is reused
    and only the Bezier view is redrawn.

    Args:
        split (str): split name of the current dataset
        index (int): index of the sample within the split
        order (int): order of the Bezier curve
    Returns:
        image1: original-label rendering (None when no dataset is loaded)
        image2: Bezier-curve rendering (None when no dataset is loaded)
    """
    if dataset is None:
        return None, None

    requested = dict(dataset=dataset, split=split, index=index)
    previous = {key: sample_info[key] for key in requested}
    same_sample = previous == requested
    if same_sample and sample_info["order"] == order:
        logging.info("No need to update")
        return sample_info["image1"], sample_info["image2"]

    sample = dataset[split][index]
    image = np.array(sample["image"])
    lines = np.array(sample["lines"])
    camera_type = sample.get("camera_type", "pinhole")
    camera_coeff = sample.get("camera_coeff", None)

    if same_sample:
        # Only the curve order changed: the original-label view is still valid.
        logging.info("Only update Bezier curve")
        image1 = sample_info["image1"]
    else:
        image1 = draw_lines(image, lines, camera_type, camera_coeff)
    image2 = draw_lines(image, lines, camera_type, camera_coeff, order)

    sample_info.update(requested, order=order, image1=image1, image2=image2)
    logging.info("Update")
    return image1, image2
def main():
    """
    Main: build the viewer UI, wire its events and launch the Gradio app

    Args:
        None
    Returns:
        None
    """
    with gr.Blocks() as demo:
        dataset_textbox = gr.Textbox(value="lh9171338/Wireframe", label="Dataset name or path")
        split_selector = gr.Dropdown(**default_split_selector_info)
        index_slider = gr.Slider(**default_index_slider_info)
        order_slider = gr.Slider(**default_order_slider_info)
        with gr.Row():
            image1 = gr.Image(label="Original Label")
            image2 = gr.Image(label="Bezier Curve")

        def wire_loader(trigger):
            # Both the textbox submit and the initial page load (re)load the dataset.
            trigger(
                submit_callback,
                [dataset_textbox, order_slider],
                [split_selector, index_slider, order_slider, image1, image2],
            )

        wire_loader(dataset_textbox.submit)
        split_selector.change(selector_change_callback, [split_selector, order_slider], [index_slider, image1, image2])
        # Moving either slider re-renders the current (split, index, order) view.
        for slider in (index_slider, order_slider):
            slider.change(show_image, [split_selector, index_slider, order_slider], [image1, image2])
        wire_loader(demo.load)
    demo.launch(share=False)
if __name__ == "__main__":
    # Configure root logging so the callbacks' logging.info/error messages are emitted.
    log_format = "[%(asctime)s - %(levelname)s - %(filename)s:%(lineno)s] %(message)s"
    logging.basicConfig(level=logging.INFO, format=log_format)
    main()