"""Gradio demo for browsing VTT samples.

Samples are read from ``demo.jsonl`` (relative to the working directory). Each
record is expected to carry ``id``, ``category``, ``topic`` and an
``annotation`` list whose items have a ``label`` -- one label per transition
between consecutive states. State images are read from ``STATES_ROOT`` as
``{id}_{n_states}_{i}.jpg``.
"""

import random
from io import BytesIO
from pathlib import Path
from textwrap import wrap

import gradio as gr
import jsonlines
import matplotlib

# Gradio runs callbacks in worker threads; force the non-interactive Agg
# backend so pyplot never tries to start a GUI event loop off the main thread.
matplotlib.use("Agg")

import matplotlib.image as mpimg  # noqa: E402
import matplotlib.pyplot as plt  # noqa: E402
from PIL import Image  # noqa: E402

CURRENT_DIR = Path(__file__).parent
LIST_FILE = "demo.jsonl"
STATES_ROOT = Path("states/")
REPEAT = 1
MAX_IMAGES_ROW = 6

TITLE = "VTT Demo"
START_TEXT = "Start"
PREV_TEXT = "Prev"
NEXT_TEXT = "Next"
CATEGORY_TEXT = "Category"
TOPIC_TEXT = "Topic"
TRANSFORMATIONS_TEXT = "Transformation Descriptions"

with jsonlines.open(LIST_FILE) as reader:
    samples = list(reader)
# Lookup of samples by their ``id`` field.
samples_dict = {sample["id"]: sample for sample in samples}


def validate_annotation_id(annotation_id):
    """Coerce *annotation_id* to an ``int`` and clamp it into ``samples``' range.

    The Gradio ``Number`` component may deliver a float, or ``None`` when it is
    empty; ``None`` is treated as 0.
    """
    if annotation_id is None:
        annotation_id = 0
    return max(0, min(int(annotation_id), len(samples) - 1))


def get_sample(annotation_id):
    """Return the sample record at (validated) position *annotation_id*."""
    # Index ``samples`` directly so images and texts always come from the same
    # record, even if two records share an ``id``.
    return samples[validate_annotation_id(annotation_id)]


def get_texts(annotation_id):
    """Return the transformation labels of a sample, in annotation order."""
    return [step["label"] for step in get_sample(annotation_id)["annotation"]]


def get_transformations(annotation_id):
    """Return the labels as one line: ``"0 -> 1: a, 1 -> 2: b, ..."``."""
    texts = get_texts(annotation_id)
    return ", ".join(f"{i} -> {i + 1}: {text}" for i, text in enumerate(texts))


