# app.py — CLIP-based text-to-image similarity search (Gradio app)
# Origin: kotlarska2, "Update app.py" (commit cd08094, verified)
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
import os
import numpy as np
from numpy.linalg import norm
import matplotlib.pyplot as plt
import gradio as gr
# Cast 1 -------------------------------------------------------------------------
# Shared CLIP model/processor, loaded lazily on first use. The original code
# re-downloaded/re-loaded the full checkpoint on EVERY call, which dominates
# runtime for a function that is invoked once per user query.
_CLIP_MODEL = None
_CLIP_PROCESSOR = None


def _load_clip():
    """Load the CLIP model and processor once and return the shared pair."""
    global _CLIP_MODEL, _CLIP_PROCESSOR
    if _CLIP_MODEL is None:
        _CLIP_MODEL = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        _CLIP_PROCESSOR = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    return _CLIP_MODEL, _CLIP_PROCESSOR


def get_clip_embeddings(input_data, input_type='text'):
    """Return CLIP embeddings for a text string or a single image.

    Args:
        input_data: the text (str or list of str) when input_type='text';
            a file path (str) or PIL.Image.Image when input_type='image'.
        input_type: 'text' or 'image'.

    Returns:
        numpy.ndarray of shape (batch, embedding_dim) on the CPU.

    Raises:
        ValueError: on an unknown input_type or unsupported image input.
    """
    model, processor = _load_clip()
    # Prepare the input based on the type
    if input_type == 'text':
        inputs = processor(text=input_data, return_tensors="pt", padding=True, truncation=True)
    elif input_type == 'image':
        if isinstance(input_data, str):
            image = Image.open(input_data)
        elif isinstance(input_data, Image.Image):
            image = input_data
        else:
            raise ValueError("For image input, provide either a file path or a PIL Image object")
        inputs = processor(images=image, return_tensors="pt")
    else:
        raise ValueError("Invalid input_type. Choose 'text' or 'image'")
    # Get the embeddings (no autograd graph needed for inference)
    with torch.no_grad():
        if input_type == 'text':
            embeddings = model.get_text_features(**inputs)
        else:
            embeddings = model.get_image_features(**inputs)
    return embeddings.numpy()
# Cast 2 -------------------------------------------------------------------------
class ImageDataset(Dataset):
    """Flat-folder image dataset yielding CLIP-preprocessed pixel tensors.

    NOTE(review): the extension filter and the os.listdir order must stay
    identical to the file listing used when mapping indices back to names,
    or embedding rows and file names fall out of alignment.
    """

    def __init__(self, image_dir, processor):
        # Keep only supported raster formats; order follows os.listdir.
        paths = []
        for entry in os.listdir(image_dir):
            if entry.endswith(('.png', '.jpg', '.jpeg')):
                paths.append(os.path.join(image_dir, entry))
        self.image_paths = paths
        self.processor = processor

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img = Image.open(self.image_paths[idx])
        processed = self.processor(images=img, return_tensors="pt")
        # Drop the leading batch dimension the processor adds, so the
        # DataLoader can collate samples into batches itself.
        return processed['pixel_values'][0]
def get_clip_embeddings_batch(image_dir, batch_size=32, device='cuda'):
    """Compute CLIP image embeddings for every image in a folder.

    Args:
        image_dir: folder containing .png/.jpg/.jpeg files.
        batch_size: images per forward pass.
        device: torch device string ('cuda' or 'cpu').

    Returns:
        numpy.ndarray of shape (num_images, embedding_dim), rows in
        os.listdir order (same order ImageDataset uses).
    """
    # Load the CLIP model and processor
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    # Create dataset and dataloader
    dataset = ImageDataset(image_dir, processor)
    if len(dataset) == 0:
        # np.concatenate([]) raises a confusing ValueError; return an
        # explicit empty matrix with the model's embedding width instead.
        return np.empty((0, model.config.projection_dim), dtype=np.float32)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4)
    all_embeddings = []
    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            batch = batch.to(device)
            embeddings = model.get_image_features(pixel_values=batch)
            all_embeddings.append(embeddings.cpu().numpy())
    return np.concatenate(all_embeddings)
