# CLIP text-to-image search demo: embeds a folder of images with CLIP,
# then serves a Gradio UI that returns the 4 images most similar to a text query.
import os
from functools import lru_cache
from pathlib import Path

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from numpy.linalg import norm
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from transformers import CLIPModel, CLIPProcessor
# Cast 1 -------------------------------------------------------------------------
@lru_cache(maxsize=1)
def _load_clip():
    # Loading the checkpoint is expensive (disk read / hub download), so do it
    # once per process and reuse the pair on every subsequent call.
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    return model, processor

def get_clip_embeddings(input_data, input_type='text'):
    """Return CLIP embeddings for a text input or a single image.

    Args:
        input_data: For ``input_type='text'``, a string or list of strings.
            For ``input_type='image'``, a file path or a PIL Image object.
        input_type: Either ``'text'`` or ``'image'``.

    Returns:
        numpy array of shape (batch, embed_dim) with the embeddings.

    Raises:
        ValueError: If ``input_type`` is invalid, or an image input is neither
            a path nor a PIL Image.
    """
    model, processor = _load_clip()
    # Prepare the input based on the type
    if input_type == 'text':
        inputs = processor(text=input_data, return_tensors="pt", padding=True, truncation=True)
    elif input_type == 'image':
        if isinstance(input_data, str):
            image = Image.open(input_data)
        elif isinstance(input_data, Image.Image):
            image = input_data
        else:
            raise ValueError("For image input, provide either a file path or a PIL Image object")
        inputs = processor(images=image, return_tensors="pt")
    else:
        raise ValueError("Invalid input_type. Choose 'text' or 'image'")
    # Inference only -- no autograd graph needed.
    with torch.no_grad():
        if input_type == 'text':
            embeddings = model.get_text_features(**inputs)
        else:
            embeddings = model.get_image_features(**inputs)
    return embeddings.numpy()
# Cast 2 -------------------------------------------------------------------------
class ImageDataset(Dataset):
    """Dataset yielding CLIP-preprocessed pixel tensors for every image in a folder.

    Filters the folder listing to .png/.jpg/.jpeg files; the listing order
    must stay consistent with any later filename lookup against the same folder.
    """
    def __init__(self, image_dir, processor):
        # image_dir: folder containing the images; processor: a CLIPProcessor.
        self.image_paths = [os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.endswith(('.png', '.jpg', '.jpeg'))]
        self.processor = processor
    def __len__(self):
        return len(self.image_paths)
    def __getitem__(self, idx):
        # Bug fix: PIL opens files lazily and keeps the handle alive; without a
        # context manager a large folder can exhaust file descriptors.
        with Image.open(self.image_paths[idx]) as image:
            return self.processor(images=image, return_tensors="pt")['pixel_values'][0]
def get_clip_embeddings_batch(image_dir, batch_size=32, device='cuda'):
    """Embed every image in ``image_dir`` with CLIP, in batches.

    Args:
        image_dir: Folder containing .png/.jpg/.jpeg images.
        batch_size: Images per forward pass.
        device: Torch device string; pass 'cpu' when CUDA is unavailable.

    Returns:
        numpy array of shape (num_images, projection_dim); an empty
        (0, projection_dim) array when the folder holds no images.
    """
    # Load the CLIP model and processor
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    dataset = ImageDataset(image_dir, processor)
    # Bug fix: np.concatenate([]) raises ValueError on an empty folder --
    # return an empty, correctly shaped array instead.
    if len(dataset) == 0:
        return np.empty((0, model.config.projection_dim), dtype=np.float32)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4)
    all_embeddings = []
    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            batch = batch.to(device)
            embeddings = model.get_image_features(pixel_values=batch)
            all_embeddings.append(embeddings.cpu().numpy())
    return np.concatenate(all_embeddings)
# Cast 3 -------------------------------------------------------------------------
# Cosine similarity between two vectors.
def cosine_similarity(x, y):
    """Return the cosine of the angle between vectors x and y."""
    dot_product = np.dot(x, y)
    magnitudes = norm(x) * norm(y)
    return dot_product / magnitudes
