# img_comparer/src/embeddings.py
# Author: Vivek Vaddina
# Commit df3522f (unverified): New UI and refactor code
import cv2
import torch
import mahotas
import numpy as np
from collections import defaultdict
from transformers import AutoModel, AutoProcessor
from skimage import measure, feature, img_as_ubyte
from sklearn.metrics.pairwise import cosine_similarity
from src.config import IMAGE_EMBEDDING_MODEL_CHECKPOINT, log
log.debug("loading HF embedding models")
# Model and processor are loaded once at import time (module-level side
# effect). eval() disables dropout/batch-norm updates for inference.
# NOTE(review): weights stay on whatever device from_pretrained chose
# (CPU by default) — the commented-out device_map="auto" hints at a
# planned GPU placement; confirm before re-enabling.
emb_model = AutoModel.from_pretrained(
    IMAGE_EMBEDDING_MODEL_CHECKPOINT  # device_map="auto"
).eval()
# use_fast=True selects the fast (Rust-backed) processor when available.
emb_processor = AutoProcessor.from_pretrained(
    IMAGE_EMBEDDING_MODEL_CHECKPOINT, use_fast=True
)
def get_image_embeddings(model, processor, img_rgbs):
    """
    Embed one or more RGB images with a HuggingFace vision model.

    Parameters
    ----------
    model : transformers.PreTrainedModel
        A ViT encoder or a SigLIP model; the family is detected via
        ``model.base_model_prefix``.
    processor : transformers.AutoProcessor
        Processor matching ``model``.
    img_rgbs : image or list of images
        A single image or a list of images accepted by ``processor``.

    Returns
    -------
    numpy.ndarray of shape (N, D)
        One embedding row per input image, moved to CPU.

    Raises
    ------
    ValueError
        If the model family is neither ViT nor SigLIP (the original
        silently returned None here).
    """
    # Bug fix: the original `if len(img_rgbs) == 1` re-wrapped a
    # one-element *list* into a nested list and did nothing for a bare
    # single image. Wrap only when the input is not already a sequence.
    if not isinstance(img_rgbs, (list, tuple)):
        img_rgbs = [img_rgbs]
    inputs = processor(images=img_rgbs, return_tensors="pt").to(model.device)
    with torch.no_grad():
        if model.base_model_prefix == "vit":
            # Use the [CLS] token embedding as the image representation.
            outputs = model(**inputs)
            return outputs.last_hidden_state[:, 0].to("cpu").numpy()
        if model.base_model_prefix == "siglip":
            return model.get_image_features(**inputs).to("cpu").numpy()
    raise ValueError(f"unsupported model family: {model.base_model_prefix!r}")
def get_hf_cosine_similarity(model, processor, masks):
    """
    Cosine similarity of the first image against each of the others.

    Parameters
    ----------
    model, processor
        HuggingFace model/processor pair passed through to
        ``get_image_embeddings``.
    masks : list of images
        Images to embed; ``masks[0]`` is the reference.

    Returns
    -------
    list of float
        ``len(masks) - 1`` similarities of masks[0] vs masks[1:].
    """
    # Bug fix: the original ignored its `model`/`processor` arguments and
    # always embedded with the module-level emb_model/emb_processor.
    emb = get_image_embeddings(model, processor, masks)
    return cosine_similarity(emb, emb)[0][1:].tolist()
