Spaces:
Runtime error
Runtime error
| import io | |
| import requests | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| import gradio as gr | |
| import uform | |
| from datetime import datetime | |
# Multilingual UForm vision-language model; find_topk uses it to embed queries.
model_multi = uform.get_model('unum-cloud/uform-vl-multilingual')

# Precomputed embedding matrix, loaded from disk and converted to a tensor in
# one step. find_topk uses the row indices of this matrix to index img_df, so
# the two are expected to be row-aligned (presumably one row per image).
embeddings = torch.tensor(np.load('tensors/embeddings.npy'))

# Disabled: per-image features, apparently intended for the reranking step.
#features = np.load('multilingual-image-search/tensors/features.npy')
#features = torch.tensor(features)

# Image metadata table; find_topk reads its 'photo_image_url' column.
img_df = pd.read_csv('image_data.csv')
def url2img(url, resize=False, fix_height=150, timeout=10):
    """Download the image at ``url`` and return it as a PIL image.

    Args:
        url: Address of the image to fetch (redirects are followed).
        resize: When true, shrink the image in place with ``Image.thumbnail``
            so that it fits inside a ``fix_height`` x ``fix_height`` box.
        fix_height: Side length, in pixels, of the bounding box used when
            ``resize`` is true.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` may block forever on an unresponsive host.

    Returns:
        The decoded ``PIL.Image.Image``.

    Raises:
        requests.RequestException: On network failure, timeout, or a non-2xx
            HTTP status.
    """
    response = requests.get(url, allow_redirects=True, timeout=timeout)
    # Fail on HTTP errors here instead of handing an error page to PIL.
    response.raise_for_status()
    img = Image.open(io.BytesIO(response.content))
    if resize:
        img.thumbnail((fix_height, fix_height), Image.LANCZOS)
    return img
def find_topk(text, top_k=20):
    """Return the URLs of the images whose embeddings best match ``text``.

    The query is embedded with ``model_multi``, compared against every row of
    the module-level ``embeddings`` tensor by cosine similarity, and the
    best-scoring row indices are used to look up ``photo_image_url`` in
    ``img_df``.

    Args:
        text: Query string entered by the user.
        top_k: Maximum number of results to return. It is clamped to the
            number of available embeddings so small indexes do not make
            ``topk`` raise.

    Returns:
        A NumPy array of image URLs, ordered from most to least similar.
    """
    print('text', text)
    text_data = model_multi.preprocess_text(text)
    # Inference only, so skip building an autograd graph.
    with torch.no_grad():
        # The first returned value (text features) is not needed here.
        _, text_embedding = model_multi.encode_text(text_data, return_features=True)
    print('Got features', datetime.now().strftime("%H:%M:%S"))

    sims = F.cosine_similarity(text_embedding, embeddings)
    # topk raises if k exceeds the size of the dimension it selects along.
    k = min(top_k, sims.shape[-1])
    _, inds = sims.topk(k)

    # Hand pandas plain Python ints rather than a torch tensor.
    top_k_urls = img_df.iloc[inds.tolist()]['photo_image_url'].values
    print('Got top_k_urls', top_k_urls)
    print(datetime.now().strftime("%H:%M:%S"))
    return top_k_urls
| # def rerank(text_features, text_data): | |
| # # create joint embeddings & get scores | |
| # joint_embedding = model_multi.encode_multimodal( | |
| # image_features=image_features, | |
| # text_features=text_features, | |
| # attention_mask=text_data['attention_mask'] | |
| # ) | |
| # score = model_multi.get_matching_scores(joint_embedding) | |
| # # argmax to get top N | |
| # return | |
| #demo = gr.Interface(find_topk, inputs = 'text', outputs = 'image') | |
# Log the installed Gradio version; the layout API differs between releases.
print('version', gr.__version__)

# Page layout: a language table beside the prompt box and search button, a row
# of example prompts, and a gallery that shows the URLs returned by find_topk.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown('# Enter a prompt in one of the supported languages.')
    with gr.Row():
        with gr.Column():
            # Markdown table listing the supported query languages.
            gr.Markdown(
                '||||||\n'
                '|:-------: |:---: |:-------: |:---: | :--- |\n'
                '|__English__| # |__French__ | # |__Russian__|\n'
                '|__German__ | # |__Italian__ | # |__Chinese (Simplified)__|\n'
                '|__Spanish__| # |__Japanese__| # |__Korean__|\n'
                '|__Turkish__| # |__Polish__ | # |.|\n')
        with gr.Column():
            prompt_box = gr.Textbox(label = 'Enter your prompt', lines = 3, container = True)
            btn_search = gr.Button("Find images")
    with gr.Row():
        # Clickable sample prompts that fill the prompt box.
        gr.Examples(['a girl wandering alone in the forest',
                     'морозное утро в городе',
                     '카메라를 바라보는 강아지 새끼',
                     'ein Schloss, das zwischen modernen Gebäuden hervorlugt',
                     'un couple sirotant un café au bord de la rivière',
                     'una banda de música actuando en un gran espacio al aire libre',
                     '秋の静かな霧の庭園'
                     ], inputs=[prompt_box])
    # Layout options go to the constructor: the .style() method is deprecated
    # and no longer exists in Gradio 4, where calling it raises AttributeError.
    gallery = gr.Gallery(columns=5, height="auto")
    btn_search.click(find_topk, inputs = prompt_box, outputs = gallery)

if __name__ == "__main__":
    demo.launch()