Spaces:
Configuration error
Configuration error
| """ | |
| # Copyright (c) 2022, salesforce.com, inc. | |
| # All rights reserved. | |
| # SPDX-License-Identifier: BSD-3-Clause | |
| # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
| """ | |
| import random | |
| from collections import OrderedDict | |
| from functools import reduce | |
| from tkinter import N | |
| import streamlit as st | |
| from lavis.common.registry import registry | |
| from lavis.datasets.builders import dataset_zoo, load_dataset | |
| from lavis.datasets.builders.base_dataset_builder import load_dataset_config | |
| from PIL import Image | |
| IMAGE_LAYOUT = 3, 4 | |
| VIDEO_LAYOUT = 1, 2 | |
| PREV_STR = "Prev" | |
| NEXT_STR = "Next" | |
| def sample_dataset(dataset, indices): | |
| samples = [dataset.displ_item(idx) for idx in indices] | |
| return samples | |
| def get_concat_v(im1, im2): | |
| margin = 5 | |
| canvas_size = (im1.width + im2.width + margin, max(im1.height, im2.height)) | |
| canvas = Image.new("RGB", canvas_size, "White") | |
| canvas.paste(im1, (0, 0)) | |
| canvas.paste(im2, (im1.width + margin, 0)) | |
| return canvas | |
| def resize_img_w(raw_img, new_w=224): | |
| if isinstance(raw_img, list): | |
| resized_imgs = [resize_img_w(img, 196) for img in raw_img] | |
| # concatenate images | |
| resized_image = reduce(get_concat_v, resized_imgs) | |
| else: | |
| w, h = raw_img.size | |
| scaling_factor = new_w / w | |
| resized_image = raw_img.resize( | |
| (int(w * scaling_factor), int(h * scaling_factor)) | |
| ) | |
| return resized_image | |
| def get_visual_key(dataset): | |
| if "image" in dataset[0]: | |
| return "image" | |
| elif "image0" in dataset[0]: # NLVR2 dataset | |
| return "image" | |
| elif "video" in dataset[0]: | |
| return "video" | |
| else: | |
| raise ValueError("Visual key not found.") | |
| def gather_items(samples, exclude=[]): | |
| gathered = [] | |
| for s in samples: | |
| ns = OrderedDict() | |
| for k in s.keys(): | |
| if k not in exclude: | |
| ns[k] = s[k] | |
| gathered.append(ns) | |
| return gathered | |
| def load_dataset_cache(name): | |
| return load_dataset(name) | |
| def format_text(text): | |
| md = "\n\n".join([f"**{k}**: {v}" for k, v in text.items()]) | |
| return md | |
| def show_samples(dataset, offset=0, is_next=False): | |
| visual_key = get_visual_key(dataset) | |
| num_rows, num_cols = IMAGE_LAYOUT if visual_key == "image" else VIDEO_LAYOUT | |
| n_samples = num_rows * num_cols | |
| if not shuffle: | |
| if is_next: | |
| start = min(int(start_idx) + offset + n_samples, len(dataset) - n_samples) | |
| else: | |
| start = max(0, int(start_idx) + offset - n_samples) | |
| st.session_state.last_start = start | |
| end = min(start + n_samples, len(dataset)) | |
| indices = list(range(start, end)) | |
| else: | |
| indices = random.sample(range(len(dataset)), n_samples) | |
| samples = sample_dataset(dataset, indices) | |
| visual_info = ( | |
| iter([resize_img_w(s[visual_key]) for s in samples]) | |
| if visual_key == "image" | |
| # else iter([s[visual_key] for s in samples]) | |
| else iter([s["file"] for s in samples]) | |
| ) | |
| text_info = gather_items(samples, exclude=["image", "video"]) | |
| text_info = iter([format_text(s) for s in text_info]) | |
| st.markdown( | |
| """<hr style="height:1px;border:none;color:#c7ccd4;background-color:#c7ccd4;"/> """, | |
| unsafe_allow_html=True, | |
| ) | |
| for _ in range(num_rows): | |
| with st.container(): | |
| for col in st.columns(num_cols): | |
| # col.text(next(text_info)) | |
| # col.caption(next(text_info)) | |
| try: | |
| col.markdown(next(text_info)) | |
| if visual_key == "image": | |
| col.image(next(visual_info), use_column_width=True, clamp=True) | |
| elif visual_key == "video": | |
| col.markdown( | |
| "" | |
| ) | |
| except StopIteration: | |
| break | |
| st.markdown( | |
| """<hr style="height:1px;border:none;color:#c7ccd4;background-color:#c7ccd4;"/> """, | |
| unsafe_allow_html=True, | |
| ) | |
| st.session_state.n_display = n_samples | |
| if __name__ == "__main__": | |
| st.set_page_config( | |
| page_title="LAVIS Dataset Explorer", | |
| # layout="wide", | |
| initial_sidebar_state="expanded", | |
| ) | |
| dataset_name = st.sidebar.selectbox("Dataset:", dataset_zoo.get_names()) | |
| function = st.sidebar.selectbox("Function:", ["Browser"], index=0) | |
| if function == "Browser": | |
| shuffle = st.sidebar.selectbox("Shuffled:", [True, False], index=0) | |
| dataset = load_dataset_cache(dataset_name) | |
| split = st.sidebar.selectbox("Split:", dataset.keys()) | |
| dataset_len = len(dataset[split]) | |
| st.success( | |
| f"Loaded {dataset_name}/{split} with **{dataset_len}** records. **Image/video directory**: {dataset[split].vis_root}" | |
| ) | |
| if "last_dataset" not in st.session_state: | |
| st.session_state.last_dataset = dataset_name | |
| st.session_state.last_split = split | |
| if "last_start" not in st.session_state: | |
| st.session_state.last_start = 0 | |
| if "start_idx" not in st.session_state: | |
| st.session_state.start_idx = 0 | |
| if "shuffle" not in st.session_state: | |
| st.session_state.shuffle = shuffle | |
| if "first_run" not in st.session_state: | |
| st.session_state.first_run = True | |
| elif ( | |
| st.session_state.last_dataset != dataset_name | |
| or st.session_state.last_split != split | |
| ): | |
| st.session_state.first_run = True | |
| st.session_state.last_dataset = dataset_name | |
| st.session_state.last_split = split | |
| elif st.session_state.shuffle != shuffle: | |
| st.session_state.shuffle = shuffle | |
| st.session_state.first_run = True | |
| if not shuffle: | |
| n_col, p_col = st.columns([0.05, 1]) | |
| prev_button = n_col.button(PREV_STR) | |
| next_button = p_col.button(NEXT_STR) | |
| else: | |
| next_button = st.button(NEXT_STR) | |
| if not shuffle: | |
| start_idx = st.sidebar.text_input(f"Begin from (total {dataset_len})", 0) | |
| if not start_idx.isdigit(): | |
| st.error(f"Input to 'Begin from' must be digits, found {start_idx}.") | |
| else: | |
| if int(start_idx) != st.session_state.start_idx: | |
| st.session_state.start_idx = int(start_idx) | |
| st.session_state.last_start = int(start_idx) | |
| if prev_button: | |
| show_samples( | |
| dataset[split], | |
| offset=st.session_state.last_start - st.session_state.start_idx, | |
| is_next=False, | |
| ) | |
| if next_button: | |
| show_samples( | |
| dataset[split], | |
| offset=st.session_state.last_start - st.session_state.start_idx, | |
| is_next=True, | |
| ) | |
| if st.session_state.first_run: | |
| st.session_state.first_run = False | |
| show_samples( | |
| dataset[split], | |
| offset=st.session_state.last_start - st.session_state.start_idx, | |
| is_next=True, | |
| ) | |