|
|
import sys |
|
|
from io import BytesIO |
|
|
from typing import List, Optional, Tuple |
|
|
|
|
|
import cv2 |
|
|
import numpy as np |
|
|
import torch |
|
|
from fastapi import FastAPI, UploadFile |
|
|
from fastapi.responses import FileResponse |
|
|
from PIL import Image |
|
|
|
|
|
from .models.experimental import attempt_load |
|
|
from .utils.general import ( |
|
|
check_img_size, |
|
|
non_max_suppression, |
|
|
scale_coords, |
|
|
set_logging, |
|
|
) |
|
|
from .utils.torch_utils import select_device |
|
|
|
|
|
|
|
|
class YOLODetector:
    """YOLO object detector for detecting and cropping cats from images."""

    def __init__(self, weights: str, device: str = '', img_size: int = 640):
        """Initialize YOLO detector with specified weights and device.

        Args:
            weights: Path to weights file. NOTE(review): currently ignored —
                the pretrained ``yolov5s`` checkpoint is always fetched from
                torch.hub. The parameter is kept for interface compatibility.
            device: Device to run inference on ('cpu' or 'cuda')
            img_size: Input image size

        Raises:
            RuntimeError: If the YOLOv5 model cannot be loaded.
        """
        self.device = select_device(device)
        try:
            self.model = torch.hub.load(
                'ultralytics/yolov5', 'yolov5s', pretrained=True)

            self.model.to(self.device)
            self.model.eval()

            # Restrict hub-model detections to COCO class 15 ("cat").
            self.model.classes = [15]
        except Exception as e:
            raise RuntimeError(f"Failed to load YOLOv5 model: {str(e)}") from e

        # Round the requested size to a valid multiple of the model stride.
        self.img_size = check_img_size(img_size, s=self.model.stride)
        set_logging()

    def _letterbox_to_tensor(self, im0: np.ndarray) -> torch.Tensor:
        """Letterbox *im0* and convert it to a normalized NCHW float tensor.

        Shared preprocessing tail used by both ``preprocess_image`` and
        ``preprocess_image_binary``.
        """
        img = letterbox(im0, new_shape=self.img_size)[0]

        img = torch.from_numpy(img).to(self.device)
        img = img.permute(2, 0, 1).float()  # HWC -> CHW
        return img.unsqueeze(0) / 255.0     # add batch dim, scale to [0, 1]

    def preprocess_image(
            self, file: UploadFile) -> Tuple[np.ndarray, torch.Tensor]:
        """Preprocess uploaded image for cat detection.

        Args:
            file: Uploaded file containing image

        Returns:
            Tuple of (original image, preprocessed tensor)

        Raises:
            RuntimeError: If image preprocessing fails
        """
        try:
            contents = file.file.read()
            nparr = np.frombuffer(contents, np.uint8)
            # cv2 decodes to BGR channel order; kept as-is downstream.
            im0 = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
            if im0 is None:
                raise ValueError("Failed to decode image")

            return im0, self._letterbox_to_tensor(im0)
        except Exception as e:
            raise RuntimeError(f"Error preprocessing image: {str(e)}") from e

    def preprocess_image_binary(
            self, image: np.ndarray) -> Tuple[np.ndarray, torch.Tensor]:
        """Preprocess image for cat detection.

        Args:
            image: Image in numpy.ndarray format

        Returns:
            Tuple of (original image, preprocessed tensor)

        Raises:
            RuntimeError: If image preprocessing fails
        """
        try:
            if image is None:
                raise ValueError("Image is None")

            im0 = image
            return im0, self._letterbox_to_tensor(im0)
        except Exception as e:
            raise RuntimeError(f"Error preprocessing image: {str(e)}") from e

    def infer(self, img: torch.Tensor) -> List[torch.Tensor]:
        """Run cat detection inference on preprocessed image.

        Args:
            img: Preprocessed image tensor

        Returns:
            List of detection tensors after NMS
        """
        with torch.no_grad():
            pred = self.model(img)

        return non_max_suppression(pred, conf_thres=0.4, iou_thres=0.45)

    def process_detections(self,
                           im0: np.ndarray,
                           detections: List[torch.Tensor],
                           resolution_factor: float = 1) -> Optional[np.ndarray]:
        """Process detections to crop the best detected cat.

        Args:
            im0: Original image
            detections: Detection results from inference
            resolution_factor: Width-to-height ratio of the returned crop
                (1 yields a square-height crop; see ``crop_and_save_image``)

        Returns:
            Cropped image of detected cat or None if no cats detected
        """
        if not isinstance(detections, (list, tuple)) or not detections:
            return None

        for det in detections:
            if len(det):
                # Map boxes from letterboxed space back to original image size.
                det[:, :4] = scale_coords(
                    (self.img_size, self.img_size), det[:, :4], im0.shape).round()

                # Keep only the highest-confidence detection (column 4 = conf).
                best_det = det[det[:, 4].argmax()]
                return self.crop_and_save_image(
                    im0, best_det[:4], resolution_factor)
        return None

    @staticmethod
    def crop_and_save_image(
            im0: np.ndarray,
            xyxy: torch.Tensor,
            resolution_factor: float) -> np.ndarray:
        """Crop detected cat from image with a specified resolution factor.

        Args:
            im0: Original image
            xyxy: Bounding box coordinates tensor
            resolution_factor: Factor to adjust the width relative to the height

        Returns:
            Cropped image array of the cat

        Raises:
            ValueError: If bounding box coordinates are invalid
        """
        x_min, y_min, x_max, y_max = map(int, xyxy)

        x_center = (x_min + x_max) / 2
        y_center = (y_min + y_max) / 2

        # Height is preserved; width is derived from it via the factor.
        new_height = (y_max - y_min)
        new_width = new_height * resolution_factor

        x_min_new = int(x_center - new_width / 2)
        x_max_new = int(x_center + new_width / 2)
        y_min_new = int(y_center - new_height / 2)
        y_max_new = int(y_center + new_height / 2)

        # Shift the box back inside the image if it spills over an x edge.
        if x_min_new < 0:
            x_max_new = min(x_max_new - x_min_new, im0.shape[1])
            x_min_new = 0
        if x_max_new > im0.shape[1]:
            x_min_new = max(x_min_new - (x_max_new - im0.shape[1]), 0)
            x_max_new = im0.shape[1]

        # Fix: clamp the y axis as well (original code only clamped x,
        # allowing out-of-bounds boxes to produce wrong/empty slices).
        if y_min_new < 0:
            y_max_new = min(y_max_new - y_min_new, im0.shape[0])
            y_min_new = 0
        if y_max_new > im0.shape[0]:
            y_min_new = max(y_min_new - (y_max_new - im0.shape[0]), 0)
            y_max_new = im0.shape[0]

        if x_min_new >= x_max_new or y_min_new >= y_max_new:
            raise ValueError("Invalid bounding box coordinates")

        return im0[y_min_new:y_max_new, x_min_new:x_max_new]
