# -*- encoding: utf-8 -*-
"""
@File    :   app.py
@Time    :   2025/8/29 15:25:00
@Author  :   lh9171338
@Version :   1.0
@Contact :   2909171338@qq.com
"""

import io
import logging
import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from datasets import load_dataset, DatasetDict
from PIL import Image

from utils.event import Event


# Cache of already loaded datasets, keyed by the name/path typed by the user
dataset_dict = dict()
# Currently displayed dataset (mapping split name -> split), None if nothing is loaded
dataset = None
# Widget configuration used before a dataset is loaded and after a failed load
default_split_selector_info = dict(
    choices=["train", "test"],
    label="Split",
    value="train",
    interactive=False,
)
default_index_slider_info = dict(
    minimum=0,
    maximum=1,
    step=1,
    label="Index",
    value=0,
    interactive=False,
)
# Last rendered sample, used by show_image to skip redundant re-rendering
sample_info = dict(
    dataset=dataset,
    split="train",
    index=0,
    blur_image=None,
    event_image=None,
    start_image=None,
    end_image=None,
)


def get_dataset(dataset_name):
    """
    Get dataset, loading it on first use and caching it in `dataset_dict`

    Args:
        dataset_name (str): dataset name or path

    Returns:
        dataset (datasets.Dataset): dataset
    """
    if dataset_name in dataset_dict:
        return dataset_dict[dataset_name]

    # SECURITY NOTE(review): trust_remote_code=True allows the dataset repository
    # to run its own loading script, and dataset_name comes straight from the UI
    # textbox. Only expose this app to trusted users.
    if os.path.exists(dataset_name):
        # Existing local path: it is passed both as the dataset and as its data_dir
        loaded = load_dataset(dataset_name, data_dir=dataset_name, trust_remote_code=True)
    else:
        loaded = load_dataset(dataset_name, trust_remote_code=True)
    dataset_dict[dataset_name] = loaded

    return loaded


def _failed_load_outputs():
    """
    Build the outputs of submit_callback that reset the UI to its default state

    Returns:
        outputs (tuple): default split selector update, default index slider update
            and four empty images
    """
    split_selector_info = gr.update(**default_split_selector_info)
    index_slider_info = gr.update(**default_index_slider_info)

    return split_selector_info, index_slider_info, None, None, None, None


def submit_callback(dataset_name):
    """
    Submit callback function

    Args:
        dataset_name (str): dataset name or path

    Returns:
        split_selector_info (dict): updated split selector info
        index_slider_info (dict): updated index slider info
        blur_image (PIL.Image): updated blur image
        event_image (PIL.Image): updated event image
        start_image (PIL.Image): updated start image
        end_image (PIL.Image): updated end image
    """
    global dataset

    try:
        dataset = get_dataset(dataset_name)
    except Exception as e:
        # UI boundary: any loading failure resets the widgets instead of crashing
        dataset = None
        logging.exception("Load dataset failed: %s", e)
        return _failed_load_outputs()

    if not isinstance(dataset, DatasetDict):
        # A single split was returned: wrap it so it can be indexed by split name
        dataset = {str(dataset.split): dataset}

    splits = list(dataset.keys())
    if not splits:
        # Nothing to display; the original code would fail on splits[0]
        dataset = None
        logging.error("Load dataset failed: %s", "dataset has no splits")
        return _failed_load_outputs()

    split = splits[0]
    # Clamp to 0 so that an empty split does not produce maximum < minimum
    maximum = max(len(dataset[split]) - 1, 0)
    index = 0
    split_selector_info = gr.update(choices=splits, value=split, interactive=True)
    index_slider_info = gr.update(minimum=0, maximum=maximum, value=index, interactive=True)
    blur_image, event_image, start_image, end_image = show_image(split=split, index=index)

    return split_selector_info, index_slider_info, blur_image, event_image, start_image, end_image


def selector_change_callback(split):
    """
    Selector change callback function

    Args:
        split (str): selected split, must be one of the splits of the loaded dataset

    Returns:
        index_slider_info (dict): updated slider info
        blur_image (PIL.Image): updated blur image
        event_image (PIL.Image): updated event image
        start_image (PIL.Image): updated start image
        end_image (PIL.Image): updated end image
    """
    if dataset is None:
        index_slider_info = gr.update(**default_index_slider_info)
        return index_slider_info, None, None, None, None

    if split not in dataset:
        # The global dataset may have been replaced since the selector was filled;
        # leave the slider untouched and show nothing instead of raising KeyError
        logging.warning("Unknown split: %r", split)
        return gr.update(), None, None, None, None

    # Clamp to 0 so that an empty split does not produce maximum < minimum
    maximum = max(len(dataset[split]) - 1, 0)
    index = 0
    index_slider_info = gr.update(minimum=0, maximum=maximum, value=index)
    blur_image, event_image, start_image, end_image = show_image(split=split, index=index)

    return index_slider_info, blur_image, event_image, start_image, end_image


