Update app.py
Browse files
app.py
CHANGED
|
@@ -1,3 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
from transformers import CLIPProcessor, CLIPModel
|
| 3 |
from PIL import Image
|
|
@@ -18,15 +25,35 @@ CACHE_FILE = "cache.pkl"
|
|
| 18 |
# Define supported image formats
|
| 19 |
IMAGE_EXTENSIONS = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.gif", "*.webp", "*.tiff", "*.tif"]
|
| 20 |
|
| 21 |
-
def get_all_image_files():
|
| 22 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
image_files = []
|
| 24 |
for ext in IMAGE_EXTENSIONS:
|
| 25 |
image_files.extend(DATASET_DIR.glob(ext))
|
| 26 |
image_files.extend(DATASET_DIR.glob(ext.upper())) # Also check uppercase
|
| 27 |
return image_files
|
| 28 |
|
| 29 |
-
def get_embedding(image: Image.Image, device="cpu"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
# Use CLIP's built-in preprocessing
|
| 31 |
inputs = processor(images=image, return_tensors="pt").to(device)
|
| 32 |
model_device = model.to(device)
|
|
@@ -37,7 +64,20 @@ def get_embedding(image: Image.Image, device="cpu"):
|
|
| 37 |
return emb
|
| 38 |
|
| 39 |
@spaces.GPU
|
| 40 |
-
def get_reference_embeddings():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
# Get all current image files
|
| 42 |
current_image_files = get_all_image_files()
|
| 43 |
current_images = set(img_path.name for img_path in current_image_files)
|
|
@@ -79,7 +119,20 @@ def get_reference_embeddings():
|
|
| 79 |
reference_embeddings = get_reference_embeddings()
|
| 80 |
|
| 81 |
@spaces.GPU
|
| 82 |
-
def search_similar(query_img):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
# Refresh embeddings to catch any new images
|
| 84 |
global reference_embeddings
|
| 85 |
reference_embeddings = get_reference_embeddings()
|
|
@@ -107,7 +160,22 @@ def search_similar(query_img):
|
|
| 107 |
return [(f"dataset/{name}", f"Score: {score:.4f}") for name, score in filtered_results[:5]]
|
| 108 |
|
| 109 |
@spaces.GPU
|
| 110 |
-
def add_image(name: str, image):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
if not name.strip():
|
| 112 |
return "Please provide a valid image name."
|
| 113 |
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
CLIP Image Search Application
|
| 3 |
+
|
| 4 |
+
A Gradio-based application for searching similar images using OpenAI's CLIP model.
|
| 5 |
+
Supports multiple image formats and provides a web interface for uploading and searching images.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
import gradio as gr
|
| 9 |
from transformers import CLIPProcessor, CLIPModel
|
| 10 |
from PIL import Image
|
|
|
|
| 25 |
# Define supported image formats
|
| 26 |
IMAGE_EXTENSIONS = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.gif", "*.webp", "*.tiff", "*.tif"]
|
| 27 |
|
| 28 |
+
def get_all_image_files() -> List[Path]:
|
| 29 |
+
"""
|
| 30 |
+
Get all image files from the dataset directory.
|
| 31 |
+
|
| 32 |
+
Searches for images with supported extensions in both lowercase and uppercase.
|
| 33 |
+
|
| 34 |
+
Returns:
|
| 35 |
+
List[Path]: List of Path objects for all found image files
|
| 36 |
+
"""
|
| 37 |
image_files = []
|
| 38 |
for ext in IMAGE_EXTENSIONS:
|
| 39 |
image_files.extend(DATASET_DIR.glob(ext))
|
| 40 |
image_files.extend(DATASET_DIR.glob(ext.upper())) # Also check uppercase
|
| 41 |
return image_files
|
| 42 |
|
| 43 |
+
def get_embedding(image: Image.Image, device: str = "cpu") -> torch.Tensor:
|
| 44 |
+
"""
|
| 45 |
+
Generate CLIP embedding for an image.
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
image (Image.Image): PIL Image object to process
|
| 49 |
+
device (str, optional): Device to run computation on. Defaults to "cpu".
|
| 50 |
+
|
| 51 |
+
Returns:
|
| 52 |
+
torch.Tensor: L2-normalized image embedding tensor
|
| 53 |
+
|
| 54 |
+
Raises:
|
| 55 |
+
RuntimeError: If CUDA is requested but not available
|
| 56 |
+
"""
|
| 57 |
# Use CLIP's built-in preprocessing
|
| 58 |
inputs = processor(images=image, return_tensors="pt").to(device)
|
| 59 |
model_device = model.to(device)
|
|
|
|
| 64 |
return emb
|
| 65 |
|
| 66 |
@spaces.GPU
|
| 67 |
+
def get_reference_embeddings() -> Dict[str, torch.Tensor]:
|
| 68 |
+
"""
|
| 69 |
+
Load or compute embeddings for all reference images in the dataset.
|
| 70 |
+
|
| 71 |
+
Checks if cached embeddings are up to date with the current dataset.
|
| 72 |
+
If not, recomputes embeddings for all images and updates the cache.
|
| 73 |
+
|
| 74 |
+
Returns:
|
| 75 |
+
Dict[str, torch.Tensor]: Dictionary mapping image filenames to their embeddings
|
| 76 |
+
|
| 77 |
+
Raises:
|
| 78 |
+
FileNotFoundError: If dataset directory doesn't exist
|
| 79 |
+
PermissionError: If unable to write cache file
|
| 80 |
+
"""
|
| 81 |
# Get all current image files
|
| 82 |
current_image_files = get_all_image_files()
|
| 83 |
current_images = set(img_path.name for img_path in current_image_files)
|
|
|
|
| 119 |
reference_embeddings = get_reference_embeddings()
|
| 120 |
|
| 121 |
@spaces.GPU
|
| 122 |
+
def search_similar(query_img: Image.Image) -> List[Tuple[str, str]]:
|
| 123 |
+
"""
|
| 124 |
+
Find similar images to the query image using CLIP embeddings.
|
| 125 |
+
|
| 126 |
+
Args:
|
| 127 |
+
query_img (Image.Image): Query image to find similar images for
|
| 128 |
+
|
| 129 |
+
Returns:
|
| 130 |
+
List[Tuple[str, str]]: List of (image_path, score_label) tuples, where image_path is "dataset/<filename>" and score_label is the similarity formatted as "Score: 0.1234"
|
| 131 |
+
Limited to top 5 results above similarity threshold
|
| 132 |
+
|
| 133 |
+
Raises:
|
| 134 |
+
RuntimeError: If CUDA operations fail
|
| 135 |
+
"""
|
| 136 |
# Refresh embeddings to catch any new images
|
| 137 |
global reference_embeddings
|
| 138 |
reference_embeddings = get_reference_embeddings()
|
|
|
|
| 160 |
return [(f"dataset/{name}", f"Score: {score:.4f}") for name, score in filtered_results[:5]]
|
| 161 |
|
| 162 |
@spaces.GPU
|
| 163 |
+
def add_image(name: str, image: Image.Image) -> str:
|
| 164 |
+
"""
|
| 165 |
+
Add a new image to the dataset and update embeddings.
|
| 166 |
+
|
| 167 |
+
Args:
|
| 168 |
+
name (str): Name for the new image (without extension)
|
| 169 |
+
image (Image.Image): PIL Image object to add to dataset
|
| 170 |
+
|
| 171 |
+
Returns:
|
| 172 |
+
str: Status message: a success message with total image count, or "Please provide a valid image name." when the name is blank
|
| 173 |
+
|
| 174 |
+
Raises:
|
| 175 |
+
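ValueError: If name is invalid (NOTE: a blank or whitespace-only name does not raise; it returns a validation message instead. Confirm whether any other invalid name actually raises)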
|
| 176 |
+
PermissionError: If unable to save image or update cache
|
| 177 |
+
RuntimeError: If embedding computation fails
|
| 178 |
+
"""
|
| 179 |
if not name.strip():
|
| 180 |
return "Please provide a valid image name."
|
| 181 |
|