|
|
|
|
|
|
|
|
def letterbox(img: np.ndarray, new_shape: int = 640, color: Tuple[int, int, int] = (
        114, 114, 114)) -> Tuple[np.ndarray, float, Tuple[float, float]]:
    """Fit *img* into a new_shape x new_shape square, padding with *color*.

    The image is scaled (aspect ratio preserved) so its longer side equals
    ``new_shape``, then centered on a constant-color canvas.

    Returns:
        Tuple of (padded image, scale ratio, (x padding, y padding)) where
        the paddings are per-side (half of the total added on each axis).
    """
    height, width = img.shape[:2]

    # Scale factor that makes the longer side exactly fill the target square.
    ratio = min(new_shape / height, new_shape / width)

    resized = int(round(width * ratio)), int(round(height * ratio))

    # Per-side padding: split the leftover space evenly on each axis.
    pad_w = (new_shape - resized[0]) / 2
    pad_h = (new_shape - resized[1]) / 2

    if (width, height) != resized:
        img = cv2.resize(img, resized, interpolation=cv2.INTER_LINEAR)

    # The +/-0.1 nudges rounding so an odd total pad splits as floor/ceil.
    top, bottom = int(round(pad_h - 0.1)), int(round(pad_h + 0.1))
    left, right = int(round(pad_w - 0.1)), int(round(pad_w + 0.1))

    img = cv2.copyMakeBorder(
        img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)
    return img, ratio, (pad_w, pad_h)
|
|
|