import sys
from io import BytesIO
from typing import List, Optional, Tuple

import cv2
import numpy as np
import torch
from fastapi import FastAPI, UploadFile
from fastapi.responses import FileResponse
from PIL import Image

from .models.experimental import attempt_load
from .utils.general import (
    check_img_size,
    non_max_suppression,
    scale_coords,
    set_logging,
)
from .utils.torch_utils import select_device


class YOLODetector:
    """YOLO object detector for detecting and cropping cats from images."""

    def __init__(self, weights: str, device: str = '', img_size: int = 640):
        """Initialize YOLO detector with specified weights and device.

        Args:
            weights: Path to weights file.
                NOTE(review): currently unused — the model is always pulled
                from torch.hub; kept for interface compatibility.
            device: Device to run inference on ('cpu' or 'cuda')
            img_size: Input image size

        Raises:
            RuntimeError: If the YOLOv5 model cannot be loaded.
        """
        self.device = select_device(device)
        try:
            # Load YOLOv5 model specifically for cat detection
            self.model = torch.hub.load(
                'ultralytics/yolov5', 'yolov5s', pretrained=True)
            self.model.to(self.device)
            self.model.eval()
            # Filter classes to only detect cats (class 15 in COCO dataset)
            self.model.classes = [15]
        except Exception as e:
            # Chain the cause so the original traceback is preserved
            raise RuntimeError(f"Failed to load YOLOv5 model: {str(e)}") from e

        self.img_size = check_img_size(img_size, s=self.model.stride)
        set_logging()

    def _to_tensor(self, im0: np.ndarray) -> torch.Tensor:
        """Convert a BGR image array into a normalized RGB CHW batch tensor.

        Shared preprocessing used by both public preprocess methods.

        Args:
            im0: Original image in OpenCV BGR HWC format

        Returns:
            Float tensor of shape (1, 3, H, W) in [0, 1], on self.device
        """
        # Resize while maintaining aspect ratio (letterbox pads to a square
        # of side self.img_size)
        img = letterbox(im0, new_shape=self.img_size)[0]
        # BGR -> RGB: YOLOv5 was trained on RGB input. The copy via
        # ascontiguousarray is required because torch.from_numpy cannot
        # wrap negative-stride (reversed) arrays.
        img = np.ascontiguousarray(img[:, :, ::-1])
        img = torch.from_numpy(img).to(self.device)
        img = img.permute(2, 0, 1).float()  # HWC to CHW
        img = img.unsqueeze(0) / 255.0  # Add batch dimension and normalize
        return img

    def preprocess_image(
            self, file: UploadFile) -> Tuple[np.ndarray, torch.Tensor]:
        """Preprocess uploaded image for cat detection.

        Args:
            file: Uploaded file containing image

        Returns:
            Tuple of (original image, preprocessed tensor)

        Raises:
            RuntimeError: If image preprocessing fails
        """
        try:
            contents = file.file.read()
            nparr = np.frombuffer(contents, np.uint8)
            im0 = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
            if im0 is None:
                raise ValueError("Failed to decode image")
            return im0, self._to_tensor(im0)
        except Exception as e:
            raise RuntimeError(f"Error preprocessing image: {str(e)}") from e

    def preprocess_image_binary(
            self, image: np.ndarray) -> Tuple[np.ndarray, torch.Tensor]:
        """Preprocess image for cat detection.

        Args:
            image: Image in numpy.ndarray format (BGR, HWC)

        Returns:
            Tuple of (original image, preprocessed tensor)

        Raises:
            RuntimeError: If image preprocessing fails
        """
        try:
            if image is None:
                raise ValueError("Image is None")
            # Original image is returned unchanged alongside the tensor
            im0 = image
            return im0, self._to_tensor(im0)
        except Exception as e:
            raise RuntimeError(f"Error preprocessing image: {str(e)}") from e

    def infer(self, img: torch.Tensor) -> List[torch.Tensor]:
        """Run cat detection inference on preprocessed image.

        Args:
            img: Preprocessed image tensor of shape (1, 3, H, W)

        Returns:
            List of detection tensors after NMS (one per image in batch)
        """
        with torch.no_grad():
            pred = self.model(img)
        # Increased confidence threshold (0.4) for more reliable cat detection
        return non_max_suppression(pred, conf_thres=0.4, iou_thres=0.45)

    def process_detections(self,
                           im0: np.ndarray,
                           detections: List[torch.Tensor],
                           resolution_factor: float = 1
                           ) -> Optional[np.ndarray]:
        """Process detections to crop the detected cat.

        Args:
            im0: Original image
            detections: Detection results from inference
            resolution_factor: Width-to-height ratio of the produced crop
                (1 yields a square crop centered on the detection)

        Returns:
            Cropped image of detected cat or None if no cats detected
        """
        if not isinstance(detections, (list, tuple)) or not detections:
            return None

        for det in detections:
            if len(det):
                # Rescale boxes from the letterboxed img_size back to im0 size
                det[:, :4] = scale_coords(
                    (self.img_size, self.img_size),
                    det[:, :4], im0.shape).round()
                # Keep only the detection with the highest confidence
                best_det = det[det[:, 4].argmax()]
                return self.crop_and_save_image(
                    im0, best_det[:4], resolution_factor)
        return None

    @staticmethod
    def crop_and_save_image(
            im0: np.ndarray,
            xyxy: torch.Tensor,
            resolution_factor: float) -> np.ndarray:
        """Crop detected cat from image with a specified resolution factor.

        The crop keeps the detection's height and center, and sets the width
        to ``height * resolution_factor``.

        Args:
            im0: Original image
            xyxy: Bounding box coordinates tensor (x_min, y_min, x_max, y_max)
            resolution_factor: Factor to adjust the width relative to the
                height

        Returns:
            Cropped image array of the cat

        Raises:
            ValueError: If bounding box coordinates are invalid
        """
        x_min, y_min, x_max, y_max = map(int, xyxy)

        # Center of the original bounding box
        x_center = (x_min + x_max) / 2
        y_center = (y_min + y_max) / 2

        # New box keeps the height; width derived from the resolution factor
        new_height = (y_max - y_min)
        new_width = new_height * resolution_factor

        x_min_new = int(x_center - new_width / 2)
        x_max_new = int(x_center + new_width / 2)
        y_min_new = int(y_center - new_height / 2)
        y_max_new = int(y_center + new_height / 2)

        # Shift the box back inside the image horizontally (width preserved
        # when possible)
        if x_min_new < 0:
            x_max_new = min(x_max_new - x_min_new, im0.shape[1])  # Shift right
            x_min_new = 0
        if x_max_new > im0.shape[1]:
            x_min_new = max(
                x_min_new - (x_max_new - im0.shape[1]), 0)  # Shift left
            x_max_new = im0.shape[1]

        # Clamp vertical bounds too; a no-op when the detection box lies
        # inside the image, but guards against rounding at the edges
        y_min_new = max(y_min_new, 0)
        y_max_new = min(y_max_new, im0.shape[0])

        if x_min_new >= x_max_new or y_min_new >= y_max_new:
            raise ValueError("Invalid bounding box coordinates")

        return im0[y_min_new:y_max_new, x_min_new:x_max_new]


def letterbox(img: np.ndarray,
              new_shape: int = 640,
              color: Tuple[int, int, int] = (114, 114, 114)
              ) -> Tuple[np.ndarray, float, Tuple[float, float]]:
    """Resize and pad image while meeting stride-multiple constraints.

    Args:
        img: Input image (HWC)
        new_shape: Target square side length
        color: Padding color (BGR)

    Returns:
        Tuple of (padded image, scale ratio, (dw, dh) half-paddings)
    """
    shape = img.shape[:2]  # current shape [height, width]

    # Scale ratio (new / old), never upscaling one side past new_shape
    r = min(new_shape / shape[0], new_shape / shape[1])

    # Compute padding
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw, dh = new_shape - new_unpad[0], new_shape - new_unpad[1]  # wh padding
    dw /= 2  # divide padding into 2 sides
    dh /= 2

    if shape[::-1] != new_unpad:  # resize only when the size actually changes
        img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)

    # The +/-0.1 rounding splits an odd pixel of padding between the sides
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    img = cv2.copyMakeBorder(
        img, top, bottom, left, right,
        cv2.BORDER_CONSTANT, value=color)
    return img, r, (dw, dh)