# Rishabh Uikey
# Stabilize grouping across environments
# 248b1be
import os
import torch
import open_clip
import numpy as np
from PIL import Image
from sklearn.preprocessing import normalize
class ProductGrouper:
    """Group detected products by visual similarity of horizontal neighbours.

    Detections are bucketed into rows by the vertical distance between their
    ``center_y`` values. Inside each row they are ordered left to right, and
    each item joins its left neighbour's group when the rounded cosine
    similarity of their CLIP (ViT-B-32, ``openai`` weights) image embeddings
    is at least ``similarity_threshold - similarity_margin``; otherwise a new
    group starts. Group ids are never shared across rows.
    """

    def __init__(
        self,
        similarity_threshold=0.92,
        row_vertical_threshold=90,
        padding=5,
        device=None,
        seed=42,
        similarity_margin=0.002,
        similarity_round_decimals=4,
    ):
        """Configure grouping parameters, seed RNGs and load the CLIP model.

        Args:
            similarity_threshold: Nominal minimum cosine similarity for two
                adjacent detections in a row to share a group.
            row_vertical_threshold: Detections whose ``center_y`` differs from
                a row's first member by less than this join that row (same
                units as ``center_y`` -- presumably pixels).
            padding: Margin added around each bbox before cropping, clamped
                to the image bounds.
            device: Torch device for inference. When ``None`` it is read from
                ``CLIP_DEVICE`` (default ``"cpu"``; ``"auto"`` selects
                ``"cuda"`` if available, else ``"cpu"``).
            seed: RNG seed; overridden by ``INFERENCE_SEED`` when set.
            similarity_margin: Tolerance subtracted from
                ``similarity_threshold``; overridden by ``SIMILARITY_MARGIN``.
            similarity_round_decimals: Decimals similarities are rounded to
                before comparison; overridden by ``SIMILARITY_ROUND_DECIMALS``.

        Side effects:
            Seeds the global NumPy and torch RNGs and sets the global cuDNN
            ``deterministic``/``benchmark`` flags.
        """
        self.similarity_threshold = similarity_threshold
        self.row_vertical_threshold = row_vertical_threshold
        self.padding = padding
        # Environment variables take precedence over constructor arguments.
        self.similarity_margin = float(
            os.getenv("SIMILARITY_MARGIN", str(similarity_margin))
        )
        self.similarity_round_decimals = int(
            os.getenv("SIMILARITY_ROUND_DECIMALS", str(similarity_round_decimals))
        )

        if device is None:
            env_device = os.getenv("CLIP_DEVICE", "cpu")
            if env_device == "auto":
                env_device = "cuda" if torch.cuda.is_available() else "cpu"
            device = env_device
        self.device = torch.device(device)

        # Global seeding plus deterministic cuDNN settings, aimed at making
        # embeddings reproducible from run to run.
        self.seed = int(os.getenv("INFERENCE_SEED", str(seed)))
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(self.seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            "ViT-B-32",
            pretrained="openai",
        )
        self.model.to(self.device)
        self.model.eval()

    def extract_embeddings(self, image_path, detections):
        """Return L2-normalised CLIP image embeddings, one row per detection.

        Args:
            image_path: Anything ``PIL.Image.open`` accepts.
            detections: Sequence of dicts whose ``"bbox"`` is
                ``(x1, y1, x2, y2)`` in image pixel coordinates.

        Returns:
            A NumPy array with one row per detection, in input order, each
            row scaled to unit L2 norm. An empty ``(0, 0)`` float32 array is
            returned when ``detections`` is empty.
        """
        # Close the file handle promptly; convert() yields an in-memory copy
        # that stays valid after the source image is closed.
        with Image.open(image_path) as source:
            image = source.convert("RGB")
        width, height = image.size

        crops = []
        for det in detections:
            x1, y1, x2, y2 = det["bbox"]
            box = (
                max(0, x1 - self.padding),
                max(0, y1 - self.padding),
                min(width, x2 + self.padding),
                min(height, y2 + self.padding),
            )
            crops.append(self.preprocess(image.crop(box)))

        if not crops:
            return np.empty((0, 0), dtype=np.float32)

        # Encode every crop in a single batch.
        batch = torch.stack(crops).to(self.device)
        with torch.no_grad():
            features = self.model.encode_image(batch)
        # Cast to float32 so the later dot products use the same precision
        # regardless of the dtype the device produced.
        embeddings = features.float().cpu().numpy()
        return normalize(embeddings)

    def group_products(self, image_path, detections):
        """Assign ``group_id`` and ``brand_group_id`` to each detection in place.

        Args:
            image_path: Image the detections were taken from.
            detections: List of dicts with ``"bbox"`` ``(x1, y1, x2, y2)`` and
                ``"center_y"`` keys.

        Returns:
            The same ``detections`` object, mutated in place. A falsy
            ``detections`` is returned unchanged without touching the image.
        """
        if not detections:
            return detections

        embeddings = self.extract_embeddings(image_path, detections)
        rows = self._cluster_rows(detections)
        group_ids = self._assign_group_ids(detections, rows, embeddings)

        for det, group_id in zip(detections, group_ids):
            det["group_id"] = group_id
            det["brand_group_id"] = group_id
        return detections

    def _cluster_rows(self, detections):
        """Bucket detection indices into rows by vertical distance to each row's anchor."""
        order = sorted(
            range(len(detections)),
            key=lambda i: (
                detections[i]["center_y"],
                detections[i]["bbox"][0],
                detections[i]["bbox"][1],
            ),
        )
        rows = []
        for idx in order:
            center_y = detections[idx]["center_y"]
            for row in rows:
                # Compare against the row's first (top-most) member, so the
                # reference point stays fixed as the row grows.
                anchor_y = detections[row[0]]["center_y"]
                if abs(center_y - anchor_y) < self.row_vertical_threshold:
                    row.append(idx)
                    break
            else:
                rows.append([idx])
        return rows

    def _assign_group_ids(self, detections, rows, embeddings):
        """Return per-detection group ids by chaining left-to-right neighbours in each row."""
        group_ids = [-1] * len(detections)
        cutoff = self.similarity_threshold - self.similarity_margin
        current_group = 0
        for row in rows:
            ordered = sorted(
                row, key=lambda i: (detections[i]["bbox"][0], detections[i]["bbox"][1])
            )
            group_ids[ordered[0]] = current_group
            for prev_idx, curr_idx in zip(ordered, ordered[1:]):
                raw_similarity = float(np.dot(embeddings[prev_idx], embeddings[curr_idx]))
                # Rounding before the comparison keeps decisions stable
                # against tiny floating-point differences between environments.
                similarity = round(raw_similarity, self.similarity_round_decimals)
                if similarity >= cutoff:
                    group_ids[curr_idx] = current_group
                else:
                    current_group += 1
                    group_ids[curr_idx] = current_group
            # Each row starts a fresh group id.
            current_group += 1
        return group_ids