Spaces:
Running on Zero
Running on Zero
| from datasets import load_dataset | |
| import torch | |
| from transformers import AutoProcessor, AutoModelForZeroShotImageClassification | |
| from loadimg import load_img | |
# Run on the GPU when CUDA is visible, otherwise fall back to the CPU.
# Apple's MPS backend is deliberately not probed (this runs as a Space).
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# The preprocessor and the model are loaded from the same CLIP checkpoint.
_CLIP_CHECKPOINT = "openai/clip-vit-large-patch14"
processor = AutoProcessor.from_pretrained(_CLIP_CHECKPOINT)
model = AutoModelForZeroShotImageClassification.from_pretrained(
    _CLIP_CHECKPOINT,
    device_map=device,
)
| class Instance: | |
def __init__(self, dataset, token=None, split="train"):
    """Load a dataset split and build a FAISS index over its embeddings.

    Args:
        dataset: dataset identifier forwarded to ``load_dataset``.
        token: optional access token; forwarded to ``load_dataset`` when
            given, so private or gated datasets can be loaded.
        split: name of the split to load.

    The loaded split must already contain an ``"embeddings"`` column,
    since that is the column the FAISS index is built on.
    """
    self.dataset = dataset
    self.token = token
    self.split = split

    # Forward the token only when one was supplied, so anonymous loads
    # issue exactly the same call as a token-less ``load_dataset``.
    load_kwargs = {"split": self.split}
    if self.token is not None:
        load_kwargs["token"] = self.token
    self.data = load_dataset(self.dataset, **load_kwargs)

    # Index the precomputed embeddings so nearest-neighbour lookups work.
    self.data = self.data.add_faiss_index("embeddings")
@staticmethod
def embed(batch):
    """Embed a batch of images and store the result on the batch.

    Intended for embedding images that already exist in an external
    dataset (currently unused). Declared as a static method because it
    takes no ``self``; ``Instance.embed(batch)`` keeps working and calling
    it on an instance no longer binds the instance as ``batch``.

    Args:
        batch: mapping with an ``"image"`` entry holding the images to embed.

    Returns:
        The same ``batch`` object, with the image features written to
        ``batch["embeddings"]``.
    """
    inputs = processor(images=batch["image"], return_tensors="pt")
    pixel_values = inputs["pixel_values"].to(device)

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        img_emb = model.get_image_features(pixel_values)

    batch["embeddings"] = img_emb
    return batch
def search(self, query, k: int = 3):
    """Embed a query image and return its nearest neighbours in the index.

    Args:
        query: the image to search for. It is handed to the processor as
            ``images=query``, so it must be an image the processor accepts
            (not a text string).
        k: the number of results to return.

    Returns:
        scores: the scores of the retrieved examples, as reported by the
            FAISS index. NOTE(review): the index is built with default
            settings, so this is presumably an L2 distance rather than a
            cosine similarity -- confirm against the datasets/FAISS docs.
        retrieved_examples: the retrieved examples.
    """
    inputs = processor(images=query, return_tensors="pt")
    pixel_values = inputs["pixel_values"].to(device)

    # Inference only: no autograd graph is needed for a lookup.
    with torch.no_grad():
        features = model.get_image_features(pixel_values)

    # A single query was embedded; take its row and hand it to the index
    # as a NumPy array on the CPU.
    query_emb = features[0].detach().cpu().numpy()

    scores, retrieved_examples = self.data.get_nearest_examples(
        "embeddings",
        query_emb,
        k=k,
    )
    return scores, retrieved_examples
| def high_level_search(self, img): | |
| """ | |
| High level wrapper for the search function. | |
| Args: | |
| img: input image (path, url, pillow or numpy) | |
| Returns: | |
| scores: the scores of the retrieved examples (cosine similarity i think in this case) | |
| retrieved_examples: the retrieved examples | |
| """ | |
| image = load_img(img) | |
| scores, retrieved_examples = self.search(image) |