import os
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import open_clip
import torch
from PIL import Image
from sklearn.preprocessing import normalize
class ProductGrouper:
    """Assign group ids to detections using row layout and CLIP similarity.

    Each detection is a dict with a ``"bbox"`` entry ``(x1, y1, x2, y2)`` and
    a ``"center_y"`` entry. Detections are first clustered into rows by
    ``center_y``; within a row, horizontally adjacent detections whose image
    embeddings are similar enough share a group id.
    """

    def __init__(
        self,
        similarity_threshold: float = 0.92,
        row_vertical_threshold: float = 90,
        padding: int = 5,
        device: Optional[Union[str, torch.device]] = None,
        seed: int = 42,
        similarity_margin: float = 0.002,
        similarity_round_decimals: int = 4,
    ) -> None:
        """Configure the grouper and load the CLIP image encoder.

        Args:
            similarity_threshold: Nominal cosine similarity needed for two
                neighbouring detections to share a group.
            row_vertical_threshold: Maximum ``center_y`` distance (exclusive)
                from a row's first member for a detection to join that row.
            padding: Pixels added on every side of a bbox before cropping.
            device: Torch device (or device string). When ``None`` the
                ``CLIP_DEVICE`` environment variable is used (default
                ``"cpu"``); the value ``"auto"`` selects CUDA if available.
            seed: RNG seed; overridden by ``INFERENCE_SEED`` when set.
            similarity_margin: Slack subtracted from ``similarity_threshold``;
                overridden by ``SIMILARITY_MARGIN`` when set.
            similarity_round_decimals: Decimals the similarity is rounded to
                before comparison; overridden by
                ``SIMILARITY_ROUND_DECIMALS`` when set.

        Raises:
            ValueError: If one of the numeric environment overrides cannot be
                parsed.
        """
        self.similarity_threshold = similarity_threshold
        self.row_vertical_threshold = row_vertical_threshold
        self.padding = padding

        # Environment variables, when present, win over constructor arguments.
        self.similarity_margin = float(
            os.getenv("SIMILARITY_MARGIN", str(similarity_margin))
        )
        self.similarity_round_decimals = int(
            os.getenv("SIMILARITY_ROUND_DECIMALS", str(similarity_round_decimals))
        )

        if device is None:
            env_device = os.getenv("CLIP_DEVICE", "cpu")
            if env_device == "auto":
                env_device = "cuda" if torch.cuda.is_available() else "cpu"
            device = env_device

        self.device = torch.device(device)
        self.seed = int(os.getenv("INFERENCE_SEED", str(seed)))

        # NOTE: this seeds the *global* NumPy/Torch RNGs and changes global
        # cuDNN flags, so constructing a grouper affects the whole process.
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(self.seed)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

        # The middle return value is discarded; only the third transform is
        # kept and later applied to every crop.
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            "ViT-B-32",
            pretrained="openai",
        )
        self.model.to(self.device)
        self.model.eval()

    def _padded_box(
        self, bbox: Sequence[float], width: int, height: int
    ) -> Tuple[int, int, int, int]:
        """Return ``bbox`` grown by ``self.padding`` and clamped to the image.

        The result always spans at least one pixel in each dimension so that a
        degenerate or out-of-frame box cannot yield an empty crop.
        """
        x1, y1, x2, y2 = bbox
        pad = self.padding

        # Round first, then clamp: for in-frame boxes this gives the same
        # integers PIL's crop() would derive from the clamped float box.
        left = max(0, int(round(x1 - pad)))
        top = max(0, int(round(y1 - pad)))
        right = min(width, int(round(x2 + pad)))
        bottom = min(height, int(round(y2 + pad)))

        # Force a non-empty box that lies inside the image.
        left = min(left, width - 1)
        top = min(top, height - 1)
        right = max(right, left + 1)
        bottom = max(bottom, top + 1)
        return left, top, right, bottom

    def extract_embeddings(
        self, image_path: str, detections: List[Dict[str, Any]]
    ) -> np.ndarray:
        """Embed the padded crop of every detection with the CLIP image encoder.

        Args:
            image_path: Path of the image the detections refer to.
            detections: Dicts carrying a ``"bbox"`` of ``(x1, y1, x2, y2)``.

        Returns:
            One L2-normalised embedding row per detection, in input order.
            An empty ``(0, 0)`` float32 array when there is nothing to embed;
            in that case the image is not opened.
        """
        if not detections:
            return np.empty((0, 0), dtype=np.float32)

        # Image.open() keeps the file open; convert() returns a new, fully
        # loaded image, so the source handle can be released immediately.
        with Image.open(image_path) as source:
            image = source.convert("RGB")

        crops = [
            self.preprocess(
                image.crop(self._padded_box(det["bbox"], image.width, image.height))
            )
            for det in detections
        ]
        # Still needed for non-sequence iterables that turned out to be empty.
        if not crops:
            return np.empty((0, 0), dtype=np.float32)

        batch = torch.stack(crops).to(self.device)
        with torch.no_grad():
            features = self.model.encode_image(batch)

        # Row-normalised so that a dot product of two rows is their cosine
        # similarity.
        return normalize(features.cpu().numpy())

    def _cluster_rows(self, detections: List[Dict[str, Any]]) -> List[List[int]]:
        """Cluster detection indices into rows by ``center_y``.

        Detections are visited top-to-bottom (ties broken by bbox x1, then
        y1). Each one joins the first existing row whose *first* member is
        closer than ``row_vertical_threshold`` in ``center_y``; otherwise it
        starts a new row.
        """
        visit_order = sorted(
            range(len(detections)),
            key=lambda i: (
                detections[i]["center_y"],
                detections[i]["bbox"][0],
                detections[i]["bbox"][1],
            ),
        )

        rows: List[List[int]] = []
        for idx in visit_order:
            center_y = detections[idx]["center_y"]
            for row in rows:
                anchor_y = detections[row[0]]["center_y"]
                if abs(center_y - anchor_y) < self.row_vertical_threshold:
                    row.append(idx)
                    break
            else:
                # No existing row was close enough.
                rows.append([idx])
        return rows

    def _assign_group_ids(
        self,
        detections: List[Dict[str, Any]],
        rows: List[List[int]],
        embeddings: np.ndarray,
    ) -> List[int]:
        """Walk each row left-to-right and chain similar neighbours into groups.

        A detection is compared only with its immediate left neighbour, not
        with the first member of the group. Every row starts a fresh group.
        """
        min_similarity = self.similarity_threshold - self.similarity_margin
        group_ids = [-1] * len(detections)
        current_group = 0

        for row in rows:
            ordered = sorted(
                row,
                key=lambda i: (detections[i]["bbox"][0], detections[i]["bbox"][1]),
            )
            group_ids[ordered[0]] = current_group

            for prev_idx, curr_idx in zip(ordered, ordered[1:]):
                raw_similarity = float(
                    np.dot(embeddings[prev_idx], embeddings[curr_idx])
                )
                # Rounding keeps tiny floating-point jitter from flipping the
                # decision for values right at the cut-off.
                similarity = round(raw_similarity, self.similarity_round_decimals)
                if similarity < min_similarity:
                    current_group += 1
                group_ids[curr_idx] = current_group

            # Groups never span rows.
            current_group += 1

        return group_ids

    def group_products(
        self, image_path: str, detections: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """Annotate every detection with ``group_id`` and ``brand_group_id``.

        Args:
            image_path: Path of the image the detections refer to.
            detections: Dicts with ``"bbox"`` and ``"center_y"`` entries.

        Returns:
            The same ``detections`` object, mutated in place. Both added keys
            hold the same integer. An empty input is returned untouched.
        """
        if not detections:
            return detections

        embeddings = self.extract_embeddings(image_path, detections)
        rows = self._cluster_rows(detections)
        group_ids = self._assign_group_ids(detections, rows, embeddings)

        for det, group_id in zip(detections, group_ids):
            det["group_id"] = group_id
            det["brand_group_id"] = group_id

        return detections