def compute_mask_embeddings(
    mask,
    zernike_degree=8,
    hog_orientations=8,
    hog_pixels_per_cell=(16, 16),
    hog_cells_per_block=(1, 1),
):
    """
    Compute a variety of classical shape embeddings for a binary mask.

    Parameters
    ----------
    mask : ndarray, shape (H, W)
        Binary mask image (dtype=bool or 0/1).
    zernike_degree : int
        Maximum degree for Zernike moments (radius is derived from the
        mask size: the largest circle fitting the centered square crop).
    hog_orientations : int
        Number of orientation bins for HOG.
    hog_pixels_per_cell : tuple
        Size (in pixels) of a cell for HOG.
    hog_cells_per_block : tuple
        Number of cells in each block for HOG.

    Returns
    -------
    embeddings : dict
        Dictionary of feature vectors:
        - 'hu_moments': 7 Hu moments (log-scaled)
        - 'zernike': len = (zernike_degree+1)*(zernike_degree+2)//2
        - 'fourier': first 32 complex Fourier descriptors (flattened to 64D,
          zero-padded when the contour is shorter than 33 points)
        - 'regionprops': [area, bbox_w, bbox_h, aspect_ratio, centroid_x,
          centroid_y, perimeter, extent]
        - 'hog': HOG descriptor (1D vector)
    """
    # Ensure mask is uint8 for OpenCV and mahotas.
    mask_u8 = img_as_ubyte(mask > 0)
    embeddings = {}

    # 1. Hu Moments (7D), log-scale transformed for numeric stability
    # (raw Hu moments span many orders of magnitude).
    moments = cv2.moments(mask_u8)
    hu = cv2.HuMoments(moments).flatten()
    hu = -np.sign(hu) * np.log10(np.abs(hu) + 1e-16)
    embeddings["hu_moments"] = hu

    # 2. Zernike moments on the largest centered square crop; the radius
    # is the largest inscribed circle of that crop.
    H, W = mask_u8.shape
    R = min(H, W) // 2
    y0, x0 = H // 2 - R, W // 2 - R
    crop = mask_u8[y0 : y0 + 2 * R, x0 : x0 + 2 * R]
    embeddings["zernike"] = mahotas.features.zernike_moments(
        crop, radius=R, degree=zernike_degree
    )

    # 3. Fourier shape descriptors from the largest external contour.
    contours, _ = cv2.findContours(
        mask_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
    )
    fourier = np.zeros(64)
    if contours:
        # reshape(-1, 2) instead of squeeze(): a single-point contour
        # squeezes to shape (2,), which would crash cnt[:, 0].
        cnt = max(contours, key=cv2.contourArea).reshape(-1, 2)
        complex_pts = cnt[:, 0] + 1j * cnt[:, 1]
        fft_coeffs = np.fft.fft(complex_pts)
        descriptors = fft_coeffs[1:33]  # up to 32 coefficients, skip DC
        if descriptors.size:
            # Normalize by the magnitude of the first descriptor for
            # scale invariance.
            descriptors = descriptors / (np.abs(descriptors[0]) + 1e-16)
            # Bug fix: zero-pad so the embedding is always 64-D even when
            # the contour has fewer than 33 points (the original returned
            # a shorter vector, breaking downstream stacking).
            fourier[: descriptors.size] = descriptors.real
            fourier[32 : 32 + descriptors.size] = descriptors.imag
    embeddings["fourier"] = fourier

    # 4. Region-properties vector of the largest connected component.
    label_img = measure.label(mask > 0)
    props = measure.regionprops(label_img)
    if props:
        p = max(props, key=lambda x: x.area)
        minr, minc, maxr, maxc = p.bbox
        bbox_w = maxc - minc
        bbox_h = maxr - minr
        aspect = bbox_w / bbox_h if bbox_h else 0
        centroid_y, centroid_x = p.centroid
        regionprops = np.array(
            [
                p.area,
                bbox_w,
                bbox_h,
                aspect,
                centroid_x,
                centroid_y,
                p.perimeter,
                p.extent,
            ],
            dtype=float,
        )
    else:
        regionprops = np.zeros(8)
    embeddings["regionprops"] = regionprops

    # 5. Histogram of Oriented Gradients (HOG) over the full mask.
    embeddings["hog"] = feature.hog(
        mask.astype(float),
        orientations=hog_orientations,
        pixels_per_cell=hog_pixels_per_cell,
        cells_per_block=hog_cells_per_block,
        block_norm="L2-Hys",
        feature_vector=True,
    )
    return embeddings
def get_mask_embedding_scores(imgs):
    """
    Score images against the first one using classical shape embeddings.

    Parameters
    ----------
    imgs : PIL.Image or list of PIL.Image
        Images to compare; ``imgs[0]`` is the reference. Each image is
        converted to grayscale before feature extraction.

    Returns
    -------
    dict
        One key per embedding family ('cos_sim_hu_moments',
        'cos_sim_zernike', ...), each mapping to a list of
        ``len(imgs) - 1`` cosine similarities of imgs[0] vs imgs[1:].
    """
    if not isinstance(imgs, list):
        imgs = [imgs]

    # Collect one feature matrix per embedding family.
    emb_mask = defaultdict(list)
    for img in imgs:
        mask_embs = compute_mask_embeddings(np.array(img.convert("L")))
        for name, vec in mask_embs.items():
            emb_mask[name].append(vec)

    # Idiom fix: the original built `scores` as a defaultdict(list) but
    # only ever assigned keys directly — a plain dict is the right type.
    scores = {}
    for name, vecs in emb_mask.items():
        arr = np.asarray(vecs)
        scores[f"cos_sim_{name}"] = cosine_similarity(arr, arr)[0][1:].tolist()
    return scores