# shelf_demo/embeddings.py
# NOTE: the following hosting-page residue was commented out so the module parses:
#   OverMind0's picture / Update embeddings.py / 1b3afb7 verified
# -*- coding: utf-8 -*-
"""Reference augmentation and embedding utilities."""
from __future__ import annotations
from collections import defaultdict
from typing import Dict, List, Tuple
import numpy as np
import torch
from PIL import Image, ImageEnhance, ImageOps
from transformers import AutoImageProcessor, AutoModel
# Module-level singletons for the DINOv2 image processor and model.
# Both start as None and are lazily populated by get_dino_model() on first use,
# so importing this module does not trigger a model download.
_DINO_PROCESSOR = None
_DINO_MODEL = None
def get_device() -> torch.device:
    """Return the preferred compute device: CUDA when available, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
def get_dino_model(device: torch.device):
    """Return the (processor, model) pair for DINOv2-small, loading it lazily.

    The processor and model are cached in module-level globals so the
    pretrained weights are downloaded/loaded at most once per process.
    The model is moved to *device* and put in eval mode at load time.

    NOTE(review): the model is bound to the *device* passed on the first
    call; later calls with a different device return the originally-placed
    model — confirm callers always pass the same device.
    """
    global _DINO_PROCESSOR, _DINO_MODEL
    if _DINO_PROCESSOR is None:
        _DINO_PROCESSOR = AutoImageProcessor.from_pretrained("facebook/dinov2-small")
    if _DINO_MODEL is None:
        _DINO_MODEL = AutoModel.from_pretrained("facebook/dinov2-small").to(device)
        # eval() disables dropout/batch-norm updates; only needed once at load.
        _DINO_MODEL.eval()
    return _DINO_PROCESSOR, _DINO_MODEL
def augment_image(img: Image.Image) -> Image.Image:
    """Return a randomly perturbed copy of *img*.

    Applies, in order: a 50% horizontal mirror, a small random rotation in
    [-10, 10] degrees, then probabilistic brightness/contrast/sharpness
    jitter. The input image is never modified.
    """
    out = img.copy()
    # Horizontal mirror half the time.
    if np.random.rand() < 0.5:
        out = ImageOps.mirror(out)
    # Small rotation is always applied.
    out = out.rotate(float(np.random.uniform(-10, 10)), resample=Image.BILINEAR)
    # Photometric jitter table: (enhancer class, probability, low, high).
    # Order matters: it preserves the original sequence of RNG draws.
    jitter = (
        (ImageEnhance.Brightness, 0.7, 0.8, 1.2),
        (ImageEnhance.Contrast, 0.7, 0.8, 1.2),
        (ImageEnhance.Sharpness, 0.5, 0.9, 1.3),
    )
    for enhancer_cls, prob, lo, hi in jitter:
        if np.random.rand() < prob:
            out = enhancer_cls(out).enhance(float(np.random.uniform(lo, hi)))
    return out
def extract_embedding_from_pil(image: Image.Image, device: torch.device) -> torch.Tensor:
    """Embed a PIL image with DINOv2 and return its L2-normalized CLS vector.

    Returns a tensor of shape (1, D) on *device*, unit-norm along dim 1.
    """
    processor, model = get_dino_model(device)
    batch = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        # CLS token (position 0) serves as the global image descriptor.
        cls_vec = model(**batch).last_hidden_state[:, 0, :]
    return torch.nn.functional.normalize(cls_vec, p=2, dim=1)
def build_reference_embeddings(
    ref_images: List[Image.Image],
    device: torch.device,
    augmentations_per_image: int = 5,
) -> torch.Tensor:
    """Embed every reference image plus random augmentations of it.

    For each of the N input images, the original and
    *augmentations_per_image* augmented variants are embedded, yielding a
    tensor of shape (N * (1 + augmentations_per_image), D).
    """
    pool: List[Image.Image] = []
    for original in ref_images:
        pool.append(original)
        pool.extend(augment_image(original) for _ in range(augmentations_per_image))
    vectors = [extract_embedding_from_pil(im, device) for im in pool]
    return torch.cat(vectors, dim=0)
def compute_similarities(
    object_crops: Dict[int, Image.Image],
    boxes: List[List[int]],
    ref_embeddings: torch.Tensor,
    device: torch.device,
) -> List[Dict[str, object]]:
    """Score each object crop against the reference embeddings.

    For every crop, the similarity is the maximum dot product against all
    reference embeddings (cosine similarity, since both sides are
    L2-normalized). Returns records sorted by similarity, descending;
    each record carries the crop's box_id and its box from *boxes*.
    """
    records: List[Dict[str, object]] = []
    for box_id, crop in object_crops.items():
        crop_emb = extract_embedding_from_pil(crop, device)
        best = torch.matmul(ref_embeddings, crop_emb.T).max().item()
        records.append({
            "box_id": box_id,
            "similarity": float(best),
            "box": boxes[box_id],
        })
    records.sort(key=lambda rec: rec["similarity"], reverse=True)
    return records
def adaptive_similarity_threshold(
    similarities: List[Dict[str, object]],
    percentile: int = 80,
    std_factor: float = 0.5,
    min_threshold: float = 0.7,
) -> float:
    """Choose a similarity cutoff from the score distribution.

    The threshold is the largest of three candidates: the given percentile
    of the scores, mean + std_factor * std, and a hard floor
    *min_threshold*. An empty input falls back to the floor.
    """
    if not similarities:
        return min_threshold
    scores = np.asarray([rec["similarity"] for rec in similarities], dtype=float)
    candidates = (
        float(np.percentile(scores, percentile)),
        float(scores.mean() + std_factor * scores.std()),
        float(min_threshold),
    )
    return max(candidates)
def count_sim_objects_per_shelf(
    similarities: List[Dict[str, object]],
    object_metadata: List[Dict[str, object]],
    threshold: float,
) -> Tuple[Dict[int, int], List[Dict[str, object]]]:
    """Tally above-threshold detections per shelf.

    Objects whose similarity is below *threshold*, or whose box_id has no
    shelf assignment in *object_metadata*, are skipped. Returns
    (shelf_id -> count, list of the kept objects in input order).
    """
    shelf_of = {meta["box_id"]: meta["shelf_id"] for meta in object_metadata}
    counts: Dict[int, int] = defaultdict(int)
    kept: List[Dict[str, object]] = []
    for record in similarities:
        if record["similarity"] < threshold:
            continue  # below the cutoff
        shelf = shelf_of.get(record["box_id"])
        if shelf is None:
            continue  # box with no shelf assignment
        counts[int(shelf)] += 1
        kept.append({
            "box_id": record["box_id"],
            "shelf_id": shelf,
            "similarity": record["similarity"],
            "box": record["box"],
        })
    return dict(counts), kept
def calculate_shelf_share(similarities: List[Dict[str, object]], threshold: float):
    """Compute the fraction of total box area matched above *threshold*.

    Returns (share, stock_status, total_area, matched_area) where share is
    matched_area / total_area (0.0 when no boxes) and stock_status is
    "high" above 90%, "low" below 50%, otherwise "medium".
    """

    def _box_area(rec: Dict[str, object]) -> float:
        x1, y1, x2, y2 = rec["box"]
        return float((x2 - x1) * (y2 - y1))

    total_area = sum((_box_area(r) for r in similarities), 0.0)
    matched_area = sum(
        (_box_area(r) for r in similarities if r["similarity"] >= threshold), 0.0
    )
    share = matched_area / total_area if total_area > 0 else 0.0
    pct = share * 100.0
    stock_status = "high" if pct > 90 else ("low" if pct < 50 else "medium")
    return share, stock_status, total_area, matched_area
def classify_facing_by_shelf(shelf_match_counts: Dict[int, int], shelf_count: int) -> Dict[int, str]:
    """Label the top half of shelves (rounded up) as good placement.

    Shelves 1..ceil(shelf_count/2) get "good place"; the rest get
    "not good place".

    NOTE(review): *shelf_match_counts* is currently unused — the labels
    depend only on shelf position; confirm whether match counts were
    intended to influence the result.
    """
    good_until = int(np.ceil(shelf_count * 0.5))
    return {
        shelf_id: "good place" if shelf_id <= good_until else "not good place"
        for shelf_id in range(1, shelf_count + 1)
    }