def draw_lines(image, lines):
    """
    Draw lines on image

    Args:
        image (np.ndarray): input image
        lines (np.ndarray): list of lines, with shape [N, 2, 2]

    Returns:
        image (PIL.Image): drawn image
    """
    height, width = image.shape[:2]

    fig = plt.figure()
    try:
        # Figure of (width / height) x 1 inches saved at dpi=height, i.e. the
        # rendered PNG has the same pixel size as the input image
        fig.set_size_inches(width / height, 1, forward=False)
        # Axes covering the whole figure, without ticks or frame
        ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
        ax.set_axis_off()
        fig.add_axes(ax)
        # Draw through the axes object rather than pyplot's implicit current figure
        ax.set_xlim([-0.5, width - 0.5])
        ax.set_ylim([height - 0.5, -0.5])
        ax.imshow(image)
        for pts in lines:
            # NOTE(review): presumably converts endpoint coordinates to the
            # pixel-centre convention used by imshow - confirm against the dataset
            pts = pts - 0.5
            ax.plot(pts[:, 0], pts[:, 1], color="orange", linewidth=0.5)
            ax.scatter(pts[:, 0], pts[:, 1], color="#33FFFF", s=1.2, edgecolors="none", zorder=5)

        buf = io.BytesIO()
        # NOTE(review): bbox_inches=0 is falsy, so no bounding-box cropping is
        # applied - presumably intentional, kept as in the original
        fig.savefig(buf, format="png", dpi=height, bbox_inches=0)
    finally:
        # Always release the figure, even when drawing or saving fails
        plt.close(fig)

    buf.seek(0)
    image = Image.open(buf)

    return image


def show_image(split, index):
    """
    Show image

    The last rendered sample is cached in `sample_info`, so asking again for the
    same dataset, split and index returns the cached images without re-rendering.

    Args:
        split (str): split name, must be one of the splits of the loaded dataset
        index (int): index of the sample

    Returns:
        blur_image (PIL.Image): drawn blurred image
        event_image (PIL.Image): drawn event image
        start_image (PIL.Image): drawn start image
        end_image (PIL.Image): drawn end image
    """
    if dataset is None:
        return None, None, None, None

    # The slider may deliver a float-typed value; samples are addressed by integer index
    index = int(index)

    old_sample_info = dict(
        dataset=sample_info["dataset"],
        split=sample_info["split"],
        index=sample_info["index"],
    )
    new_sample_info = dict(dataset=dataset, split=split, index=index)
    if old_sample_info == new_sample_info:
        # No need to update
        logging.info("No need to update")
        return (
            sample_info["blur_image"],
            sample_info["event_image"],
            sample_info["start_image"],
            sample_info["end_image"],
        )

    try:
        sample = dataset[split][index]
    except (KeyError, IndexError) as e:
        # Stale split name or out-of-range index (e.g. the dataset was replaced
        # between two UI events): show nothing instead of raising in the callback
        logging.warning("Sample not available (split=%r, index=%r): %s", split, index, e)
        return None, None, None, None

    blur_image = sample["blur_image"]
    start_image = np.array(sample["start_image"])
    end_image = np.array(sample["end_image"])
    # Flatten whatever layout is stored into N lines of two (x, y) endpoints
    lines = np.array(sample["lines"]).reshape(-1, 2, 2)
    event_image = Image.fromarray(Event(events=sample["events"]).event2image())
    # Match the event image size to the blurred image size
    event_image = event_image.resize(blur_image.size)
    start_image = draw_lines(start_image, lines)
    end_image = draw_lines(end_image, lines)

    sample_info.update(new_sample_info)
    sample_info["blur_image"] = blur_image
    sample_info["event_image"] = event_image
    sample_info["start_image"] = start_image
    sample_info["end_image"] = end_image
    logging.info("Update")

    return blur_image, event_image, start_image, end_image


def main():
    """
    Main: build the Gradio interface, wire the callbacks and launch the app

    Args:
        None

    Returns:
        None
    """
    with gr.Blocks() as demo:
        dataset_textbox = gr.Textbox(value="lh9171338/FE-Blurframe", label="Dataset name or path")
        split_selector = gr.Dropdown(**default_split_selector_info)
        index_slider = gr.Slider(**default_index_slider_info)
        with gr.Row():
            blur_image = gr.Image(label="Blurred Image")
            event_image = gr.Image(label="Event Image")
            start_image = gr.Image(label="Start Image")
            end_image = gr.Image(label="End Image")

        # Pressing enter in the textbox (re)loads the dataset
        dataset_textbox.submit(
            submit_callback,
            dataset_textbox,
            [split_selector, index_slider, blur_image, event_image, start_image, end_image],
        )
        split_selector.change(
            selector_change_callback,
            split_selector,
            [index_slider, blur_image, event_image, start_image, end_image],
        )
        index_slider.change(
            show_image,
            [split_selector, index_slider],
            [blur_image, event_image, start_image, end_image],
        )
        # Load the default dataset as soon as the page is opened
        demo.load(
            submit_callback,
            dataset_textbox,
            [split_selector, index_slider, blur_image, event_image, start_image, end_image],
        )

    demo.launch(share=False)


if __name__ == "__main__":
    # set base logging config
    fmt = "[%(asctime)s - %(levelname)s - %(filename)s:%(lineno)s] %(message)s"
    logging.basicConfig(format=fmt, level=logging.INFO)

    main()