FE-LineViewer / app.py
lihao57
initial commit
7ebf5dc
# -*- encoding: utf-8 -*-
"""
@File : app.py
@Time : 2025/8/29 15:25:00
@Author : lh9171338
@Version : 1.0
@Contact : 2909171338@qq.com
"""
import os
import gradio as gr
from PIL import Image
import io
import logging
import matplotlib.pyplot as plt
import numpy as np
from datasets import load_dataset, DatasetDict
from utils.event import Event
# Cache of datasets already loaded by get_dataset(), keyed by the name/path string.
dataset_dict = {}
# Dataset currently displayed in the UI; None until a load succeeds.
dataset = None
# Initial (and reset) configuration of the split dropdown.
default_split_selector_info = {
    "choices": ["train", "test"],
    "label": "Split",
    "value": "train",
    "interactive": False,
}
# Initial (and reset) configuration of the sample index slider.
default_index_slider_info = {
    "minimum": 0,
    "maximum": 1,
    "step": 1,
    "label": "Index",
    "value": 0,
    "interactive": False,
}
# Last rendered sample: show_image() returns these images again when the same
# (dataset, split, index) is requested.
sample_info = {
    "dataset": dataset,
    "split": "train",
    "index": 0,
    "blur_image": None,
    "event_image": None,
    "start_image": None,
    "end_image": None,
}
def get_dataset(dataset_name):
    """
    Load a dataset by Hub name or local path, memoizing the result

    Args:
        dataset_name (str): dataset name or path
    Returns:
        dataset: whatever ``load_dataset`` returns (a DatasetDict or a Dataset);
            repeated calls with the same name return the cached object
    """
    if dataset_name not in dataset_dict:
        load_kwargs = {"trust_remote_code": True}
        if os.path.exists(dataset_name):
            # Local path: also pass it as data_dir to the loader.
            load_kwargs["data_dir"] = dataset_name
        dataset_dict[dataset_name] = load_dataset(dataset_name, **load_kwargs)
    return dataset_dict[dataset_name]
def submit_callback(dataset_name):
    """
    Submit callback function: load the requested dataset and show its first sample

    Args:
        dataset_name (str): dataset name or path
    Returns:
        split_selector_info (dict): updated split selector info
        index_slider_info (dict): updated index slider info
        blur_image (PIL.Image): updated blur image
        event_image (PIL.Image): updated event image
        start_image (PIL.Image): updated start image
        end_image (PIL.Image): updated end image
    If loading fails or the dataset has no splits, the global dataset is cleared and
    the widgets are reset to their defaults. If the first split is empty, the images
    are None.
    """
    global dataset

    def _reset():
        # Default widget state with no images.
        return (
            gr.update(**default_split_selector_info),
            gr.update(**default_index_slider_info),
            None,
            None,
            None,
            None,
        )

    try:
        loaded = get_dataset(dataset_name)
    except Exception as e:
        dataset = None
        # logging.exception keeps the traceback in addition to the message.
        logging.exception("Load dataset failed: %s", e)
        return _reset()

    # Wrap a single Dataset as {split_name: Dataset} so the UI can index by split either way.
    if not isinstance(loaded, DatasetDict):
        loaded = {str(loaded.split): loaded}
    splits = list(loaded.keys())
    if not splits:
        # An empty DatasetDict would otherwise crash on splits[0].
        dataset = None
        logging.error("Load dataset failed: %s has no splits", dataset_name)
        return _reset()

    # show_image reads the global, so publish the dataset before rendering.
    dataset = loaded
    split = splits[0]
    num_samples = len(dataset[split])
    index = 0
    split_selector_info = gr.update(choices=splits, value=split, interactive=True)
    # Clamp so an empty split does not yield a slider whose maximum is below its minimum.
    index_slider_info = gr.update(minimum=0, maximum=max(num_samples - 1, 0), value=index, interactive=True)
    if num_samples == 0:
        # Nothing to render; dataset[split][0] would raise IndexError.
        return split_selector_info, index_slider_info, None, None, None, None
    blur_image, event_image, start_image, end_image = show_image(split=split, index=index)
    return split_selector_info, index_slider_info, blur_image, event_image, start_image, end_image
