| import streamlit as st |
| import pandas as pd, numpy as np |
| from html import escape |
| import os |
| from transformers import CLIPProcessor, CLIPTextModel, CLIPModel |
|
|
@st.cache(
    show_spinner=False,
    # Each listed type is hashed as a constant None instead of by content.
    hash_funcs={
        CLIPModel: lambda _: None,
        CLIPTextModel: lambda _: None,
        CLIPProcessor: lambda _: None,
        dict: lambda _: None,
    },
)
def load():
    """Load the CLIP model, its processor, the metadata tables and the image embeddings.

    Reads ``data.csv`` / ``data2.csv`` and ``embeddings.npy`` / ``embeddings2.npy``
    from the current working directory.

    Returns:
        A ``(model, processor, df, embeddings)`` tuple. ``df`` and ``embeddings``
        are dicts keyed by 0 and 1 (``image_search`` uses 0 for 'Unsplash' and 1
        otherwise). Every embedding matrix has its rows scaled to unit
        Euclidean length.
    """
    checkpoint = "openai/clip-vit-base-patch32"
    model = CLIPModel.from_pretrained(checkpoint)
    processor = CLIPProcessor.from_pretrained(checkpoint)

    df = {0: pd.read_csv('data.csv'), 1: pd.read_csv('data2.csv')}
    raw_embeddings = {0: np.load('embeddings.npy'), 1: np.load('embeddings2.npy')}

    # Divide each row by its L2 norm so a dot product with a query vector
    # ranks rows by direction rather than magnitude.
    embeddings = {
        key: np.divide(matrix, np.sqrt(np.sum(matrix**2, axis=1, keepdims=True)))
        for key, matrix in raw_embeddings.items()
    }
    return model, processor, df, embeddings
# Shared resources, loaded once through the cached loader.
model, processor, df, embeddings = load()


# Attribution text appended to every tooltip, keyed like `df` / `embeddings`.
source = {
    0: '\nSource: Unsplash',
    1: '\nSource: The Movie Database (TMDB)',
}
|
|
def get_html(url_list, height=200):
    """Build an HTML flexbox gallery from ``(url, title, link)`` triples.

    Args:
        url_list: Iterable of ``(url, title, link)`` tuples. ``url`` becomes the
            image ``src``, ``title`` its tooltip, and ``link`` the target of a
            surrounding anchor when it is a non-empty string.
        height: Image height in pixels.

    Returns:
        One HTML string: a wrapping ``<div>`` with one ``<img>`` per entry,
        each optionally wrapped in ``<a target='_blank'>``. All three values
        are HTML-escaped before being placed in attributes.
    """
    parts = ["<div style='margin-top: 20px; max-width: 1200px; display: flex; flex-wrap: wrap; justify-content: space-evenly'>"]
    for url, title, link in url_list:
        tag = f"<img title='{escape(title)}' style='height: {height}px; margin: 5px' src='{escape(url)}'>"
        # A missing link can arrive as None or NaN (pandas yields NaN for an
        # empty CSV cell); len() on those raises TypeError, so only a real,
        # non-empty string gets an anchor.
        if isinstance(link, str) and link:
            tag = f"<a href='{escape(link)}' target='_blank'>{tag}</a>"
        parts.append(tag)
    parts.append("</div>")
    return "".join(parts)
|
|
def compute_text_embeddings(list_of_strings):
    """Embed a batch of texts with the module-level CLIP model.

    Args:
        list_of_strings: Texts to embed; passed as ``text=`` to the processor.

    Returns:
        The result of ``model.get_text_features`` for the tokenized batch
        (the caller in this file applies ``.detach().numpy()`` to it).
    """
    # truncation=True clips over-long input to the tokenizer's maximum length.
    # Without it, a query longer than CLIP's text context is passed through
    # whole and the model fails instead of returning an embedding.
    inputs = processor(
        text=list_of_strings,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    return model.get_text_features(**inputs)
|
|
# The original line here was a bare `st.cache(show_spinner=False)` call: without
# the leading '@' the decorator was created and thrown away, so searches were
# never cached. The hash_funcs mirror load(): the model, processor and
# dict-valued globals this function reads are hashed as a constant instead of
# by content.
# NOTE(review): confirm cache behaviour on the deployed Streamlit version.
@st.cache(
    show_spinner=False,
    hash_funcs={
        CLIPModel: lambda _: None,
        CLIPProcessor: lambda _: None,
        dict: lambda _: None,
    },
)
def image_search(query, corpus, n_results=24):
    """Return the images whose embeddings best match a text query.

    Args:
        query: Free-text search string.
        corpus: ``'Unsplash'`` selects dataset 0; any other value selects dataset 1.
        n_results: Maximum number of hits to return.

    Returns:
        A list of ``(path, tooltip, link)`` tuples ordered from highest to
        lowest dot-product score, where ``tooltip`` has the corpus attribution
        from ``source`` appended.
    """
    text_embeddings = compute_text_embeddings([query]).detach().numpy()
    k = 0 if corpus == 'Unsplash' else 1

    # One score per image: dot product of its embedding with the query vector.
    scores = (embeddings[k] @ text_embeddings.T)[:, 0]
    # argsort is ascending, so walk backwards from the end for the best hits.
    best = np.argsort(scores)[-1:-n_results-1:-1]

    # Select all matching rows at once instead of building a row Series
    # three times per hit.
    rows = df[k].iloc[best]
    return [
        (path, tooltip + source[k], link)
        for path, tooltip, link in zip(rows['path'], rows['tooltip'], rows['link'])
    ]
|
|
# Markdown shown in the sidebar: title, usage hint and credits.
description = """
# Semantic image search

**Enter your query and hit enter**

*Built with OpenAI's [CLIP](https://openai.com/blog/clip/) model, 🤗 Hugging Face's [transformers library](https://huggingface.co/transformers/), [Streamlit](https://streamlit.io/), 25k images from [Unsplash](https://unsplash.com/) and 8k images from [The Movie Database (TMDB)](https://www.themoviedb.org/)*

*Inspired by [Unsplash Image Search](https://github.com/haltakov/natural-language-image-search) from Vladimir Haltakov and [Alph, The Sacred River](https://github.com/thoppe/alph-the-sacred-river) from Travis Hoppe*
"""
|
|
def main():
    """Render the page: custom CSS, sidebar text, query widgets and result gallery."""
    page_css = '''
<style>
.block-container{
max-width: 1200px;
}
div.row-widget.stRadio > div{
flex-direction:row;
display: flex;
justify-content: center;
}
div.row-widget.stRadio > div > label{
margin-left: 5px;
margin-right: 5px;
}
section.main>div:first-child {
padding-top: 0px;
}
section:not(.main)>div:first-child {
padding-top: 30px;
}
div.reportview-container > section:first-child{
max-width: 320px;
}
#MainMenu {
visibility: hidden;
}
footer {
visibility: hidden;
}
</style>'''
    st.markdown(page_css, unsafe_allow_html=True)
    st.sidebar.markdown(description)

    # Three columns weighted 1:3:1; only the middle one holds the text box.
    _, centre, _ = st.beta_columns((1, 3, 1))
    query = centre.text_input('', value='clouds at sunset')
    corpus = st.radio('', ["Unsplash", "Movies"])

    # Nothing to search for until the user has typed something.
    if not query:
        return

    gallery = get_html(image_search(query, corpus))
    st.markdown(gallery, unsafe_allow_html=True)
|
|
# Render the app only when this file is run as the main module.
if __name__ == "__main__":
    main()
|
|