Spaces:
Sleeping
Sleeping
| # -*- encoding: utf-8 -*- | |
| """ | |
| @File : app.py | |
| @Time : 2025/8/29 15:25:00 | |
| @Author : lh9171338 | |
| @Version : 1.0 | |
| @Contact : 2909171338@qq.com | |
| """ | |
| import os | |
| import gradio as gr | |
| from PIL import Image | |
| import io | |
| import logging | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from datasets import load_dataset, DatasetDict | |
| from utils.event import Event | |
# Datasets that have already been loaded, keyed by the name/path the user typed.
dataset_dict = {}
# Dataset currently being browsed; None until a load succeeds (or after one fails).
dataset = None
# Initial, disabled configuration of the split dropdown; also reused to reset it.
default_split_selector_info = {
    "choices": ["train", "test"],
    "label": "Split",
    "value": "train",
    "interactive": False,
}
# Initial, disabled configuration of the index slider; also reused to reset it.
default_index_slider_info = {
    "minimum": 0,
    "maximum": 1,
    "step": 1,
    "label": "Index",
    "value": 0,
    "interactive": False,
}
# Memo of the last rendered sample: its (dataset, split, index) key plus the four
# output images, so show_image can skip re-rendering an unchanged selection.
sample_info = {
    "dataset": dataset,
    "split": "train",
    "index": 0,
    "blur_image": None,
    "event_image": None,
    "start_image": None,
    "end_image": None,
}
def get_dataset(dataset_name):
    """
    Load a dataset by Hub name or local path, memoizing the result.

    Args:
        dataset_name (str): dataset name or path

    Returns:
        dataset (datasets.Dataset): whatever ``load_dataset`` returned; it is cached
            in ``dataset_dict`` so later calls with the same name skip loading
    """
    if dataset_name in dataset_dict:
        return dataset_dict[dataset_name]

    # NOTE(review): trust_remote_code=True lets the dataset's loading code run, and
    # dataset_name comes from a user-editable textbox in the UI. Anyone who can
    # reach the app can therefore point it at arbitrary code; consider an
    # allow-list of dataset names before exposing this publicly.
    load_kwargs = {"trust_remote_code": True}
    if os.path.exists(dataset_name):
        # Existing local path: it is passed as data_dir as well.
        load_kwargs["data_dir"] = dataset_name
    loaded = load_dataset(dataset_name, **load_kwargs)
    dataset_dict[dataset_name] = loaded
    return loaded
def _reset_submit_outputs():
    """
    Build submit outputs that return the UI to its initial, disabled state.

    Returns:
        tuple: (split selector update, index slider update, None, None, None, None)
    """
    return (
        gr.update(**default_split_selector_info),
        gr.update(**default_index_slider_info),
        None,
        None,
        None,
        None,
    )


def submit_callback(dataset_name):
    """
    Load the requested dataset and show the first sample of its first split.

    If loading fails, or the dataset has no splits, the global ``dataset`` is reset
    to None. The split selector and index slider return to their disabled defaults,
    and all images are cleared.

    Args:
        dataset_name (str): dataset name or path

    Returns:
        split_selector_info (dict): updated split selector info
        index_slider_info (dict): updated index slider info
        blur_image (PIL.Image): updated blur image
        event_image (PIL.Image): updated event image
        start_image (PIL.Image): updated start image
        end_image (PIL.Image): updated end image
    """
    global dataset
    try:
        dataset = get_dataset(dataset_name)
    except Exception as e:
        # Any load failure resets the UI instead of propagating to Gradio.
        dataset = None
        # logging.exception records the traceback as well as the message.
        logging.exception(f"Load dataset failed: {e}")
        return _reset_submit_outputs()

    if not isinstance(dataset, DatasetDict):
        # A bare Dataset: wrap it as {split_name: dataset} so the rest of the app
        # can index it by split name, as it does with a DatasetDict.
        dataset = {str(dataset.split): dataset}

    splits = list(dataset.keys())
    if not splits:
        # Without this guard, splits[0] below would raise an uncaught IndexError.
        logging.error(f"Dataset has no splits: {dataset_name}")
        dataset = None
        return _reset_submit_outputs()

    split = splits[0]
    index = 0
    maximum = len(dataset[split]) - 1
    split_selector_info = gr.update(choices=splits, value=split, interactive=True)
    index_slider_info = gr.update(minimum=0, maximum=maximum, value=index, interactive=True)
    blur_image, event_image, start_image, end_image = show_image(split=split, index=index)
    return split_selector_info, index_slider_info, blur_image, event_image, start_image, end_image
def selector_change_callback(split):
    """
    React to a new split being chosen in the dropdown.

    The index slider is re-ranged to the new split and moved back to sample 0,
    which is then rendered.

    Args:
        split (str): newly selected split name (one of the loaded dataset's keys)

    Returns:
        index_slider_info (dict): gr.update for the index slider
        blur_image (PIL.Image): blurred image of sample 0, or None
        event_image (PIL.Image): event image of sample 0, or None
        start_image (PIL.Image): start image with lines drawn, or None
        end_image (PIL.Image): end image with lines drawn, or None
    """
    if dataset is None:
        # Nothing loaded: restore the disabled default slider and clear the images.
        return gr.update(**default_index_slider_info), None, None, None, None

    first_index = 0
    last_index = len(dataset[split]) - 1
    slider_update = gr.update(minimum=0, maximum=last_index, value=first_index)
    blur, events, start, end = show_image(split=split, index=first_index)
    return slider_update, blur, events, start, end
