# niknikita's picture
# fix
# d73dbf8
import argparse
import time
from pathlib import Path
import cv2
import torch
import torch.backends.cudnn as cudnn
from numpy import random
from PIL import Image
from models.experimental import attempt_load
from utils.datasets import LoadStreams, LoadImages
from utils.general import check_img_size, non_max_suppression, scale_coords, strip_optimizer, set_logging, increment_path
from utils.plots import plot_one_box
from utils.torch_utils import select_device, time_synchronized
def process_and_save_image(im0, xyxy, save_dir, filename, target_size=(224, 224)):
    """Save a square, black-padded crop around one detection as an RGB PNG.

    The crop window is the smallest square centred on the box that still
    contains it. Wherever that square reaches past the image border the
    missing area is filled with black, so the box stays centred and the
    final resize does not distort its aspect ratio.

    Args:
        im0: Source image array indexed as ``[row, column, channel]``; it is
            converted with ``cv2.COLOR_BGR2RGB`` before saving, so BGR
            channel order is expected.
        xyxy: Box corners ``(x_min, y_min, x_max, y_max)`` in ``im0`` pixel
            coordinates; each entry must be convertible with ``int()``.
        save_dir: Output directory; must support the ``/`` path operator
            (e.g. ``pathlib.Path``).
        filename: Output stem; the file is written as ``<filename>_cat.png``.
        target_size: ``(width, height)`` passed to ``cv2.resize``.

    Returns:
        None. A box with no area inside the image is skipped (with a
        message) and nothing is written.
    """
    img_h, img_w = im0.shape[0], im0.shape[1]

    # Clamp the box to the image so a slightly out-of-frame detection cannot
    # produce negative slice indices or an empty crop further down.
    x_min = min(max(int(xyxy[0]), 0), img_w)
    y_min = min(max(int(xyxy[1]), 0), img_h)
    x_max = min(max(int(xyxy[2]), 0), img_w)
    y_max = min(max(int(xyxy[3]), 0), img_h)

    width, height = x_max - x_min, y_max - y_min
    if width <= 0 or height <= 0:
        # An empty crop would make the OpenCV calls below raise.
        print(f"Skipped {filename}: detection box has no area inside the image")
        return

    # Ideal square window centred on the box. It always contains the box but
    # may extend beyond the image on any side.
    side = max(width, height)
    win_x0 = (x_min + x_max) // 2 - side // 2
    win_y0 = (y_min + y_max) // 2 - side // 2
    win_x1 = win_x0 + side
    win_y1 = win_y0 + side

    # Part of the window that actually lies inside the image.
    crop_x0, crop_y0 = max(0, win_x0), max(0, win_y0)
    crop_x1, crop_y1 = min(img_w, win_x1), min(img_h, win_y1)
    cat_image = im0[crop_y0:crop_y1, crop_x0:crop_x1]

    # Pad by exactly the amount the window overhangs the image, which restores
    # a side x side square. (Measuring the padding against the detection box
    # instead always yields zero, because the window already contains the box.)
    padded_image = cv2.copyMakeBorder(
        cat_image,
        top=crop_y0 - win_y0,
        bottom=win_y1 - crop_y1,
        left=crop_x0 - win_x0,
        right=win_x1 - crop_x1,
        borderType=cv2.BORDER_CONSTANT,
        value=[0, 0, 0],
    )

    # Resize to the target size and convert BGR -> RGB for PIL
    resized_image = cv2.resize(padded_image, target_size)
    image_rgb = cv2.cvtColor(resized_image, cv2.COLOR_BGR2RGB)
    pil_image = Image.fromarray(image_rgb)

    output_path = save_dir / f"{filename}_cat.png"
    pil_image.save(output_path, 'PNG')
    print(f"Saved processed image to {output_path}")
def detect(save_img=False):
    """Run detection over ``opt.source`` and save a square crop per matching box.

    All settings are read from the module-level ``opt`` namespace: ``source``,
    ``weights``, ``img_size``, ``project``, ``name``, ``exist_ok``, ``device``,
    ``augment``, ``conf_thres``, ``iou_thres``, ``classes`` and
    ``agnostic_nms``. Crops are written into a run directory derived from
    ``opt.project / opt.name`` via ``process_and_save_image``.

    Args:
        save_img: Not used by this function; kept so existing callers that
            pass it keep working.
    """
    source, weights, imgsz = opt.source, opt.weights, opt.img_size
    webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith(
        ('rtsp://', 'rtmp://', 'http://'))

    # Directories
    save_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))  # increment run
    save_dir.mkdir(parents=True, exist_ok=True)  # make dir

    # Initialize
    set_logging()
    device = select_device(opt.device)
    half = device.type != 'cpu'  # half precision only supported on CUDA

    # Load model
    model = attempt_load(weights, map_location=device)  # load FP32 model
    imgsz = check_img_size(imgsz, s=model.stride.max())  # check img_size
    if half:
        model.half()  # to FP16

    # Set Dataloader
    dataset = LoadStreams(source, img_size=imgsz) if webcam else LoadImages(source, img_size=imgsz)

    # Class indices whose detections are cropped; built once instead of per box.
    # NOTE(review): the original comment said index 15 is 'cat', yet 14-20 are
    # all accepted and every file is still named '*_cat.png' - confirm the
    # intended filter against the model's class names.
    target_classes = frozenset(range(14, 21))

    # Number of crops already written per source stem. Without it every crop
    # from the same image/stream shares one file name and overwrites the last.
    crops_per_stem = {}

    # Run inference
    t0 = time.time()
    if device.type != 'cpu':
        # Warm-up pass on a blank image (skipped on CPU)
        warmup = torch.zeros((1, 3, imgsz, imgsz), device=device)
        with torch.no_grad():
            model(warmup.half() if half else warmup)

    for path, img, im0s, vid_cap in dataset:
        img = torch.from_numpy(img).to(device)
        img = img.half() if half else img.float()  # uint8 to fp16/32
        img /= 255.0  # 0 - 255 to 0.0 - 1.0
        if img.ndimension() == 3:
            img = img.unsqueeze(0)

        # Inference; gradients are never needed here, so do not record them
        with torch.no_grad():
            pred = model(img, augment=opt.augment)[0]

        # Apply NMS
        pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres,
                                   classes=opt.classes, agnostic=opt.agnostic_nms)

        # Process detections
        for i, det in enumerate(pred):  # detections per image
            p, im0 = (path[i], im0s[i].copy()) if webcam else (path, im0s)
            if not len(det):
                continue

            # Rescale boxes from img_size to im0 size
            det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()

            stem = Path(p).stem
            # Rows are visited front to back so the un-suffixed file goes to the
            # first row. Presumably non_max_suppression returns rows in
            # descending confidence, which keeps '<stem>_cat.png' holding the
            # top detection as it did before - TODO confirm.
            for *xyxy, _conf, cls in det:
                if int(cls) not in target_classes:
                    continue
                already_saved = crops_per_stem.get(stem, 0)
                crops_per_stem[stem] = already_saved + 1
                # First crop keeps the plain stem; later ones get _1, _2, ...
                filename = stem if already_saved == 0 else f"{stem}_{already_saved}"
                process_and_save_image(im0, xyxy, save_dir, filename)

    print(f'Done. ({time.time() - t0:.3f}s)')