| import imp |
| from time import time |
| from typing import List |
| from sam2.build_sam import build_sam2 |
| from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator |
| import tqdm |
| import numpy as np |
| import matplotlib.pyplot as plt |
| from torchvision import transforms |
| import torch |
| import torch.nn as nn |
| from torchvision import models |
| import cv2 |
| from PIL import Image |
| import numpy as np |
| import timeit |
| import albumentations as A |
| from albumentations.pytorch import ToTensorV2 |
|
|
| |
# Pick the best available accelerator: CUDA first, then Apple MPS, then CPU.
if torch.cuda.is_available():
    _backend = "cuda"
elif torch.backends.mps.is_available():
    _backend = "mps"
else:
    _backend = "cpu"
device = torch.device(_backend)
print(f"using device: {device}")
|
|
if device.type == "cuda":

    # Enable bfloat16 autocast for the entire session. Calling __enter__()
    # without a matching __exit__() is deliberate: the autocast context stays
    # active for all subsequent CUDA work in this process (same pattern as the
    # official SAM 2 example notebooks).
    torch.autocast("cuda", dtype=torch.bfloat16).__enter__()

    # On GPUs with compute capability >= 8 (Ampere or newer), allow TF32
    # matmul/cuDNN kernels for extra throughput at slightly reduced precision.
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
elif device.type == "mps":
    print(
        "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
        "give numerically different outputs and sometimes degraded performance on MPS. "
        "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
    )
|
|
|
|
class FridgeItemClassifier(nn.Module):
    """MLP classification head mapping pooled ResNet-50 features to class logits.

    Args:
        input_size: dimensionality of the input feature vector (2048 for ResNet-50).
        hidden_size: width of the single hidden layer.
        num_classes: number of output classes.
        dropout: dropout probability applied after the hidden activation.
    """

    def __init__(self, input_size=2048, hidden_size=512, num_classes=67, dropout=0.5):
        # Python 3 idiom: zero-argument super() instead of super(Class, self).
        super().__init__()
        self.name = "FridgeItemClassifier"
        # Single hidden layer with ReLU non-linearity and dropout regularization.
        self.classifier = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, num_classes)
        )

    def forward(self, x):
        """Return raw (unnormalized) class logits of shape (batch, num_classes)."""
        return self.classifier(x)
