# LineViewer / app.py
# lihao57 · set default dataset · 17ef249
# -*- encoding: utf-8 -*-
"""
@File : app.py
@Time : 2025/8/29 15:25:00
@Author : lh9171338
@Version : 1.0
@Contact : 2909171338@qq.com
"""
import os
import gradio as gr
from PIL import Image
import io
import logging
import matplotlib.pyplot as plt
import numpy as np
from datasets import load_dataset, DatasetDict
import utils.camera as cam
import utils.bezier as bez
# Cache of already-loaded datasets, keyed by the name/path the user submitted.
dataset_dict = {}
# Currently selected dataset: a DatasetDict, or a {split_name: Dataset} dict
# built by submit_callback; None while nothing is loaded.
dataset = None

# Initial (disabled) widget configurations. They are also reused to reset the
# widgets whenever loading a dataset fails.
default_split_selector_info = {
    "choices": ["train", "test"],
    "label": "Split",
    "value": "train",
    "interactive": False,
}
default_index_slider_info = {
    "minimum": 0,
    "maximum": 1,
    "step": 1,
    "label": "Index",
    "value": 0,
    "interactive": False,
}
# Bezier order slider; order 0 makes draw_lines return the raw image.
default_order_slider_info = {
    "minimum": 0,
    "maximum": 6,
    "step": 1,
    "label": "Order",
    "value": 0,
    "interactive": False,
}

# Memo of the most recently rendered request and its two images, used by
# show_image to skip redundant redraws.
sample_info = {
    "dataset": dataset,
    "split": "train",
    "index": 0,
    "order": 0,
    "image1": None,
    "image2": None,
}
def get_dataset(dataset_name):
    """
    Get dataset, loading it on first request and serving it from cache afterwards

    Args:
        dataset_name (str): dataset name or path
    Returns:
        dataset (datasets.DatasetDict | datasets.Dataset): loaded dataset
    """
    # Cache hit: reuse the object loaded by an earlier call.
    if dataset_name in dataset_dict:
        return dataset_dict[dataset_name]

    # An existing local path is read with the "imagefolder" builder; any other
    # string is passed to load_dataset unchanged (e.g. a Hub id such as
    # "lh9171338/Wireframe").
    if os.path.exists(dataset_name):
        loaded = load_dataset("imagefolder", data_dir=dataset_name)
    else:
        loaded = load_dataset(dataset_name)

    # Only successful loads reach this point, so failures are never cached.
    dataset_dict[dataset_name] = loaded
    return loaded
def _reset_widget_updates():
    """
    Build the Gradio updates that return every control to its disabled default

    Returns:
        tuple: (split selector update, index slider update, order slider update,
            None, None), where the two Nones clear both images
    """
    return (
        gr.update(**default_split_selector_info),
        gr.update(**default_index_slider_info),
        gr.update(**default_order_slider_info),
        None,
        None,
    )


def submit_callback(dataset_name, order):
    """
    Submit callback function

    Loads the requested dataset (via the get_dataset cache), selects its first
    non-empty split and renders sample 0 of it. If loading fails, or the
    dataset has no samples at all, the global dataset is cleared, every control
    is reset to its disabled default, and both images are cleared.

    Args:
        dataset_name (str): dataset name or path
        order (int): order of the Bezier curve
    Returns:
        split_selector_info (dict): updated split selector info
        index_slider_info (dict): updated index slider info
        order_slider_info (dict): updated slider info
        image1 (PIL.Image | None): image with the original line labels
        image2 (PIL.Image | np.ndarray | None): image with the Bezier curves
    """
    global dataset
    try:
        dataset = get_dataset(dataset_name)
    except Exception as e:
        # UI boundary: any load error (bad name/path, network, ...) is logged
        # with its traceback and surfaced as a reset, disabled UI instead of
        # propagating out of the event handler.
        dataset = None
        logging.exception("Load dataset failed: %s", e)
        return _reset_widget_updates()

    if not isinstance(dataset, DatasetDict):
        # A single Dataset is wrapped so it can be indexed by split name,
        # like a DatasetDict.
        dataset = {str(dataset.split): dataset}
    splits = list(dataset.keys())

    # Start on the first split that actually has samples. Rendering sample 0
    # of an empty split, or choosing from no splits at all, would raise.
    split = next((name for name in splits if len(dataset[name]) > 0), None)
    if split is None:
        logging.error("Dataset has no samples: %s", dataset_name)
        dataset = None
        return _reset_widget_updates()

    maximum = len(dataset[split]) - 1
    index = 0
    split_selector_info = gr.update(choices=splits, value=split, interactive=True)
    index_slider_info = gr.update(minimum=0, maximum=maximum, value=index, interactive=True)
    order_slider_info = gr.update(interactive=True)
    image1, image2 = show_image(split=split, index=index, order=order)
    return split_selector_info, index_slider_info, order_slider_info, image1, image2