def show_figures(path_list, title=None, labels=None, show_indices=True):
    """Draw the images in *path_list* on a grid of subplots.

    At most ``MAX_IMAGES_ROW`` images are placed per row; each subplot is
    shrunk to half size when the images do not fit in a single row.

    Args:
        path_list: ``pathlib.Path`` objects; cells whose file is missing are
            left empty.
        title: Optional figure-level title (``fig.suptitle``).
        labels: Optional captions; ``labels[i]`` (if truthy) is drawn under
            image ``i`` as ``"{i-1}-{i}: {label}"``.
        show_indices: Whether to title each cell with its index.

    Returns:
        The created ``matplotlib.figure.Figure``; the caller should close it.

    Note: mutates the global ``plt.rcParams``. ``get_image`` relies on the
    ``savefig.*`` values set here when it saves the figure.
    """
    n_img = len(path_list)
    width, height = plt.figaspect(1)
    plt.rcParams.update(
        {
            "savefig.bbox": "tight",
            "axes.linewidth": 0,
            "axes.titlepad": 6,
            "axes.titlesize": 12,
            "font.family": "Helvetica",
            "axes.labelweight": "normal",
            "font.size": 12,
            "figure.dpi": 100,
            "savefig.dpi": 100,
            "figure.titlesize": 18,
        }
    )

    if n_img > MAX_IMAGES_ROW:
        width = width / 2
        height = height / 2
    # Grid of n_row x n_image_row cells; at least 1x1 so an empty list does
    # not divide by zero or ask matplotlib for a zero-sized grid.
    n_image_row = max(1, min(n_img, MAX_IMAGES_ROW))
    n_row = max(1, (n_img - 1) // n_image_row + 1)
    # squeeze=False always yields a 2-D array of Axes, so a single image (1x1
    # grid) and a single row are handled like any other layout.
    fig, axarr = plt.subplots(
        n_row,
        n_image_row,
        figsize=(width * n_image_row, height * n_row),
        squeeze=False,
    )
    # textwrap needs an integer width; figaspect returns floats.
    wrap_width = max(1, int(width * 10))

    # axarr.flat walks the grid row by row, matching image order.
    for i, ax in enumerate(axarr.flat):
        if i < n_img:
            if path_list[i].exists():
                ax.imshow(mpimg.imread(path_list[i]))
            if show_indices:
                ax.set_title(f"{i}")
            # Cells past the last image (end of a partial row) have no label.
            if labels is not None and i < len(labels) and labels[i]:
                ax.set_xlabel(
                    "\n".join(wrap(f"{i-1}-{i}: {labels[i]}", width=wrap_width))
                )
        ax.set_xticks([])
        ax.set_yticks([])

    if title:
        fig.suptitle(title)
    fig.tight_layout()
    return fig


def show_sample(sample, texts):
    """Plot all state images of *sample*, captioning state ``i`` with ``texts[i-1]``.

    Returns the figure created by ``show_figures``.
    """
    n_states = len(sample["annotation"]) + 1
    state_path_list = [
        STATES_ROOT / f"{sample['id']}_{n_states}_{i}.jpg" for i in range(n_states)
    ]
    # State 0 has no incoming transformation, hence the empty first label.
    return show_figures(state_path_list, labels=[""] + list(texts))


def get_image(annotation_id):
    """Render a sample's state grid and return it as a PIL image (PNG-decoded)."""
    sample = get_sample(annotation_id)
    fig = show_sample(sample, get_texts(annotation_id))
    buf = BytesIO()
    try:
        fig.savefig(buf, format="png")
    finally:
        # Close this specific figure (not whatever pyplot considers "current")
        # so figures never accumulate, even if saving fails.
        plt.close(fig)
    buf.seek(0)
    img = Image.open(buf)
    img.load()  # decode now; Image.open is lazy
    return img


def get_category_topic(annotation_id):
    """Return ``(category, topic)`` of the sample at *annotation_id*."""
    sample = get_sample(annotation_id)
    return sample["category"], sample["topic"]


def _render(annotation_id):
    """Return ``(category, topic, image, transformations)`` for a valid index."""
    category, topic = get_category_topic(annotation_id)
    return (
        category,
        topic,
        get_image(annotation_id),
        get_transformations(annotation_id),
    )


def start(annotation_id):
    """Callback: show the sample at *annotation_id* (clamped to a valid index)."""
    return _render(validate_annotation_id(annotation_id))


def prev_sample(annotation_id):
    """Callback: step back one sample; returns the new id plus the view values."""
    annotation_id = validate_annotation_id((annotation_id or 0) - 1)
    return (annotation_id, *_render(annotation_id))


def next_sample(annotation_id):
    """Callback: jump to a uniformly random sample.

    The current *annotation_id* is ignored; it is accepted only because the
    Gradio event passes it in. Returns the new id plus the view values.
    """
    annotation_id = random.randrange(len(samples))
    return (annotation_id, *_render(annotation_id))


def main():
    """Build the Gradio UI, wire the navigation callbacks and launch the server."""
    with gr.Blocks(title="VTT") as demo:
        gr.Markdown(f"## {TITLE}")
        with gr.Row():
            with gr.Column():
                # Hidden state: index of the sample currently shown. Starting at
                # 0 (not empty) means the first render never receives ``None``.
                annotation_id = gr.Number(
                    value=0, label="Annotation ID", visible=False
                )
                start_button = gr.Button(START_TEXT, visible=False)
                with gr.Row():
                    prev_button = gr.Button(PREV_TEXT, visible=False)
                    next_button = gr.Button(NEXT_TEXT)
                category = gr.Text(label=CATEGORY_TEXT)
                topic = gr.Text(label=TOPIC_TEXT)
                image = gr.Image()
                transformations = gr.Text(label=TRANSFORMATIONS_TEXT)

        view_outputs = [category, topic, image, transformations]
        start_button.click(start, inputs=[annotation_id], outputs=view_outputs)
        prev_button.click(
            prev_sample,
            inputs=[annotation_id],
            outputs=[annotation_id, *view_outputs],
        )
        next_button.click(
            next_sample,
            inputs=[annotation_id],
            outputs=[annotation_id, *view_outputs],
        )
        # Render the first sample as soon as the page loads. Calling ``start``
        # directly replaces a JS snippet that searched the DOM for a button
        # whose text was "Start" and clicked it.
        demo.load(start, inputs=[annotation_id], outputs=view_outputs)

    # Add share=True to also obtain a public Gradio link.
    demo.launch(server_name="0.0.0.0")


if __name__ == "__main__":
    main()