# faceid/app/services/segmentation/segmentation_model.py
# (Hugging Face revision 71c4b57, updated by niknikita)
import sys
from io import BytesIO
from typing import List, Optional, Tuple
import cv2
import numpy as np
import torch
from fastapi import FastAPI, UploadFile
from fastapi.responses import FileResponse
from PIL import Image
from .models.experimental import attempt_load
from .utils.general import (
check_img_size,
non_max_suppression,
scale_coords,
set_logging,
)
from .utils.torch_utils import select_device
class YOLODetector:
    """YOLO object detector for detecting and cropping cats from images."""

    def __init__(self, weights: str, device: str = '', img_size: int = 640):
        """Initialize YOLO detector with specified weights and device.

        Args:
            weights: Path to weights file. NOTE(review): currently unused —
                the model is always pulled from torch.hub below; kept for
                backward-compatible interface.
            device: Device to run inference on ('cpu' or 'cuda')
            img_size: Input image size

        Raises:
            RuntimeError: If the YOLOv5 model cannot be loaded.
        """
        self.device = select_device(device)
        try:
            # Load YOLOv5 model specifically for cat detection
            print("segmentation_model1")
            self.model = torch.hub.load(
                'ultralytics/yolov5', 'yolov5s', pretrained=True)
            self.model.to(self.device)
            self.model.eval()
            # Filter classes to only detect cats (class 15 in COCO dataset)
            self.model.classes = [15]
        except Exception as e:
            # Chain the original exception so the root cause is preserved.
            raise RuntimeError(f"Failed to load YOLOv5 model: {str(e)}") from e
        # Round the requested size to a multiple of the model stride.
        self.img_size = check_img_size(img_size, s=self.model.stride)
        set_logging()

    def _to_input_tensor(self, im0: np.ndarray) -> torch.Tensor:
        """Convert a decoded BGR image into a normalized model input tensor.

        Shared pipeline for `preprocess_image` and `preprocess_image_binary`.
        """
        # Resize while maintaining aspect ratio (letterbox pads to a square)
        img = letterbox(im0, new_shape=self.img_size)[0]
        # cv2 decodes images as BGR while YOLOv5 is trained on RGB input,
        # so convert the channel order before building the tensor.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = torch.from_numpy(img).to(self.device)
        img = img.permute(2, 0, 1).float()  # HWC to CHW
        img = img.unsqueeze(0) / 255.0  # Add batch dimension and normalize
        return img

    def preprocess_image(
            self, file: UploadFile) -> Tuple[np.ndarray, torch.Tensor]:
        """Preprocess uploaded image for cat detection.

        Args:
            file: Uploaded file containing image

        Returns:
            Tuple of (original image, preprocessed tensor)

        Raises:
            RuntimeError: If image preprocessing fails
        """
        try:
            contents = file.file.read()
            nparr = np.frombuffer(contents, np.uint8)
            im0 = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
            if im0 is None:
                raise ValueError("Failed to decode image")
            return im0, self._to_input_tensor(im0)
        except Exception as e:
            raise RuntimeError(f"Error preprocessing image: {str(e)}") from e

    def preprocess_image_binary(
            self, image: np.ndarray) -> Tuple[np.ndarray, torch.Tensor]:
        """Preprocess image for cat detection.

        Args:
            image: Image in numpy.ndarray format (BGR channel order assumed,
                as produced by cv2 — TODO confirm with callers)

        Returns:
            Tuple of (original image, preprocessed tensor)

        Raises:
            RuntimeError: If image preprocessing fails
        """
        try:
            if image is None:
                raise ValueError("Image is None")
            # The original image is returned unchanged alongside the tensor.
            im0 = image
            return im0, self._to_input_tensor(im0)
        except Exception as e:
            raise RuntimeError(f"Error preprocessing image: {str(e)}") from e

    def infer(self, img: torch.Tensor) -> List[torch.Tensor]:
        """Run cat detection inference on preprocessed image.

        Args:
            img: Preprocessed image tensor

        Returns:
            List of detection tensors after NMS
        """
        with torch.no_grad():
            pred = self.model(img)
        # Increase confidence threshold for more reliable cat detection
        return non_max_suppression(pred, conf_thres=0.4, iou_thres=0.45)

    def process_detections(self,
                           im0: np.ndarray,
                           detections: List[torch.Tensor],
                           resolution_factor: float = 1) -> Optional[np.ndarray]:
        """Process detections to crop the detected cat.

        Args:
            im0: Original image
            detections: Detection results from inference
            resolution_factor: Factor relating crop width to crop height
                (1 produces a square crop)

        Returns:
            Cropped image of detected cat or None if no cats detected
        """
        if not isinstance(detections, (list, tuple)) or not detections:
            return None
        for det in detections:
            if len(det):
                # Rescale boxes from img_size to im0 size
                det[:, :4] = scale_coords(
                    (self.img_size, self.img_size), det[:, :4], im0.shape).round()
                # Get detection with highest confidence (column 4 is conf)
                best_det = det[det[:, 4].argmax()]
                return self.crop_and_save_image(
                    im0, best_det[:4], resolution_factor)
        return None

    @staticmethod
    def crop_and_save_image(
            im0: np.ndarray,
            xyxy: torch.Tensor,
            resolution_factor: float) -> np.ndarray:
        """Crop detected cat from image with a specified resolution factor.

        Args:
            im0: Original image
            xyxy: Bounding box coordinates tensor (x_min, y_min, x_max, y_max)
            resolution_factor: Factor to adjust the width relative to the height

        Returns:
            Cropped image array of the cat

        Raises:
            ValueError: If bounding box coordinates are invalid
        """
        # Convert to integers
        x_min, y_min, x_max, y_max = map(int, xyxy)
        # Calculate the center of the bounding box
        x_center = (x_min + x_max) / 2
        y_center = (y_min + y_max) / 2
        # Width is derived from the box height so the crop aspect ratio
        # is fixed by resolution_factor regardless of the box's own width.
        new_height = (y_max - y_min)
        new_width = new_height * resolution_factor
        # Calculate new bounding box coordinates
        x_min_new = int(x_center - new_width / 2)
        x_max_new = int(x_center + new_width / 2)
        y_min_new = int(y_center - new_height / 2)
        y_max_new = int(y_center + new_height / 2)
        # Shift the window back inside the image if it spills over an edge.
        if x_min_new < 0:
            x_max_new = min(x_max_new - x_min_new, im0.shape[1])  # Shift right
            x_min_new = 0
        if x_max_new > im0.shape[1]:
            x_min_new = max(x_min_new - (x_max_new - im0.shape[1]), 0)  # Shift left
            x_max_new = im0.shape[1]
        # Clamp vertically as well: rounding can push the window out of bounds,
        # and a negative slice index would silently produce a wrong crop.
        if y_min_new < 0:
            y_max_new = min(y_max_new - y_min_new, im0.shape[0])  # Shift down
            y_min_new = 0
        if y_max_new > im0.shape[0]:
            y_min_new = max(y_min_new - (y_max_new - im0.shape[0]), 0)  # Shift up
            y_max_new = im0.shape[0]
        if x_min_new >= x_max_new or y_min_new >= y_max_new:
            raise ValueError("Invalid bounding box coordinates")
        return im0[y_min_new:y_max_new, x_min_new:x_max_new]
