| | import random |
| | from pathlib import Path |
| | from io import BytesIO |
| |
|
| | import gradio as gr |
| | import jsonlines |
| | import matplotlib.image as mpimg |
| | import matplotlib.pyplot as plt |
| | from PIL import Image |
| |
|
# Directory containing this script. NOTE(review): not used below -- LIST_FILE
# and STATES_ROOT are resolved against the current working directory.
CURRENT_DIR = Path(__file__).parent

# --- Data locations -------------------------------------------------------
LIST_FILE = "demo.jsonl"  # JSON Lines file, one sample record per line
STATES_ROOT = Path("states/")  # folder holding the per-state .jpg images

# --- Rendering / behavior knobs --------------------------------------------
REPEAT = 1  # not referenced in this file
MAX_IMAGES_ROW = 6  # maximum number of state images per figure row

# --- User-facing labels ------------------------------------------------------
TITLE = "VTT Demo"
START_TEXT = "Start"
PREV_TEXT = "Prev"
NEXT_TEXT = "Next"
CATEGORY_TEXT = "Category"
TOPIC_TEXT = "Topic"
TRANSFORMATIONS_TEXT = "Transformation Descriptions"
| |
|
| |
|
# Load all sample records once, at import time. ``samples`` keeps file order
# (its positions are the "annotation ids" used by the UI); ``samples_dict``
# maps each record's "id" field to the record (later duplicates win).
with jsonlines.open(LIST_FILE) as reader:
    samples = [*reader]
samples_dict = {entry["id"]: entry for entry in samples}
| |
|
| |
|
def get_sample(annotation_id):
    """Return the sample record displayed for ``annotation_id``.

    Args:
        annotation_id: Position of the sample in ``samples``. Anything
            ``validate_annotation_id`` accepts (e.g. a float); it is converted
            to an int and clamped into the valid index range.

    Returns:
        The sample dict at the clamped position.
    """
    # Use the validated value: indexing a list with the raw argument fails for
    # floats and for ids outside [0, len(samples) - 1].
    index = validate_annotation_id(annotation_id)
    # Index ``samples`` directly, exactly like ``get_texts`` does, so the record
    # used to draw the image is always the one whose texts are shown -- even if
    # two records share the same "id".
    return samples[index]
| |
|
| |
|
def get_texts(annotation_id):
    """Collect the transformation labels of one sample.

    Args:
        annotation_id: Sample position; clamped via ``validate_annotation_id``.

    Returns:
        A list holding the ``"label"`` of every entry in the sample's
        ``"annotation"`` list, in order.
    """
    record = samples[validate_annotation_id(annotation_id)]
    return [step["label"] for step in record["annotation"]]
| |
|
| |
|
def get_transformations(annotation_id):
    """Describe a sample's transformations as one comma-separated string.

    Label ``k`` is rendered as ``"k -> k+1: <label>"``, i.e. it describes the
    change from state ``k`` to state ``k + 1``.
    """
    labels = get_texts(annotation_id)
    return ", ".join(
        f"{step} -> {step + 1}: {label}" for step, label in enumerate(labels)
    )
