import os
import random
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, Union

import clip
import cv2
import matplotlib.pyplot as plt
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import requests
import torch
from PIL import Image
from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline


@dataclass
class BoundingBox:
    xmin: int
    ymin: int
    xmax: int
    ymax: int

    @property
    def xyxy(self) -> List[float]:
        return [self.xmin, self.ymin, self.xmax, self.ymax]


@dataclass
class DetectionResult:
    score: float
    label: str
    box: BoundingBox
    mask: Optional[np.ndarray] = None

    @classmethod
    def from_dict(cls, detection_dict: Dict) -> 'DetectionResult':
        return cls(score=detection_dict['score'],
                   label=detection_dict['label'],
                   box=BoundingBox(xmin=detection_dict['box']['xmin'],
                                   ymin=detection_dict['box']['ymin'],
                                   xmax=detection_dict['box']['xmax'],
                                   ymax=detection_dict['box']['ymax']))


def annotate(image: Union[Image.Image, np.ndarray], detection_results: List[DetectionResult]) -> np.ndarray:
    """Draw bounding boxes, labels, and mask contours onto the image."""
    image_cv2 = np.array(image) if isinstance(image, Image.Image) else image
    image_cv2 = cv2.cvtColor(image_cv2, cv2.COLOR_RGB2BGR)

    for detection in detection_results:
        label = detection.label
        score = detection.score
        box = detection.box
        mask = detection.mask

        # Pick a random color for this detection
        color = np.random.randint(0, 256, size=3)

        cv2.rectangle(image_cv2, (box.xmin, box.ymin), (box.xmax, box.ymax), color.tolist(), 2)
        cv2.putText(image_cv2, f'{label}: {score:.2f}', (box.xmin, box.ymin - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color.tolist(), 2)

        if mask is not None:
            # Binarize first so 0/255 masks don't overflow when scaled
            mask_uint8 = (mask > 0).astype(np.uint8) * 255
            contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            cv2.drawContours(image_cv2, contours, -1, color.tolist(), 2)

    return cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB)


def plot_detections(
    image: Union[Image.Image, np.ndarray],
    detections: List[DetectionResult],
    save_name: Optional[str] = None
) -> None:
    annotated_image = annotate(image, detections)
    plt.imshow(annotated_image)
    plt.axis('off')
    if save_name:
        plt.savefig(save_name, bbox_inches='tight')
    plt.show()


def random_named_css_colors(num_colors: int) -> List[str]:
    """
    Returns a list of randomly selected named CSS colors.

    Args:
    - num_colors (int): Number of random colors to generate.

    Returns:
    - list: List of randomly selected named CSS colors.
    """
    named_css_colors = [
        'aliceblue', 'antiquewhite', 'aqua', 'aquamarine', 'azure', 'beige', 'bisque', 'black', 'blanchedalmond',
        'blue', 'blueviolet', 'brown', 'burlywood', 'cadetblue', 'chartreuse', 'chocolate', 'coral', 'cornflowerblue',
        'cornsilk', 'crimson', 'cyan', 'darkblue', 'darkcyan', 'darkgoldenrod', 'darkgray', 'darkgreen', 'darkgrey',
        'darkkhaki', 'darkmagenta', 'darkolivegreen', 'darkorange', 'darkorchid', 'darkred', 'darksalmon', 'darkseagreen',
        'darkslateblue', 'darkslategray', 'darkslategrey', 'darkturquoise', 'darkviolet', 'deeppink', 'deepskyblue',
        'dimgray', 'dimgrey', 'dodgerblue', 'firebrick', 'floralwhite', 'forestgreen', 'fuchsia', 'gainsboro', 'ghostwhite',
        'gold', 'goldenrod', 'gray', 'green', 'greenyellow', 'grey', 'honeydew', 'hotpink', 'indianred', 'indigo', 'ivory',
        'khaki', 'lavender', 'lavenderblush', 'lawngreen', 'lemonchiffon', 'lightblue', 'lightcoral', 'lightcyan', 'lightgoldenrodyellow',
        'lightgray', 'lightgreen', 'lightgrey', 'lightpink', 'lightsalmon', 'lightseagreen', 'lightskyblue', 'lightslategray',
        'lightslategrey', 'lightsteelblue', 'lightyellow', 'lime', 'limegreen', 'linen', 'magenta', 'maroon', 'mediumaquamarine',
        'mediumblue', 'mediumorchid', 'mediumpurple', 'mediumseagreen', 'mediumslateblue', 'mediumspringgreen', 'mediumturquoise',
        'mediumvioletred', 'midnightblue', 'mintcream', 'mistyrose', 'moccasin', 'navajowhite', 'navy', 'oldlace', 'olive',
        'olivedrab', 'orange', 'orangered', 'orchid', 'palegoldenrod', 'palegreen', 'paleturquoise', 'palevioletred', 'papayawhip',
        'peachpuff', 'peru', 'pink', 'plum', 'powderblue', 'purple', 'rebeccapurple', 'red', 'rosybrown', 'royalblue', 'saddlebrown',
        'salmon', 'sandybrown', 'seagreen', 'seashell', 'sienna', 'silver', 'skyblue', 'slateblue', 'slategray', 'slategrey',
        'snow', 'springgreen', 'steelblue', 'tan', 'teal', 'thistle', 'tomato', 'turquoise', 'violet', 'wheat', 'white',
        'whitesmoke', 'yellow', 'yellowgreen'
    ]

    return random.sample(named_css_colors, min(num_colors, len(named_css_colors)))


def plot_detections_plotly(
    image: np.ndarray,
    detections: List[DetectionResult],
    class_colors: Optional[Dict[int, str]] = None
) -> None:
    # Map each detection index to a color if no mapping was supplied
    if class_colors is None:
        num_detections = len(detections)
        colors = random_named_css_colors(num_detections)
        class_colors = {i: colors[i] for i in range(num_detections)}

    fig = px.imshow(image)

    # One filled polygon trace per detection, plus a toggleable rectangle shape
    shapes = []
    annotations = []
    for idx, detection in enumerate(detections):
        label = detection.label
        box = detection.box
        score = detection.score
        mask = detection.mask

        polygon = mask_to_polygon(mask)

        fig.add_trace(go.Scatter(
            x=[point[0] for point in polygon] + [polygon[0][0]],
            y=[point[1] for point in polygon] + [polygon[0][1]],
            mode='lines',
            line=dict(color=class_colors[idx], width=2),
            fill='toself',
            name=f"{label}: {score:.2f}"
        ))

        xmin, ymin, xmax, ymax = box.xyxy
        shape = [
            dict(
                type="rect",
                xref="x", yref="y",
                x0=xmin, y0=ymin,
                x1=xmax, y1=ymax,
                line=dict(color=class_colors[idx])
            )
        ]
        annotation = [
            dict(
                x=(xmin + xmax) // 2, y=(ymin + ymax) // 2,
                xref="x", yref="y",
                text=f"{label}: {score:.2f}",
            )
        ]

        shapes.append(shape)
        annotations.append(annotation)

    # Buttons to show no boxes, one detection's box, or all boxes at once
    button_shapes = [dict(label="None", method="relayout", args=["shapes", []])]
    button_shapes = button_shapes + [
        dict(label=f"Detection {idx + 1}", method="relayout", args=["shapes", shape])
        for idx, shape in enumerate(shapes)
    ]
    button_shapes = button_shapes + [dict(label="All", method="relayout", args=["shapes", sum(shapes, [])])]

    fig.update_layout(
        xaxis=dict(visible=False),
        yaxis=dict(visible=False),
        showlegend=True,
        updatemenus=[
            dict(
                type="buttons",
                direction="up",
                buttons=button_shapes
            )
        ],
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        )
    )

    fig.show()


def mask_to_polygon(mask: np.ndarray) -> List[List[int]]:
    """Convert a binary mask to the polygon of its largest contour."""
    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Keep only the largest contour by area
    largest_contour = max(contours, key=cv2.contourArea)

    # Flatten the (N, 1, 2) contour array into a list of [x, y] vertices
    polygon = largest_contour.reshape(-1, 2).tolist()

    return polygon


def polygon_to_mask(polygon: List[Tuple[int, int]], image_shape: Tuple[int, int]) -> np.ndarray:
    """
    Convert a polygon to a segmentation mask.

    Args:
    - polygon (list): List of (x, y) coordinates representing the vertices of the polygon.
    - image_shape (tuple): Shape of the image (height, width) for the mask.

    Returns:
    - np.ndarray: Segmentation mask with the polygon filled.
    """
    mask = np.zeros(image_shape, dtype=np.uint8)

    pts = np.array(polygon, dtype=np.int32)

    cv2.fillPoly(mask, [pts], color=(255,))

    return mask


def load_image(image_str: str) -> Image.Image:
    if image_str.startswith("http"):
        image = Image.open(requests.get(image_str, stream=True).raw).convert("RGB")
    else:
        image = Image.open(image_str).convert("RGB")

    return image


def get_boxes(results: List[DetectionResult]) -> List[List[List[float]]]:
    boxes = []
    for result in results:
        xyxy = result.box.xyxy
        boxes.append(xyxy)

    # SAM expects one list of boxes per image, hence the extra nesting
    return [boxes]


def refine_masks(masks: torch.BoolTensor, polygon_refinement: bool = False) -> List[np.ndarray]:
    # Collapse the mask channels into a single binary mask per detection
    masks = masks.cpu().float()
    masks = masks.permute(0, 2, 3, 1)
    masks = masks.mean(axis=-1)
    masks = (masks > 0).int()
    masks = masks.numpy().astype(np.uint8)
    masks = list(masks)

    if polygon_refinement:
        # Round-trip each mask through its largest polygon to smooth ragged edges
        for idx, mask in enumerate(masks):
            shape = mask.shape
            polygon = mask_to_polygon(mask)
            mask = polygon_to_mask(polygon, shape)
            masks[idx] = mask

    return masks


def detect(
    image: Image.Image,
    labels: List[str],
    threshold: float = 0.3,
    detector_id: Optional[str] = None
) -> List[DetectionResult]:
    """
    Use Grounding DINO to detect a set of labels in an image in a zero-shot fashion.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    detector_id = detector_id if detector_id is not None else "IDEA-Research/grounding-dino-tiny"
    object_detector = pipeline(model=detector_id, task="zero-shot-object-detection", device=device)

    # Grounding DINO expects each text query to end with a period
    labels = [label if label.endswith(".") else label + "." for label in labels]

    results = object_detector(image, candidate_labels=labels, threshold=threshold)
    results = [DetectionResult.from_dict(result) for result in results]

    return results


def segment(
    image: Image.Image,
    detection_results: List[DetectionResult],
    polygon_refinement: bool = False,
    segmenter_id: Optional[str] = None
) -> List[DetectionResult]:
    """
    Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    segmenter_id = segmenter_id if segmenter_id is not None else "facebook/sam-vit-base"

    segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to(device)
    processor = AutoProcessor.from_pretrained(segmenter_id)

    boxes = get_boxes(detection_results)
    inputs = processor(images=image, input_boxes=boxes, return_tensors="pt").to(device)

    # Inference only, so skip gradient tracking
    with torch.no_grad():
        outputs = segmentator(**inputs)
    masks = processor.post_process_masks(
        masks=outputs.pred_masks,
        original_sizes=inputs.original_sizes,
        reshaped_input_sizes=inputs.reshaped_input_sizes
    )[0]

    masks = refine_masks(masks, polygon_refinement)

    # Attach each mask to its corresponding detection
    for detection_result, mask in zip(detection_results, masks):
        detection_result.mask = mask

    return detection_results


def grounded_segmentation(
    image: Union[Image.Image, str],
    labels: List[str],
    threshold: float = 0.3,
    polygon_refinement: bool = False,
    detector_id: Optional[str] = None,
    segmenter_id: Optional[str] = None
) -> Tuple[Image.Image, List[DetectionResult]]:
    if isinstance(image, str):
        image = load_image(image)

    detections = detect(image, labels, threshold, detector_id)
    detections = segment(image, detections, polygon_refinement, segmenter_id)

    return image, detections


def cut_image(image, mask, box):
    """Apply the mask to the image and crop to the bounding box (returns BGR)."""
    np_image = np.array(image)
    # cv2.bitwise_and treats any nonzero mask value as foreground; binarize
    # first so masks stored as 0/255 don't overflow under uint8 arithmetic
    binary_mask = (mask > 0).astype(np.uint8)
    cut = cv2.bitwise_and(np_image, np_image, mask=binary_mask)
    x0, y0, x1, y1 = map(int, box.xyxy)
    cropped = cut[y0:y1, x0:x1]
    cropped_bgr = cv2.cvtColor(cropped, cv2.COLOR_RGB2BGR)
    return cropped_bgr


image_url = "/home/dheena/Downloads/fashion/images (1).jpeg"
labels = ["a dress"]
threshold = 0.3

detector_id = "IDEA-Research/grounding-dino-tiny"
segmenter_id = "facebook/sam-vit-base"

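# A minimal end-to-end sketch using the configuration above; uncomment to
# run the Grounding DINO + SAM pipeline on the sample image and visualize
# the result (polygon_refinement=True is one reasonable choice, not a
# requirement).
# image, detections = grounded_segmentation(
#     image=image_url,
#     labels=labels,
#     threshold=threshold,
#     polygon_refinement=True,
#     detector_id=detector_id,
#     segmenter_id=segmenter_id,
# )
# plot_detections(image, detections)
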
import faiss
from fastapi import FastAPI
from openai import OpenAI
from torch.utils.data import DataLoader
from tqdm import tqdm

client = OpenAI()

# Load CLIP once at import time; all embeddings come from the ViT-B/32
# image encoder (512-dimensional output)
device = "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)


def get_image_features(image: Image.Image) -> np.ndarray:
    """Extract CLIP features from an image."""
    image_input = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = model.encode_image(image_input).float()
    return image_features.cpu().numpy()


# Inner-product index over 512-d CLIP embeddings; with L2-normalized
# vectors, inner product equals cosine similarity
index = faiss.IndexFlatIP(512)
meta_data_store = []

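# The index lives in memory only; a possible sketch for persisting it
# across restarts (the file name is hypothetical):
# faiss.write_index(index, "clip_catalog.index")
# index = faiss.read_index("clip_catalog.index")
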
def save_image_in_index(image_features: np.ndarray, metadata: dict):
    """Normalize features and add to index."""
    # normalize_L2 works in place; normalized vectors make IP == cosine
    faiss.normalize_L2(image_features)
    index.add(image_features)
    meta_data_store.append(metadata)


def process_image_embedding(image_url: str, labels=['clothes']) -> Image.Image:
    """Segment the garment out of a query image and return it as a PIL image."""
    search_image, search_detections = grounded_segmentation(image=image_url, labels=labels)
    cropped_image = cut_image(search_image, search_detections[0].mask, search_detections[0].box)

    # cut_image returns BGR; convert back to RGB before handing to PIL/CLIP
    cropped_image = cv2.cvtColor(cropped_image, cv2.COLOR_BGR2RGB)

    if cropped_image.dtype != np.uint8:
        cropped_image = (cropped_image * 255).astype(np.uint8)
    if cropped_image.ndim == 2:
        cropped_image = np.stack([cropped_image] * 3, axis=-1)

    pil_image = Image.fromarray(cropped_image)
    return pil_image


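# A sketch of how the index might be populated from a folder of catalog
# images; `catalog_dir` is a hypothetical path, and the metadata stored
# here is just the file path.
# catalog_dir = "/path/to/catalog"
# for file_name in tqdm(os.listdir(catalog_dir)):
#     path = os.path.join(catalog_dir, file_name)
#     garment = process_image_embedding(path)
#     features = get_image_features(garment)
#     save_image_in_index(features, metadata={"path": path})
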
def get_top_k_results(image_url: str, k: int = 10) -> List[dict]:
    """Find top-k similar images from the index."""
    processed_image = process_image_embedding(image_url)
    image_search_embedding = get_image_features(processed_image)
    faiss.normalize_L2(image_search_embedding)
    distances, indices = index.search(image_search_embedding.reshape(1, -1), k)

    results = []
    for i, dist in zip(indices[0], distances[0]):
        # FAISS pads with -1 when the index holds fewer than k vectors
        if 0 <= i < len(meta_data_store):
            results.append({
                'metadata': meta_data_store[i],
                'score': float(dist)
            })
    return results


app = FastAPI()


@app.get("/similar_images")
def get_similar_images(image_url: str, k: int = 10):
    results = get_top_k_results(image_url, k)

    return {
        "results": [
            {
                "metadata": item["metadata"],
                "score": item["score"]
            }
            for item in results
        ]
    }
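
# To serve the API locally (assuming this file is saved as main.py and
# uvicorn is installed):
#     uvicorn main:app --reload
# Example request: GET /similar_images?image_url=https://example.com/dress.jpg&k=5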