def selector_change_callback(split):
    """
    Selector change callback function: reset the index to 0 and show that sample

    Args:
        split (str): selected split; expected to be a key of the loaded dataset
    Returns:
        index_slider_info (dict): updated slider info
        blur_image (PIL.Image): updated blur image
        event_image (PIL.Image): updated event image
        start_image (PIL.Image): updated start image
        end_image (PIL.Image): updated end image
    The slider is reset and the images are None when no dataset is loaded or the split is
    unknown. The images are also None when the split is empty.
    """
    if dataset is None or split not in dataset:
        # No dataset, or a dropdown value that is not a split of it (e.g. cleared):
        # dataset[split] would raise KeyError, so reset instead.
        index_slider_info = gr.update(**default_index_slider_info)
        return index_slider_info, None, None, None, None
    num_samples = len(dataset[split])
    index = 0
    # Clamp so an empty split does not yield a slider whose maximum is below its minimum.
    index_slider_info = gr.update(minimum=0, maximum=max(num_samples - 1, 0), value=index)
    if num_samples == 0:
        # Nothing to render; dataset[split][0] would raise IndexError.
        return index_slider_info, None, None, None, None
    blur_image, event_image, start_image, end_image = show_image(split=split, index=index)
    return index_slider_info, blur_image, event_image, start_image, end_image
def draw_lines(image, lines):
    """
    Draw line segments (orange) and their endpoints (cyan) on top of an image

    Args:
        image (np.ndarray): input image, shape (H, W), (H, W, 1), (H, W, 3) or (H, W, 4)
        lines (np.ndarray): list of lines, with shape [N, 2, 2] (two (x, y) endpoints each)
    Returns:
        image (PIL.Image): drawn image, rendered at roughly the input resolution
    """
    if image.ndim == 3 and image.shape[2] == 1:
        # imshow rejects a trailing singleton channel; treat it as grayscale.
        image = image[..., 0]
    height, width = image.shape[:2]

    imshow_kwargs = {}
    if image.ndim == 2:
        # Without an explicit cmap, single-channel data is pseudo-coloured by the default colormap.
        imshow_kwargs["cmap"] = "gray"
        if image.dtype == np.uint8:
            # Keep the full 8-bit range instead of auto-stretching to the data's min/max.
            imshow_kwargs.update(vmin=0, vmax=255)

    fig = plt.figure()
    try:
        # A 1-inch-tall figure saved at dpi=height gives about width x height output pixels.
        fig.set_size_inches(width / height, 1, forward=False)
        # Axes covering the whole figure with no ticks/frame, drawn via the explicit Axes API
        # rather than pyplot's global "current axes".
        ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
        ax.set_axis_off()
        # Pixel centres sit at integer coordinates, so the image spans [-0.5, size - 0.5];
        # the y-axis is inverted so row 0 is at the top.
        ax.set_xlim([-0.5, width - 0.5])
        ax.set_ylim([height - 0.5, -0.5])
        ax.imshow(image, **imshow_kwargs)
        for pts in lines:
            # NOTE(review): presumably line coordinates use a pixel-corner convention,
            # hence the half-pixel shift; confirm against the dataset spec.
            pts = pts - 0.5
            ax.plot(pts[:, 0], pts[:, 1], color="orange", linewidth=0.5)
            ax.scatter(pts[:, 0], pts[:, 1], color="#33FFFF", s=1.2, edgecolors="none", zorder=5)
        buf = io.BytesIO()
        # bbox_inches=0 (falsy but not None) disables bbox cropping so the saved size is exact.
        fig.savefig(buf, format="png", dpi=height, bbox_inches=0)
    finally:
        # Always release the figure, even if drawing or saving raised.
        plt.close(fig)
    buf.seek(0)
    image = Image.open(buf)
    return image