| |
|
| |
|
def show_figures(path_list, title=None, labels=None, show_indices=True):
    """Draw images in a grid on a new pyplot figure.

    Images are laid out at most ``MAX_IMAGES_ROW`` per row. When there are
    more images than that, every cell is drawn at half size. Grid cells past
    the end of ``path_list``, and paths that do not exist, are left blank. The
    figure is left open; the caller is responsible for saving/closing it.

    This function sets several *global* ``plt.rcParams`` (including
    ``savefig.bbox`` and ``savefig.dpi``). Callers such as ``get_image`` rely
    on those settings still being active when they call ``plt.savefig``.

    Args:
        path_list: Sequence of ``pathlib.Path`` objects (``.exists()`` is
            called on each) pointing at the images to show.
        title: Accepted for API compatibility; currently not drawn.
        labels: Optional sequence parallel to ``path_list``. A non-empty
            ``labels[i]`` is shown under image ``i`` as ``"{i-1}-{i}: label"``.
        show_indices: Whether to title each image with its index.

    Raises:
        ValueError: If ``path_list`` is empty.
    """
    from textwrap import wrap

    n_img = len(path_list)
    if n_img == 0:
        # Would otherwise fail with ZeroDivisionError in the row computation.
        raise ValueError("path_list must contain at least one path")
    width, height = plt.figaspect(1)

    plt.rcParams["savefig.bbox"] = "tight"
    plt.rcParams["axes.linewidth"] = 0
    plt.rcParams["axes.titlepad"] = 6
    plt.rcParams["axes.titlesize"] = 12
    plt.rcParams["font.family"] = "Helvetica"
    plt.rcParams["axes.labelweight"] = "normal"
    plt.rcParams["font.size"] = 12
    plt.rcParams["figure.dpi"] = 100
    plt.rcParams["savefig.dpi"] = 100
    plt.rcParams["figure.titlesize"] = 18

    # Shrink cells when the grid spans several rows.
    if n_img > MAX_IMAGES_ROW:
        width = width / 2
        height = height / 2

    n_image_row = min(n_img, MAX_IMAGES_ROW)
    n_row = (n_img - 1) // n_image_row + 1
    # squeeze=False always returns a 2-D array of Axes. Without it a 1x1 grid
    # (a single image) yields a bare Axes object that cannot be indexed.
    fig, axarr = plt.subplots(
        n_row,
        n_image_row,
        figsize=(width * n_image_row, height * n_row),
        squeeze=False,
    )

    # textwrap needs an integer width; figaspect returns floats, and a float
    # width raises TypeError as soon as a single word exceeds it.
    wrap_width = int(width * 10)

    for i in range(n_row * n_image_row):
        ax = axarr[i // n_image_row, i % n_image_row]
        if i < n_img and path_list[i].exists():
            ax.imshow(mpimg.imread(path_list[i]))
            if show_indices:
                ax.set_title(f"{i}")
            # Guard against a labels sequence shorter than path_list.
            if labels is not None and i < len(labels) and labels[i]:
                ax.set_xlabel(
                    "\n".join(wrap(f"{i-1}-{i}: {labels[i]}", width=wrap_width))
                )
        ax.set_xticks([])
        ax.set_yticks([])

    plt.tight_layout()
| |
|
| |
|
def show_sample(sample, texts):
    """Plot all state images of ``sample`` with their transformation labels.

    A sample with ``k`` annotation entries has ``k + 1`` states. State ``s`` is
    read from ``STATES_ROOT / "<id>_<k + 1>_<s>.jpg"``. State 0 gets no label,
    and state ``s`` (``s >= 1``) is labelled with ``texts[s - 1]``.
    """
    state_count = len(sample["annotation"]) + 1
    sample_id = sample["id"]
    paths = []
    for state in range(state_count):
        paths.append(STATES_ROOT / f"{sample_id}_{state_count}_{state}.jpg")
    show_figures(paths, labels=[""] + texts)
| |
|
| |
|
def get_image(annotation_id):
    """Render a sample's state grid and return it as a PIL image.

    Draws the figure with ``show_sample``, saves it as PNG into an in-memory
    buffer and reopens that buffer with PIL.

    The current pyplot figure is closed even when drawing or saving raises.
    pyplot keeps figures alive until they are closed, so otherwise every failed
    request would leak a figure in the running server.

    Args:
        annotation_id: Sample position; clamped by the helpers it calls.

    Returns:
        A ``PIL.Image.Image`` backed by the in-memory PNG buffer.
    """
    sample = get_sample(annotation_id)
    texts = get_texts(annotation_id)
    buf = BytesIO()
    try:
        show_sample(sample, texts)
        plt.savefig(buf, format="png")
    finally:
        plt.close()
    buf.seek(0)
    return Image.open(buf)
| |
|
| |
|
def get_category_topic(annotation_id):
    """Return the ``(category, topic)`` pair of the sample at ``annotation_id``."""
    record = get_sample(annotation_id)
    category = record["category"]
    topic = record["topic"]
    return category, topic
| |
|
| |
|
def validate_annotation_id(annotation_id, n_samples=None):
    """Coerce ``annotation_id`` to an ``int`` index clamped into range.

    Args:
        annotation_id: Requested index; anything ``int()`` accepts (e.g. the
            float a Gradio number field may deliver). ``None``, as an empty
            number field may produce, is treated as ``0`` instead of raising
            ``TypeError``.
        n_samples: Length of the collection being indexed. Defaults to
            ``len(samples)``.

    Returns:
        An int in ``[0, n_samples - 1]``. When ``n_samples`` is 0 the result is
        ``0``, so indexing an empty collection still fails downstream.
    """
    if n_samples is None:
        n_samples = len(samples)
    if annotation_id is None:
        annotation_id = 0
    return max(0, min(int(annotation_id), n_samples - 1))
| |
|
| |
|
def start(annotation_id):
    """Render the view for ``annotation_id`` without changing the stored id.

    Returns:
        ``(category, topic, image, transformations)``, in the order of the
        Gradio outputs this callback is bound to.
    """
    index = validate_annotation_id(annotation_id)
    category, topic = get_category_topic(index)
    picture = get_image(index)
    description = get_transformations(index)
    return category, topic, picture, description
| |
|
| |
|
def prev_sample(annotation_id):
    """Step back one sample (clamped at the first one) and render it.

    Returns:
        ``(new_id, category, topic, image, transformations)``. The new id is
        written back into the hidden annotation-id field.
    """
    target = validate_annotation_id(annotation_id - 1)
    category, topic = get_category_topic(target)
    picture = get_image(target)
    description = get_transformations(target)
    return target, category, topic, picture, description
| |
|
| |
|
def next_sample(annotation_id):
    """Jump to a uniformly random sample and render it.

    Args:
        annotation_id: The currently shown id. Kept for the Gradio callback
            signature; it does not influence which sample is chosen.

    Returns:
        ``(new_id, category, topic, image, transformations)``. The new id is
        written back into the hidden annotation-id field.

    Raises:
        ValueError: If ``samples`` is empty (raised by ``random.randrange``).
    """
    # Draw uniformly from [0, len(samples)). Adding 1 to a random index and
    # clamping afterwards made index 0 unreachable and doubled the chance of
    # the last sample.
    target = validate_annotation_id(random.randrange(len(samples)))
    category, topic = get_category_topic(target)
    picture = get_image(target)
    return (
        target,
        category,
        topic,
        picture,
        get_transformations(target),
    )
| |
|
| |
|
def main():
    """Build the Gradio UI and launch the demo server on all interfaces.

    The page shows a "Next" button (jumps to a random sample), the sample's
    category and topic, the rendered state grid and the transformation
    descriptions. The first sample is rendered by a ``demo.load`` event as soon
    as the page opens.
    """
    with gr.Blocks(title="VTT") as demo:
        gr.Markdown(f"## {TITLE}")

        with gr.Row():
            with gr.Column():
                # Hidden state: index of the sample currently displayed. An
                # explicit 0 ensures callbacks never receive an empty (None)
                # value on the first event.
                annotation_id = gr.Number(
                    value=0, label="Annotation ID", visible=False
                )
                start_button = gr.Button(START_TEXT, visible=False)
                with gr.Row():
                    prev_button = gr.Button(PREV_TEXT, visible=False)
                    next_button = gr.Button(NEXT_TEXT)
                category = gr.Text(label=CATEGORY_TEXT)
                topic = gr.Text(label=TOPIC_TEXT)

        image = gr.Image()

        transformations = gr.Text(label=TRANSFORMATIONS_TEXT)

        # Components refreshed by every navigation callback, in the order the
        # callbacks return them (after the optional new annotation id).
        view_outputs = [category, topic, image, transformations]

        start_button.click(
            start,
            inputs=[annotation_id],
            outputs=list(view_outputs),
        )
        prev_button.click(
            prev_sample,
            inputs=[annotation_id],
            outputs=[annotation_id, *view_outputs],
        )
        next_button.click(
            next_sample,
            inputs=[annotation_id],
            outputs=[annotation_id, *view_outputs],
        )

        # Render the initial sample on page load by calling ``start`` directly.
        # The previous client-side hack searched the DOM for a button labelled
        # 'Start'. That broke silently if START_TEXT changed or the hidden
        # button was not present in the page.
        demo.load(
            start,
            inputs=[annotation_id],
            outputs=list(view_outputs),
        )

    demo.launch(server_name="0.0.0.0")
| |
|
| |
|
if __name__ == "__main__":
    # Start the web demo only when run as a script, not when imported.
    main()
| |
|