#!/usr/bin/env python # coding: utf-8 # In[30]: import random from dataclasses import dataclass from typing import Any, List, Dict, Optional, Union, Tuple import os import cv2 import torch import requests import numpy as np from PIL import Image import clip import plotly.express as px from datetime import datetime import matplotlib.pyplot as plt import plotly.graph_objects as go from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline # In[2]: @dataclass class BoundingBox: xmin: int ymin: int xmax: int ymax: int @property def xyxy(self) -> List[float]: return [self.xmin, self.ymin, self.xmax, self.ymax] @dataclass class DetectionResult: score: float label: str box: BoundingBox mask: Optional[np.array] = None @classmethod def from_dict(cls, detection_dict: Dict) -> 'DetectionResult': return cls(score=detection_dict['score'], label=detection_dict['label'], box=BoundingBox(xmin=detection_dict['box']['xmin'], ymin=detection_dict['box']['ymin'], xmax=detection_dict['box']['xmax'], ymax=detection_dict['box']['ymax'])) # In[3]: def annotate(image: Union[Image.Image, np.ndarray], detection_results: List[DetectionResult]) -> np.ndarray: # Convert PIL Image to OpenCV format image_cv2 = np.array(image) if isinstance(image, Image.Image) else image image_cv2 = cv2.cvtColor(image_cv2, cv2.COLOR_RGB2BGR) # Iterate over detections and add bounding boxes and masks for detection in detection_results: label = detection.label score = detection.score box = detection.box mask = detection.mask # Sample a random color for each detection color = np.random.randint(0, 256, size=3) # Draw bounding box cv2.rectangle(image_cv2, (box.xmin, box.ymin), (box.xmax, box.ymax), color.tolist(), 2) cv2.putText(imagUnione_cv2, f'{label}: {score:.2f}', (box.xmin, box.ymin - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color.tolist(), 2) # If mask is available, apply it if mask is not None: # Convert mask to uint8 mask_uint8 = (mask * 255).astype(np.uint8) contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) cv2.drawContours(image_cv2, contours, -1, color.tolist(), 2) return cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB) def plot_detections( image: Union[Image.Image, np.ndarray], detections: List[DetectionResult], save_name: Optional[str] = None ) -> None: annotated_image = annotate(image, detections) plt.imshow(annotated_image) plt.axis('off') if save_name: plt.savefig(save_name, bbox_inches='tight') plt.show() # In[4]: def random_named_css_colors(num_colors: int) -> List[str]: """ Returns a list of randomly selected named CSS colors. Args: - num_colors (int): Number of random colors to generate. Returns: - list: List of randomly selected named CSS colors. """ # List of named CSS colors named_css_colors = [ 'aliceblue', 'antiquewhite', 'aqua', 'aquamarine', 'azure', 'beige', 'bisque', 'black', 'blanchedalmond', 'blue', 'blueviolet', 'brown', 'burlywood', 'cadetblue', 'chartreuse', 'chocolate', 'coral', 'cornflowerblue', 'cornsilk', 'crimson', 'cyan', 'darkblue', 'darkcyan', 'darkgoldenrod', 'darkgray', 'darkgreen', 'darkgrey', 'darkkhaki', 'darkmagenta', 'darkolivegreen', 'darkorange', 'darkorchid', 'darkred', 'darksalmon', 'darkseagreen', 'darkslateblue', 'darkslategray', 'darkslategrey', 'darkturquoise', 'darkviolet', 'deeppink', 'deepskyblue', 'dimgray', 'dimgrey', 'dodgerblue', 'firebrick', 'floralwhite', 'forestgreen', 'fuchsia', 'gainsboro', 'ghostwhite', 'gold', 'goldenrod', 'gray', 'green', 'greenyellow', 'grey', 'honeydew', 'hotpink', 'indianred', 'indigo', 'ivory', 'khaki', 'lavender', 'lavenderblush', 'lawngreen', 'lemonchiffon', 'lightblue', 'lightcoral', 'lightcyan', 'lightgoldenrodyellow', 'lightgray', 'lightgreen', 'lightgrey', 'lightpink', 'lightsalmon', 'lightseagreen', 'lightskyblue', 'lightslategray', 'lightslategrey', 'lightsteelblue', 'lightyellow', 'lime', 'limegreen', 'linen', 'magenta', 'maroon', 'mediumaquamarine', 'mediumblue', 'mediumorchid', 'mediumpurple', 'mediumseagreen', 'mediumslateblue', 'mediumspringgreen', 'mediumturquoise', 'mediumvioletred', 'midnightblue', 'mintcream', 'mistyrose', 'moccasin', 'navajowhite', 'navy', 'oldlace', 'olive', 'olivedrab', 'orange', 'orangered', 'orchid', 'palegoldenrod', 'palegreen', 'paleturquoise', 'palevioletred', 'papayawhip', 'peachpuff', 'peru', 'pink', 'plum', 'powderblue', 'purple', 'rebeccapurple', 'red', 'rosybrown', 'royalblue', 'saddlebrown', 'salmon', 'sandybrown', 'seagreen', 'seashell', 'sienna', 'silver', 'skyblue', 'slateblue', 'slategray', 'slategrey', 'snow', 'springgreen', 'steelblue', 'tan', 'teal', 'thistle', 'tomato', 'turquoise', 'violet', 'wheat', 'white', 'whitesmoke', 'yellow', 'yellowgreen' ] # Sample random named CSS colors return random.sample(named_css_colors, min(num_colors, len(named_css_colors))) def plot_detections_plotly( image: np.ndarray, detections: List[DetectionResult], class_colors: Optional[Dict[str, str]] = None ) -> None: # If class_colors is not provided, generate random colors for each class if class_colors is None: num_detections = len(detections) colors = random_named_css_colors(num_detections) class_colors = {} for i in range(num_detections): class_colors[i] = colors[i] fig = px.imshow(image) # Add bounding boxes shapes = [] annotations = [] for idx, detection in enumerate(detections): label = detection.label box = detection.box score = detection.score mask = detection.mask polygon = mask_to_polygon(mask) fig.add_trace(go.Scatter( x=[point[0] for point in polygon] + [polygon[0][0]], y=[point[1] for point in polygon] + [polygon[0][1]], mode='lines', line=dict(color=class_colors[idx], width=2), fill='toself', name=f"{label}: {score:.2f}" )) xmin, ymin, xmax, ymax = box.xyxy shape = [ dict( type="rect", xref="x", yref="y", x0=xmin, y0=ymin, x1=xmax, y1=ymax, line=dict(color=class_colors[idx]) ) ] annotation = [ dict( x=(xmin+xmax) // 2, y=(ymin+ymax) // 2, xref="x", yref="y", text=f"{label}: {score:.2f}", ) ] shapes.append(shape) annotations.append(annotation) # Update layout button_shapes = [dict(label="None",method="relayout",args=["shapes", []])] button_shapes = button_shapes + [ dict(label=f"Detection {idx+1}",method="relayout",args=["shapes", shape]) for idx, shape in enumerate(shapes) ] button_shapes = button_shapes + [dict(label="All", method="relayout", args=["shapes", sum(shapes, [])])] fig.update_layout( xaxis=dict(visible=False), yaxis=dict(visible=False), # margin=dict(l=0, r=0, t=0, b=0), showlegend=True, updatemenus=[ dict( type="buttons", direction="up", buttons=button_shapes ) ], legend=dict( orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1 ) ) # Show plot fig.show() # In[5]: def mask_to_polygon(mask: np.ndarray) -> List[List[int]]: # Find contours in the binary mask contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) # Find the contour with the largest area largest_contour = max(contours, key=cv2.contourArea) # Extract the vertices of the contour polygon = largest_contour.reshape(-1, 2).tolist() return polygon def polygon_to_mask(polygon: List[Tuple[int, int]], image_shape: Tuple[int, int]) -> np.ndarray: """ Convert a polygon to a segmentation mask. Args: - polygon (list): List of (x, y) coordinates representing the vertices of the polygon. - image_shape (tuple): Shape of the image (height, width) for the mask. Returns: - np.ndarray: Segmentation mask with the polygon filled. """ # Create an empty mask mask = np.zeros(image_shape, dtype=np.uint8) # Convert polygon to an array of points pts = np.array(polygon, dtype=np.int32) # Fill the polygon with white color (255) cv2.fillPoly(mask, [pts], color=(255,)) return mask def load_image(image_str: str) -> Image.Image: if image_str.startswith("http"): image = Image.open(requests.get(image_str, stream=True).raw).convert("RGB") else: image = Image.open(image_str).convert("RGB") return image def get_boxes(results: DetectionResult) -> List[List[List[float]]]: boxes = [] for result in results: xyxy = result.box.xyxy boxes.append(xyxy) return [boxes] def refine_masks(masks: torch.BoolTensor, polygon_refinement: bool = False) -> List[np.ndarray]: masks = masks.cpu().float() masks = masks.permute(0, 2, 3, 1) masks = masks.mean(axis=-1) masks = (masks > 0).int() masks = masks.numpy().astype(np.uint8) masks = list(masks) if polygon_refinement: for idx, mask in enumerate(masks): shape = mask.shape polygon = mask_to_polygon(mask) mask = polygon_to_mask(polygon, shape) masks[idx] = mask return masks # In[6]: def detect( image: Image.Image, labels: List[str], threshold: float = 0.3, detector_id: Optional[str] = None ) -> List[Dict[str, Any]]: """ Use Grounding DINO to detect a set of labels in an image in a zero-shot fashion. """ device = "cuda" if torch.cuda.is_available() else "cpu" detector_id = detector_id if detector_id is not None else "IDEA-Research/grounding-dino-tiny" object_detector = pipeline(model=detector_id, task="zero-shot-object-detection", device=device) labels = [label if label.endswith(".") else label+"." for label in labels] results = object_detector(image, candidate_labels=labels, threshold=threshold) results = [DetectionResult.from_dict(result) for result in results] return results def segment( image: Image.Image, detection_results: List[Dict[str, Any]], polygon_refinement: bool = False, segmenter_id: Optional[str] = None ) -> List[DetectionResult]: """ Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes. """ device = "cuda" if torch.cuda.is_available() else "cpu" segmenter_id = segmenter_id if segmenter_id is not None else "facebook/sam-vit-base" segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to(device) processor = AutoProcessor.from_pretrained(segmenter_id) boxes = get_boxes(detection_results) inputs = processor(images=image, input_boxes=boxes, return_tensors="pt").to(device) outputs = segmentator(**inputs) masks = processor.post_process_masks( masks=outputs.pred_masks, original_sizes=inputs.original_sizes, reshaped_input_sizes=inputs.reshaped_input_sizes )[0] masks = refine_masks(masks, polygon_refinement) for detection_result, mask in zip(detection_results, masks): detection_result.mask = mask return detection_results def grounded_segmentation( image: Union[Image.Image, str], labels: List[str], threshold: float = 0.3, polygon_refinement: bool = False, detector_id: Optional[str] = None, segmenter_id: Optional[str] = None ) -> Tuple[np.ndarray, List[DetectionResult]]: if isinstance(image, str): image = load_image(image) detections = detect(image, labels, threshold, detector_id) detections = segment(image, detections, polygon_refinement, segmenter_id) return image, detections # In[7]: # save clipped images def cut_image(image, mask, box): ny_image = np.array(image) cut = cv2.bitwise_and(ny_image, ny_image, mask=mask.astype(np.uint8)*255) x0, y0, x1, y1 = map(int, box.xyxy) cropped = cut[y0:y1, x0:x1] cropped_bgr = cv2.cvtColor(cropped, cv2.COLOR_RGB2BGR) return cropped_bgr # In[8]: image_url = "/home/dheena/Downloads/fashion/images (1).jpeg" labels = ["a dress"] threshold = 0.3 detector_id = "IDEA-Research/grounding-dino-tiny" segmenter_id = "facebook/sam-vit-base" # In[9]: # image, detections = grounded_segmentation( # image=image_url, # labels=labels, # threshold=threshold, # polygon_refinement=True, # detector_id=detector_id, # segmenter_id=segmenter_id # ) # current = datetime.now().strftime("%Y-%m-%d_%H:%M:%S") # cropped_image = cut_image(image, detections[0].mask, detections[0].box) # cv2.imwrite("/home/dheena/Downloads/fashion/output/" + current, cropped_image) # In[44]: # plot_detections(np.array(image), detections, "test.png") # In[60]: # model imports import faiss import torch import clip from openai import OpenAI from torch.utils.data import DataLoader # helper imports from tqdm import tqdm import os import numpy as np from typing import List, Tuple # visualization imports from PIL import Image from fastapi import FastAPI from typing import List import matplotlib.pyplot as plt client = OpenAI() # Set device device = "cpu" model, preprocess = clip.load("ViT-B/32", device=device) # # Directory path # direc = '/home/dheena/Downloads/fashion/output/' # def get_image(filepath: str) -> Image.Image: # """Safely load and convert an image file to RGB PIL format.""" # try: # return Image.open(filepath).convert("RGB") # except Exception as e: # print(f"Failed to load {filepath}: {e}") # return None # def get_all_images_from_dir(directory: str) -> List[Tuple[str, Image.Image]]: # """Load all supported images from a directory, with paths.""" # supported_exts = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp') # image_data = [] # for root, _, files in os.walk(directory): # for file in files: # if file.lower().endswith(supported_exts): # full_path = os.path.join(root, file) # try: # img = Image.open(full_path).convert("RGB") # image_data.append((full_path, img)) # except Exception as e: # print(f"Error loading {full_path}: {e}") # return image_data def get_image_features(image: Image.Image) -> np.ndarray: """Extract CLIP features from an image.""" image_input = preprocess(image).unsqueeze(0).to(device) with torch.no_grad(): image_features = model.encode_image(image_input).float() return image_features.cpu().numpy() # FAISS setup index = faiss.IndexFlatIP(512) meta_data_store = [] def save_image_in_index(image_features: np.ndarray, metadata: dict): """Normalize features and add to index.""" faiss.normalize_L2(image_features) index.add(image_features) meta_data_store.append(metadata) def process_image_embedding(image_url: str, labels=['clothes']) -> np.ndarray: """Get feature embedding for a query image.""" search_image, search_detections = grounded_segmentation(image=image_url, labels=labels) cropped_image = cut_image(search_image, search_detections[0].mask, search_detections[0].box) # Convert to valid RGB if cropped_image.dtype != np.uint8: cropped_image = (cropped_image * 255).astype(np.uint8) if cropped_image.ndim == 2: cropped_image = np.stack([cropped_image] * 3, axis=-1) pil_image = Image.fromarray(cropped_image) return pil_image def get_top_k_results(image_url: str, k: int = 10) -> List[dict]: """Find top-k similar images from the index.""" processed_image = process_image_embedding(image_url) image_search_embedding = get_image_features(processed_image) faiss.normalize_L2(image_search_embedding) distances, indices = index.search(image_search_embedding.reshape(1, -1), k) results = [] for i, dist in zip(indices[0], distances[0]): if i < len(meta_data_store): results.append({ 'metadata': meta_data_store[i], 'score': float(dist) }) return results # def display_similar_images(results: List[dict]): # """Display retrieved images using matplotlib.""" # for item in results: # img = get_image(item['metadata']['image_path']) # if img: # print(f"Score: {item['score']:.4f}") # plt.imshow(img) # plt.axis('off') # plt.show() # In[73]: app = FastAPI() @app.get("/similar_images") def get_similar_images(image_url: str, k: int = 10): results = get_top_k_results(image_url, k) # display_similar_images(results) # Optional visualization call return { "results": [ { "metadata": item["metadata"], "score": item["score"] } for item in results ] } # Example usage: # results = get_top_k_results("/home/dheena/Downloads/fashion/temp/KPR-120-Wine_2_1024x1024.webp") # display_similar_images(results) # In[54]: