# inference.py — spatstat-app (shell-size-app): tiled YOLO/SAM inference utilities
# %% |
from io import BytesIO
import PIL
from typing import Optional
import cv2
import numpy as np
import requests
import supervision as sv
import torch
from ultralytics import YOLO, SAM
from pathlib import Path
from supervision.detection.overlap_filter import OverlapFilter
from supervision.utils.image import crop_image
from supervision.detection.overlap_filter import (
box_non_max_suppression,
)
def plot_detections(image: np.ndarray | PIL.Image.Image | Path | str,
                    detections: sv.Detections,
                    annotations=("mask", "box", "label"),
                    ):
    """
    Use the supervision package to plot detections on an image.

    Args:
        image: np.ndarray, PIL.Image.Image, or a path (Path/str) to an image file.
        detections: supervision Detections object.
        annotations: iterable of layers to draw; any of "mask", "box", "label".
            Defaults to ("mask", "box", "label").

    Raises:
        FileNotFoundError: if a path is given and the image cannot be read.
    """
    if isinstance(image, (Path, str)):
        # cv2.imread requires a str path (Path raises TypeError) and
        # returns None instead of raising when the file cannot be read.
        annotated_image = cv2.imread(str(image))
        if annotated_image is None:
            raise FileNotFoundError(f"Could not read image: {image}")
    else:
        annotated_image = image.copy()
    # Draw masks first, then boxes, then labels, so the label layer is not
    # painted over by masks/boxes.
    if "mask" in annotations:
        annotated_image = sv.MaskAnnotator().annotate(
            scene=annotated_image, detections=detections)
    if "box" in annotations:
        annotated_image = sv.BoxAnnotator().annotate(
            scene=annotated_image, detections=detections)
    if "label" in annotations:
        annotated_image = sv.LabelAnnotator().annotate(
            scene=annotated_image, detections=detections)
    sv.plot_image(annotated_image)
def filter_ultralytics(results, edge_pct=0.01, conf_threshold=0.7):
    """Return a boolean tensor selecting detections to keep.

    A detection is kept when its normalized box lies strictly inside the
    image by more than ``edge_pct`` on every side AND its confidence
    exceeds ``conf_threshold``.
    """
    boxes = results.boxes
    # every corner coordinate must be below the upper margin ...
    below_upper_edge = (boxes.xyxyn < 1 - edge_pct).all(axis=1)
    # ... and above the lower margin
    above_lower_edge = (boxes.xyxyn > edge_pct).all(axis=1)
    confident = boxes.conf > conf_threshold
    return below_upper_edge & above_lower_edge & confident
def inference_large(image: np.ndarray | PIL.Image.Image | Path | str,
                    model_path: Path,
                    sam_path: Optional[Path] = None,
                    edge_pct=0.01,
                    conf_threshold=0.7,
                    overlap_px=320,
                    tile_px=800,
                    ) -> sv.Detections:
    """
    Run tiled (sliced) detection on a large image, optionally segmenting with SAM.

    The image is split into overlapping tiles; YOLO runs per tile, detections
    touching a tile edge or below the confidence cutoff are dropped (edge
    objects are re-detected by neighbouring tiles thanks to the overlap), and
    the merged result is de-duplicated with box NMS.

    Args:
        image: image array / PIL image, or a path (Path/str) to an image file.
        model_path: path to the YOLO detection weights.
        sam_path: optional path to SAM weights; when given, SAM produces a mask
            for each kept detection box (on CPU).
        edge_pct: normalized margin; boxes within this fraction of a tile edge
            are discarded.
        conf_threshold: minimum detection confidence to keep.
        overlap_px: tile overlap in pixels.
        tile_px: tile width/height in pixels.

    Returns:
        sv.Detections with cross-tile duplicates removed.

    Raises:
        FileNotFoundError: if a path is given and the image cannot be read.
    """
    if isinstance(image, (Path, str)):
        # cv2.imread requires a str path (Path raises TypeError) and
        # returns None instead of raising when the file cannot be read.
        image_path = image
        image = cv2.imread(str(image_path))
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
    else:
        image = image.copy()
    det_model = YOLO(model_path)
    sam_model = SAM(sam_path) if sam_path else None

    def callback(image_slice: np.ndarray) -> sv.Detections:
        # Per-tile inference: detect, filter, then optionally segment.
        det_results = det_model(image_slice)[0]
        keep_idx = filter_ultralytics(det_results, edge_pct=edge_pct,
                                      conf_threshold=conf_threshold)
        if not keep_idx.any():
            return sv.Detections.empty()
        if sam_model is None:
            return sv.Detections.from_ultralytics(det_results[keep_idx])
        boxes = det_results.boxes.xyxy[keep_idx]  # kept bboxes as prompts for SAM
        sam_results = sam_model(det_results.orig_img.copy(), bboxes=boxes,
                                verbose=False, save=False, device="cpu")[0]
        return sv.Detections.from_ultralytics(sam_results)

    slicer = sv.InferenceSlicer(callback=callback,
                                overlap_ratio_wh=None,  # use absolute overlap_wh below
                                overlap_wh=(overlap_px, overlap_px),
                                slice_wh=(tile_px, tile_px),
                                iou_threshold=0.7,
                                )
    detections = slicer(image)
    # Manual NMS to remove cross-tile duplicates: the slicer's built-in merge
    # uses mask IOU, which can miss boxes that overlap while their masks do not.
    predictions = np.hstack((detections.xyxy, detections.confidence.reshape(-1, 1)))
    indices = box_non_max_suppression(
        predictions=predictions, iou_threshold=0.7
    )
    return detections[indices]