# CookBasedOnWhatYouHave / fridgeItemDetector.py
# Author: Fourtris — commit 74d6c39 ("Commit Project")
import imp
from time import time
from typing import List
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
import tqdm
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
import torch
import torch.nn as nn
from torchvision import models
import cv2
from PIL import Image
import numpy as np
import timeit
import albumentations as A
from albumentations.pytorch import ToTensorV2
# Pick the best available computation device: CUDA first, then Apple MPS,
# falling back to plain CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print(f"using device: {device}")

if device.type == "cuda":
    # Run the whole module under bfloat16 autocast; the context manager is
    # entered once here and intentionally never exited (notebook-style setup).
    torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
    # Enable TF32 matmul/cudnn kernels on Ampere-or-newer GPUs
    # (compute capability >= 8.0), see
    # https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
elif device.type == "mps":
    print(
        "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
        "give numerically different outputs and sometimes degraded performance on MPS. "
        "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
    )
class FridgeItemClassifier(nn.Module):
    """Small MLP head mapping a feature vector to per-class logits.

    Architecture: Linear -> ReLU -> Dropout -> Linear. Designed to sit on
    top of a 2048-d pooled backbone feature (ResNet50 by default).
    """

    def __init__(self, input_size=2048, hidden_size=512, num_classes=67, dropout=0.5):
        super().__init__()
        self.name = "FridgeItemClassifier"
        # NOTE: the submodule must keep the name "classifier" so that
        # previously saved state_dicts still load.
        layers = [
            nn.Linear(input_size, hidden_size),   # project features down
            nn.ReLU(),                            # non-linearity
            nn.Dropout(dropout),                  # regularization
            nn.Linear(hidden_size, num_classes),  # one logit per class
        ]
        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw (unnormalized) class logits for feature batch `x`."""
        return self.classifier(x)
class FridgeItemDetector:
    """End-to-end fridge-item detection pipeline for a single image.

    SAM2 automatic mask generation proposes object regions, each masked
    crop is embedded with a ResNet50 backbone, and a small MLP head
    (FridgeItemClassifier) assigns a class label to each crop.
    """

    def __init__(self,
                 class_dict: List[str],
                 class_color_dict: List,
                 input_image_size = (1024, 1024),
                 segmentation_size = (224, 224),
                 device = torch.device("cpu"),
                 sam2_checkpoint = "sam2.1_hiera_small.pt",
                 sam2_model_cfg = "configs/sam2.1/sam2.1_hiera_s.yaml",
                 fridge_object_classifier_checkpoint = "model_FridgeItemClassifier_bs128_lr0.001_epoch29.pt",
                 resnet50_checkpoint = None):
        """Build the SAM2 mask generator and the classification stack.

        Args:
            class_dict: class-index -> class-name lookup (list of names).
            class_color_dict: class-index -> RGB color (floats in [0, 1]).
            input_image_size: size input images are resized to before mask
                generation; None keeps the original size.
            segmentation_size: (H, W) input size expected by the classifier.
            device: torch device used by all models.
            sam2_checkpoint / sam2_model_cfg: SAM2 weights and model config.
            fridge_object_classifier_checkpoint: weights for the MLP head.
            resnet50_checkpoint: optional local ResNet50 weights; when None,
                torchvision's pretrained ImageNet weights are downloaded.
        """
        # Parameters
        self.class_dict = class_dict
        self.class_color_dict = class_color_dict
        self.input_image_size = input_image_size
        self.segmentation_size = segmentation_size
        self.device = device
        # SAM2: object localization (automatic mask proposals)
        self.sam2 = build_sam2(sam2_model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)
        self.mask_generator = SAM2AutomaticMaskGenerator(
            model=self.sam2,
            device=device,
            points_per_side=32,
            points_per_batch=32,
            pred_iou_thresh=0.9,
            stability_score_thresh=0.9,
            stability_score_offset=0.7,
            crop_n_layers=0,
            box_nms_thresh=0.7,
            crop_n_points_downscale_factor=1,
            min_mask_region_area=0,
            use_m2m=True,
        )
        # ResNet50: feature extractor for object classification.
        # NOTE(review): torchvision's `pretrained=` argument is deprecated in
        # favour of `weights=`; kept here for older-torchvision compatibility.
        if resnet50_checkpoint is not None:
            resnet50_state = torch.load(resnet50_checkpoint, weights_only=True, map_location=device)
            self.resnet50 = models.resnet50(pretrained=False)
            self.resnet50.load_state_dict(resnet50_state)
        else:
            self.resnet50 = models.resnet50(pretrained=True)
        # Drop the final FC layer; the output is the 2048-d pooled feature.
        self.resnet50_feature_extractor = nn.Sequential(*list(self.resnet50.children())[:-1])
        self.object_classifier = FridgeItemClassifier(input_size=2048, hidden_size=512, num_classes=len(self.class_dict), dropout=0.2)
        state = torch.load(fridge_object_classifier_checkpoint, weights_only=True, map_location=device)
        self.object_classifier.load_state_dict(state)
        self.resnet50_feature_extractor.eval()
        self.object_classifier.eval()
        self.resnet50_feature_extractor.to(device)
        self.object_classifier.to(device)
        # Letterbox each crop to segmentation_size, then ImageNet-normalize.
        self.transform_sub_image = A.Compose([
            A.LongestMaxSize(max_size=max(*self.segmentation_size)),
            A.PadIfNeeded(min_height=self.segmentation_size[0], min_width=self.segmentation_size[1], border_mode=cv2.BORDER_CONSTANT, value=(0, 0, 0)),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ])

    def _set_progress_description(self, description):
        """Print the start-of-stage progress line for `description`."""
        print(f"FridgeItemDetector: {description}...", flush=True)

    def _set_progress_tick(self):
        """Mark the current pipeline stage as finished."""
        print("Done!", flush=True)

    def load_image(self, image_path):
        """Load `image_path` as an RGB uint8 array, optionally resized."""
        image = Image.open(image_path)
        if self.input_image_size is not None:
            image = image.resize(self.input_image_size)
        image = np.array(image.convert("RGB"))
        return image

    def locate_objects(self, image):
        """Run SAM2 automatic mask generation; returns a list of mask dicts."""
        self._set_progress_description("Locating Objects")
        masks = self.mask_generator.generate(image)
        self._set_progress_tick()
        return masks

    def classify_objects(self, sub_images, batch_size=32):
        """Classify each cropped object image.

        Args:
            sub_images: list of HxWx3 uint8 crops (background zeroed).
            batch_size: number of crops per forward pass.

        Returns:
            A list of (class_index, class_name, probability) tuples, one
            per input crop, in input order.
        """
        self._set_progress_description("Classifying Objects")
        if len(sub_images) == 0:
            # BUGFIX: the early return used to skip the progress tick,
            # leaving the "Classifying Objects..." line unterminated.
            self._set_progress_tick()
            return []
        # Step 1: transform sub_images to the ResNet50 input dim
        transformed_sub_images = [self.transform_sub_image(image=sub_image)["image"] for sub_image in sub_images]
        sub_images = torch.stack(transformed_sub_images)
        results = []
        with torch.no_grad():
            for batch_start_idx in range(0, len(sub_images), batch_size):
                batch_end_idx = min(batch_start_idx + batch_size, len(sub_images))
                batch_sub_images = sub_images[batch_start_idx: batch_end_idx]
                batch_sub_images = batch_sub_images.to(self.device)
                # Step 2: feature extract using resnet50
                batch_features = self.resnet50_feature_extractor(batch_sub_images)
                batch_features = batch_features.view(batch_features.size(0), -1)
                # Step 3: fridge item classifier
                pred_logits = self.object_classifier(batch_features)
                pred_prob = torch.softmax(pred_logits, dim=1)
                probs, indices = torch.max(pred_prob, dim=1)
                probs = probs.detach().cpu().numpy()
                indices = indices.detach().cpu().numpy()
                for prob, index in zip(probs, indices):
                    results.append((index.item(), self.class_dict[index], prob.item()))
                # Release per-batch tensors promptly to cap peak memory.
                del batch_sub_images, batch_features, pred_logits, pred_prob, probs, indices
                # BUGFIX: only meaningful on CUDA; skip on CPU/MPS devices.
                if self.device.type == "cuda":
                    torch.cuda.empty_cache()
        self._set_progress_tick()
        return results

    def crop_objects(self, image, masks):
        """Crop each mask's bounding box and zero out background pixels.

        Args:
            image: HxWx3 array.
            masks: SAM2 mask dicts with "bbox" (x, y, w, h) and a boolean
                "segmentation" array matching `image`'s spatial shape.
        """
        sub_images = []
        for mask in masks:
            x, y, w, h = mask["bbox"]
            m = mask["segmentation"]
            sub_image = image[int(y):int(y+h), int(x):int(x+w)]
            sub_mask = m[int(y):int(y+h), int(x):int(x+w)]
            # Keep only in-mask pixels; background becomes black.
            sub_image = sub_image * sub_mask[:, :, np.newaxis]
            sub_images.append(sub_image)
        return sub_images

    def annotate_objects(self, orig_image, results, draw_borders=True, draw_boxes=True, draw_text=True):
        """Overlay masks / boxes / labels for `results` on the current axes.

        `results` is a list of (mask_dict, (class_index, class_name, prob))
        pairs; everything is drawn on a transparent RGBA layer shown on top
        of the already-plotted image via matplotlib.
        """
        if len(results) == 0:
            return
        # Draw large masks first so small objects stay visible on top.
        sorted_masks = sorted(results, key=(lambda x: x[0]['area']), reverse=True)
        ax = plt.gca()
        ax.set_autoscale_on(False)
        img = np.ones((orig_image.shape[0], orig_image.shape[1], 4))
        img[:, :, 3] = 0  # fully transparent background
        for mask, (class_index, class_name, class_prob) in sorted_masks:
            m = mask['segmentation']
            class_color = self.class_color_dict[class_index]
            color_mask = np.concatenate([class_color, [0.5]])  # half-opaque fill
            img[m] = color_mask
            if draw_borders:
                contours, _ = cv2.findContours(m.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
                contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
                cv2.drawContours(img, contours, -1, (0, 0, 1, 0.4), thickness=1)
            if draw_boxes:
                x, y, w, h = mask["bbox"]
                color = np.concatenate([class_color, [1]])
                cv2.rectangle(img, (int(x), int(y)), (int(x+w), int(y+h)), color, thickness=2)
            if draw_text:
                x, y, w, h = mask["bbox"]
                text = f"{class_name} {class_prob * 100:.1f}%"
                font = cv2.FONT_HERSHEY_SIMPLEX
                fontScale = 0.5
                color = (0, 0, 0, 1)
                thickness = 2
                cv2.putText(img, text, (int(x), int(y)), font, fontScale, color, thickness, cv2.LINE_AA)
        ax.imshow(img)

    def detect_objects(self,
                       image_path,            # path of the inference image
                       prob_cutoff=0.9,       # cutoff based on prediction probability
                       black_list_classes=None,  # remove predictions with these class labels
                       annotate_image=False,  # if true, show annotated images
                       annotate_unfiltered=False,  # if true, show unfiltered annotated images
                       return_unique_label=True,   # only return unique labels
                       debug=False):
        """Run the full detection pipeline on one image.

        Returns:
            When debug: the raw list of (mask_dict, (index, name, prob))
            pairs after filtering and sorting. When return_unique_label:
            a list of unique class names, most confident first. Otherwise:
            a list of (index, name, prob) tuples.
        """
        start_time = timeit.default_timer()
        # BUGFIX: `black_list_classes=[]` was a mutable default argument;
        # None now stands in for "no extra exclusions".
        if black_list_classes is None:
            black_list_classes = []
        # perform pipeline
        image = self.load_image(image_path)               # Load image
        masks = self.locate_objects(image)                # SAM2 mask gen
        sub_images = self.crop_objects(image, masks)      # Crop subimages by mask
        class_labels = self.classify_objects(sub_images)  # Classify subimages
        results = list(zip(masks, class_labels))
        # Drop "non_food", black-listed, and low-confidence predictions.
        # (One pass is equivalent to the previous three sequential filters.)
        results = [(mask, label) for mask, label in results
                   if label[1] != "non_food"
                   and label[1] not in black_list_classes
                   and label[2] >= prob_cutoff]
        # sort by prob, most confident first
        results = sorted(results, key=lambda x: x[1][2], reverse=True)
        if annotate_image:
            self._set_progress_description("Generating Annotation")
            # show org image for compare
            plt.figure(figsize=(20, 20))
            plt.imshow(image)
            plt.axis('off')
            plt.show()
            # show annotated image with non_food masks
            if annotate_unfiltered:
                plt.figure(figsize=(20, 20))
                plt.imshow(image)
                annotate_results = [
                    (mask, class_label)
                    for mask, class_label in zip(masks, class_labels)
                    if class_label[2] >= prob_cutoff]  # only filter by cutoff
                self.annotate_objects(image, annotate_results)
                plt.axis('off')
                plt.savefig(f'{image_path}_annotated_with_non_food.jpg')
                plt.show()
            # show annotated image
            plt.figure(figsize=(20, 20))
            plt.imshow(image)
            self.annotate_objects(image, results)
            plt.axis('off')
            plt.savefig(f'{image_path}_annotated.jpg')
            plt.show()
            self._set_progress_tick()
        elapsed_time = timeit.default_timer() - start_time
        print(f"FridgeItemDetector: detect_objects took {elapsed_time:.4f} seconds")
        if debug:
            return results
        if return_unique_label:
            # Deduplicate by class index while preserving confidence order.
            seen_class_indices = set()
            unique_results = []
            for _, (class_index, class_name, _) in results:
                if class_index not in seen_class_indices:
                    unique_results.append(class_name)
                    seen_class_indices.add(class_index)
            return unique_results
        return [label for _, label in results]
# Build the module-level FridgeItemDetector instance used for inference.
np.random.seed(23333)  # fixed seed -> reproducible per-class colors

# Food classes the classifier was trained on (index == label id).
class_dict = ['apple', 'asparagus', 'aubergine', 'bacon', 'banana', 'basil', 'beans', 'beef', 'beetroot', 'bell pepper', 'bitter gourd', 'blueberries', 'broccoli', 'cabbage', 'carrot', 'cauliflower', 'cheese', 'chicken', 'chillies', 'chocolate', 'coriander', 'corn', 'courgettes', 'cream', 'cucumber', 'dates', 'egg', 'flour', 'garlic', 'ginger', 'green beans', 'green chilies', 'ham', 'juice', 'lemon', 'lettuce', 'lime', 'mango', 'meat', 'mineral water', 'mushroom', 'olive', 'onion', 'orange', 'parsley', 'peach', 'peas', 'peppers', 'potato', 'pumpkin', 'red grapes', 'red onion', 'salami', 'sauce', 'sausage', 'shallot', 'shrimp', 'spinach', 'spring onion', 'strawberry', 'sugar', 'sweet potato', 'swiss butter', 'swiss jam', 'swiss yoghurt', 'tomato', 'watermelon']
# An extra catch-all class for anything that is not food.
extended_class_dict = class_dict + ['non_food']
# One random RGB color per class (including 'non_food') for annotation.
class_color_dict = [np.random.random(3) for _ in extended_class_dict]

fridge_item_detector = FridgeItemDetector(
    class_dict=extended_class_dict,
    class_color_dict=class_color_dict,
    input_image_size=(1024, 1024),
    device=device,
    sam2_checkpoint="sam2.1_hiera_large.pt",
    sam2_model_cfg="configs/sam2.1/sam2.1_hiera_l.yaml",
    fridge_object_classifier_checkpoint="model_FridgeItemClassifier_bs128_lr0.001_epoch29.pt",
    resnet50_checkpoint="resnet50-0676ba61.pth")

# Classes the model performs poorly on; callers can pass this to
# detect_objects(..., black_list_classes=black_list_classes) to exclude them.
black_list_classes = ["bitter gourd", "pumpkin", "blueberries"]
# Inference Example:
#
# from fridgeItemDetector import fridge_item_detector
#
# fridge_item_detector.detect_objects(
# "path/to/image",
# prob_cutoff=0.9,
# annotate_image=False,
# return_unique_label=True
# )
#