from io import BytesIO
from math import ceil

import clip
from multilingual_clip import legacy_multilingual_clip, pt_multilingual_clip
import numpy as np
import pandas as pd
from PIL import Image
import requests
import streamlit as st
import torch
from torchvision.transforms import ToPILImage
from transformers import AutoTokenizer, AutoModel, BertTokenizer

from CLIP_Explainability.clip_ import load, tokenize
from CLIP_Explainability.rn_cam import (
    interpret_rn_overlapped,
    rn_perword_relevance,
)
from CLIP_Explainability.vit_cam import (
    vit_perword_relevance,
    interpret_vit_overlapped,
)

from pytorch_grad_cam.grad_cam import GradCAM
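
# When RUN_LITE is True, the M-CLIP image encoder is not loaded; image upload and
# the per-image "Explain this" view are disabled while M-CLIP is the active model.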
RUN_LITE = True

MAX_IMG_WIDTH = 500
MAX_IMG_HEIGHT = 800

st.set_page_config(layout="wide")


def find_best_matches(text_features, image_features, image_ids):
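    """Rank all images against the query text features.

    Image and text features are already L2-normalized, so the matrix product is a
    cosine similarity. Returns [image_id, score] pairs sorted best-first.
    """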
    similarities = (image_features @ text_features.T).squeeze(1)

    best_image_idx = (-similarities).argsort()

    return [[image_ids[i], similarities[i].item()] for i in best_image_idx]


def encode_search_query(search_query, model_type):
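    """Encode a text query with the selected model's text encoder and return the
    L2-normalized features on the session device."""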
    with torch.no_grad():
        if model_type == "M-CLIP (multilingual ViT)":
            text_encoded = st.session_state.ml_model.forward(
                search_query, st.session_state.ml_tokenizer
            )
            text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
        elif model_type == "J-CLIP (日本語 ViT)":
            t_text = st.session_state.ja_tokenizer(
                search_query,
                padding=True,
                return_tensors="pt",
                device=st.session_state.device,
            )
            text_encoded = st.session_state.ja_model.get_text_features(**t_text)
            text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
        else:
            text_encoded = st.session_state.rn_model(search_query)
            text_encoded /= text_encoded.norm(dim=-1, keepdim=True)

    return text_encoded.to(st.session_state.device)


def clip_search(search_query):
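    """Encode the query, rank all images against it, and store the ordered IDs
    and scores in session state."""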
    if st.session_state.search_field_value != search_query:
        st.session_state.search_field_value = search_query

    model_type = st.session_state.active_model

    if len(search_query) >= 1:
        text_features = encode_search_query(search_query, model_type)

        if model_type == "M-CLIP (multilingual ViT)":
            matches = find_best_matches(
                text_features,
                st.session_state.ml_image_features,
                st.session_state.image_ids,
            )
        elif model_type == "J-CLIP (日本語 ViT)":
            matches = find_best_matches(
                text_features,
                st.session_state.ja_image_features,
                st.session_state.image_ids,
            )
        else:
            matches = find_best_matches(
                text_features,
                st.session_state.rn_image_features,
                st.session_state.image_ids,
            )

        st.session_state.search_image_ids = [match[0] for match in matches]
        st.session_state.search_image_scores = {
            match[0]: match[1] for match in matches
        }


def string_search():
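    """Search with the current text-field value and refresh the uploader's
    disabled state for the active model."""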
    st.session_state.disable_uploader = (
        RUN_LITE and st.session_state.active_model == "M-CLIP (multilingual ViT)"
    )

    if "search_field_value" in st.session_state:
        clip_search(st.session_state.search_field_value)


def load_image_features():
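    """Load the precomputed image features matching the selected vision mode,
    L2-normalize them, and re-run the current search."""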
    if st.session_state.vision_mode == "tiled":
        ml_image_features = np.load("./image_features/tiled_ml_features.npy")
        ja_image_features = np.load("./image_features/tiled_ja_features.npy")
        rn_image_features = np.load("./image_features/tiled_rn_features.npy")
    elif st.session_state.vision_mode == "stretched":
        ml_image_features = np.load("./image_features/resized_ml_features.npy")
        ja_image_features = np.load("./image_features/resized_ja_features.npy")
        rn_image_features = np.load("./image_features/resized_rn_features.npy")
    else:
        ml_image_features = np.load("./image_features/cropped_ml_features.npy")
        ja_image_features = np.load("./image_features/cropped_ja_features.npy")
        rn_image_features = np.load("./image_features/cropped_rn_features.npy")

    device = st.session_state.device
    if device == "cpu":
        ml_image_features = torch.from_numpy(ml_image_features).float().to(device)
        ja_image_features = torch.from_numpy(ja_image_features).float().to(device)
        rn_image_features = torch.from_numpy(rn_image_features).float().to(device)
    else:
        ml_image_features = torch.from_numpy(ml_image_features).to(device)
        ja_image_features = torch.from_numpy(ja_image_features).to(device)
        rn_image_features = torch.from_numpy(rn_image_features).to(device)

    st.session_state.ml_image_features = ml_image_features / ml_image_features.norm(
        dim=-1, keepdim=True
    )
    st.session_state.ja_image_features = ja_image_features / ja_image_features.norm(
        dim=-1, keepdim=True
    )
    st.session_state.rn_image_features = rn_image_features / rn_image_features.norm(
        dim=-1, keepdim=True
    )

    string_search()


def init():
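    """Load the CLIP models, tokenizers, image metadata, and precomputed image
    features into session state."""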
    st.session_state.current_page = 1

    device = "cpu"

    st.session_state.device = device

    with st.spinner("Loading models and data, please wait..."):
        ml_model_name = "M-CLIP/XLM-Roberta-Large-Vit-B-16Plus"
        ml_model_path = "./models/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"

        if not RUN_LITE:
            st.session_state.ml_image_model, st.session_state.ml_image_preprocess = (
                load(ml_model_path, device=device, jit=False)
            )

        st.session_state.ml_model = (
            pt_multilingual_clip.MultilingualCLIP.from_pretrained(ml_model_name)
        ).to(device)
        st.session_state.ml_tokenizer = AutoTokenizer.from_pretrained(ml_model_name)

        ja_model_name = "hakuhodo-tech/japanese-clip-vit-h-14-bert-wider"
        ja_model_path = "./models/ViT-H-14-laion2B-s32B-b79K.bin"

        st.session_state.ja_image_model, st.session_state.ja_image_preprocess = load(
            ja_model_path, device=device, jit=False
        )

        st.session_state.ja_model = AutoModel.from_pretrained(
            ja_model_name, trust_remote_code=True
        ).to(device)
        st.session_state.ja_tokenizer = AutoTokenizer.from_pretrained(
            ja_model_name, trust_remote_code=True
        )

        st.session_state.rn_image_model, st.session_state.rn_image_preprocess = (
            clip.load("RN50x4", device=device)
        )

        st.session_state.rn_model = legacy_multilingual_clip.load_model(
            "M-BERT-Base-69"
        ).to(device)
        st.session_state.rn_tokenizer = BertTokenizer.from_pretrained(
            "bert-base-multilingual-cased"
        )

        st.session_state.images_info = pd.read_csv("./metadata.csv")
        st.session_state.images_info.set_index("filename", inplace=True)

        with open("./images_list.txt", "r", encoding="utf-8") as images_list:
            st.session_state.image_ids = list(images_list.read().strip().split("\n"))

    st.session_state.active_model = "J-CLIP (日本語 ViT)"

    st.session_state.vision_mode = "tiled"
    st.session_state.search_image_ids = []
    st.session_state.search_image_scores = {}
    st.session_state.text_table_df = None
    st.session_state.disable_uploader = (
        RUN_LITE and st.session_state.active_model == "Legacy (multilingual ResNet)"
    )

    with st.spinner("Loading models and data, please wait..."):
        load_image_features()
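

# st.session_state persists across reruns, so the models and data are only loaded
# on the first run of a session.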
if "images_info" not in st.session_state:
    init()


def get_overlay_vis(image, img_dim, image_model):
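    """Build a query-relevance overlay for a single image.

    The image is tiled, stretched, or center-cropped to the model's input size
    according to the selected vision mode, and an overlay is computed per tile.
    Returns the stitched overlay (resized for display), the per-tile features or
    preprocessed tensors reused for per-token relevance, and the mean image-query
    similarity.
    """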
    orig_img_dims = image.size

    tile_behavior = None
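
    # Match the preprocessing behind the stored features: "tiled" splits a
    # non-square image into square tiles, "stretched" squashes it into one square,
    # and the default ("cropped") mode center-crops a square region.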
    if st.session_state.vision_mode == "tiled":
        scaled_dims = [img_dim, img_dim]

        if orig_img_dims[0] > orig_img_dims[1]:
            scale_ratio = round(orig_img_dims[0] / orig_img_dims[1])
            if scale_ratio > 1:
                scaled_dims = [scale_ratio * img_dim, img_dim]
                tile_behavior = "width"
        elif orig_img_dims[0] < orig_img_dims[1]:
            scale_ratio = round(orig_img_dims[1] / orig_img_dims[0])
            if scale_ratio > 1:
                scaled_dims = [img_dim, scale_ratio * img_dim]
                tile_behavior = "height"

        resized_image = image.resize(scaled_dims, Image.LANCZOS)

        if tile_behavior == "width":
            image_tiles = []
            for x in range(0, scale_ratio):
                box = (x * img_dim, 0, (x + 1) * img_dim, img_dim)
                image_tiles.append(resized_image.crop(box))

        elif tile_behavior == "height":
            image_tiles = []
            for y in range(0, scale_ratio):
                box = (0, y * img_dim, img_dim, (y + 1) * img_dim)
                image_tiles.append(resized_image.crop(box))

        else:
            image_tiles = [resized_image]

    elif st.session_state.vision_mode == "stretched":
        image_tiles = [image.resize((img_dim, img_dim), Image.LANCZOS)]

    else:
        if orig_img_dims[0] > orig_img_dims[1]:
            scale_factor = orig_img_dims[0] / orig_img_dims[1]
            resized_img_dims = (round(scale_factor * img_dim), img_dim)
            resized_img = image.resize(resized_img_dims)
        elif orig_img_dims[0] < orig_img_dims[1]:
            scale_factor = orig_img_dims[1] / orig_img_dims[0]
            resized_img_dims = (img_dim, round(scale_factor * img_dim))
        else:
            resized_img_dims = (img_dim, img_dim)

        resized_img = image.resize(resized_img_dims)

        left = round((resized_img_dims[0] - img_dim) / 2)
        top = round((resized_img_dims[1] - img_dim) / 2)
        x_right = round(resized_img_dims[0] - img_dim) - left
        x_bottom = round(resized_img_dims[1] - img_dim) - top
        right = resized_img_dims[0] - x_right
        bottom = resized_img_dims[1] - x_bottom

        image_tiles = [resized_img.crop((left, top, right, bottom))]

    image_visualizations = []
    image_features = []
    image_similarities = []
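
    # Encode the query with the active model's text encoder, then compute a
    # relevance overlay and similarity score for each tile.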
    if st.session_state.active_model == "M-CLIP (multilingual ViT)":
        text_features = st.session_state.ml_model.forward(
            st.session_state.search_field_value, st.session_state.ml_tokenizer
        )

        if st.session_state.device == "cpu":
            text_features = text_features.float().to(st.session_state.device)
        else:
            text_features = text_features.to(st.session_state.device)

        for altered_image in image_tiles:
            p_image = (
                st.session_state.ml_image_preprocess(altered_image)
                .unsqueeze(0)
                .to(st.session_state.device)
            )

            vis_t, img_feats, similarity = interpret_vit_overlapped(
                p_image.type(image_model.dtype),
                text_features.type(image_model.dtype),
                image_model.visual,
                st.session_state.device,
                img_dim=img_dim,
            )

            image_visualizations.append(vis_t)
            image_features.append(img_feats)
            image_similarities.append(similarity.item())

    elif st.session_state.active_model == "J-CLIP (日本語 ViT)":
        t_text = st.session_state.ja_tokenizer(
            st.session_state.search_field_value,
            return_tensors="pt",
            device=st.session_state.device,
        )

        text_features = st.session_state.ja_model.get_text_features(**t_text)

        if st.session_state.device == "cpu":
            text_features = text_features.float().to(st.session_state.device)
        else:
            text_features = text_features.to(st.session_state.device)

        for altered_image in image_tiles:
            p_image = (
                st.session_state.ja_image_preprocess(altered_image)
                .unsqueeze(0)
                .to(st.session_state.device)
            )

            vis_t, img_feats, similarity = interpret_vit_overlapped(
                p_image.type(image_model.dtype),
                text_features.type(image_model.dtype),
                image_model.visual,
                st.session_state.device,
                img_dim=img_dim,
            )

            image_visualizations.append(vis_t)
            image_features.append(img_feats)
            image_similarities.append(similarity.item())

    else:
        text_features = st.session_state.rn_model(st.session_state.search_field_value)

        if st.session_state.device == "cpu":
            text_features = text_features.float().to(st.session_state.device)
        else:
            text_features = text_features.to(st.session_state.device)

        for altered_image in image_tiles:
            p_image = (
                st.session_state.rn_image_preprocess(altered_image)
                .unsqueeze(0)
                .to(st.session_state.device)
            )

            vis_t = interpret_rn_overlapped(
                p_image.type(image_model.dtype),
                text_features.type(image_model.dtype),
                image_model.visual,
                GradCAM,
                st.session_state.device,
                img_dim=img_dim,
            )
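
            # interpret_rn_overlapped returns only the visualization, so the
            # image-query similarity is computed here from normalized features.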
            text_features_norm = text_features.norm(dim=-1, keepdim=True)
            text_features_new = text_features / text_features_norm

            image_feats = image_model.encode_image(p_image.type(image_model.dtype))
            image_feats_norm = image_feats.norm(dim=-1, keepdim=True)
            image_feats_new = image_feats / image_feats_norm

            similarity = image_feats_new[0].dot(text_features_new[0])

            image_visualizations.append(vis_t)
            image_features.append(p_image)
            image_similarities.append(similarity.item())

    transform = ToPILImage()

    vis_images = [transform(vis_t) for vis_t in image_visualizations]

    if st.session_state.vision_mode == "cropped":
        resized_img.paste(vis_images[0], (left, top))
        vis_images = [resized_img]

    if orig_img_dims[0] > orig_img_dims[1]:
        scale_factor = MAX_IMG_WIDTH / orig_img_dims[0]
        scaled_dims = [MAX_IMG_WIDTH, int(orig_img_dims[1] * scale_factor)]
    else:
        scale_factor = MAX_IMG_HEIGHT / orig_img_dims[1]
        scaled_dims = [int(orig_img_dims[0] * scale_factor), MAX_IMG_HEIGHT]
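
    # Stitch the per-tile overlays back together and resize to display size.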
    if tile_behavior == "width":
        vis_image = Image.new("RGB", (len(vis_images) * img_dim, img_dim))
        for x, v_img in enumerate(vis_images):
            vis_image.paste(v_img, (x * img_dim, 0))
        activations_image = vis_image.resize(scaled_dims)

    elif tile_behavior == "height":
        vis_image = Image.new("RGB", (img_dim, len(vis_images) * img_dim))
        for y, v_img in enumerate(vis_images):
            vis_image.paste(v_img, (0, y * img_dim))
        activations_image = vis_image.resize(scaled_dims)

    else:
        activations_image = vis_images[0].resize(scaled_dims)

    return activations_image, image_features, np.mean(image_similarities)


def visualize_gradcam(image):
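    """Show the selected image, its query-relevance overlay, and (on request) a
    table of per-token text importance scores."""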
    if "search_field_value" not in st.session_state:
        return

    header_cols = st.columns([80, 20], vertical_alignment="bottom")
    with header_cols[0]:
        st.title("Image + query activation gradients")
    with header_cols[1]:
        if st.button("Close"):
            st.rerun()

    if st.session_state.active_model == "M-CLIP (multilingual ViT)":
        img_dim = 240
        image_model = st.session_state.ml_image_model

        tokenized_text = st.session_state.ml_tokenizer.tokenize(
            st.session_state.search_field_value
        )
    elif st.session_state.active_model == "Legacy (multilingual ResNet)":
        img_dim = 288
        image_model = st.session_state.rn_image_model

        tokenized_text = st.session_state.rn_tokenizer.tokenize(
            st.session_state.search_field_value
        )
    else:
        img_dim = 224
        image_model = st.session_state.ja_image_model

        tokenized_text = st.session_state.ja_tokenizer.tokenize(
            st.session_state.search_field_value
        )

    st.image(image)

    with st.spinner("Calculating..."):
        activations_image, image_features, similarity_score = get_overlay_vis(
            image, img_dim, image_model
        )

    st.markdown(
        f"**Query text:** {st.session_state.search_field_value} | "
        f"**Approx. image relevance:** {round(similarity_score.item(), 3)}"
    )

    st.image(activations_image)
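
    # Strip subword markers from the tokenized query and drop common function-word
    # tokens before estimating per-token importance.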
    tokenized_text = [
        tok.replace("▁", "").replace("#", "") for tok in tokenized_text if tok != "▁"
    ]
    tokenized_text = [
        tok
        for tok in tokenized_text
        if tok
        not in ["s", "ed", "a", "the", "an", "ing", "て", "に", "の", "は", "と", "た"]
    ]

    if (
        len(tokenized_text) > 1
        and len(tokenized_text) < 25
        and st.button(
            "Calculate text importance (may take some time)",
        )
    ):
        scores_per_token = {}

        progress_text = f"Processing {len(tokenized_text)} text tokens"
        progress_bar = st.progress(0.0, text=progress_text)

        for t, tok in enumerate(tokenized_text):
            token = tok

            for img_feats in image_features:
                if st.session_state.active_model == "Legacy (multilingual ResNet)":
                    word_rel = rn_perword_relevance(
                        img_feats,
                        st.session_state.search_field_value,
                        image_model,
                        tokenize,
                        GradCAM,
                        st.session_state.device,
                        token,
                        data_only=True,
                        img_dim=img_dim,
                    )
                else:
                    word_rel = vit_perword_relevance(
                        img_feats,
                        st.session_state.search_field_value,
                        image_model,
                        tokenize,
                        st.session_state.device,
                        token,
                        img_dim=img_dim,
                    )
                avg_score = np.mean(word_rel)
                if avg_score == 0 or np.isnan(avg_score):
                    continue

                if token not in scores_per_token:
                    scores_per_token[token] = [1 / avg_score]
                else:
                    scores_per_token[token].append(1 / avg_score)

            progress_bar.progress(
                (t + 1) / len(tokenized_text),
                text=f"Processing token {t+1} of {len(tokenized_text)}",
            )
        progress_bar.empty()

        avg_scores_per_token = [
            np.mean(scores_per_token[tok]) for tok in list(scores_per_token.keys())
        ]

        normed_scores = torch.softmax(torch.tensor(avg_scores_per_token), dim=0)

        token_scores = [f"{round(score.item() * 100, 3)}%" for score in normed_scores]
        st.session_state.text_table_df = pd.DataFrame(
            {"token": list(scores_per_token.keys()), "importance": token_scores}
        )

    st.markdown("**Importance of each text token to relevance score**")
    st.table(st.session_state.text_table_df)


@st.dialog(" ", width="large")
def image_modal(image):
    visualize_gradcam(image)


def vis_known_image(vis_image_id):
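    """Fetch a result image from its source URL and open the explanation dialog."""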
    image_url = st.session_state.images_info.loc[vis_image_id]["image_url"]
    image_response = requests.get(image_url)
    image = Image.open(BytesIO(image_response.content), formats=["JPEG", "GIF", "PNG"])
    image = image.convert("RGB")

    image_modal(image)


def vis_uploaded_image():
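    """Open the explanation dialog for an image supplied through the file uploader."""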
    uploaded_file = st.session_state.uploaded_image
    if uploaded_file is not None:
        bytes_data = uploaded_file.getvalue()
        image = Image.open(BytesIO(bytes_data), formats=["JPEG", "GIF", "PNG"])
        image = image.convert("RGB")

        image_modal(image)


def format_vision_mode(mode_stub):
    return mode_stub.capitalize()
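

# Page layout: global CSS tweaks, search controls, suggested searches, pagination
# controls, and the results grid.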
st.title("Explore Japanese visual aesthetics with CLIP models")

st.markdown(
    """
    <style>
    [data-testid=stImageCaption] {
        padding: 0 0 0 0;
    }
    [data-testid=stVerticalBlockBorderWrapper] {
        line-height: 1.2;
    }
    [data-testid=stVerticalBlock] {
        gap: .75rem;
    }
    [data-testid=baseButton-secondary] {
        min-height: 1rem;
        padding: 0 0.75rem;
        margin: 0 0 1rem 0;
    }
    div[aria-label="dialog"]>button[aria-label="Close"] {
        display: none;
    }
    [data-testid=stFullScreenFrame] {
        display: flex;
        flex-direction: column;
        align-items: center;
    }
    </style>
    """,
    unsafe_allow_html=True,
)

search_row = st.columns([45, 8, 8, 10, 1, 8, 20], vertical_alignment="center")
with search_row[0]:
    search_field = st.text_input(
        label="search",
        label_visibility="collapsed",
        placeholder="Type something, or click a suggested search below.",
        on_change=string_search,
        key="search_field_value",
    )
with search_row[1]:
    st.button(
        "Search", on_click=string_search, use_container_width=True, type="primary"
    )
with search_row[2]:
    st.markdown("**Vision mode:**")
with search_row[3]:
    st.selectbox(
        "Vision mode",
        options=["tiled", "stretched", "cropped"],
        key="vision_mode",
        help="How to consider images that aren't square",
        on_change=load_image_features,
        format_func=format_vision_mode,
        label_visibility="collapsed",
    )
with search_row[4]:
    st.empty()
with search_row[5]:
    st.markdown("**CLIP model:**")
with search_row[6]:
    st.selectbox(
        "CLIP Model:",
        options=[
            "J-CLIP (日本語 ViT)",
            "M-CLIP (multilingual ViT)",
            "Legacy (multilingual ResNet)",
        ],
        key="active_model",
        on_change=string_search,
        label_visibility="collapsed",
    )

canned_searches = st.columns([12, 22, 22, 22, 22], vertical_alignment="top")
with canned_searches[0]:
    st.markdown("**Suggested searches:**")
if st.session_state.active_model == "J-CLIP (日本語 ViT)":
    with canned_searches[1]:
        st.button(
            "間",
            on_click=clip_search,
            args=["間"],
            use_container_width=True,
        )
    with canned_searches[2]:
        st.button("奥", on_click=clip_search, args=["奥"], use_container_width=True)
    with canned_searches[3]:
        st.button("山", on_click=clip_search, args=["山"], use_container_width=True)
    with canned_searches[4]:
        st.button(
            "花に酔えり 羽織着て刀 さす女",
            on_click=clip_search,
            args=["花に酔えり 羽織着て刀 さす女"],
            use_container_width=True,
        )
else:
    with canned_searches[1]:
        st.button(
            "negative space",
            on_click=clip_search,
            args=["negative space"],
            use_container_width=True,
        )
    with canned_searches[2]:
        st.button("間", on_click=clip_search, args=["間"], use_container_width=True)
    with canned_searches[3]:
        st.button("음각", on_click=clip_search, args=["음각"], use_container_width=True)
    with canned_searches[4]:
        st.button(
            "αρνητικός χώρος",
            on_click=clip_search,
            args=["αρνητικός χώρος"],
            use_container_width=True,
        )

controls = st.columns([25, 25, 20, 35], gap="large", vertical_alignment="center")
with controls[0]:
    im_per_pg = st.columns([30, 70], vertical_alignment="center")
    with im_per_pg[0]:
        st.markdown("**Images/page:**")
    with im_per_pg[1]:
        batch_size = st.select_slider(
            "Images/page:", range(10, 50, 10), label_visibility="collapsed"
        )
with controls[1]:
    im_per_row = st.columns([30, 70], vertical_alignment="center")
    with im_per_row[0]:
        st.markdown("**Images/row:**")
    with im_per_row[1]:
        row_size = st.select_slider(
            "Images/row:", range(1, 6), value=5, label_visibility="collapsed"
        )
num_batches = ceil(len(st.session_state.image_ids) / batch_size)
with controls[2]:
    pager = st.columns([40, 60], vertical_alignment="center")
    with pager[0]:
        st.markdown(f"Page **{st.session_state.current_page}** of **{num_batches}**")
    with pager[1]:
        st.number_input(
            "Page",
            min_value=1,
            max_value=num_batches,
            step=1,
            label_visibility="collapsed",
            key="current_page",
        )
with controls[3]:
    st.file_uploader(
        "Upload an image",
        type=["jpg", "jpeg", "gif", "png"],
        key="uploaded_image",
        label_visibility="collapsed",
        on_change=vis_uploaded_image,
        disabled=st.session_state.disable_uploader,
    )
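
# Paginate the current search results for display in the grid below.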
if len(st.session_state.search_image_ids) == 0:
    batch = []
else:
    batch = st.session_state.search_image_ids[
        (st.session_state.current_page - 1) * batch_size : st.session_state.current_page
        * batch_size
    ]

grid = st.columns(row_size)
col = 0
for image_id in batch:
    with grid[col]:
        link_text = st.session_state.images_info.loc[image_id]["permalink"].split("/")[
            2
        ]

        st.html(
            f"""<div style="display: flex; flex-direction: column; align-items: center">
            <img src="{st.session_state.images_info.loc[image_id]['image_url']}" style="max-width: 100%; max-height: {MAX_IMG_HEIGHT}px" />
            <div>{st.session_state.images_info.loc[image_id]['caption']} <b>[{round(st.session_state.search_image_scores[image_id], 3)}]</b></div>
            </div>"""
        )
        st.caption(
            f"""<div style="display: flex; flex-direction: column; align-items: center; position: relative; top: -12px">
            <a href="{st.session_state.images_info.loc[image_id]['permalink']}">{link_text}</a>
            </div>""",
            unsafe_allow_html=True,
        )
        if not (
            RUN_LITE and st.session_state.active_model == "M-CLIP (multilingual ViT)"
        ):
            st.button(
                "Explain this",
                on_click=vis_known_image,
                args=[image_id],
                use_container_width=True,
                key=image_id,
            )
        else:
            st.empty()
    col = (col + 1) % row_size