def selector_change_callback(split, order):
    """
    Selector change callback function

    Points the index slider at the newly selected split and renders its
    first sample.

    Args:
        split (str): selected split, one of the splits offered by the loaded dataset
        order (int): order of the Bezier curve
    Returns:
        index_slider_info (dict): updated slider info
        image1 (PIL.Image | None): image with the original line labels
        image2 (PIL.Image | np.ndarray | None): image with the Bezier curves
    """
    if dataset is None:
        # Nothing loaded: restore the slider's default config and clear both images.
        return gr.update(**default_index_slider_info), None, None

    last_index = len(dataset[split]) - 1
    slider_update = gr.update(minimum=0, maximum=last_index, value=0)
    first_images = show_image(split=split, index=0, order=order)
    return (slider_update, *first_images)
def draw_lines(image, lines, camera_type="pinhole", camera_coeff=None, order=None):
    """
    Draw lines on image

    The lines are first truncated and interpolated by the camera model. When
    ``order`` is None they are drawn as-is. Otherwise each line is fitted with
    a Bezier curve of that order, and the fitted curve is drawn instead.

    Args:
        image (np.ndarray): input image, indexed as (height, width, ...)
        lines (np.ndarray): list of lines, with shape [N, 2, 2]
        camera_type (str): camera type, value must be one of ["pinhole", "fisheye", "spherical"]
        camera_coeff (dict | None): dict of camera coefficients
        order (int | None): order of the Bezier curve; 0 skips drawing entirely
    Returns:
        image (PIL.Image | np.ndarray): rendered PNG image with the same pixel
            size as the input, or the unchanged input array when ``order == 0``
    Raises:
        ValueError: if ``camera_type`` is not one of the supported types
    """
    if order == 0:  # Show original image
        return image
    if camera_type not in ("pinhole", "fisheye", "spherical"):
        # Explicit error instead of `assert`, which is stripped under `python -O`.
        raise ValueError(f"Unsupported camera type: {camera_type!r}")
    # A standalone Figure keeps no pyplot global state, so concurrent Gradio
    # requests cannot draw onto each other's "current" axes. It also leaves
    # nothing registered in pyplot if drawing raises part-way.
    from matplotlib.figure import Figure

    height, width = image.shape[:2]
    if camera_type == "pinhole":
        camera = cam.Pinhole(coeff=camera_coeff)
    elif camera_type == "fisheye":
        camera = cam.Fisheye(coeff=camera_coeff)
    else:
        camera = cam.Spherical(image_size=(width, height), coeff=camera_coeff)

    fig = Figure()
    # 1 inch tall, saved below at dpi=height -> output is width x height pixels.
    fig.set_size_inches(width / height, 1, forward=False)
    ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
    ax.set_axis_off()
    ax.set_xlim([-0.5, width - 0.5])
    # Inverted y-limits put row 0 at the top, as in image coordinates.
    ax.set_ylim([height - 0.5, -0.5])
    ax.imshow(image)

    lines = camera.truncate_line(lines)
    pts_list = camera.interp_line(lines)
    if order is not None:  # Draw Bezier curve
        bezier = bez.Bezier(order=order)
        lines, t_list = bezier.fit_line(pts_list)
        pts_list = bezier.interp_line(lines, t_list)
    for pts in pts_list:
        # Half-pixel shift onto imshow's pixel-centre grid. This presumes the
        # label coordinates are corner-origin; confirm against utils.camera.
        pts = pts - 0.5
        ax.plot(pts[:, 0], pts[:, 1], color="orange", linewidth=0.5)
        # Mark both endpoints of each line.
        ax.scatter(pts[[0, -1], 0], pts[[0, -1], 1], color="#33FFFF", s=1.2, edgecolors="none", zorder=5)

    buf = io.BytesIO()
    # bbox_inches=0 (falsy) disables any bbox cropping so the size stays exact.
    fig.savefig(buf, format="png", dpi=height, bbox_inches=0)
    buf.seek(0)
    return Image.open(buf)
