Spaces:
Runtime error
Runtime error
| import torch | |
| from PIL import Image | |
| from transformers import CLIPProcessor, CLIPModel | |
| from pathlib import Path | |
| from torch.utils.data import Dataset, DataLoader | |
| import os | |
| import numpy as np | |
| from numpy.linalg import norm | |
| import matplotlib.pyplot as plt | |
| import gradio as gr | |
| # Cast 1 ------------------------------------------------------------------------- | |
_CLIP_MODEL_NAME = "openai/clip-vit-base-patch32"
# Holds the lazily loaded (model, processor) pair so repeated calls reuse it
# instead of re-reading the checkpoint on every query.
_CLIP_CACHE = {}


def _load_clip():
    """Return the shared CLIP ``(model, processor)`` pair, loading it on first use."""
    if "pair" not in _CLIP_CACHE:
        _CLIP_CACHE["pair"] = (
            CLIPModel.from_pretrained(_CLIP_MODEL_NAME),
            CLIPProcessor.from_pretrained(_CLIP_MODEL_NAME),
        )
    return _CLIP_CACHE["pair"]


def get_clip_embeddings(input_data, input_type='text'):
    """Embed text or a single image with the CLIP model.

    Args:
        input_data: For ``input_type='text'``, the text passed to the CLIP
            processor. For ``input_type='image'``, either a file path
            (``str`` or ``os.PathLike``) or an already opened PIL image.
        input_type: ``'text'`` (default) or ``'image'``.

    Returns:
        The features returned by ``get_text_features`` /
        ``get_image_features`` converted to a NumPy array.

    Raises:
        ValueError: If ``input_type`` is neither ``'text'`` nor ``'image'``,
            or if an image input is neither a path nor a PIL image.
    """
    # Validate cheap arguments first so a bad call fails before the
    # (expensive) model load is triggered.
    if input_type not in ('text', 'image'):
        raise ValueError("Invalid input_type. Choose 'text' or 'image'")

    image = None
    if input_type == 'image':
        if isinstance(input_data, (str, os.PathLike)):
            image = Image.open(input_data)
        elif isinstance(input_data, Image.Image):
            image = input_data
        else:
            raise ValueError("For image input, provide either a file path or a PIL Image object")

    model, processor = _load_clip()

    # Inference only: no gradients are needed for embedding extraction.
    with torch.no_grad():
        if input_type == 'text':
            inputs = processor(text=input_data, return_tensors="pt", padding=True, truncation=True)
            embeddings = model.get_text_features(**inputs)
        else:
            inputs = processor(images=image, return_tensors="pt")
            embeddings = model.get_image_features(**inputs)

    return embeddings.numpy()
| # Cast 2 ------------------------------------------------------------------------- | |
class ImageDataset(Dataset):
    """Dataset yielding processor-prepared pixel tensors for images in a folder.

    Only files whose names end (case-sensitively) in ``.png``, ``.jpg`` or
    ``.jpeg`` are included, in the order reported by ``os.listdir``.
    """

    def __init__(self, image_dir, processor):
        """Collect the image paths in *image_dir* and keep the *processor*."""
        collected = []
        for entry in os.listdir(image_dir):
            if entry.endswith(('.png', '.jpg', '.jpeg')):
                collected.append(os.path.join(image_dir, entry))
        self.image_paths = collected
        self.processor = processor

    def __len__(self):
        """Return the number of image files found."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Open image *idx* and return its ``pixel_values`` entry."""
        picture = Image.open(self.image_paths[idx])
        encoded = self.processor(images=picture, return_tensors="pt")
        # The processor returns a batch; take the first (only) item of it.
        return encoded['pixel_values'][0]
def get_clip_embeddings_batch(image_dir, batch_size=32, device='cuda', num_workers=4):
    """Embed every image in *image_dir* with CLIP, batch by batch.

    Args:
        image_dir: Folder scanned by ``ImageDataset`` for image files.
        batch_size: Number of images per forward pass.
        device: Torch device the model and batches are moved to.
        num_workers: Worker processes used by the ``DataLoader``; pass ``0``
            to load images in the main process.

    Returns:
        A NumPy array with one row of image features per image, in dataset
        order (the loader does not shuffle). If the folder contains no
        matching images, an empty array of shape ``(0, projection_dim)`` is
        returned instead of letting ``np.concatenate`` fail on an empty list.
    """
    # Load the CLIP model and processor
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    # Create dataset and dataloader
    dataset = ImageDataset(image_dir, processor)
    if len(dataset) == 0:
        # NOTE(review): float32 assumes the default checkpoint dtype -- confirm
        # if the model is ever loaded in reduced precision.
        return np.empty((0, model.config.projection_dim), dtype=np.float32)

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,  # keep row order aligned with dataset.image_paths
        num_workers=num_workers,
    )

    all_embeddings = []
    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            batch = batch.to(device)
            embeddings = model.get_image_features(pixel_values=batch)
            # Move back to the CPU before converting to NumPy.
            all_embeddings.append(embeddings.cpu().numpy())

    return np.concatenate(all_embeddings)
| # Cast 3 ------------------------------------------------------------------------- | |
| # Funkcia na výpočet cosinovej similarity | |
def cosine_similarity(x, y):
    """Return the dot product of *x* and *y* divided by the product of their norms."""
    dot_product = np.dot(x, y)
    magnitude_product = norm(x) * norm(y)
    return dot_product / magnitude_product
