Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import pickle | |
| import numpy as np | |
| import pandas as pd | |
| from transformers import CLIPProcessor, CLIPModel | |
| from transformers import VisionTextDualEncoderModel, VisionTextDualEncoderProcessor | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| import csv | |
| from PIL import Image | |
| model_path_rclip = "kaveh/rclip" | |
| embeddings_file_rclip = './image_embeddings_rclip.pkl' | |
| model_path_pubmedclip = "flaviagiammarino/pubmed-clip-vit-base-patch32" | |
| embeddings_file_pubmedclip = './image_embeddings_pubmedclip.pkl' | |
| csv_path = "./captions.txt" | |
def load_image_ids(csv_file):
    """Read a tab-separated file of (image id, caption) rows.

    Parameters
    ----------
    csv_file : str
        Path to a TSV file whose first column is an image id and whose
        second column is the caption for that image.

    Returns
    -------
    tuple[list[str], list[str]]
        Parallel lists of image ids and captions.
    """
    ids = []
    captions = []
    # newline='' is the documented way to open files for the csv module;
    # it prevents newline translation from corrupting quoted fields.
    with open(csv_file, 'r', newline='') as f:
        reader = csv.reader(f, delimiter='\t')
        for row in reader:
            # Skip blank or malformed rows (e.g. a trailing newline yields
            # an empty row, which previously raised IndexError on row[0]).
            if len(row) < 2:
                continue
            ids.append(row[0])
            captions.append(row[1])
    return ids, captions
def load_embeddings(embeddings_file):
    """Deserialize the pre-computed image embeddings stored at *embeddings_file*.

    NOTE(review): pickle.load on an untrusted file can execute arbitrary
    code; the pickles here ship with the app, so this is assumed safe.
    """
    with open(embeddings_file, 'rb') as handle:
        return pickle.load(handle)
def find_similar_images(query_embedding, image_embeddings, k=2):
    """Return the indices and cosine-similarity scores of the *k* images
    most similar to *query_embedding*.

    Parameters
    ----------
    query_embedding : np.ndarray
        1-D text embedding of the query (reshaped to 1xN internally).
    image_embeddings : np.ndarray
        2-D array of image embeddings, one row per image.
    k : int
        Number of results to return; clipped to the number of images.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        Top-k image indices (best first) and their similarity scores,
        aligned element-for-element.
    """
    similarities = cosine_similarity(query_embedding.reshape(1, -1), image_embeddings)[0]
    # Never ask for more results than there are images.
    k = min(int(k), similarities.shape[0])
    closest_indices = np.argsort(similarities)[::-1][:k]
    # Index the similarity vector directly so scores are guaranteed to line
    # up with their indices (the original sorted the two independently).
    scores = similarities[closest_indices]
    return closest_indices, scores
| def main(query, model_id="RCLIP", k=2): | |
| if model_id=="RCLIP": | |
| # Load RCLIP model | |
| model = VisionTextDualEncoderModel.from_pretrained(model_path_rclip) | |
| processor = VisionTextDualEncoderProcessor.from_pretrained(model_path_rclip) | |
| # Load image embeddings | |
| image_embeddings = load_embeddings(embeddings_file_rclip) | |
| elif model_id=="PubMedCLIP": | |
| model = CLIPModel.from_pretrained(model_path_pubmedclip) | |
| processor = CLIPProcessor.from_pretrained(model_path_pubmedclip) | |
| # Load image embeddings | |
| image_embeddings = load_embeddings(embeddings_file_pubmedclip) | |
| # Embed the query | |
| inputs = processor(text=query, images=None, return_tensors="pt", padding=True) | |
| with torch.no_grad(): | |
| query_embedding = model.get_text_features(**inputs)[0].numpy() | |
| # Get image names | |
| ids, captions = load_image_ids(csv_path) | |
| # Find similar images | |
| similar_image_indices, scores = find_similar_images(query_embedding, image_embeddings, k=int(k)) | |
| # Return the results | |
| similar_image_names = [f"./images/{ids[index]}.jpg" for index in similar_image_indices] | |
| similar_image_captions = [captions[index] for index in similar_image_indices] | |
| similar_images = [Image.open(i) for i in similar_image_names] | |
| return similar_images, pd.DataFrame([[t+1 for t in range(k)], similar_image_names, similar_image_captions, scores], index=["#", "path", "caption", "score"]).T | |
# Define the Gradio interface
# Example (query, model, k) triples shown as clickable inputs under the UI.
examples = [
    ["Chest X-ray photos", "RCLIP", 10],
    ["Chest X-ray photos", "PubMedCLIP", 10],
    ["Orthopantogram (OPG)", "RCLIP", 10],
    ["Brain MRI", "RCLIP", 10],
    ["Ultrasound", "RCLIP", 10],
]
title = "RCLIP Image Retrieval"
description = "CLIP model fine-tuned on the ROCO dataset"
with gr.Blocks(title=title) as demo:
    # Header: app title and a one-line description.
    with gr.Row():
        with gr.Column(scale=5):
            gr.Markdown("# " + title)
            gr.Markdown(description)
    # Search controls: free-text query plus the submit button.
    with gr.Row(variant="compact"):
        query = gr.Textbox(value="Chest X-Ray Photos", label="Enter your query", show_label=False, placeholder="Enter your query", scale=5)
        btn = gr.Button("Search query", variant="primary", scale=1)
    # Model selector and the number of top results to retrieve.
    with gr.Row(variant="compact"):
        model_id = gr.Dropdown(["RCLIP", "PubMedCLIP"], value="RCLIP", label="Model", type="value", scale=1)
        n_s = gr.Slider(2, 10, label='Number of Top Results', value=10, step=1.0, show_label=True, scale=1)
    # Results area: image gallery plus a table of path/caption/score rows.
    with gr.Column(variant="compact"):
        gr.Markdown("## Results")
        gallery = gr.Gallery(label="found images", show_label=True, elem_id="gallery", columns=[2], rows=[4], object_fit="contain", height="400px", preview=True)
        gr.Markdown("Information of the found images")
        df = gr.DataFrame()
    # Wire the search button to the retrieval function defined above.
    btn.click(main, [query, model_id, n_s], [gallery, df])
    # Clickable example inputs.
    with gr.Column(variant="compact"):
        gr.Markdown("## Examples")
        gr.Examples(examples, [query, model_id, n_s])
# NOTE(review): debug='True' passes a (truthy) string where gradio expects a
# boolean — behaves the same, but debug=True would be the intended spelling.
demo.launch(debug='True')