Spaces:
Runtime error
Runtime error
| from io import BytesIO | |
| from icevision import * | |
| import collections | |
| import PIL | |
| import torch | |
| import numpy as np | |
| import torchvision | |
| import icevision.models.ross.efficientdet | |
# icevision EfficientDet model family used by get_model and predict below; it supplies
# the `backbones` configs, the `model` constructor and the `end2end_detect` helper.
MODEL_TYPE = icevision.models.ross.efficientdet
def get_model(checkpoint_path, num_classes=2, img_size=512):
    """Build an EfficientDet-D0 model and load checkpoint weights into it.

    Args:
        checkpoint_path: path of the checkpoint file, passed to ``get_checkpoint``.
        num_classes: number of classes the detection head is built for. The default
            of 2 matches ``ClassMap(classes=['Waste'])`` used in ``predict``
            (presumably background + 'Waste' -- confirm against training setup).
        img_size: input resolution the EfficientDet model is configured for; must
            agree with the resize size used at inference time.

    Returns:
        The model with the checkpoint's state dict loaded (strict key matching,
        PyTorch's ``load_state_dict`` default).
    """
    backbone = MODEL_TYPE.backbones.d0
    # EfficientDet needs its input size at construction time.
    model = MODEL_TYPE.model(
        backbone=backbone(pretrained=True),
        num_classes=num_classes,
        img_size=img_size,
    )
    model.load_state_dict(get_checkpoint(checkpoint_path))
    return model
def get_checkpoint(checkpoint_path, prefix='model.'):
    """Load a training checkpoint and return a state dict for the bare model.

    The checkpoint at ``checkpoint_path`` must be a dict whose ``'state_dict'``
    entry maps parameter names to tensors. Each key starting with ``prefix``
    (by default ``'model.'``, presumably the attribute under which the training
    wrapper stored the network -- confirm) has that prefix removed. Any other key
    is kept unchanged, so the result can be fed to ``model.load_state_dict``.

    Args:
        checkpoint_path: path of the checkpoint file to load.
        prefix: key prefix to strip from every state-dict entry that has it.

    Returns:
        ``collections.OrderedDict`` of the renamed entries, in checkpoint order,
        with all tensors mapped to CPU.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    ckpt = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    fixed_state_dict = collections.OrderedDict()
    for k, v in ckpt['state_dict'].items():
        new_k = k[len(prefix):] if k.startswith(prefix) else k
        fixed_state_dict[new_k] = v
    return fixed_state_dict
def predict(model, image, detection_threshold, img_size=512):
    """Run detection on a single image with icevision's ``end2end_detect``.

    Args:
        model: detection model, e.g. the one returned by ``get_model``.
        image: anything ``PIL.Image.open`` accepts (a path or a binary file-like
            object), or the raw encoded image as ``bytes``/``bytearray``.
        detection_threshold: score threshold forwarded to ``end2end_detect``.
        img_size: side length for ``resize_and_pad``; should match the
            ``img_size`` the model was built with (512 in ``get_model``).

    Returns:
        The prediction dict from ``MODEL_TYPE.end2end_detect``, requested with
        ``return_img=True`` and ``return_as_pil_img=False``.
    """
    if isinstance(image, (bytes, bytearray)):
        image = BytesIO(image)
    # The context manager closes the underlying file. convert() loads the pixels
    # before that happens and maps palette/grayscale/alpha images to 3-channel RGB,
    # which the default Normalize (3-channel mean/std) presumably expects -- confirm.
    with PIL.Image.open(image) as pil_img:
        img = pil_img.convert("RGB")
    class_map = ClassMap(classes=['Waste'])
    transforms = tfms.A.Adapter([
        *tfms.A.resize_and_pad(img_size),
        tfms.A.Normalize(),
    ])
    return MODEL_TYPE.end2end_detect(
        img,
        transforms,
        model,
        class_map=class_map,
        detection_threshold=detection_threshold,
        return_as_pil_img=False,
        return_img=True,
        display_bbox=False,
        display_score=False,
        display_label=False,
    )
def prepare_prediction(pred_dict, threshold):
    """Apply class-aware NMS to a prediction dict and return surviving boxes.

    Args:
        pred_dict: dict with a ``'detection'`` entry holding ``'bboxes'`` (objects
            exposing ``to_tensor()``), ``'scores'`` and ``'label_ids'``, plus an
            ``'img'`` entry -- the shape produced by ``predict``.
        threshold: IoU threshold passed to ``torchvision.ops.batched_nms``.

    Returns:
        ``(boxes, image)``: the kept boxes as a tensor, ordered by decreasing score
        as ``batched_nms`` returns them, and ``pred_dict['img']`` as a numpy array.
        With no detections, ``boxes`` is an empty ``(0, 4)`` float32 tensor.
    """
    detection = pred_dict['detection']
    image = np.array(pred_dict['img'])
    bboxes = detection['bboxes']
    if len(bboxes) == 0:
        # torch.stack raises on an empty list; report "no boxes" instead of crashing.
        return torch.empty((0, 4), dtype=torch.float32), image
    boxes = torch.stack([box.to_tensor() for box in bboxes])
    # NMS requires boxes and scores to share a dtype.
    scores = torch.as_tensor(detection['scores'], dtype=boxes.dtype)
    labels = torch.as_tensor(detection['label_ids'])
    keep = torchvision.ops.batched_nms(boxes, scores, labels, threshold)
    return boxes[keep], image