OverMind0 committed on
Commit
a840fa0
·
verified ·
1 Parent(s): ced6f9d

Upload 5 files

Browse files
Files changed (5) hide show
  1. augmentation_embeddings.py +137 -0
  2. best.pt +3 -0
  3. processor.py +188 -0
  4. requirements.txt +7 -0
  5. router.py +143 -0
augmentation_embeddings.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """Augmentation and embedding helpers."""
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import List, Dict, Tuple
7
+
8
+ import random
9
+ import numpy as np
10
+ import torch
11
+ from PIL import Image, ImageEnhance, ImageOps
12
+ from transformers import AutoImageProcessor, AutoModel
13
+
14
+ _DINO_PROCESSOR = None
15
+ _DINO_MODEL = None
16
+
17
+
18
+ def get_dino_model(device: torch.device):
19
+ global _DINO_PROCESSOR, _DINO_MODEL
20
+
21
+ if _DINO_PROCESSOR is None or _DINO_MODEL is None:
22
+ _DINO_PROCESSOR = AutoImageProcessor.from_pretrained("facebook/dinov2-small")
23
+ _DINO_MODEL = AutoModel.from_pretrained("facebook/dinov2-small").to(device)
24
+ _DINO_MODEL.eval()
25
+
26
+ return _DINO_PROCESSOR, _DINO_MODEL
27
+
28
+
29
+ def augment_image(img: Image.Image) -> Image.Image:
30
+ aug = img.copy()
31
+
32
+ if random.random() < 0.5:
33
+ aug = ImageOps.mirror(aug)
34
+
35
+ angle = random.uniform(-10, 10)
36
+ aug = aug.rotate(angle, resample=Image.BILINEAR)
37
+
38
+ if random.random() < 0.7:
39
+ enhancer = ImageEnhance.Brightness(aug)
40
+ aug = enhancer.enhance(random.uniform(0.8, 1.2))
41
+
42
+ if random.random() < 0.7:
43
+ enhancer = ImageEnhance.Contrast(aug)
44
+ aug = enhancer.enhance(random.uniform(0.8, 1.2))
45
+
46
+ if random.random() < 0.5:
47
+ enhancer = ImageEnhance.Sharpness(aug)
48
+ aug = enhancer.enhance(random.uniform(0.9, 1.3))
49
+
50
+ return aug
51
+
52
+
53
+ def extract_embedding_from_pil(image: Image.Image, device: torch.device) -> torch.Tensor:
54
+ processor, model = get_dino_model(device)
55
+
56
+ inputs = processor(images=image, return_tensors="pt").to(device)
57
+ with torch.no_grad():
58
+ outputs = model(**inputs)
59
+
60
+ emb = outputs.last_hidden_state[:, 0, :]
61
+ emb = torch.nn.functional.normalize(emb, p=2, dim=1)
62
+ return emb
63
+
64
+
65
+ def build_reference_embeddings(
66
+ ref_images: List[Image.Image],
67
+ device: torch.device,
68
+ augmentations_per_image: int = 10,
69
+ ) -> torch.Tensor:
70
+ augmented_images: List[Image.Image] = []
71
+
72
+ for img in ref_images:
73
+ augmented_images.append(img)
74
+ for _ in range(augmentations_per_image):
75
+ augmented_images.append(augment_image(img))
76
+
77
+ ref_embeddings = []
78
+ for img in augmented_images:
79
+ ref_embeddings.append(extract_embedding_from_pil(img, device))
80
+
81
+ return torch.cat(ref_embeddings, dim=0)
82
+
83
+
84
+ def adaptive_similarity_threshold(
85
+ similarities: List[Dict[str, float]],
86
+ percentile: int = 80,
87
+ std_factor: float = 0.5,
88
+ min_threshold: float = 0.7,
89
+ ) -> float:
90
+ sims = np.array([s["similarity"] for s in similarities])
91
+ if sims.size == 0:
92
+ return min_threshold
93
+
94
+ p_thresh = np.percentile(sims, percentile)
95
+ mean_thresh = sims.mean() + std_factor * sims.std()
96
+
97
+ return max(p_thresh, mean_thresh, min_threshold)
98
+
99
+
100
+ def compute_similarities(
101
+ object_crops: Dict[int, Image.Image],
102
+ ref_embeddings: torch.Tensor,
103
+ device: torch.device,
104
+ ) -> List[Dict[str, float]]:
105
+ similarities = []
106
+ for i, crop in object_crops.items():
107
+ prod_emb = extract_embedding_from_pil(crop, device)
108
+ sim = torch.matmul(ref_embeddings, prod_emb.T).max().item()
109
+ similarities.append({"box_id": i, "similarity": sim})
110
+
111
+ similarities.sort(key=lambda x: x["similarity"], reverse=True)
112
+ return similarities
113
+
114
+
115
+ def calculate_shelf_share(similarities: List[Dict[str, float]], boxes, threshold: float):
116
+ matched_area = 0
117
+ total_area = 0
118
+ stock_status = ""
119
+
120
+ for s in similarities:
121
+ x1, y1, x2, y2 = boxes[s["box_id"]]
122
+ area = (x2 - x1) * (y2 - y1)
123
+ total_area += area
124
+
125
+ if s["similarity"] >= threshold:
126
+ matched_area += area
127
+
128
+ share = matched_area / total_area if total_area > 0 else 0
129
+
130
+ if share > 0.9:
131
+ stock_status = "high"
132
+ elif share < 0.5:
133
+ stock_status = "low"
134
+ else:
135
+ stock_status = "medium"
136
+
137
+ return share, stock_status, total_area, matched_area
best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:731b75c451e7797635b321905e9304b98570c82960c9830a4ae58f43f0634101
3
+ size 5358277
processor.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """Shelf inventory processing utilities."""
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import List, Tuple, Dict, Any
7
+
8
+ import numpy as np
9
+ from PIL import Image, ImageDraw
10
+
11
+
12
+ class ShelfInventoryProcessor:
13
+ def __init__(
14
+ self,
15
+ model,
16
+ overlap_threshold: float = 0.5,
17
+ min_box_height: int = 20,
18
+ min_items_per_shelf: int = 8,
19
+ merge_overlap_threshold: float = 0.3,
20
+ ) -> None:
21
+ self.model = model
22
+ self.overlap_threshold = overlap_threshold
23
+ self.min_box_height = min_box_height
24
+ self.min_items_per_shelf = min_items_per_shelf
25
+ self.merge_overlap_threshold = merge_overlap_threshold
26
+
27
+ @staticmethod
28
+ def vertical_overlap(range1: Tuple[float, float], range2: Tuple[float, float]) -> float:
29
+ inter = min(range1[1], range2[1]) - max(range1[0], range2[0])
30
+ if inter <= 0:
31
+ return 0.0
32
+ h1 = range1[1] - range1[0]
33
+ return inter / h1 if h1 > 0 else 0.0
34
+
35
+ def run_inference(self, image: Image.Image):
36
+ results = self.model(image, verbose=False)[0]
37
+ img = image.convert("RGB")
38
+ draw = ImageDraw.Draw(img)
39
+
40
+ if not results.boxes:
41
+ return None, img, draw
42
+
43
+ boxes = results.boxes.xyxy.cpu().numpy()
44
+ boxes = boxes[np.argsort(boxes[:, 1])]
45
+
46
+ return boxes, img, draw
47
+
48
+ def group_boxes_into_shelves(self, boxes: np.ndarray) -> List[List[np.ndarray]]:
49
+ shelves: List[List[np.ndarray]] = []
50
+
51
+ for box in boxes:
52
+ x1, y1, x2, y2 = box
53
+ box_h = y2 - y1
54
+
55
+ if box_h < self.min_box_height:
56
+ continue
57
+
58
+ matched = False
59
+ for shelf in shelves:
60
+ s_y1 = np.median([b[1] for b in shelf])
61
+ s_y2 = np.median([b[3] for b in shelf])
62
+
63
+ inter = min(y2, s_y2) - max(y1, s_y1)
64
+ overlap_ratio = inter / box_h if box_h > 0 else 0
65
+
66
+ if overlap_ratio > self.overlap_threshold:
67
+ shelf.append(box)
68
+ matched = True
69
+ break
70
+
71
+ if not matched:
72
+ shelves.append([box])
73
+
74
+ return shelves
75
+
76
+ def build_shelf_objects(self, shelves: List[List[np.ndarray]]) -> List[Dict[str, Any]]:
77
+ shelf_objs: List[Dict[str, Any]] = []
78
+ for shelf in shelves:
79
+ ys = [b[1] for b in shelf] + [b[3] for b in shelf]
80
+ shelf_objs.append({"boxes": shelf, "y_range": (min(ys), max(ys))})
81
+ return shelf_objs
82
+
83
+ def merge_weak_shelves(self, shelf_objs: List[Dict[str, Any]]) -> List[List[np.ndarray]]:
84
+ merged: List[List[np.ndarray]] = []
85
+ used = [False] * len(shelf_objs)
86
+
87
+ for i in range(len(shelf_objs)):
88
+ if used[i]:
89
+ continue
90
+
91
+ cur_boxes = shelf_objs[i]["boxes"]
92
+ cur_range = shelf_objs[i]["y_range"]
93
+
94
+ for j in range(i + 1, len(shelf_objs)):
95
+ if used[j]:
96
+ continue
97
+
98
+ overlap = self.vertical_overlap(cur_range, shelf_objs[j]["y_range"])
99
+
100
+ if (
101
+ overlap > self.merge_overlap_threshold
102
+ and (
103
+ len(cur_boxes) < self.min_items_per_shelf
104
+ or len(shelf_objs[j]["boxes"]) < self.min_items_per_shelf
105
+ )
106
+ ):
107
+ cur_boxes.extend(shelf_objs[j]["boxes"])
108
+ used[j] = True
109
+
110
+ merged.append(cur_boxes)
111
+ used[i] = True
112
+
113
+ return merged
114
+
115
+ def annotate_and_build_metadata(self, shelves, draw: ImageDraw.ImageDraw):
116
+ final_boxes = []
117
+ shelf_metadata = []
118
+
119
+ avg_items = np.mean([len(s) for s in shelves]) if shelves else 1
120
+
121
+ for shelf_id, shelf in enumerate(shelves, start=1):
122
+ ys = [b[1] for b in shelf] + [b[3] for b in shelf]
123
+ min_y, max_y = min(ys), max(ys)
124
+
125
+ num_items = len(shelf)
126
+ confidence = round(num_items / avg_items, 2)
127
+
128
+ shelf_metadata.append(
129
+ {
130
+ "shelf_id": shelf_id,
131
+ "num_items": num_items,
132
+ "y_range": (int(min_y), int(max_y)),
133
+ "confidence": confidence,
134
+ "status": "stable" if confidence >= 0.5 else "unstable",
135
+ }
136
+ )
137
+
138
+ for b in shelf:
139
+ draw.rectangle([b[0], b[1], b[2], b[3]], outline="red", width=3)
140
+ draw.text((b[0], b[1] - 10), f"S{shelf_id}", fill="red")
141
+ final_boxes.append(b)
142
+
143
+ return final_boxes, shelf_metadata
144
+
145
+ def crop_annotated_image_by_object(
146
+ self,
147
+ annotated_img: Image.Image,
148
+ boxes: List[np.ndarray],
149
+ box_id: int | None = None,
150
+ padding: int = 5,
151
+ ):
152
+ width, height = annotated_img.size
153
+
154
+ def _safe_crop(x1, y1, x2, y2):
155
+ x1 = max(0, int(x1 - padding))
156
+ y1 = max(0, int(y1 - padding))
157
+ x2 = min(width, int(x2 + padding))
158
+ y2 = min(height, int(y2 + padding))
159
+ return annotated_img.crop((x1, y1, x2, y2))
160
+
161
+ if box_id is not None:
162
+ if box_id < 0 or box_id >= len(boxes):
163
+ raise IndexError(f"Box ID {box_id} out of range")
164
+
165
+ x1, y1, x2, y2 = boxes[box_id]
166
+ return _safe_crop(x1, y1, x2, y2)
167
+
168
+ cropped = {}
169
+ for i, (x1, y1, x2, y2) in enumerate(boxes):
170
+ cropped[i] = _safe_crop(x1, y1, x2, y2)
171
+
172
+ return cropped
173
+
174
+ def run(self, image: Image.Image):
175
+ boxes, img, draw = self.run_inference(image)
176
+
177
+ if boxes is None:
178
+ return [], [], 0, img
179
+
180
+ shelves = self.group_boxes_into_shelves(boxes)
181
+ shelf_objs = self.build_shelf_objects(shelves)
182
+ merged_shelves = self.merge_weak_shelves(shelf_objs)
183
+
184
+ final_boxes, shelf_metadata = self.annotate_and_build_metadata(
185
+ merged_shelves, draw
186
+ )
187
+
188
+ return final_boxes, shelf_metadata, len(merged_shelves), img
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio>=4.0.0
2
+ ultralytics>=8.0.0
3
+ torch
4
+ torchvision
5
+ transformers>=4.38.0
6
+ pillow
7
+ numpy
router.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """Gradio router for shelf analysis."""
3
+
4
+ from __future__ import annotations
5
+
6
+ from pathlib import Path
7
+ from typing import List
8
+
9
+ import sys
10
+
11
+ import gradio as gr
12
+ import torch
13
+ from PIL import Image, ImageDraw
14
+ from ultralytics import YOLO
15
+
16
+ PROJECT_ROOT = Path(__file__).resolve().parents[1]
17
+ sys.path.append(str(PROJECT_ROOT))
18
+
19
+ from src.processor import ShelfInventoryProcessor
20
+ from src.augmentation_embeddings import (
21
+ build_reference_embeddings,
22
+ compute_similarities,
23
+ adaptive_similarity_threshold,
24
+ calculate_shelf_share,
25
+ )
26
+
27
+
28
# Expected location of the fine-tuned YOLO weights (stored via Git LFS).
MODEL_PATH = PROJECT_ROOT / "models" / "best.pt"
29
+
30
+
31
+ def get_device() -> torch.device:
32
+ return torch.device("cuda" if torch.cuda.is_available() else "cpu")
33
+
34
+
35
+ def load_model() -> YOLO:
36
+ if not MODEL_PATH.exists():
37
+ raise FileNotFoundError(f"Model not found at {MODEL_PATH}")
38
+ return YOLO(str(MODEL_PATH))
39
+
40
+
41
+ MODEL = load_model()
42
+ PROCESSOR = ShelfInventoryProcessor(model=MODEL)
43
+
44
+
45
+ def _load_reference_images(reference_paths: List[str]) -> List[Image.Image]:
46
+ images: List[Image.Image] = []
47
+ for path in reference_paths:
48
+ img = Image.open(path).convert("RGB")
49
+ images.append(img)
50
+ return images
51
+
52
+
53
+ def _build_facing_text(shelf_metadata, shelf_share: float) -> str:
54
+ if not shelf_metadata:
55
+ return "facing: no shelves detected"
56
+
57
+ best_shelf = max(shelf_metadata, key=lambda s: s["num_items"])
58
+ label = "very good place" if shelf_share >= 0.7 else "needs attention"
59
+ return f"facing: shelf {best_shelf['shelf_id']} {label}"
60
+
61
+
62
+ def analyze_shelf(shelf_image: Image.Image, reference_files: List[str]):
63
+ if shelf_image is None:
64
+ return "Please upload a shelf photo.", None
65
+ if not reference_files:
66
+ return "Please upload at least one reference photo.", None
67
+
68
+ device = get_device()
69
+
70
+ boxes, metadata, _shelf_count, annotated_img = PROCESSOR.run(shelf_image)
71
+ if not boxes:
72
+ return "No products detected.", annotated_img
73
+
74
+ object_crops = PROCESSOR.crop_annotated_image_by_object(shelf_image, boxes)
75
+
76
+ ref_images = _load_reference_images(reference_files)
77
+ ref_embeddings = build_reference_embeddings(ref_images, device)
78
+
79
+ similarities = compute_similarities(object_crops, ref_embeddings, device)
80
+ if not similarities:
81
+ return "No matches found.", annotated_img
82
+
83
+ threshold = adaptive_similarity_threshold(similarities)
84
+ shelf_share, stock_status, total_area, matched_area = calculate_shelf_share(
85
+ similarities, boxes, threshold
86
+ )
87
+
88
+ facing_text = _build_facing_text(metadata, shelf_share)
89
+
90
+ result_lines = [
91
+ f"Shelf Share: {shelf_share * 100:.2f}%",
92
+ facing_text,
93
+ f"Stock Status: {stock_status}",
94
+ f"Matched Area: {matched_area:.0f} px² / Total Shelf Area: {total_area:.0f} px²",
95
+ ]
96
+
97
+ annotated = shelf_image.copy()
98
+ draw = ImageDraw.Draw(annotated)
99
+
100
+ for s in similarities:
101
+ if s["similarity"] < threshold:
102
+ continue
103
+ x1, y1, x2, y2 = map(int, boxes[s["box_id"]])
104
+ draw.rectangle([x1, y1, x2, y2], outline="green", width=3)
105
+ draw.text((x1, max(y1 - 12, 0)), f"{s['similarity']:.2f}", fill="green")
106
+
107
+ return "\n".join(result_lines), annotated
108
+
109
+
110
+ def build_app():
111
+ with gr.Blocks(title="Shelf Analysis") as demo:
112
+ gr.Markdown("# Shelf Analysis")
113
+ gr.Markdown(
114
+ "Upload a shelf photo and one or more reference product photos."
115
+ )
116
+
117
+ with gr.Row():
118
+ shelf_input = gr.Image(type="pil", label="Shelf Photo")
119
+ ref_input = gr.File(
120
+ file_types=["image"],
121
+ file_count="multiple",
122
+ type="filepath",
123
+ label="Reference Photos",
124
+ )
125
+
126
+ with gr.Row():
127
+ output_text = gr.Textbox(label="Results", lines=6)
128
+ output_image = gr.Image(type="pil", label="Annotated Matches")
129
+
130
+ analyze_btn = gr.Button("Analyze")
131
+
132
+ analyze_btn.click(
133
+ fn=analyze_shelf,
134
+ inputs=[shelf_input, ref_input],
135
+ outputs=[output_text, output_image],
136
+ )
137
+
138
+ return demo
139
+
140
+
141
+ if __name__ == "__main__":
142
+ app = build_app()
143
+ app.launch()