# vtt / app.py
# Mr.Blue
# update demo
# fd039a1
import random
from pathlib import Path
from io import BytesIO
import gradio as gr
import jsonlines
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
from PIL import Image
CURRENT_DIR = Path(__file__).parent
# Data locations are resolved against this file's directory rather than the
# process working directory, so the demo also starts when launched from
# somewhere else.
LIST_FILE = CURRENT_DIR / "demo.jsonl"  # JSON Lines file of sample records
# Directory of state images named "<id>_<n_states>_<i>.jpg" (see show_sample).
STATES_ROOT = CURRENT_DIR / "states"
REPEAT = 1  # NOTE(review): not referenced anywhere in this file
MAX_IMAGES_ROW = 6  # maximum number of images per grid row in show_figures
# User-facing UI strings.
TITLE = "VTT Demo"
START_TEXT = "Start"
PREV_TEXT = "Prev"
NEXT_TEXT = "Next"
CATEGORY_TEXT = "Category"
TOPIC_TEXT = "Topic"
TRANSFORMATIONS_TEXT = "Transformation Descriptions"
# Read every sample record once, at import time, preserving file order.
with jsonlines.open(LIST_FILE) as reader:
    samples = [*reader]
# Index the records by their "id" field. If an id repeats, the later record
# overwrites the earlier one.
samples_dict = {}
for record in samples:
    samples_dict[record["id"]] = record
def get_sample(annotation_id):
    """Return the sample record for an annotation index.

    The index is clamped into range with ``validate_annotation_id`` first, so
    floats (e.g. from a ``gr.Number``) and out-of-range values are accepted.

    NOTE(review): the record is fetched through ``samples_dict`` by its
    ``"id"``; if ids are duplicated in the list file this returns the last
    record with that id, which may differ from ``samples[index]``.
    """
    # Use the clamped index; the raw value may be a float or out of range.
    index = validate_annotation_id(annotation_id)
    sample_id = samples[index]["id"]
    return samples_dict[sample_id]
def get_texts(annotation_id):
    """Return the transformation labels of a sample, in annotation order."""
    index = validate_annotation_id(annotation_id)
    return [step["label"] for step in samples[index]["annotation"]]
def get_transformations(annotation_id):
    """Summarize a sample's transformations as one comma-separated string.

    Each entry has the form ``"<k> -> <k+1>: <label>"``, describing the step
    from state ``k`` to state ``k + 1``.
    """
    parts = []
    for step, label in enumerate(get_texts(annotation_id)):
        parts.append(f"{step} -> {step + 1}: {label}")
    return ", ".join(parts)