# Find the indices of the images most similar to the query text.
def maxCS_indices(text_input, embeddings):
    """Return indices of the 4 rows of ``embeddings`` most similar to the text.

    Args:
        text_input: Query string.
        embeddings: numpy array of shape (num_images, embed_dim).

    Returns:
        1-D numpy array of the four best row indices, in ascending order of
        similarity (most similar last).
    """
    # get_clip_embeddings returns shape (1, dim); flatten so the per-row
    # similarities below come out as a plain 1-D vector instead of (N, 1).
    query = get_clip_embeddings(text_input, input_type='text').flatten()
    cosine_similarities = np.array([cosine_similarity(query, row) for row in embeddings])
    # argsort is ascending, so the last four entries are the most similar.
    top_indices = np.argsort(cosine_similarities)[-4:]
    return top_indices
# Cast 4 -------------------------------------------------------------------------
def which_images(images_folder, indices):
    """Map embedding-row indices back to image filenames.

    The filename list must be built exactly like ImageDataset builds its list
    (same extensions, same os.listdir order); otherwise indices computed from
    the embedding matrix point at the wrong files.

    Args:
        images_folder: The folder the embeddings were computed from.
        indices: Array of row indices into the embedding matrix.

    Returns:
        1-D numpy array of the selected filenames.
    """
    # Bug fix: '.jpeg' was missing here although ImageDataset includes it,
    # misaligning indices whenever the folder contains .jpeg files.
    image_filenames = [f for f in os.listdir(images_folder) if f.endswith(('.png', '.jpg', '.jpeg'))]
    image_names_array = np.array(image_filenames)
    selected = image_names_array[indices]
    # Indices may arrive as a column vector; flatten to a plain 1-D vector.
    return selected.flatten()
# Cast 5 -------------------------------------------------------------------------
def display_images(folder_path, image_names):
    """Show the given images side by side in a single 1x4 matplotlib row."""
    base = Path(folder_path)
    # One row of four subplots, one per image.
    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    for axis, name in zip(axes, image_names):
        picture = Image.open(base / name)
        axis.imshow(picture)
        axis.set_title(name)  # filename as the subplot title
        axis.axis('off')      # hide the axis ticks and frame
    # Render the figure on screen.
    plt.show()
# Cast 6 -------------------------------------------------------------------------
# Folder with the searchable image database, and the inference device.
images_folder = "Datasets/kotlarska2/Trains"
device = "cuda" if torch.cuda.is_available() else "cpu"
# Precompute CLIP embeddings for every image once, at import time.
embeddings = get_clip_embeddings_batch(images_folder, 32, device)
# Main entry point: embed the query text, pick the best matches, render them.
def process_input(text_input):
    """Gradio handler: return the path of a PNG grid of the 4 best matches."""
    best_indices = maxCS_indices(text_input, embeddings)
    best_images = which_images(images_folder, best_indices)
    return display_our_images(images_folder, best_images)
# Render the matched images into a single PNG file on disk.
def display_our_images(folder_path, image_names):
    """Draw the given images in a 1x4 grid, save it, and return the file path."""
    base = Path(folder_path)
    # One row of four subplots, one per image.
    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    for axis, name in zip(axes, image_names):
        picture = Image.open(base / name)
        axis.imshow(picture)
        axis.set_title(name)  # filename as the subplot title
        axis.axis('off')      # hide the axis ticks and frame
    # Save the grid to a file that Gradio can display.
    plt.tight_layout()
    plt.savefig('output_images.png')
    plt.close()
    return 'output_images.png'
# Build and launch the Gradio UI; share=True exposes a public link.
iface = gr.Interface(
fn=process_input,
inputs="text",
outputs="image",
title="Image Similarity",
description="Zadaj text a zobrazia sa 4 najpodobnejšie obrázky z našej databázy SUV vozidiel.")
iface.launch(share=True)
|