Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer | |
| import sentence_transformers | |
| from sentence_transformers import SentenceTransformer, util | |
| import pickle | |
| from PIL import Image | |
| import os | |
| from datasets import load_dataset | |
| from huggingface_hub.hf_api import HfFolder | |
| import numpy as np | |
| import torch | |
| import os | |
| from PIL import Image | |
| import io | |
def convert_to_image(byte_data):
    """Decode a raw byte string into a PIL Image.

    Parameters
    ----------
    byte_data : bytes
        Encoded image data (e.g. the PNG/JPEG bytes stored in the
        dataset's 'images' column).

    Returns
    -------
    PIL.Image.Image
        The decoded image (lazily loaded by PIL).
    """
    buffer = io.BytesIO(byte_data)
    return Image.open(buffer)
# Load the model and dataset
# Sentence-Transformers CLIP wrapper used to embed text prompts into the
# same vector space as the dataset's precomputed image embeddings.
model = SentenceTransformer('clip-ViT-B-32')
# Microscopy dataset; columns used below: 'images' (raw encoded bytes),
# 'caption_summary', and 'img_embeddings' (precomputed CLIP vectors).
ds_with_embeddings = load_dataset("kvriza8/clip_microscopy_image_text_embeddings")
# Initialize FAISS index once
# Builds an in-memory FAISS index over 'img_embeddings' so that
# get_nearest_examples() can serve nearest-neighbour queries.
ds_with_embeddings['train'].add_faiss_index(column='img_embeddings')
def get_image_from_text(text_prompt, number_to_retrieve=1):
    """Retrieve the dataset images closest to a free-text prompt.

    Parameters
    ----------
    text_prompt : str
        Natural-language description to search for.
    number_to_retrieve : int, optional
        Number of nearest neighbours to fetch (must be >= 1).

    Returns
    -------
    tuple[list, list]
        PIL images and their 'caption_summary' strings.

    Raises
    ------
    ValueError
        If ``number_to_retrieve`` is not a positive integer.
    """
    # Validate up front, consistent with get_image_from_image(); a k <= 0
    # would otherwise be passed straight into the FAISS query.
    if number_to_retrieve <= 0:
        raise ValueError("Number to retrieve must be a positive integer")
    # Encode the prompt into the shared CLIP embedding space.
    prompt = model.encode(text_prompt)
    scores, retrieved_examples = ds_with_embeddings['train'].get_nearest_examples(
        'img_embeddings', prompt, k=number_to_retrieve)
    # Convert byte images to PIL images
    images = [convert_to_image(img) for img in retrieved_examples['images']]
    captions = retrieved_examples['caption_summary']
    return images, captions
# Raw HF CLIP model/processor used for the image->image retrieval path;
# the SentenceTransformer above handles the text->image path.
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
def get_image_from_image(query_image, number_to_retrieve=1):
    """Retrieve dataset images most similar to an uploaded query image.

    Parameters
    ----------
    query_image : numpy.ndarray
        Image as delivered by the Gradio ``Image`` component — a uint8
        RGB array in [0, 255]. Float arrays in [0, 1] are also accepted
        and rescaled.
    number_to_retrieve : int, optional
        Number of nearest neighbours to fetch (must be >= 1).

    Returns
    -------
    tuple[list, list]
        PIL images and their 'caption_summary' strings.

    Raises
    ------
    ValueError
        If ``number_to_retrieve`` is not a positive integer.
    """
    if number_to_retrieve <= 0:
        raise ValueError("Number to retrieve must be a positive integer")
    # BUG FIX: the original unconditionally computed (query_image * 255),
    # which wraps around for the uint8 [0, 255] arrays Gradio actually
    # supplies, corrupting the query. Only rescale float [0, 1] input.
    if np.issubdtype(query_image.dtype, np.floating):
        pixels = (np.clip(query_image, 0.0, 1.0) * 255).astype(np.uint8)
    else:
        pixels = query_image.astype(np.uint8)
    image = Image.fromarray(pixels)
    # Embed the query image with the raw CLIP vision tower.
    inputs = clip_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        image_features = clip_model.get_image_features(**inputs)
    image_features_numpy = image_features.cpu().detach().numpy()
    scores, retrieved_examples = ds_with_embeddings['train'].get_nearest_examples(
        'img_embeddings', image_features_numpy, k=number_to_retrieve)
    images = [convert_to_image(img) for img in retrieved_examples['images']]
    captions = retrieved_examples['caption_summary']
    return images, captions
def plot_images(text_prompt="", number_to_retrieve=1, query_image=None):
    """Gradio callback: dispatch to image- or text-based retrieval.

    The uploaded image takes priority over the text prompt; if neither is
    provided, or the requested count is not positive, a friendly message
    is returned instead of raising inside the UI.

    Parameters
    ----------
    text_prompt : str, optional
        Free-text query; used only when no image is uploaded.
    number_to_retrieve : int, optional
        How many results to show (must be >= 1).
    query_image : numpy.ndarray or None, optional
        Optional uploaded query image.

    Returns
    -------
    tuple[list, str]
        Gallery images and a newline-joined caption string.
    """
    # Guard against a zero/negative count (the UI slider can produce 0);
    # without this the image path raises an uncaught ValueError and the
    # text path queries FAISS with k=0.
    if number_to_retrieve is None or number_to_retrieve < 1:
        return [], "Please choose at least one image to retrieve"
    if query_image is not None:
        # Handle image input
        sample_images, sample_titles = get_image_from_image(query_image, number_to_retrieve)
    elif text_prompt:
        # Handle text input
        sample_images, sample_titles = get_image_from_text(text_prompt, number_to_retrieve)
    else:
        # Handle empty input
        return [], "No input provided"
    concatenated_captions = "\n".join(sample_titles)
    return sample_images, concatenated_captions
# Build and launch the Gradio UI.
iface = gr.Interface(
    title="Microscopy image retrieval",
    fn=plot_images,
    inputs=[
        gr.Textbox(lines=4, label="Insert your prompt", placeholder="Text Here..."),
        # Minimum of 1 (was 0) so a zero-size retrieval can never be
        # requested; labelled and defaulted for a usable UI.
        gr.Slider(1, 8, step=1, value=1, label="Number of images to retrieve"),
        gr.Image(label="Or Upload an Image"),
    ],
    outputs=[gr.Gallery(label="Retrieved Images"), gr.Textbox(label="Image Captions")],
    # Each example must supply one value per input component; the image
    # slot is None for these text-only examples (the original examples
    # listed only two of the three inputs).
    examples=[
        ["TEM image", 2, None],
        ["Nanoparticles", 1, None],
        ["ZnSe-ZnTe core-shell nanowire", 2, None],
    ],
).launch(debug=True)