def draw_lines(image, lines):
    """
    Draw line segments and their endpoints on top of an image.

    The figure is 1 inch tall and ``width / height`` inches wide, and it is saved
    at ``dpi=height``. The rendered PNG is therefore (up to rounding) the same
    pixel size as the input image.

    Args:
        image (np.ndarray): input image; its first two dimensions are (height, width)
        lines (np.ndarray): line segments, array-like of shape [N, 2, 2]
            (N segments, 2 endpoints, coordinates plotted as x then y)

    Returns:
        image (PIL.Image): rendered PNG image
    """
    # Use a standalone Figure rather than pyplot. pyplot keeps one global
    # current-figure/current-axes state shared by every request thread of the web
    # server, and a pyplot figure leaks if an exception skips plt.close(). A plain
    # Figure is never registered with pyplot, so it is simply garbage-collected.
    from matplotlib.figure import Figure

    height, width = image.shape[:2]
    fig = Figure(figsize=(width / height, 1))
    ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
    ax.set_axis_off()
    # Same limits as imshow's default extent; y is inverted so row 0 is at the top.
    ax.set_xlim([-0.5, width - 0.5])
    ax.set_ylim([height - 0.5, -0.5])
    ax.imshow(image)

    # Shift by half a pixel into imshow's pixel-centre coordinates
    # (assumes the line coordinates put pixel centres at +0.5 — TODO confirm).
    segments = np.asarray(lines, dtype=float).reshape(-1, 2, 2) - 0.5
    for pts in segments:
        ax.plot(pts[:, 0], pts[:, 1], color="orange", linewidth=0.5)
    if len(segments):
        # All endpoints in a single scatter call; zorder=5 draws them above the lines.
        endpoints = segments.reshape(-1, 2)
        ax.scatter(endpoints[:, 0], endpoints[:, 1], color="#33FFFF", s=1.2, edgecolors="none", zorder=5)

    buf = io.BytesIO()
    # bbox_inches=0 is falsy, so no tight-bbox cropping is applied and the output
    # keeps the exact figure size.
    fig.savefig(buf, format="png", dpi=height, bbox_inches=0)
    buf.seek(0)
    return Image.open(buf)
def show_image(split, index):
    """
    Render the four views of one sample, reusing the last render when unchanged.

    The previous render is kept in ``sample_info``. If the (dataset, split, index)
    selection equals the stored one, the stored images are returned without
    decoding or drawing anything.

    Args:
        split (str): split name (a key of the currently loaded dataset)
        index (int): position of the sample within the split

    Returns:
        blur_image (PIL.Image): blurred image as stored in the sample
        event_image (PIL.Image): event rendering resized to the blurred image's size
        start_image (PIL.Image): start image with the sample's lines drawn on it
        end_image (PIL.Image): end image with the sample's lines drawn on it
        All four are None when no dataset is loaded.
    """
    if dataset is None:
        return None, None, None, None

    previous_key = {field: sample_info[field] for field in ("dataset", "split", "index")}
    current_key = {"dataset": dataset, "split": split, "index": index}
    if previous_key == current_key:  # No need to update
        logging.info("No need to update")
        return (
            sample_info["blur_image"],
            sample_info["event_image"],
            sample_info["start_image"],
            sample_info["end_image"],
        )

    record = dataset[split][index]
    blurred = record["blur_image"]
    start_pixels = np.array(record["start_image"])
    end_pixels = np.array(record["end_image"])
    segments = np.array(record["lines"]).reshape(-1, 2, 2)
    events_view = Image.fromarray(Event(events=record["events"]).event2image())
    events_view = events_view.resize(blurred.size)
    start_view = draw_lines(start_pixels, segments)
    end_view = draw_lines(end_pixels, segments)

    # Remember this selection and its renders for the next call.
    sample_info.update(current_key)
    sample_info.update(
        blur_image=blurred,
        event_image=events_view,
        start_image=start_view,
        end_image=end_view,
    )
    logging.info("Update")
    return blurred, events_view, start_view, end_view
def main():
    """
    Build the Gradio viewer UI, wire up its events and launch it.

    Args:
        None

    Returns:
        None
    """
    with gr.Blocks() as demo:
        name_box = gr.Textbox(value="lh9171338/FE-Blurframe", label="Dataset name or path")
        split_dropdown = gr.Dropdown(**default_split_selector_info)
        slider = gr.Slider(**default_index_slider_info)
        with gr.Row():
            blur_view = gr.Image(label="Blurred Image")
            event_view = gr.Image(label="Event Image")
            start_view = gr.Image(label="Start Image")
            end_view = gr.Image(label="End Image")
        image_views = [blur_view, event_view, start_view, end_view]
        load_outputs = [split_dropdown, slider, *image_views]

        # Loading a dataset (Enter in the textbox, or once on page load) refreshes
        # the split choices, the slider range and all four images.
        name_box.submit(submit_callback, name_box, load_outputs)
        split_dropdown.change(selector_change_callback, split_dropdown, [slider, *image_views])
        slider.change(show_image, [split_dropdown, slider], image_views)
        demo.load(submit_callback, name_box, load_outputs)
    demo.launch(share=False)
if __name__ == "__main__":
    # Configure root logging at INFO before launching, so the callbacks'
    # logging.info / logging.error messages are emitted.
    log_format = "[%(asctime)s - %(levelname)s - %(filename)s:%(lineno)s] %(message)s"
    logging.basicConfig(level=logging.INFO, format=log_format)
    main()