def letterbox(img: np.ndarray, new_shape: int = 640, color: Tuple[int, int, int] = (
        114, 114, 114)) -> Tuple[np.ndarray, float, Tuple[float, float]]:
    """Scale an image to fit a new_shape x new_shape canvas, padding the rest.

    The aspect ratio is preserved: the image is resized by a single ratio,
    then centered on a constant-color border.

    Returns:
        Tuple of (padded image, scale ratio, (width padding, height padding)),
        where each padding value is the per-side amount in pixels.
    """
    height, width = img.shape[:2]
    # Single scale factor so neither dimension exceeds the target square.
    ratio = min(new_shape / height, new_shape / width)
    unpadded = (int(round(width * ratio)), int(round(height * ratio)))
    # Split the leftover space evenly between the two sides of each axis.
    pad_w = (new_shape - unpadded[0]) / 2
    pad_h = (new_shape - unpadded[1]) / 2
    if (width, height) != unpadded:  # skip resize when already at target size
        img = cv2.resize(img, unpadded, interpolation=cv2.INTER_LINEAR)
    # The +/-0.1 nudge makes an odd leftover pixel land on one fixed side.
    top = int(round(pad_h - 0.1))
    bottom = int(round(pad_h + 0.1))
    left = int(round(pad_w - 0.1))
    right = int(round(pad_w + 0.1))
    img = cv2.copyMakeBorder(img, top, bottom, left, right,
                             cv2.BORDER_CONSTANT, value=color)
    return img, ratio, (pad_w, pad_h)