Spaces:
Build error
Build error
| import os | |
| import jax | |
| import jax.numpy as jnp | |
| import numpy as np | |
| import requests | |
| import streamlit as st | |
| from PIL import Image | |
| from utils import load_model | |
def split_image(im, num_rows=3, num_cols=3):
    """Cut an image into a ``num_rows`` x ``num_cols`` grid of tiles.

    The tile height and width are the floor-divided image height and width,
    so any leftover pixels along the bottom and right edges are dropped.

    Args:
        im: Anything ``np.array`` can convert to an array whose first two
            axes are (height, width), e.g. a PIL image or an ndarray.
        num_rows: Number of grid rows.
        num_cols: Number of grid columns.

    Returns:
        A list of ``num_rows * num_cols`` array slices in row-major order
        (left to right, then top to bottom).

    Raises:
        ValueError: Raised by ``range`` when a tile dimension floors to zero
            (more rows/columns requested than there are pixels).
    """
    pixels = np.array(im)
    tile_height = pixels.shape[0] // num_rows
    tile_width = pixels.shape[1] // num_cols

    pieces = []
    for top in range(0, num_rows * tile_height, tile_height):
        bottom = top + tile_height
        for left in range(0, num_cols * tile_width, tile_width):
            pieces.append(pixels[top:bottom, left : left + tile_width])
    return pieces
def _load_query_image(url, uploaded_file):
    """Open the query image, preferring the uploaded file over the URL.

    Args:
        url: Image URL typed by the user; only used when no file is uploaded.
        uploaded_file: File-like object from ``st.file_uploader`` or ``None``.

    Returns:
        A PIL image converted to RGB.

    Raises:
        requests.RequestException: The download failed, timed out or
            returned an HTTP error status.
        OSError: The bytes could not be decoded as an image.
    """
    if uploaded_file is not None:
        source = uploaded_file
    else:
        from io import BytesIO

        # Fetch the whole body through ``.content`` rather than ``.raw``:
        # ``.raw`` skips content-encoding decoding, and an unchecked error
        # page would otherwise be handed to the image decoder.
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        source = BytesIO(response.content)
    # Normalise to RGB so palette, grayscale and RGBA inputs all become
    # height x width x 3 arrays before they are tiled.
    return Image.open(source).convert("RGB")


def app(model_name):
    """Render the "Patch-based Relevance Retrieval" Streamlit page.

    The page takes an image (URL or upload), a text query and a grid size,
    splits the image into tiles with ``split_image``, scores every tile
    against the query with the model and shows the tiles from highest to
    lowest score.

    Args:
        model_name: Suffix appended to ``"koclip/"`` and passed to
            ``load_model``.
    """
    model, processor = load_model(f"koclip/{model_name}")

    st.title("Patch-based Relevance Retrieval")
    st.markdown(
        """
Given a piece of text, the CLIP model finds the part of an image that best explains the text.
To try it out, you can
1. Upload an image
2. Explain a part of the image in text
which will yield the most relevant image tile from a grid of the image. You can specify how
granular you want to be with your search by specifying the number of rows and columns that
make up the image grid.
"""
    )

    query1 = st.text_input(
        "Enter a URL to an image...",
        value="https://img.sbs.co.kr/newimg/news/20200823/201463830_1280.jpg",
    )
    query2 = st.file_uploader("or upload an image...", type=["jpg", "jpeg", "png"])
    captions = st.text_input(
        "Enter query to find most relevant part of image ",
        value="이건 서울의 경복궁 사진이다.",
    )

    # ``st.beta_columns`` is gone from current Streamlit releases; prefer
    # ``st.columns`` and fall back only when running on an old version.
    make_columns = getattr(st, "columns", None) or st.beta_columns
    col1, col2 = make_columns(2)
    with col1:
        num_rows = st.slider(
            "Number of rows", min_value=1, max_value=5, value=3, step=1
        )
    with col2:
        num_cols = st.slider(
            "Number of columns", min_value=1, max_value=5, value=3, step=1
        )

    if not st.button("질문 (Query)"):
        return
    if not any([query1, query2]):
        st.error("Please upload an image or paste an image URL.")
        return

    st.markdown("""---""")
    with st.spinner("Computing..."):
        try:
            image = _load_query_image(query1, query2)
        except (requests.RequestException, OSError) as err:
            # Report bad URLs / undecodable files instead of crashing the page.
            st.error(f"Could not load the image: {err}")
            return
        st.image(image)

        images = split_image(image, num_rows, num_cols)
        inputs = processor(
            text=captions, images=images, return_tensors="jax", padding=True
        )
        # Reorder pixel axes (0, 1, 2, 3) -> (0, 2, 3, 1); presumably
        # channels-first to channels-last for the model -- TODO confirm.
        inputs["pixel_values"] = jnp.transpose(
            inputs["pixel_values"], axes=[0, 2, 3, 1]
        )
        outputs = model(**inputs)

        # Softmax across tiles (axis 0) so the scores of all tiles sum to 1.
        probs = jax.nn.softmax(outputs.logits_per_image, axis=0)
        # One row per tile; column 0 holds the score for the text query.
        # Pull the values to host floats once so sorting and formatting do
        # not operate on array objects.
        scores = np.asarray(probs)[:, 0].tolist()
        ranked = sorted(enumerate(scores), key=lambda pair: pair[1], reverse=True)
        for idx, score in ranked:
            st.text(f"Score: {score:.3f}")
            st.image(images[idx])