import os
from pathlib import Path

import pandas as pd
import numpy as np
import torch
from torch import nn
import gradio as gr
import requests
from PIL import Image, ImageFile
from transformers import CLIPProcessor, CLIPModel

ImageFile.LOAD_TRUNCATED_IMAGES = True
LABELS = Path('class_names.txt').read_text().splitlines()

class_model = nn.Sequential(
    nn.Conv2d(1, 32, 3, padding='same'),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(32, 64, 3, padding='same'),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(64, 128, 3, padding='same'),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(1152, 256),
    nn.ReLU(),
    nn.Linear(256, len(LABELS)),
)
state_dict = torch.load('pytorch_model.bin', map_location='cpu')
class_model.load_state_dict(state_dict, strict=False)
class_model.eval()
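# Note: the Linear(1152, 256) layer implies a 28x28 single-channel input
# (28 -> 14 -> 7 -> 3 after three 2x2 max-pools; 128 * 3 * 3 = 1152),
# i.e. this classifier expects a 28x28 grayscale sketch.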
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

df = pd.read_csv('clip.csv')
embeddings_npy = np.load('clip.npy')
# L2-normalise each row so a plain dot product behaves like cosine similarity
embeddings = embeddings_npy / np.linalg.norm(embeddings_npy, axis=1, keepdims=True)
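# Illustrative sanity check (not part of the original app; assumes clip.npy holds
# one pre-computed CLIP image embedding per row of clip.csv):
# assert embeddings.shape[0] == len(df)
# assert np.allclose(np.linalg.norm(embeddings, axis=1), 1.0)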
def compute_text_embeddings(list_of_strings):
    inputs = processor(text=list_of_strings, return_tensors="pt", padding=True)
    return model.get_text_features(**inputs)

def compute_image_embeddings(list_of_images):
    inputs = processor(images=list_of_images, return_tensors="pt", padding=True)
    return model.get_image_features(**inputs)
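# Illustrative usage (not in the original app): embed a free-text query and rank the
# pre-computed image embeddings by dot product; the rows of `embeddings` are unit-length,
# so the ranking matches cosine similarity regardless of the query vector's norm.
# query_vec = compute_text_embeddings(["a wooden sailing ship"]).detach().numpy()
# scores = (embeddings @ query_vec.T)[:, 0]
# top3 = np.argsort(scores)[::-1][:3]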
def load_image(image, same_height=False):
    im = Image.fromarray(np.uint8(image))
    if im.mode != 'RGB':
        im = im.convert('RGB')
    if same_height:
        ratio = 224 / im.size[1]
    else:
        ratio = 224 / min(im.size)
    return im.resize((int(im.size[0] * ratio), int(im.size[1] * ratio)))
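# load_image keeps the aspect ratio: the height (same_height=True) or the shorter side
# is scaled to 224 px. The CLIP processor typically resizes/center-crops to its own
# 224x224 input afterwards, so this pre-resize mainly bounds the amount of pixel data.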
def download_img(identifier, url):
    # Cache the thumbnail on local disk, keyed by the record identifier
    local_path = f"{identifier}.jpg"
    if not os.path.isfile(local_path):
        img_data = requests.get(url).content
        with open(local_path, 'wb') as handler:
            handler.write(img_data)
    return local_path
def predict(image=None, text=None, sketch=None):
    if image is not None:
        input_embeddings = compute_image_embeddings([load_image(image)]).detach().numpy()
        topk = {"local": 1}
    else:
        if text:
            query = text
            topk = {text: 1}
        else:
            # Classify the sketch, then use the predicted label as the text query
            x = torch.tensor(sketch, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.
            with torch.no_grad():
                out = class_model(x)
            probabilities = torch.nn.functional.softmax(out[0], dim=0)
            values, indices = torch.topk(probabilities, 5)
            query = LABELS[indices[0]]
            topk = {LABELS[i]: v.item() / 100.0 for i, v in zip(indices, values)}
        input_embeddings = compute_text_embeddings([query]).detach().numpy()

    n_results = 3
    # Rank all pre-computed embeddings by similarity and keep the top three matches
    results = np.argsort((embeddings @ input_embeddings.T)[:, 0])[-1:-n_results - 1:-1]
    outputs = [download_img(df.iloc[i]['id'], df.iloc[i]['thumbnail']) for i in results]
    outputs.insert(0, topk)
    print(outputs)
    return outputs
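# For a text query the returned list is one gr.Label dict followed by three local
# thumbnail paths, e.g. (illustrative values, ids come from the 'id' column of clip.csv):
# [{'kids playing in the snow': 1}, '<id1>.jpg', '<id2>.jpg', '<id3>.jpg']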
def predict_text(text):
    return predict(None, text, None)

title = "Type to search in the Nasjonalbiblioteket"
description = "Find images in the Nasjonalbiblioteket image collections based on what you type"

interface = gr.Interface(
    fn=predict_text,
    inputs=["text"],
    outputs=[
        gr.Label(num_top_classes=3),
        gr.Image(type="filepath"),
        gr.Image(type="filepath"),
        gr.Image(type="filepath"),
    ],
    title=title,
    description=description,
    # live=True,
    examples=[
        ["kids playing in the snow"],
        ["walking in the dark"],
        ["woman sitting on a chair while drinking a beer"],
        ["nice view out the window on a train"],
    ],
)
interface.launch(debug=True)