# Cast 3 -------------------------------------------------------------------------
# Funkcia na výpočet cosinovej similarity
def cosine_similarity(x, y):
    """Return the cosine of the angle between vectors *x* and *y*."""
    dot_product = np.dot(x, y)
    magnitude_product = norm(x) * norm(y)
    return dot_product / magnitude_product
# Funkcia na nájdenie indexov obrázkov
def maxCS_indices(text_input, embeddings):
    """Return indices of the 4 rows of `embeddings` most similar to the text.

    Args:
        text_input: the query string, embedded with CLIP.
        embeddings: (N, dim) matrix of precomputed image embeddings.

    Returns:
        numpy.ndarray of 4 row indices, in ascending similarity order
        (most similar last — np.argsort semantics, as in the original).
    """
    # Flatten the (1, dim) text embedding to a 1-D vector so each dot
    # product below yields a scalar rather than a length-1 array.
    text_vec = get_clip_embeddings(text_input, input_type='text').ravel()
    # Cosine similarity of the query against every embedding row.
    similarities = np.array([cosine_similarity(text_vec, row) for row in embeddings])
    # Indices of the four most similar rows.
    return np.argsort(similarities)[-4:]
# Cast 4 -------------------------------------------------------------------------
def which_images(images_folder, indices):
    """Map embedding-row indices back to image file names.

    BUG FIX: the filter previously accepted only ('.jpg', '.png') while
    ImageDataset — which produced the embeddings the indices refer to —
    accepts ('.png', '.jpg', '.jpeg'). Any .jpeg file in the folder shifted
    every index (or raised IndexError). The two filters must be identical.

    Args:
        images_folder: folder the embeddings were computed from.
        indices: integer index array (any shape) into that listing.

    Returns:
        1-D numpy array of file names corresponding to the indices.
    """
    # Same extension filter and listing order as ImageDataset.
    image_filenames = [f for f in os.listdir(images_folder)
                       if f.endswith(('.png', '.jpg', '.jpeg'))]
    image_names_array = np.array(image_filenames)
    # Index, then flatten in case `indices` carries an extra axis.
    return image_names_array[indices].flatten()
# Cast 5 -------------------------------------------------------------------------
def display_images(folder_path, image_names):
    """Show the given images side by side in a single 1x4 matplotlib row."""
    base = Path(folder_path)
    # One row of four panels, 20x5 inches overall.
    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    for axis, name in zip(axes, image_names):
        axis.imshow(Image.open(base / name))
        axis.set_title(name)  # file name as panel title
        axis.axis('off')      # hide axis ticks/frames
    # Render the figure to screen.
    plt.show()
# Cast 6 -------------------------------------------------------------------------
# Parameters for process_input.
# Folder with the image database (relative to the working directory).
images_folder = "Datasets/kotlarska2/Trains"
# Prefer the GPU when available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Precompute CLIP embeddings for the whole folder once, at import time —
# app startup blocks until this finishes.
embeddings = get_clip_embeddings_batch(images_folder, 32, device)
# Hlavná funkcia na spracovanie vstupu a zobrazenie obrázkov
def process_input(text_input):
    """Gradio handler: find the 4 images most similar to the query text
    and return the path of a rendered grid image."""
    matched_names = which_images(images_folder, maxCS_indices(text_input, embeddings))
    return display_our_images(images_folder, matched_names)
# Funkcia na zobrazenie obrázkov
def display_our_images(folder_path, image_names):
    """Render the given images in a 1x4 row and save the figure to a PNG.

    Returns the saved file's path, which the Gradio "image" output displays.
    """
    base = Path(folder_path)
    # One row of four panels, 20x5 inches overall.
    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    for axis, name in zip(axes, image_names):
        axis.imshow(Image.open(base / name))
        axis.set_title(name)  # file name as panel title
        axis.axis('off')      # hide axis ticks/frames
    # Save to disk and release the figure's memory.
    plt.tight_layout()
    plt.savefig('output_images.png')
    plt.close()
    return 'output_images.png'
# Set up and launch the Gradio interface: a text box in, one image out.
iface = gr.Interface(
    fn=process_input,
    inputs="text",
    outputs="image",
    title="Image Similarity",
    description="Zadaj text a zobrazia sa 4 najpodobnejšie obrázky z našej databázy SUV vozidiel.")
# share=True additionally exposes a temporary public URL via Gradio's tunnel.
iface.launch(share=True)