|
|
|
|
class FridgeItemDetector:
    """End-to-end fridge-item detector.

    Pipeline:
      1. SAM 2 automatic mask generation proposes object masks.
      2. Each masked crop is letterboxed/normalized and embedded with a
         ResNet-50 backbone (final fc layer stripped -> 2048-d features).
      3. A small MLP head (FridgeItemClassifier) assigns a class label with
         a softmax confidence.
    """

    def __init__(self,
                 class_dict: List[str],
                 class_color_dict: List,
                 input_image_size = (1024, 1024),
                 segmentation_size = (224, 224),
                 device = torch.device("cpu"),
                 sam2_checkpoint = "sam2.1_hiera_small.pt",
                 sam2_model_cfg = "configs/sam2.1/sam2.1_hiera_s.yaml",
                 fridge_object_classifier_checkpoint = "model_FridgeItemClassifier_bs128_lr0.001_epoch29.pt",
                 resnet50_checkpoint = None):
        """Load SAM 2, the ResNet-50 feature extractor and the classifier head.

        Args:
            class_dict: class names; index i corresponds to classifier output i.
            class_color_dict: per-class RGB colors (floats in [0, 1]) for annotation.
            input_image_size: size the input image is resized to, or None to keep
                the original size.
            segmentation_size: target size each cropped object is letterboxed to
                before classification.
            device: torch device all models are loaded onto.
            sam2_checkpoint: path to SAM 2 weights.
            sam2_model_cfg: SAM 2 model configuration file.
            fridge_object_classifier_checkpoint: weights for the MLP head.
            resnet50_checkpoint: optional local ResNet-50 state dict; falls back
                to torchvision's pretrained weights when None.
        """
        self.class_dict = class_dict
        self.class_color_dict = class_color_dict
        self.input_image_size = input_image_size
        self.segmentation_size = segmentation_size
        self.device = device

        # SAM 2 automatic mask generator; thresholds as tuned for this project.
        self.sam2 = build_sam2(sam2_model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)
        self.mask_generator = SAM2AutomaticMaskGenerator(
            model=self.sam2,
            device=device,
            points_per_side=32,
            points_per_batch=32,
            pred_iou_thresh=0.9,
            stability_score_thresh=0.9,
            stability_score_offset=0.7,
            crop_n_layers=0,
            box_nms_thresh=0.7,
            crop_n_points_downscale_factor=1,
            min_mask_region_area=0,
            use_m2m=True,
        )

        # ResNet-50 backbone: load a local checkpoint if given, otherwise use
        # torchvision's pretrained ImageNet weights.
        if resnet50_checkpoint is not None:
            resnet50_state = torch.load(resnet50_checkpoint, weights_only=True, map_location=device)
            self.resnet50 = models.resnet50(pretrained=False)
            self.resnet50.load_state_dict(resnet50_state)
        else:
            self.resnet50 = models.resnet50(pretrained=True)
        # Strip the final fc layer so the backbone emits pooled 2048-d features.
        self.resnet50_feature_extractor = nn.Sequential(*list(self.resnet50.children())[:-1])
        self.object_classifier = FridgeItemClassifier(input_size=2048, hidden_size=512, num_classes=len(self.class_dict), dropout=0.2)
        state = torch.load(fridge_object_classifier_checkpoint, weights_only=True, map_location=device)
        self.object_classifier.load_state_dict(state)
        self.resnet50_feature_extractor.eval()
        self.object_classifier.eval()
        self.resnet50_feature_extractor.to(device)
        self.object_classifier.to(device)
        # Letterbox each crop to segmentation_size (pad with black) and
        # normalize with ImageNet statistics to match ResNet-50 training.
        self.transform_sub_image = A.Compose([
            A.LongestMaxSize(max_size=max(*self.segmentation_size)),
            A.PadIfNeeded(min_height=self.segmentation_size[0], min_width=self.segmentation_size[1], border_mode=cv2.BORDER_CONSTANT, value=(0, 0, 0)),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ])

    def _set_progress_description(self, description):
        # Lightweight progress reporting; flush so messages appear immediately.
        print(f"FridgeItemDetector: {description}...", flush=True)

    def _set_progress_tick(self):
        print("Done!", flush=True)

    def load_image(self, image_path):
        """Load the image at image_path as an RGB numpy array, optionally resized."""
        image = Image.open(image_path)
        if self.input_image_size is not None:
            image = image.resize(self.input_image_size)
        image = np.array(image.convert("RGB"))
        return image

    def locate_objects(self, image):
        """Run SAM 2 automatic mask generation; returns a list of mask dicts."""
        self._set_progress_description("Locating Objects")
        masks = self.mask_generator.generate(image)
        self._set_progress_tick()
        return masks

    def classify_objects(self, sub_images, batch_size=32):
        """Classify cropped object images in batches.

        Args:
            sub_images: list of HxWx3 numpy crops (masked object images).
            batch_size: number of crops processed per forward pass.

        Returns:
            A list of (class_index, class_name, probability) tuples, one per
            input crop, in input order.
        """
        self._set_progress_description("Classifying Objects")
        if len(sub_images) == 0:
            return []

        transformed_sub_images = [self.transform_sub_image(image=sub_image)["image"] for sub_image in sub_images]
        sub_images = torch.stack(transformed_sub_images)
        results = []
        with torch.no_grad():
            for batch_start_idx in range(0, len(sub_images), batch_size):
                # Python slicing clamps past the end, so no explicit min() needed.
                batch_sub_images = sub_images[batch_start_idx: batch_start_idx + batch_size].to(self.device)

                batch_features = self.resnet50_feature_extractor(batch_sub_images)
                batch_features = batch_features.view(batch_features.size(0), -1)

                pred_logits = self.object_classifier(batch_features)
                pred_prob = torch.softmax(pred_logits, dim=1)
                probs, indices = torch.max(pred_prob, dim=1)
                # Inside torch.no_grad() the tensors carry no graph; .detach()
                # before .cpu() is redundant.
                probs = probs.cpu().numpy()
                indices = indices.cpu().numpy()
                for prob, index in zip(probs, indices):
                    results.append((index.item(), self.class_dict[index], prob.item()))

                # Drop batch tensors eagerly to keep peak memory low; only
                # touch the CUDA allocator when actually running on CUDA.
                del batch_sub_images, batch_features, pred_logits, pred_prob, probs, indices
                if self.device.type == "cuda":
                    torch.cuda.empty_cache()
        self._set_progress_tick()
        return results

    def crop_objects(self, image, masks):
        """Crop each detected object's bbox and zero out pixels outside its mask.

        Returns a list of masked crops (background pixels become black).
        """
        sub_images = []
        for mask in masks:
            x, y, w, h = mask["bbox"]
            m = mask["segmentation"]
            sub_image = image[int(y):int(y+h), int(x):int(x+w)]
            sub_mask = m[int(y):int(y+h), int(x):int(x+w)]
            # Broadcast the boolean mask over the channel axis.
            sub_image = sub_image * sub_mask[:, :, np.newaxis]
            sub_images.append(sub_image)
        return sub_images

    def annotate_objects(self, orig_image, results, draw_borders=True, draw_boxes=True, draw_text=True):
        """Overlay masks, boxes and labels for (mask, class_label) pairs on the
        current matplotlib axes. Does nothing when results is empty."""
        if len(results) == 0:
            return
        # Draw larger objects first so smaller ones remain visible on top.
        sorted_results = sorted(results, key=(lambda x: x[0]['area']), reverse=True)
        ax = plt.gca()
        ax.set_autoscale_on(False)

        # RGBA overlay, fully transparent to start with.
        img = np.ones((orig_image.shape[0], orig_image.shape[1], 4))
        img[:, :, 3] = 0
        for mask, (class_index, class_name, class_prob) in sorted_results:
            m = mask['segmentation']
            class_color = self.class_color_dict[class_index]
            # Semi-transparent class-colored fill over the mask region.
            color_mask = np.concatenate([class_color, [0.5]])
            img[m] = color_mask
            if draw_borders:
                contours, _ = cv2.findContours(m.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
                contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
                cv2.drawContours(img, contours, -1, (0, 0, 1, 0.4), thickness=1)
            if draw_boxes:
                x, y, w, h = mask["bbox"]
                # OpenCV's color argument expects a plain scalar sequence, not
                # a numpy array — convert to a tuple of floats.
                color = tuple(np.concatenate([class_color, [1]]).tolist())
                cv2.rectangle(img, (int(x), int(y)), (int(x+w), int(y+h)), color, thickness=2)
            if draw_text:
                x, y, w, h = mask["bbox"]
                text = f"{class_name} {class_prob * 100:.1f}%"
                font = cv2.FONT_HERSHEY_SIMPLEX
                fontScale = 0.5
                color = (0, 0, 0, 1)
                thickness = 2
                cv2.putText(img, text, (int(x), int(y)), font, fontScale, color, thickness, cv2.LINE_AA)
        ax.imshow(img)

    def detect_objects(self,
                       image_path,
                       prob_cutoff=0.9,
                       black_list_classes=None,
                       annotate_image=False,
                       annotate_unfiltered=False,
                       return_unique_label=True,
                       debug=False):
        """Detect fridge items in the image at image_path.

        Args:
            image_path: path to the input image.
            prob_cutoff: drop detections with classifier confidence below this.
            black_list_classes: iterable of class names to suppress (default none).
            annotate_image: render and save annotated images next to the input.
            annotate_unfiltered: additionally save a version that includes
                detections classified as non-food.
            return_unique_label: return each detected class name at most once.
            debug: return raw (mask, class_label) pairs instead of names.

        Returns:
            Raw (mask, class_label) pairs when debug is True; a list of unique
            class names when return_unique_label is True; otherwise all
            (class_index, class_name, probability) labels.
        """
        # Fix: avoid the mutable-default-argument pitfall (the original used
        # black_list_classes=[], which is shared across calls).
        if black_list_classes is None:
            black_list_classes = []
        start_time = timeit.default_timer()

        image = self.load_image(image_path)
        masks = self.locate_objects(image)
        sub_images = self.crop_objects(image, masks)
        class_labels = self.classify_objects(sub_images)

        # Keep only confident, food, non-blacklisted detections (single pass
        # instead of three sequential list rebuilds).
        results = [
            (mask, class_label)
            for mask, class_label in zip(masks, class_labels)
            if class_label[1] != "non_food"
            and class_label[1] not in black_list_classes
            and class_label[2] >= prob_cutoff
        ]

        # Highest-confidence detections first.
        results = sorted(results, key=lambda x: x[1][2], reverse=True)
        if annotate_image:
            self._set_progress_description("Generating Annotation")

            plt.figure(figsize=(20, 20))
            plt.imshow(image)
            plt.axis('off')
            plt.show()

            if annotate_unfiltered:
                plt.figure(figsize=(20, 20))
                plt.imshow(image)
                annotate_results = [
                    (mask, class_label)
                    for mask, class_label in zip(masks, class_labels)
                    if class_label[2] >= prob_cutoff]
                self.annotate_objects(image, annotate_results)
                plt.axis('off')
                plt.savefig(f'{image_path}_annotated_with_non_food.jpg')
                plt.show()

            plt.figure(figsize=(20, 20))
            plt.imshow(image)
            self.annotate_objects(image, results)
            plt.axis('off')
            plt.savefig(f'{image_path}_annotated.jpg')
            plt.show()
            self._set_progress_tick()
        elapsed_time = timeit.default_timer() - start_time
        print(f"FridgeItemDetector: detect_objects took {elapsed_time:.4f} seconds")
        if debug:
            return results
        if return_unique_label:
            seen_class_indices = set()
            unique_results = []
            for _, (class_index, class_name, _) in results:
                if class_index not in seen_class_indices:
                    unique_results.append(class_name)
                    seen_class_indices.add(class_index)
            return unique_results
        # Fix: use a singular loop variable instead of shadowing class_labels.
        return [class_label for _, class_label in results]
|
|
| |
# ---------------------------------------------------------------------------
# Script configuration / entry point.
# ---------------------------------------------------------------------------
# Fixed seed so every class gets the same annotation color on each run.
np.random.seed(23333)
# Classifier label names. Order is significant: index i here must match output
# unit i of the trained checkpoint — do not reorder or insert entries without
# retraining the classifier head.
class_dict = ['apple', 'asparagus', 'aubergine', 'bacon', 'banana', 'basil', 'beans', 'beef', 'beetroot', 'bell pepper', 'bitter gourd', 'blueberries', 'broccoli', 'cabbage', 'carrot', 'cauliflower', 'cheese', 'chicken', 'chillies', 'chocolate', 'coriander', 'corn', 'courgettes', 'cream', 'cucumber', 'dates', 'egg', 'flour', 'garlic', 'ginger', 'green beans', 'green chilies', 'ham', 'juice', 'lemon', 'lettuce', 'lime', 'mango', 'meat', 'mineral water', 'mushroom', 'olive', 'onion', 'orange', 'parsley', 'peach', 'peas', 'peppers', 'potato', 'pumpkin', 'red grapes', 'red onion', 'salami', 'sauce', 'sausage', 'shallot', 'shrimp', 'spinach', 'spring onion', 'strawberry', 'sugar', 'sweet potato', 'swiss butter', 'swiss jam', 'swiss yoghurt', 'tomato', 'watermelon']
# The trained head has one extra "non_food" class appended at the end; food
# detections with that label are filtered out in detect_objects.
extended_class_dict = class_dict + ['non_food']
# One random RGB color (floats in [0, 1]) per class for mask/box annotation.
class_color_dict = [np.random.random(3) for i in range(len(extended_class_dict))]
# Build the detector using the large SAM 2 variant and local checkpoints.
# NOTE(review): checkpoint paths are relative — assumed present in the working
# directory; confirm before running.
fridge_item_detector = FridgeItemDetector(
    class_dict=extended_class_dict,
    class_color_dict=class_color_dict,
    input_image_size=(1024, 1024),
    device=device,
    sam2_checkpoint="sam2.1_hiera_large.pt",
    sam2_model_cfg="configs/sam2.1/sam2.1_hiera_l.yaml",
    fridge_object_classifier_checkpoint="model_FridgeItemClassifier_bs128_lr0.001_epoch29.pt",
    resnet50_checkpoint="resnet50-0676ba61.pth")


# Class names suppressed at detection time — presumably frequent false
# positives for this setup; verify empirically before changing.
black_list_classes = ["bitter gourd", "pumpkin", "blueberries"]
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |