import os
from matplotlib import pyplot as plt
import skimage
# Explicit submodule import: on scikit-image releases without lazy submodule
# loading, ``import skimage`` alone does not make ``skimage.transform`` available.
import skimage.transform
from skimage import io as skimage_io
import numpy as np
import re
import cv2


# Lower-cased file extensions that read_images treats as images.
_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")


def rescale_detection_box(boxes, image):
    """Convert normalized ``(cx, cy, w, h)`` boxes to pixel ``(x1, y1, x2, y2)`` corners.

    The box values are multiplied by the side of the square that
    ``preprocess_images`` pads *image* to (``max(height, width)``). That
    padding is added only on the bottom and right edges, so the top-left
    origin is unchanged and scaling alone maps boxes back onto *image*; no
    offset correction is needed.

    Boxes are not clipped, so corners may land in the padded region or
    outside the image.

    Args:
        boxes: Iterable of ``(cx, cy, w, h)``, presumably normalized to the
            padded square (i.e. roughly in ``[0, 1]``) — confirm against the
            model's output format.
        image: The original, unpadded image; only its first two dimensions
            are used, so both 2-D and 3-D arrays are accepted.

    Returns:
        List of ``(x1, y1, x2, y2)`` tuples in pixel coordinates of *image*.
    """
    img_h, img_w = image.shape[:2]
    size = max(img_h, img_w)

    recovered_boxes = []
    for cx, cy, box_w, box_h in boxes:
        cx = cx * size
        cy = cy * size
        box_w = box_w * size
        box_h = box_h * size
        recovered_boxes.append(
            (cx - box_w / 2, cy - box_h / 2, cx + box_w / 2, cy + box_h / 2)
        )
    return recovered_boxes


def read_images(image_dir):
    """Load every PNG/JPEG in *image_dir* as a float32 array divided by 255.

    Files are matched by extension (case-insensitive) and ordered by their
    name without the extension.

    Args:
        image_dir: Directory to scan; only its direct entries are considered.

    Returns:
        ``(images, filenames)``, where ``images[i]`` holds the pixel data of
        ``filenames[i]``. NOTE(review): dividing by 255 assumes 8-bit input;
        16-bit images would end up with values above 1 — confirm inputs.
    """
    filenames = sorted(
        (
            name
            for name in os.listdir(image_dir)
            if os.path.splitext(name)[-1].lower() in _IMAGE_EXTENSIONS
        ),
        key=lambda name: os.path.splitext(name)[0],
    )
    images = [
        skimage_io.imread(os.path.join(image_dir, name)).astype(np.float32) / 255.0
        for name in filenames
    ]
    return images, filenames


def preprocess_images(images, model_input_size):
    """Pad each image to a square, then resize it to the model's input size.

    Padding uses the constant 0.5 and is added only on the bottom and right,
    which keeps the top-left origin fixed so ``rescale_detection_box`` can
    map predictions back by scaling alone.

    Args:
        images: Sequence of ``(h, w)`` or ``(h, w, d)`` arrays.
        model_input_size: Side length of the square output images.

    Returns:
        float32 array of shape ``(b, model_input_size, model_input_size, d)``
        (without the last axis for 2-D inputs). All images must have the same
        number of channels, otherwise ``np.array`` cannot stack them.
    """
    processed_images = []
    for image in images:
        h, w = image.shape[:2]
        size = max(h, w)
        # Pad only the two spatial axes; any trailing channel axis is left alone.
        pad_width = ((0, size - h), (0, size - w)) + ((0, 0),) * (image.ndim - 2)
        image_padded = np.pad(image, pad_width, constant_values=0.5)

        # resize keeps any trailing (channel) axis not covered by output_shape.
        image_resized = skimage.transform.resize(
            image_padded,
            (model_input_size, model_input_size),
            anti_aliasing=True,
        )
        processed_images.append(image_resized)

    return np.array(processed_images, dtype=np.float32)