def show_image(split, index, order):
    """
    Show image

    Renders sample ``index`` of ``split`` twice: once with its original line
    labels and once with the Bezier fit of the given order. The last request
    and its renders are memoised in ``sample_info``:
    - an identical request reuses both images;
    - a request that differs only in ``order`` redraws just the Bezier image.

    Args:
        split (str): split name, one of the loaded dataset's splits
        index (int): index of the sample within the split
        order (int): order of the Bezier curve (0 yields the raw image as image2)
    Returns:
        image1 (PIL.Image | None): image with the original line labels
        image2 (PIL.Image | np.ndarray | None): image with the Bezier curves
    """
    if dataset is None:
        return None, None

    requested = (dataset, split, index, order)
    previous = tuple(sample_info[key] for key in ("dataset", "split", "index", "order"))
    if previous == requested:  # Identical request: serve both cached renders
        logging.info("No need to update")
        return sample_info["image1"], sample_info["image2"]

    sample = dataset[split][index]
    image = np.array(sample["image"])
    lines = np.array(sample["lines"])
    camera_type = sample.get("camera_type", "pinhole")
    camera_coeff = sample.get("camera_coeff", None)

    # Same dataset/split/index as last time: the original-label render is still valid.
    if previous[:3] == requested[:3]:
        image1 = sample_info["image1"]
        logging.info("Only update Bezier curve")
    else:
        image1 = draw_lines(image, lines, camera_type, camera_coeff)
    image2 = draw_lines(image, lines, camera_type, camera_coeff, order)

    sample_info.update(
        dataset=dataset,
        split=split,
        index=index,
        order=order,
        image1=image1,
        image2=image2,
    )
    logging.info("Update")
    return image1, image2
def main():
    """
    Build the LineViewer Gradio UI, wire its events and launch it

    Args:
        None
    Returns:
        None
    """
    with gr.Blocks() as demo:
        dataset_textbox = gr.Textbox(value="lh9171338/Wireframe", label="Dataset name or path")
        split_selector = gr.Dropdown(**default_split_selector_info)
        index_slider = gr.Slider(**default_index_slider_info)
        order_slider = gr.Slider(**default_order_slider_info)
        with gr.Row():
            image1 = gr.Image(label="Original Label")
            image2 = gr.Image(label="Bezier Curve")

        def bind_loader(event):
            # Pressing Enter in the textbox and the initial page load both
            # (re)load the dataset and refresh every control.
            event(
                submit_callback,
                [dataset_textbox, order_slider],
                [split_selector, index_slider, order_slider, image1, image2],
            )

        bind_loader(dataset_textbox.submit)
        split_selector.change(
            selector_change_callback,
            [split_selector, order_slider],
            [index_slider, image1, image2],
        )
        # Moving either slider re-renders the current sample.
        for slider in (index_slider, order_slider):
            slider.change(
                show_image,
                [split_selector, index_slider, order_slider],
                [image1, image2],
            )
        bind_loader(demo.load)
    demo.launch(share=False)
if __name__ == "__main__":
    # Configure the root logger before the app starts emitting records.
    logging.basicConfig(
        level=logging.INFO,
        format="[%(asctime)s - %(levelname)s - %(filename)s:%(lineno)s] %(message)s",
    )
    main()