"""Streamlit demo: ask a visual-question-answering model about a sample image.

The script reads sample records from ``vqa_samples.json``, lets the user pick
a model and a sample, shows the sample image, and prints the model's answer to
a user-typed question next to the sample's ground-truth label.
"""
import sys

# Allow imports from the current working directory
# (presumably where model_loader lives -- confirm against the repo layout).
sys.path.append(".")

import streamlit as st
import pandas as pd
from PIL import Image
from model_loader import *
from datasets import load_dataset

# Load the sample records. An earlier revision pulled "HuggingFaceM4/VQAv2"
# through datasets.load_dataset instead of this local JSON file.
df = pd.read_json('vqa_samples.json', orient="columns")

if df.empty:
    # With no rows there is no valid index to offer in the selector below.
    st.error("vqa_samples.json contains no samples.")
    st.stop()

# Sidebar selector for the model; the chosen name is passed to load_model
# and get_answer below.
model_name = st.sidebar.selectbox(
    "Select a model: ",
    ('vilt', 'git', 'blip', 'vbert')
)

# Row selector. The upper bound is inclusive, so it must be len(df) - 1:
# the last valid position for df.iloc. (The previous bound of len(df) raised
# IndexError when the maximum value was selected.)
image_selector_unspecific = st.number_input(
    "Select an image id: ",
    min_value=0,
    max_value=len(df) - 1,
)

# Fetch the selected record and display its image.
sample = df.iloc[image_selector_unspecific]
img_path = sample['img_path']
image = Image.open(f'images/{img_path}.jpg')
st.image(image, channels="RGB")
label = sample['label']

# The original code read the sample question from both 'ques' and 'question'.
# Prefer 'question' when the column exists, otherwise fall back to 'ques'.
# NOTE(review): confirm which column vqa_samples.json really uses and drop
# the other lookup.
example_question = sample.get('question', sample['ques'])

# Inference: the sample's own question is only shown as a hint; the model is
# asked whatever the user types.
question = st.text_input(f"Ask the model a question related to the image: \n"
                         f"(e.g. \"{example_question}\")")
args = load_model(model_name)  # TODO: cache
answer = get_answer(args, image, question, model_name)
st.text(f"Answer by {model_name}: {answer}")
st.text(f"Ground truth: {label}")