def show_figures(path_list, title=None, labels=None, show_indices=True):
    """Draw a grid of images on a new pyplot figure.

    Args:
        path_list: sequence of ``pathlib.Path`` image paths. A path whose file
            does not exist leaves its grid cell blank.
        title: optional figure-level title (drawn with ``fig.suptitle``).
        labels: optional per-image captions. A non-empty ``labels[i]`` is shown
            under image ``i`` as ``"<i-1>-<i>: <label>"``.
        show_indices: if true, each drawn image gets its index as a title.

    Returns:
        The created matplotlib ``Figure``; it is also the current pyplot figure,
        so callers may keep using ``plt.savefig`` / ``plt.close``.
    """
    from textwrap import wrap

    n_img = len(path_list)
    width, height = plt.figaspect(1)
    # These rcParams are process-global: they also stay in effect for the
    # caller's later plt.savefig (bbox/dpi), so they are not scoped here.
    plt.rcParams["savefig.bbox"] = "tight"
    plt.rcParams["axes.linewidth"] = 0
    plt.rcParams["axes.titlepad"] = 6
    plt.rcParams["axes.titlesize"] = 12
    # NOTE(review): if Helvetica is not installed, matplotlib falls back to
    # its default font and emits a findfont warning.
    plt.rcParams["font.family"] = "Helvetica"
    plt.rcParams["axes.labelweight"] = "normal"
    plt.rcParams["font.size"] = 12
    plt.rcParams["figure.dpi"] = 100
    plt.rcParams["savefig.dpi"] = 100
    plt.rcParams["figure.titlesize"] = 18

    # Shrink each cell when the images will not fit on a single row.
    if n_img > MAX_IMAGES_ROW:
        width = width / 2
        height = height / 2
    # max(1, ...) keeps the grid valid (at least 1x1) for an empty path_list.
    n_image_row = max(1, min(n_img, MAX_IMAGES_ROW))
    n_row = max(1, (n_img - 1) // n_image_row + 1)
    # squeeze=False always yields a 2-D array of Axes, even for a 1x1 grid
    # (otherwise plt.subplots returns a bare Axes that cannot be indexed).
    fig, axarr = plt.subplots(
        n_row,
        n_image_row,
        figsize=(width * n_image_row, height * n_row),
        squeeze=False,
    )
    if title:
        fig.suptitle(title)

    # textwrap slices long words by this width, so it must be an int.
    wrap_width = max(1, round(width * 10))
    # axarr.flat walks the grid row by row: cell i is axarr[i // n_image_row][i % n_image_row].
    for i, ax in enumerate(axarr.flat):
        if i < n_img and path_list[i].exists():
            ax.imshow(mpimg.imread(path_list[i]))
            if show_indices:
                ax.set_title(f"{i}")
            if labels is not None and i < len(labels) and labels[i]:
                ax.set_xlabel(
                    "\n".join(wrap(f"{i-1}-{i}: {labels[i]}", width=wrap_width))
                )
        ax.set_xticks([])
        ax.set_yticks([])
    plt.tight_layout()
    return fig
def show_sample(sample, texts):
    """Plot every state image of ``sample`` with its transformation captions.

    A sample with N annotations has N + 1 states. State ``k`` is read from
    ``STATES_ROOT / "<id>_<N+1>_<k>.jpg"``. State 0 gets no caption, and state
    ``k > 0`` is captioned with ``texts[k - 1]``.
    """
    num_states = 1 + len(sample["annotation"])
    prefix = f"{sample['id']}_{num_states}"
    paths = [STATES_ROOT / f"{prefix}_{k}.jpg" for k in range(num_states)]
    show_figures(paths, labels=["", *texts])
def get_image(annotation_id):
    """Render a sample's state images and return them as a PIL image.

    The figure is drawn with ``show_sample``, saved as PNG into an in-memory
    buffer, and the buffer is opened with PIL.
    """
    sample = get_sample(annotation_id)
    texts = get_texts(annotation_id)
    buf = BytesIO()
    try:
        show_sample(sample, texts)
        plt.savefig(buf, format="png")
    finally:
        # Always release the pyplot figure, even when rendering fails, so
        # figures do not pile up across requests.
        plt.close()
    buf.seek(0)
    # The PNG bytes live in buf, so closing the figure first is safe.
    return Image.open(buf)
def get_category_topic(annotation_id):
    """Return the ``(category, topic)`` pair of a sample."""
    record = get_sample(annotation_id)
    category = record["category"]
    topic = record["topic"]
    return category, topic
def validate_annotation_id(annotation_id, n_samples=None):
    """Clamp ``annotation_id`` to a valid index into ``samples``.

    Args:
        annotation_id: requested index. Numbers and numeric strings are passed
            through ``int()``, which truncates floats toward zero. ``None``
            (e.g. an empty ``gr.Number``) is treated as 0.
        n_samples: number of available samples; defaults to ``len(samples)``.

    Returns:
        An int in ``[0, n_samples - 1]``. If ``n_samples`` is 0 this is 0,
        which is not a valid index, so the caller's lookup will fail.
    """
    if n_samples is None:
        n_samples = len(samples)
    index = 0 if annotation_id is None else int(annotation_id)
    return max(0, min(index, n_samples - 1))
def start(annotation_id):
    """Gradio handler for the Start button: render the sample at ``annotation_id``.

    Returns:
        ``(category, topic, image, transformations)`` for the four display
        components.
    """
    index = validate_annotation_id(annotation_id)
    category, topic = get_category_topic(index)
    image = get_image(index)
    transformations = get_transformations(index)
    return category, topic, image, transformations
def prev_sample(annotation_id):
    """Gradio handler for the Prev button: step back one sample (clamped at 0).

    Returns:
        ``(index, category, topic, image, transformations)``. The new index is
        written back into the hidden ``annotation_id`` component.
    """
    index = validate_annotation_id(annotation_id - 1)
    category, topic = get_category_topic(index)
    image = get_image(index)
    return index, category, topic, image, get_transformations(index)
def next_sample(annotation_id):
    """Gradio handler for the Next button: jump to a uniformly random sample.

    The incoming ``annotation_id`` is ignored; the sample is chosen at random.

    Returns:
        ``(index, category, topic, image, transformations)``. The new index is
        written back into the hidden ``annotation_id`` component.
    """
    # Use the draw as-is. Shifting it (e.g. "+ 1") and then clamping would make
    # index 0 unreachable and double the odds of the last sample.
    annotation_id = validate_annotation_id(random.randint(0, len(samples) - 1))
    category, topic = get_category_topic(annotation_id)
    image = get_image(annotation_id)
    return (
        annotation_id,
        category,
        topic,
        image,
        get_transformations(annotation_id),
    )
def main():
    """Build the VTT demo UI, wire its event handlers, and launch the server.

    The server listens on all interfaces (``server_name="0.0.0.0"``).
    """
    with gr.Blocks(title="VTT") as demo:
        gr.Markdown(f"## {TITLE}")
        with gr.Row():
            with gr.Column():
                # Hidden state holding the index of the sample currently shown.
                # value=0 / precision=0 make handlers receive an int, never None.
                annotation_id = gr.Number(
                    label="Annotation ID", value=0, precision=0, visible=False
                )
                start_button = gr.Button(START_TEXT, visible=False)
                with gr.Row():
                    prev_button = gr.Button(PREV_TEXT, visible=False)
                    next_button = gr.Button(NEXT_TEXT)
                category = gr.Text(label=CATEGORY_TEXT)
                topic = gr.Text(label=TOPIC_TEXT)
                image = gr.Image()
                transformations = gr.Text(label=TRANSFORMATIONS_TEXT)

        # Components filled in by every handler, in handler return order.
        display = (category, topic, image, transformations)

        start_button.click(start, inputs=[annotation_id], outputs=[*display])
        # prev/next also return the new index, which updates annotation_id.
        prev_button.click(
            prev_sample, inputs=[annotation_id], outputs=[annotation_id, *display]
        )
        next_button.click(
            next_sample, inputs=[annotation_id], outputs=[annotation_id, *display]
        )
        # Render the initial sample on page load by calling the Start handler
        # directly. This avoids locating a hidden button by its text from
        # client-side JavaScript.
        demo.load(start, inputs=[annotation_id], outputs=[*display])

    demo.launch(server_name="0.0.0.0")


if __name__ == "__main__":
    main()