import os
import torch
import open_clip
import numpy as np
from PIL import Image
from sklearn.preprocessing import normalize


class ProductGrouper:
    """Group adjacent, visually similar product detections within shelf rows.

    Detections are first bucketed into rows by vertical proximity of their
    ``center_y`` values. Each row is then scanned left to right. A detection
    joins the previous detection's group when the dot product of their
    L2-normalised CLIP (``ViT-B-32``, ``openai`` weights) image embeddings,
    rounded to ``similarity_round_decimals``, is at least
    ``similarity_threshold - similarity_margin``. Otherwise it starts a new
    group.
    """

    def __init__(
        self,
        similarity_threshold=0.92,
        row_vertical_threshold=90,
        padding=5,
        device=None,
        seed=42,
        similarity_margin=0.002,
        similarity_round_decimals=4,
    ):
        """Configure thresholds, seed RNGs and load the CLIP model.

        Args:
            similarity_threshold: Base similarity cutoff for neighbours to
                share a group.
            row_vertical_threshold: Two detections belong to the same row when
                their ``center_y`` differ by strictly less than this.
            padding: Pixels added around each bbox before cropping (clamped
                to the image bounds).
            device: Torch device string. ``None`` reads ``CLIP_DEVICE`` (default
                ``"cpu"``). ``"auto"``, whether passed directly or via the env
                var, picks ``"cuda"`` when available, else ``"cpu"``.
            seed: Default RNG seed. Overridden by the ``INFERENCE_SEED`` env var.
            similarity_margin: Subtracted from ``similarity_threshold``.
                Overridden by the ``SIMILARITY_MARGIN`` env var.
            similarity_round_decimals: Decimals the similarity is rounded to
                before comparison. Overridden by ``SIMILARITY_ROUND_DECIMALS``.

        Note:
            Seeds NumPy's and Torch's *global* RNGs, and, when CUDA is
            available, sets cuDNN to deterministic mode, which affects the
            whole process.
        """
        self.similarity_threshold = similarity_threshold
        self.row_vertical_threshold = row_vertical_threshold
        self.padding = padding
        self.similarity_margin = float(
            os.getenv("SIMILARITY_MARGIN", str(similarity_margin))
        )
        self.similarity_round_decimals = int(
            os.getenv("SIMILARITY_ROUND_DECIMALS", str(similarity_round_decimals))
        )

        if device is None:
            device = os.getenv("CLIP_DEVICE", "cpu")
        # "auto" is resolved whether it came from the env var or the argument;
        # torch.device() itself does not understand "auto".
        if device == "auto":
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

        self.seed = int(os.getenv("INFERENCE_SEED", str(seed)))
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(self.seed)
            # Trade cuDNN autotuning speed for run-to-run reproducibility.
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            "ViT-B-32", pretrained="openai"
        )
        self.model.to(self.device)
        self.model.eval()

    def extract_embeddings(self, image_path, detections):
        """Return L2-normalised CLIP image embeddings, one row per detection.

        Args:
            image_path: Path (or file object) accepted by ``PIL.Image.open``.
            detections: Iterable of dicts whose ``"bbox"`` is
                ``(x1, y1, x2, y2)``, used as left/top/right/bottom pixel
                coordinates.

        Returns:
            A 2-D NumPy array with one row per detection, in input order.
            Returns an empty ``(0, 0)`` float32 array when there are no
            detections.
        """
        # Image.open is lazy and can keep the file handle open. The context
        # manager releases it as soon as the RGB copy has been made.
        with Image.open(image_path) as source:
            image = source.convert("RGB")

        crops = []
        for det in detections:
            x1, y1, x2, y2 = det["bbox"]
            box = (
                max(0, x1 - self.padding),
                max(0, y1 - self.padding),
                min(image.width, x2 + self.padding),
                min(image.height, y2 + self.padding),
            )
            crops.append(self.preprocess(image.crop(box)))

        if not crops:
            return np.empty((0, 0), dtype=np.float32)

        # Encode every crop in a single batch.
        batch = torch.stack(crops).to(self.device)
        with torch.no_grad():
            features = self.model.encode_image(batch)
        # Row-wise L2 normalisation, so a dot product equals cosine similarity.
        return normalize(features.cpu().numpy())

    def group_products(self, image_path, detections):
        """Assign a group id to every detection, in place.

        Args:
            image_path: Image the detections were produced from.
            detections: List of dicts with ``"bbox"`` ``(x1, y1, x2, y2)`` and
                ``"center_y"``.

        Returns:
            The same ``detections`` list. Each dict gains ``"group_id"`` and
            ``"brand_group_id"``, which always hold the same value. Group ids
            are consecutive integers starting at 0. An empty or falsy input is
            returned unchanged.
        """
        if not detections:
            return detections

        embeddings = self.extract_embeddings(image_path, detections)
        rows = self._cluster_rows(detections)
        group_ids = self._assign_group_ids(rows, detections, embeddings)

        for det, group_id in zip(detections, group_ids):
            det["group_id"] = group_id
            det["brand_group_id"] = group_id
        return detections

    def _cluster_rows(self, detections):
        """Bucket detection indices into rows by closeness to each row's first member."""

        def top_to_bottom(i):
            det = detections[i]
            return (det["center_y"], det["bbox"][0], det["bbox"][1])

        rows = []
        for idx in sorted(range(len(detections)), key=top_to_bottom):
            center_y = detections[idx]["center_y"]
            # Compare only against each row's anchor (its first, i.e. topmost,
            # member) and join the first row that is close enough.
            for row in rows:
                anchor_y = detections[row[0]]["center_y"]
                if abs(center_y - anchor_y) < self.row_vertical_threshold:
                    row.append(idx)
                    break
            else:
                rows.append([idx])
        return rows

    def _assign_group_ids(self, rows, detections, embeddings):
        """Chain left-to-right neighbours in each row into groups by embedding similarity."""

        def left_to_right(i):
            bbox = detections[i]["bbox"]
            return (bbox[0], bbox[1])

        cutoff = self.similarity_threshold - self.similarity_margin
        group_ids = [-1] * len(detections)
        current_group = 0
        for row in rows:
            ordered = sorted(row, key=left_to_right)
            group_ids[ordered[0]] = current_group
            for prev_idx, curr_idx in zip(ordered, ordered[1:]):
                raw_similarity = float(
                    np.dot(embeddings[prev_idx], embeddings[curr_idx])
                )
                similarity = round(raw_similarity, self.similarity_round_decimals)
                # Written as ">=" (not "< cutoff") so that a NaN similarity
                # starts a new group.
                if similarity >= cutoff:
                    group_ids[curr_idx] = current_group
                else:
                    current_group += 1
                    group_ids[curr_idx] = current_group
            # Groups never span rows.
            current_group += 1
        return group_ids