# Hugging Face Space: text-to-image search with CLIP
import os
import pickle

import gradio as gr
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from transformers import CLIPModel, CLIPProcessor
class ImageDataset(Dataset):
    """Loads every image in a directory and preprocesses it for CLIP."""

    def __init__(self, image_dir, processor):
        exts = (".png", ".jpg", ".jpeg")
        self.image_paths = [os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.lower().endswith(exts)]
        self.processor = processor

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Convert to RGB so grayscale/RGBA files don't break the processor.
        image = Image.open(self.image_paths[idx]).convert("RGB")
        return self.processor(images=image, return_tensors="pt")["pixel_values"][0]
def get_and_save_clip_embeddings(image_dir, output_file, batch_size=32, device="cuda"):
    """Embed every image in image_dir with CLIP and pickle the embeddings and paths."""
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    dataset = ImageDataset(image_dir, processor)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    all_embeddings = []
    image_paths = []
    model.eval()
    with torch.no_grad():
        for batch_idx, batch in enumerate(dataloader):
            batch = batch.to(device)
            embeddings = model.get_image_features(pixel_values=batch)
            all_embeddings.append(embeddings.cpu().numpy())
            # shuffle=False above, so each batch maps back to dataset order.
            start_idx = batch_idx * batch_size
            image_paths.extend(dataset.image_paths[start_idx:start_idx + len(batch)])

    all_embeddings = np.concatenate(all_embeddings)
    with open(output_file, "wb") as f:
        pickle.dump({"embeddings": all_embeddings, "image_paths": image_paths}, f)
# One-time precompute step; run locally before deploying the app:
# image_dir = "dataset/"
# output_file = "image_embeddings.pkl"
# batch_size = 32
# device = "cuda" if torch.cuda.is_available() else "cpu"
# get_and_save_clip_embeddings(image_dir, output_file, batch_size, device)
# APP: text-to-image search over the precomputed embeddings
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

with open("image_embeddings.pkl", "rb") as f:
    data = pickle.load(f)
image_embeddings = data["embeddings"]
image_names = data["image_paths"]
def cosine_similarity(a, b):
    """Row-wise cosine similarity between two 2-D arrays."""
    a = a / np.linalg.norm(a, axis=-1, keepdims=True)
    b = b / np.linalg.norm(b, axis=-1, keepdims=True)
    return np.dot(a, b.T)
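# A possible optimization (a sketch, not wired into the app): the image
# embeddings never change at serving time, so they could be L2-normalized
# once right after loading, leaving each query to normalize only the text
# embedding:
#
#   image_embeddings = image_embeddings / np.linalg.norm(
#       image_embeddings, axis=-1, keepdims=True)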
def find_similar_images(text):
    """Return the paths of the 4 images whose CLIP embeddings best match the text."""
    inputs = processor(text=[text], return_tensors="pt", padding=True)
    with torch.no_grad():
        text_embedding = model.get_text_features(**inputs).cpu().numpy()
    similarities = cosine_similarity(text_embedding, image_embeddings)
    # Sort descending and keep the 4 best matches.
    top_indices = np.argsort(similarities[0])[::-1][:4]
    return [image_names[i] for i in top_indices]
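# Quick sanity check outside Gradio (hypothetical query; uncomment to run):
# print(find_similar_images("a photo of a dog"))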
text_input = gr.Textbox(label="Input text", placeholder="Enter an image description")
imgs_output = gr.Gallery(label="Top 4 most similar images")

intf = gr.Interface(
    fn=find_similar_images,
    inputs=text_input,
    outputs=imgs_output,
)
extra_text = gr.Markdown("""
The dataset contains images of dogs, rabbits, sharks, and deer.
Displaying the images might take a couple of seconds.
""")

with gr.Blocks() as app:
    # Components built outside the Blocks context are placed inside via render().
    intf.render()
    extra_text.render()

app.launch(share=True)