# NOTE: the lines below ("Spaces: / Running / Running") are residue from the
# Hugging Face Spaces status banner captured when this file was scraped;
# kept here as a comment so the file remains valid Python.
import gc
from io import BytesIO

import jax
import pandas as pd
import requests
import streamlit as st
from jax import numpy as jnp
from PIL import Image

from text2image import get_image_transform, get_model, get_tokenizer
from utils import image_encoder, text_encoder
# Browser-like User-Agent: some image hosts (e.g. Wikimedia) refuse requests
# that identify themselves as a generic Python HTTP client.
headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/70.0.3538.102 Safari/537.36 Edge/18.19582"}
def _fetch_image(url):
    """Download *url* and return it as an RGB PIL image.

    Uses the module-level browser-like ``headers`` and a 10-second timeout so
    a dead host cannot hang the Streamlit worker, and raises
    ``requests.HTTPError`` on a non-2xx response so a broken URL fails with a
    clear message instead of an opaque PIL decoding error.
    """
    response = requests.get(url, headers=headers, stream=True, timeout=10)
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")


def app():
    """Render the zero-shot image classification page.

    The user provides an image URL and up to ``MAX_CAP`` free-text (Italian)
    labels. Each label is embedded with the text encoder, the image with the
    image encoder, and the softmax over their dot products is displayed as a
    bar chart next to the image.
    """
    st.title("From Image to Text")
    # NOTE(review): "π" / "π€" below look like mojibake of emoji from the
    # original source; kept byte-for-byte because the originals are unknown.
    st.markdown(
        """
### π Ciao!
Here you can find the captions or the labels that are most related to a given image. It is a zero-shot
image classification task!
π€ Italian mode on! π€
For example, try typing "gatto" (cat) in the space for label1 and "cane" (dog) in the space for label2 and click
"classify"!
"""
    )

    image_url = st.text_input(
        "You can input the URL of an image",
        value="https://upload.wikimedia.org/wikipedia/commons/thumb/8/88/Ragdoll%2C_blue_mitted.JPG/1280px-Ragdoll%2C_blue_mitted.JPG",
    )

    MAX_CAP = 4  # upper bound on the number of candidate labels

    col1, col2 = st.columns([0.75, 0.25])
    with col2:
        captions_count = st.selectbox(
            "Number of labels", options=range(1, MAX_CAP + 1), index=1
        )
        compute = st.button("CLASSIFY")
    with col1:
        captions = [
            st.text_input(f"Insert label {idx+1}")
            for idx in range(min(MAX_CAP, captions_count))
        ]

    if compute:
        # Drop empty label boxes before scoring.
        captions = [c for c in captions if c != ""]
        if not captions or not image_url:
            st.error("Please choose one image and at least one label")
        else:
            with st.spinner("Computing..."):
                model = get_model()
                tokenizer = get_tokenizer()

                # Accumulate one embedding per caption, then stack into a
                # single jnp array for the matmul below.
                text_embeds = []
                for caption in captions:
                    text_embeds.extend(text_encoder(caption, model, tokenizer)[0])
                text_embeds = jnp.array(text_embeds)

                image = _fetch_image(image_url)
                transform = get_image_transform(model.config.vision_config.image_size)
                image_embed, _ = image_encoder(transform(image), model)

                # Softmax over the image-label similarity scores turns them
                # into a probability distribution over the labels.
                probs = jax.nn.softmax(jnp.matmul(image_embed, text_embeds.T))

                chart_data = pd.Series(probs[0], index=captions)

                col1, col2 = st.columns(2)
                with col1:
                    st.bar_chart(chart_data)
                with col2:
                    # NOTE(review): use_column_width is deprecated in newer
                    # Streamlit releases in favour of use_container_width;
                    # switch once the deployed version supports it.
                    st.image(image, use_column_width=True)
        # Free model/embedding memory between interactions.
        gc.collect()
    elif image_url:
        # No classification requested yet: just preview the chosen image.
        st.image(_fetch_image(image_url))