| # Funkcia na nájdenie indexov obrázkov | |
def maxCS_indices(text_input, embeddings, top_k=4):
    """Return the indices of the embeddings most similar to a text query.

    Args:
        text_input: Query text; it is embedded with ``get_clip_embeddings``.
            Intended for a single query string.
        embeddings: 2-D array with one embedding per row.
        top_k: How many of the best-matching rows to return (default 4).

    Returns:
        Integer array of row indices into *embeddings*. For a single query
        its shape is ``(min(top_k, n_rows), 1)``. Rows are ordered from lower
        to higher cosine similarity, so the best match is the last entry.
        When *embeddings* has no rows, an empty ``(0, 1)`` index array is
        returned.

    Raises:
        ValueError: If *top_k* is smaller than 1.
    """
    if top_k < 1:
        raise ValueError("top_k must be a positive integer")

    query = get_clip_embeddings(text_input, input_type='text')
    candidates = np.asarray(embeddings)
    if candidates.shape[0] == 0:
        return np.empty((0, 1), dtype=np.intp)

    # Cosine similarity of every row against the query in one matrix product
    # instead of a Python-level loop over rows.
    row_norms = norm(candidates, axis=1, keepdims=True)
    similarities = (candidates @ query.T) / (norm(query) * row_norms)

    # argsort is ascending, so the last top_k rows are the highest scores.
    return np.argsort(similarities, axis=0)[-top_k:]
| # Cast 4 ------------------------------------------------------------------------- | |
def which_images(images_folder, indices):
    """Map embedding indices back to image file names.

    The folder is listed with the same extension filter that ``ImageDataset``
    uses (``.png``, ``.jpg``, ``.jpeg``). The previous filter omitted
    ``.jpeg``, so any ``.jpeg`` file in the folder shifted the positions and
    made the indices point at the wrong names (or out of range).

    Args:
        images_folder: Folder containing the images.
        indices: Integer index or array of indices into the filtered listing.

    Returns:
        1-D NumPy array with the selected file names.
    """
    image_filenames = [
        f for f in os.listdir(images_folder)
        if f.endswith(('.png', '.jpg', '.jpeg'))
    ]
    # Index as an array so (k, 1)-shaped index arrays work, then flatten the
    # selection to a plain vector of names.
    return np.array(image_filenames)[indices].flatten()
| # Cast 5 ------------------------------------------------------------------------- | |
def display_images(folder_path, image_names):
    """Show up to four images from *folder_path* side by side.

    Args:
        folder_path: Folder that contains the image files.
        image_names: Iterable of file names inside *folder_path*. Only the
            first four are drawn, because the figure has four slots.
    """
    folder = Path(folder_path)

    # One row of four subplots.
    fig, axes = plt.subplots(1, 4, figsize=(20, 5))

    for ax, img_name in zip(axes, image_names):
        # imshow copies the pixel data, so the file can be closed right away.
        with Image.open(folder / img_name) as img:
            ax.imshow(img)
        ax.set_title(img_name)  # use the file name as the subplot title

    # Hide the axes of every slot, including those left without an image
    # when fewer than four names were supplied.
    for ax in axes:
        ax.axis('off')

    plt.show()
| # Cast 6 ------------------------------------------------------------------------- | |
| # Nastavenie parametrov pre funkciu process_input | |
# Folder with the searchable images; its embeddings are computed once at start-up.
images_folder = "Datasets/kotlarska2/Trains"

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

embeddings = get_clip_embeddings_batch(images_folder, batch_size=32, device=device)
| # Hlavná funkcia na spracovanie vstupu a zobrazenie obrázkov | |
def process_input(text_input):
    """Find the images best matching *text_input* and return the rendered result.

    Returns whatever ``display_our_images`` returns for the selected images.
    """
    best_indices = maxCS_indices(text_input, embeddings)
    selected_names = which_images(images_folder, best_indices)
    rendered = display_our_images(images_folder, selected_names)
    return rendered
| # Funkcia na zobrazenie obrázkov | |
def display_our_images(folder_path, image_names):
    """Render up to four images side by side and save them to a PNG file.

    Args:
        folder_path: Folder that contains the image files.
        image_names: Iterable of file names inside *folder_path*. Only the
            first four are drawn, because the figure has four slots.

    Returns:
        The path of the written file, ``'output_images.png'``.
    """
    folder = Path(folder_path)
    output_path = 'output_images.png'

    # One row of four subplots.
    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    try:
        for ax, img_name in zip(axes, image_names):
            # imshow copies the pixel data, so the file can be closed right away.
            with Image.open(folder / img_name) as img:
                ax.imshow(img)
            ax.set_title(img_name)  # use the file name as the subplot title

        # Hide the axes of every slot, including those left without an image
        # when fewer than four names were supplied.
        for ax in axes:
            ax.axis('off')

        fig.tight_layout()
        fig.savefig(output_path)
    finally:
        # Always release the figure, even if loading or saving failed, so a
        # long-running server does not accumulate open figures.
        plt.close(fig)

    return output_path
| # Nastav a spusti Gradio rozhranie | |
# Build the Gradio interface: one text box in, one image out.
iface = gr.Interface(
    process_input,
    "text",
    "image",
    title="Image Similarity",
    description="Zadaj text a zobrazia sa 4 najpodobnejšie obrázky z našej databázy SUV vozidiel.",
)

# Start the web app with a public share link.
iface.launch(share=True)