def show_image(split, index):
    """
    Render the four views of one sample, reusing the last result when possible

    Args:
        split (str): split name; expected to be a key of the loaded dataset
        index (int): index of the sample within the split
    Returns:
        blur_image (PIL.Image): blurred image
        event_image (PIL.Image): event image, resized to the blurred image's size
        start_image (PIL.Image): start image with lines drawn
        end_image (PIL.Image): end image with lines drawn
    All four are None when no dataset is loaded or (split, index) does not address a sample.
    """
    if dataset is None:
        return None, None, None, None
    # NOTE(review): depending on the Gradio version the slider value may arrive as a float;
    # coerce it so dataset indexing and the cache comparison below both see an int.
    index = int(index)

    # Return the cached images when the same (dataset, split, index) is requested again.
    requested = dict(dataset=dataset, split=split, index=index)
    previous = dict(
        dataset=sample_info["dataset"],
        split=sample_info["split"],
        index=sample_info["index"],
    )
    if previous == requested:  # No need to update
        logging.info("No need to update")
        return sample_info["blur_image"], sample_info["event_image"], sample_info["start_image"], sample_info["end_image"]

    if split not in dataset or not 0 <= index < len(dataset[split]):
        # E.g. an empty split: indexing would raise KeyError/IndexError.
        logging.warning("No sample at split=%s, index=%s", split, index)
        return None, None, None, None

    sample = dataset[split][index]
    blur_image = sample["blur_image"]
    start_image = np.array(sample["start_image"])
    end_image = np.array(sample["end_image"])
    # Flat coordinate list -> one [2 endpoints, (x, y)] array per line segment.
    lines = np.array(sample["lines"]).reshape(-1, 2, 2)
    event_image = Image.fromarray(Event(events=sample["events"]).event2image())
    event_image = event_image.resize(blur_image.size)
    start_image = draw_lines(start_image, lines)
    end_image = draw_lines(end_image, lines)

    # Update the cache only after every view rendered successfully.
    sample_info.update(requested)
    sample_info.update(
        blur_image=blur_image,
        event_image=event_image,
        start_image=start_image,
        end_image=end_image,
    )
    logging.info("Update")
    return blur_image, event_image, start_image, end_image
def main():
    """
    Build the viewer UI, wire its callbacks and launch the Gradio app

    Args:
        None
    Returns:
        None
    """
    with gr.Blocks() as demo:
        dataset_textbox = gr.Textbox(value="lh9171338/FE-Blurframe", label="Dataset name or path")
        split_selector = gr.Dropdown(**default_split_selector_info)
        index_slider = gr.Slider(**default_index_slider_info)
        with gr.Row():
            blur_image = gr.Image(label="Blurred Image")
            event_image = gr.Image(label="Event Image")
            start_image = gr.Image(label="Start Image")
            end_image = gr.Image(label="End Image")
        image_outputs = [blur_image, event_image, start_image, end_image]

        # Loading a dataset refreshes the dropdown, the slider and all four images.
        dataset_textbox.submit(submit_callback, dataset_textbox, [split_selector, index_slider, *image_outputs])
        split_selector.change(selector_change_callback, split_selector, [index_slider, *image_outputs])
        index_slider.change(show_image, [split_selector, index_slider], [*image_outputs])
        # Also load the default dataset once when the page opens.
        demo.load(submit_callback, dataset_textbox, [split_selector, index_slider, *image_outputs])
    demo.launch(share=False)
if __name__ == "__main__":
    # Configure root logging before the app starts so the callbacks' log lines are shown.
    logging.basicConfig(
        level=logging.INFO,
        format="[%(asctime)s - %(levelname)s - %(filename)s:%(lineno)s] %(message)s",
    )
    main()