def too_small(bbox, threshold=400):
    """Return True if the area of *bbox* ``(x1, y1, x2, y2)`` is below *threshold*.

    Negative widths/heights (inverted corners) are clamped to 0, so inverted
    boxes count as having zero area.
    """
    x1, y1, x2, y2 = bbox
    width = max(0, x2 - x1)
    height = max(0, y2 - y1)
    area = width * height
    return area < threshold


def too_large(bbox, image, threshold=0.9):
    """Return True if *bbox* covers at least *threshold* of *image*'s area.

    Like ``too_small``, negative widths/heights are clamped to 0. Without the
    clamp, a box with both corners inverted would get a positive area and
    could be flagged as large.

    Args:
        bbox: ``(x1, y1, x2, y2)`` in pixel coordinates.
        image: Array whose first two dimensions are ``(height, width)``.
        threshold: Minimum box-area / image-area ratio considered "too large".
    """
    x1, y1, x2, y2 = bbox
    bbox_area = max(0, x2 - x1) * max(0, y2 - y1)
    image_height, image_width = image.shape[:2]
    image_area = image_width * image_height
    return bbox_area / image_area >= threshold


def plot_bboxes_on_orig_image(image, boxes, output_path):
    """Draw *boxes* as rectangle outlines over *image* and save to *output_path*.

    Uses the global pyplot state: the current figure is cleared first and
    closed after saving (300 dpi, tight bounding box).

    Args:
        image: Image accepted by ``plt.imshow``.
        boxes: Iterable of ``(x1, y1, x2, y2)`` pixel-coordinate boxes.
        output_path: Destination file path for the figure.
    """
    plt.clf()
    plt.imshow(image)
    plt.axis('off')
    for x1, y1, x2, y2 in boxes:
        # Closed polyline tracing the rectangle's four corners.
        plt.plot(
            [x1, x2, x2, x1, x1],
            [y1, y1, y2, y2, y1],
            linewidth=0.8,
            alpha=0.6,
        )
    plt.savefig(output_path, bbox_inches='tight', pad_inches=0.1, dpi=300)
    plt.close()
    print(f" Done! Visualization saved to {output_path}")


def compute_iou(box1, box2):
    """Return the intersection-over-union of two ``(x1, y1, x2, y2)`` boxes.

    Returns 0 when the union area is not positive (e.g. two degenerate boxes).
    """
    x1_inter = max(box1[0], box2[0])
    y1_inter = max(box1[1], box2[1])
    x2_inter = min(box1[2], box2[2])
    y2_inter = min(box1[3], box2[3])

    inter_area = max(0, x2_inter - x1_inter) * max(0, y2_inter - y1_inter)
    box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
    box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union_area = box1_area + box2_area - inter_area

    return inter_area / union_area if union_area > 0 else 0


def remove_overlapping_bboxes(bboxes, iou_threshold=0.7):
    """Greedily drop boxes that overlap a larger already-kept box.

    Boxes are visited from largest to smallest area. A box is kept unless its
    IoU with some already-kept box exceeds *iou_threshold*.

    Args:
        bboxes: Sequence of ``(x1, y1, x2, y2)`` boxes.
        iou_threshold: IoU above which the smaller box is discarded.

    Returns:
        List of kept boxes, ordered by descending area.
    """
    if not bboxes:
        return []

    by_area = sorted(
        bboxes, key=lambda b: (b[2] - b[0]) * (b[3] - b[1]), reverse=True
    )
    keep = []
    for bbox in by_area:
        if not any(compute_iou(bbox, kept) > iou_threshold for kept in keep):
            keep.append(bbox)
    return keep


def get_centroid(bbox):
    """Return the center of ``(x1, y1, x2, y2)`` as an ``(int, int)`` pair.

    ``int()`` truncates toward zero rather than rounding.
    """
    x1, y1, x2, y2 = bbox
    cx = (x1 + x2) / 2
    cy = (y1 + y2) / 2
    return (int(cx), int(cy))