#!/usr/bin/env python
# coding: utf-8
# In[30]:
import random
from dataclasses import dataclass
from typing import Any, List, Dict, Optional, Union, Tuple
import os
import cv2
import torch
import requests
import numpy as np
from PIL import Image
import clip
import plotly.express as px
from datetime import datetime
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline
# In[2]:
@dataclass
class BoundingBox:
    xmin: int
    ymin: int
    xmax: int
    ymax: int

    @property
    def xyxy(self) -> List[float]:
        return [self.xmin, self.ymin, self.xmax, self.ymax]


@dataclass
class DetectionResult:
    score: float
    label: str
    box: BoundingBox
    mask: Optional[np.ndarray] = None

    @classmethod
    def from_dict(cls, detection_dict: Dict) -> 'DetectionResult':
        return cls(score=detection_dict['score'],
                   label=detection_dict['label'],
                   box=BoundingBox(xmin=detection_dict['box']['xmin'],
                                   ymin=detection_dict['box']['ymin'],
                                   xmax=detection_dict['box']['xmax'],
                                   ymax=detection_dict['box']['ymax']))
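# A minimal sketch of building a DetectionResult by hand; the dict layout below
# is hypothetical but matches what the zero-shot detection pipeline returns.
# example_dict = {'score': 0.91, 'label': 'a dress.',
#                 'box': {'xmin': 10, 'ymin': 20, 'xmax': 110, 'ymax': 220}}
# detection = DetectionResult.from_dict(example_dict)
# print(detection.box.xyxy)  # [10, 20, 110, 220]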
# In[3]:
def annotate(image: Union[Image.Image, np.ndarray], detection_results: List[DetectionResult]) -> np.ndarray:
    # Convert PIL Image to OpenCV (BGR) format
    image_cv2 = np.array(image) if isinstance(image, Image.Image) else image
    image_cv2 = cv2.cvtColor(image_cv2, cv2.COLOR_RGB2BGR)

    # Iterate over detections and draw bounding boxes and masks
    for detection in detection_results:
        label = detection.label
        score = detection.score
        box = detection.box
        mask = detection.mask

        # Sample a random color for each detection
        color = np.random.randint(0, 256, size=3)

        # Draw bounding box and label
        cv2.rectangle(image_cv2, (box.xmin, box.ymin), (box.xmax, box.ymax), color.tolist(), 2)
        cv2.putText(image_cv2, f'{label}: {score:.2f}', (box.xmin, box.ymin - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color.tolist(), 2)

        # If a mask is available, draw its contours
        if mask is not None:
            # Convert mask to uint8
            mask_uint8 = (mask * 255).astype(np.uint8)
            contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            cv2.drawContours(image_cv2, contours, -1, color.tolist(), 2)

    return cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB)
def plot_detections(
    image: Union[Image.Image, np.ndarray],
    detections: List[DetectionResult],
    save_name: Optional[str] = None
) -> None:
    annotated_image = annotate(image, detections)
    plt.imshow(annotated_image)
    plt.axis('off')
    if save_name:
        plt.savefig(save_name, bbox_inches='tight')
    plt.show()
# In[4]:
def random_named_css_colors(num_colors: int) -> List[str]:
    """
    Returns a list of randomly selected named CSS colors.

    Args:
    - num_colors (int): Number of random colors to generate.

    Returns:
    - list: List of randomly selected named CSS colors.
    """
    # List of named CSS colors
    named_css_colors = [
        'aliceblue', 'antiquewhite', 'aqua', 'aquamarine', 'azure', 'beige', 'bisque', 'black', 'blanchedalmond',
        'blue', 'blueviolet', 'brown', 'burlywood', 'cadetblue', 'chartreuse', 'chocolate', 'coral', 'cornflowerblue',
        'cornsilk', 'crimson', 'cyan', 'darkblue', 'darkcyan', 'darkgoldenrod', 'darkgray', 'darkgreen', 'darkgrey',
        'darkkhaki', 'darkmagenta', 'darkolivegreen', 'darkorange', 'darkorchid', 'darkred', 'darksalmon', 'darkseagreen',
        'darkslateblue', 'darkslategray', 'darkslategrey', 'darkturquoise', 'darkviolet', 'deeppink', 'deepskyblue',
        'dimgray', 'dimgrey', 'dodgerblue', 'firebrick', 'floralwhite', 'forestgreen', 'fuchsia', 'gainsboro', 'ghostwhite',
        'gold', 'goldenrod', 'gray', 'green', 'greenyellow', 'grey', 'honeydew', 'hotpink', 'indianred', 'indigo', 'ivory',
        'khaki', 'lavender', 'lavenderblush', 'lawngreen', 'lemonchiffon', 'lightblue', 'lightcoral', 'lightcyan', 'lightgoldenrodyellow',
        'lightgray', 'lightgreen', 'lightgrey', 'lightpink', 'lightsalmon', 'lightseagreen', 'lightskyblue', 'lightslategray',
        'lightslategrey', 'lightsteelblue', 'lightyellow', 'lime', 'limegreen', 'linen', 'magenta', 'maroon', 'mediumaquamarine',
        'mediumblue', 'mediumorchid', 'mediumpurple', 'mediumseagreen', 'mediumslateblue', 'mediumspringgreen', 'mediumturquoise',
        'mediumvioletred', 'midnightblue', 'mintcream', 'mistyrose', 'moccasin', 'navajowhite', 'navy', 'oldlace', 'olive',
        'olivedrab', 'orange', 'orangered', 'orchid', 'palegoldenrod', 'palegreen', 'paleturquoise', 'palevioletred', 'papayawhip',
        'peachpuff', 'peru', 'pink', 'plum', 'powderblue', 'purple', 'rebeccapurple', 'red', 'rosybrown', 'royalblue', 'saddlebrown',
        'salmon', 'sandybrown', 'seagreen', 'seashell', 'sienna', 'silver', 'skyblue', 'slateblue', 'slategray', 'slategrey',
        'snow', 'springgreen', 'steelblue', 'tan', 'teal', 'thistle', 'tomato', 'turquoise', 'violet', 'wheat', 'white',
        'whitesmoke', 'yellow', 'yellowgreen'
    ]

    # Sample random named CSS colors (capped at the number of available names)
    return random.sample(named_css_colors, min(num_colors, len(named_css_colors)))
def plot_detections_plotly(
    image: np.ndarray,
    detections: List[DetectionResult],
    class_colors: Optional[Dict[int, str]] = None
) -> None:
    # If class_colors is not provided, generate a random color per detection index
    if class_colors is None:
        num_detections = len(detections)
        colors = random_named_css_colors(num_detections)
        class_colors = {i: colors[i] for i in range(num_detections)}

    fig = px.imshow(image)

    # Add bounding boxes and mask polygons
    shapes = []
    annotations = []
    for idx, detection in enumerate(detections):
        label = detection.label
        box = detection.box
        score = detection.score
        mask = detection.mask

        polygon = mask_to_polygon(mask)

        fig.add_trace(go.Scatter(
            x=[point[0] for point in polygon] + [polygon[0][0]],
            y=[point[1] for point in polygon] + [polygon[0][1]],
            mode='lines',
            line=dict(color=class_colors[idx], width=2),
            fill='toself',
            name=f"{label}: {score:.2f}"
        ))

        xmin, ymin, xmax, ymax = box.xyxy
        shape = [
            dict(
                type="rect",
                xref="x", yref="y",
                x0=xmin, y0=ymin,
                x1=xmax, y1=ymax,
                line=dict(color=class_colors[idx])
            )
        ]
        annotation = [
            dict(
                x=(xmin + xmax) // 2, y=(ymin + ymax) // 2,
                xref="x", yref="y",
                text=f"{label}: {score:.2f}",
            )
        ]
        shapes.append(shape)
        annotations.append(annotation)

    # Buttons to toggle which detection's bounding box is shown
    button_shapes = [dict(label="None", method="relayout", args=["shapes", []])]
    button_shapes = button_shapes + [
        dict(label=f"Detection {idx + 1}", method="relayout", args=["shapes", shape]) for idx, shape in enumerate(shapes)
    ]
    button_shapes = button_shapes + [dict(label="All", method="relayout", args=["shapes", sum(shapes, []))]

    fig.update_layout(
        xaxis=dict(visible=False),
        yaxis=dict(visible=False),
        # margin=dict(l=0, r=0, t=0, b=0),
        showlegend=True,
        updatemenus=[
            dict(
                type="buttons",
                direction="up",
                buttons=button_shapes
            )
        ],
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        )
    )

    # Show plot
    fig.show()
# In[5]:
def mask_to_polygon(mask: np.ndarray) -> List[List[int]]:
    # Find contours in the binary mask
    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Find the contour with the largest area
    largest_contour = max(contours, key=cv2.contourArea)

    # Extract the vertices of the contour
    polygon = largest_contour.reshape(-1, 2).tolist()

    return polygon
def polygon_to_mask(polygon: List[Tuple[int, int]], image_shape: Tuple[int, int]) -> np.ndarray:
    """
    Convert a polygon to a segmentation mask.

    Args:
    - polygon (list): List of (x, y) coordinates representing the vertices of the polygon.
    - image_shape (tuple): Shape of the image (height, width) for the mask.

    Returns:
    - np.ndarray: Segmentation mask with the polygon filled.
    """
    # Create an empty mask
    mask = np.zeros(image_shape, dtype=np.uint8)

    # Convert polygon to an array of points
    pts = np.array(polygon, dtype=np.int32)

    # Fill the polygon with white color (255)
    cv2.fillPoly(mask, [pts], color=(255,))

    return mask
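# A quick round-trip sketch on a hypothetical toy mask: extracting the largest
# contour polygon and rasterizing it back should approximately reproduce the mask.
# toy_mask = np.zeros((100, 100), dtype=np.uint8)
# toy_mask[20:80, 30:70] = 1
# poly = mask_to_polygon(toy_mask)
# rebuilt = polygon_to_mask(poly, toy_mask.shape)  # filled with 255 inside the polygon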
def load_image(image_str: str) -> Image.Image:
    if image_str.startswith("http"):
        image = Image.open(requests.get(image_str, stream=True).raw).convert("RGB")
    else:
        image = Image.open(image_str).convert("RGB")
    return image
def get_boxes(results: List[DetectionResult]) -> List[List[List[float]]]:
    boxes = []
    for result in results:
        xyxy = result.box.xyxy
        boxes.append(xyxy)
    # SAM's processor expects one list of boxes per image, hence the extra nesting
    return [boxes]
def refine_masks(masks: torch.BoolTensor, polygon_refinement: bool = False) -> List[np.ndarray]:
    # Collapse the channel dimension and binarize: (N, C, H, W) -> N masks of shape (H, W)
    masks = masks.cpu().float()
    masks = masks.permute(0, 2, 3, 1)
    masks = masks.mean(axis=-1)
    masks = (masks > 0).int()
    masks = masks.numpy().astype(np.uint8)
    masks = list(masks)

    if polygon_refinement:
        # Smooth each mask by replacing it with its largest filled contour
        for idx, mask in enumerate(masks):
            shape = mask.shape
            polygon = mask_to_polygon(mask)
            mask = polygon_to_mask(polygon, shape)
            masks[idx] = mask

    return masks
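# A minimal sketch of refine_masks on dummy SAM-shaped output, assuming the
# post-processed masks arrive as a (num_masks, num_channels, H, W) bool tensor.
# dummy = torch.zeros((1, 3, 64, 64), dtype=torch.bool)
# dummy[0, :, 16:48, 16:48] = True
# refined = refine_masks(dummy, polygon_refinement=True)
# print(refined[0].shape, refined[0].max())  # (64, 64) 255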
# In[6]:
def detect(
    image: Image.Image,
    labels: List[str],
    threshold: float = 0.3,
    detector_id: Optional[str] = None
) -> List[DetectionResult]:
    """
    Use Grounding DINO to detect a set of labels in an image in a zero-shot fashion.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    detector_id = detector_id if detector_id is not None else "IDEA-Research/grounding-dino-tiny"
    object_detector = pipeline(model=detector_id, task="zero-shot-object-detection", device=device)

    # Grounding DINO expects each text prompt to end with a period
    labels = [label if label.endswith(".") else label + "." for label in labels]

    results = object_detector(image, candidate_labels=labels, threshold=threshold)
    results = [DetectionResult.from_dict(result) for result in results]

    return results
def segment(
    image: Image.Image,
    detection_results: List[DetectionResult],
    polygon_refinement: bool = False,
    segmenter_id: Optional[str] = None
) -> List[DetectionResult]:
    """
    Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    segmenter_id = segmenter_id if segmenter_id is not None else "facebook/sam-vit-base"

    segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to(device)
    processor = AutoProcessor.from_pretrained(segmenter_id)

    boxes = get_boxes(detection_results)
    inputs = processor(images=image, input_boxes=boxes, return_tensors="pt").to(device)

    outputs = segmentator(**inputs)
    masks = processor.post_process_masks(
        masks=outputs.pred_masks,
        original_sizes=inputs.original_sizes,
        reshaped_input_sizes=inputs.reshaped_input_sizes
    )[0]

    masks = refine_masks(masks, polygon_refinement)

    # Attach each mask to its corresponding detection
    for detection_result, mask in zip(detection_results, masks):
        detection_result.mask = mask

    return detection_results
def grounded_segmentation(
    image: Union[Image.Image, str],
    labels: List[str],
    threshold: float = 0.3,
    polygon_refinement: bool = False,
    detector_id: Optional[str] = None,
    segmenter_id: Optional[str] = None
) -> Tuple[Image.Image, List[DetectionResult]]:
    if isinstance(image, str):
        image = load_image(image)

    detections = detect(image, labels, threshold, detector_id)
    detections = segment(image, detections, polygon_refinement, segmenter_id)

    return image, detections
# In[7]:
# save clipped images
def cut_image(image, mask, box):
    np_image = np.array(image)
    # Zero out everything outside the mask, then crop to the bounding box
    cut = cv2.bitwise_and(np_image, np_image, mask=mask.astype(np.uint8) * 255)
    x0, y0, x1, y1 = map(int, box.xyxy)
    cropped = cut[y0:y1, x0:x1]
    cropped_bgr = cv2.cvtColor(cropped, cv2.COLOR_RGB2BGR)
    return cropped_bgr
# In[8]:
image_url = "/home/dheena/Downloads/fashion/images (1).jpeg"
labels = ["a dress"]
threshold = 0.3
detector_id = "IDEA-Research/grounding-dino-tiny"
segmenter_id = "facebook/sam-vit-base"
# In[9]:
# image, detections = grounded_segmentation(
#     image=image_url,
#     labels=labels,
#     threshold=threshold,
#     polygon_refinement=True,
#     detector_id=detector_id,
#     segmenter_id=segmenter_id
# )
# current = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
# cropped_image = cut_image(image, detections[0].mask, detections[0].box)
# cv2.imwrite("/home/dheena/Downloads/fashion/output/" + current + ".png", cropped_image)
# In[44]:
# plot_detections(np.array(image), detections, "test.png")
# In[60]:
# model imports
import faiss
import torch
import clip
from openai import OpenAI
from torch.utils.data import DataLoader
# helper imports
from tqdm import tqdm
import os
import numpy as np
from typing import List, Tuple
# visualization imports
from PIL import Image
from fastapi import FastAPI
from typing import List
import matplotlib.pyplot as plt
client = OpenAI()
# Set device
device = "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
# # Directory path
# direc = '/home/dheena/Downloads/fashion/output/'
# def get_image(filepath: str) -> Image.Image:
#     """Safely load and convert an image file to RGB PIL format."""
#     try:
#         return Image.open(filepath).convert("RGB")
#     except Exception as e:
#         print(f"Failed to load {filepath}: {e}")
#         return None

# def get_all_images_from_dir(directory: str) -> List[Tuple[str, Image.Image]]:
#     """Load all supported images from a directory, with paths."""
#     supported_exts = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp')
#     image_data = []
#     for root, _, files in os.walk(directory):
#         for file in files:
#             if file.lower().endswith(supported_exts):
#                 full_path = os.path.join(root, file)
#                 try:
#                     img = Image.open(full_path).convert("RGB")
#                     image_data.append((full_path, img))
#                 except Exception as e:
#                     print(f"Error loading {full_path}: {e}")
#     return image_data
def get_image_features(image: Image.Image) -> np.ndarray:
    """Extract CLIP features from an image."""
    image_input = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = model.encode_image(image_input).float()
    return image_features.cpu().numpy()
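# A small sanity-check sketch: CLIP ViT-B/32 produces 512-dimensional embeddings,
# which is why the FAISS index below is built with dimension 512.
# blank = Image.new("RGB", (224, 224))
# feats = get_image_features(blank)
# print(feats.shape)  # (1, 512)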
# FAISS setup: inner product over L2-normalized vectors is cosine similarity
index = faiss.IndexFlatIP(512)
meta_data_store = []

def save_image_in_index(image_features: np.ndarray, metadata: dict):
    """Normalize features and add to index."""
    faiss.normalize_L2(image_features)
    index.add(image_features)
    meta_data_store.append(metadata)
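# A minimal sketch of populating the index from a directory of cropped images,
# along the lines of the commented-out loaders above (the path is hypothetical).
# for root, _, files in os.walk('/home/dheena/Downloads/fashion/output/'):
#     for file in files:
#         if file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp')):
#             full_path = os.path.join(root, file)
#             img = Image.open(full_path).convert("RGB")
#             save_image_in_index(get_image_features(img), {'image_path': full_path})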
def process_image_embedding(image_url: str, labels=['clothes']) -> Image.Image:
    """Segment and crop the query image, returning a PIL image ready for CLIP."""
    search_image, search_detections = grounded_segmentation(image=image_url, labels=labels)
    cropped_image = cut_image(search_image, search_detections[0].mask, search_detections[0].box)

    # Convert to valid RGB
    if cropped_image.dtype != np.uint8:
        cropped_image = (cropped_image * 255).astype(np.uint8)
    if cropped_image.ndim == 2:
        cropped_image = np.stack([cropped_image] * 3, axis=-1)

    # cut_image returns BGR (for cv2.imwrite); convert back to RGB for PIL/CLIP
    cropped_image = cv2.cvtColor(cropped_image, cv2.COLOR_BGR2RGB)

    pil_image = Image.fromarray(cropped_image)
    return pil_image
def get_top_k_results(image_url: str, k: int = 10) -> List[dict]:
    """Find top-k similar images from the index."""
    processed_image = process_image_embedding(image_url)
    image_search_embedding = get_image_features(processed_image)
    faiss.normalize_L2(image_search_embedding)

    distances, indices = index.search(image_search_embedding.reshape(1, -1), k)

    results = []
    for i, dist in zip(indices[0], distances[0]):
        if i < len(meta_data_store):
            results.append({
                'metadata': meta_data_store[i],
                'score': float(dist)
            })
    return results
# def display_similar_images(results: List[dict]):
#     """Display retrieved images using matplotlib."""
#     for item in results:
#         img = get_image(item['metadata']['image_path'])
#         if img:
#             print(f"Score: {item['score']:.4f}")
#             plt.imshow(img)
#             plt.axis('off')
#             plt.show()
# In[73]:
app = FastAPI()

@app.get("/similar_images")
def get_similar_images(image_url: str, k: int = 10):
    results = get_top_k_results(image_url, k)
    # display_similar_images(results)  # Optional visualization call
    return {
        "results": [
            {
                "metadata": item["metadata"],
                "score": item["score"]
            }
            for item in results
        ]
    }
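# To serve the endpoint, a typical invocation (assuming this file is importable
# as a module named, e.g., image_segmentation) would be:
#   uvicorn image_segmentation:app --reload
# and a query against the running server:
#   curl "http://127.0.0.1:8000/similar_images?image_url=/path/to/query.jpg&k=5"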
# Example usage:
# results = get_top_k_results("/home/dheena/Downloads/fashion/temp/KPR-120-Wine_2_1024x1024.webp")
# display_similar_images(results)
# In[54]: