File size: 4,958 Bytes
af29397
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import numpy as np
import torch
import logging
from typing import List, Tuple, Dict, Any, Optional

def get_slice_bboxes(
    image_height: int,
    image_width: int,
    slice_height: int = 640,
    slice_width: int = 640,
    overlap_height_ratio: float = 0.2,
    overlap_width_ratio: float = 0.2,
) -> List[List[int]]:
    """
    Calculate bounding boxes for overlapping slices tiling an image.

    Slices step across the image with the requested overlap; slices that
    would run past the right/bottom edge are shifted back so they stay
    inside the image (so edge slices keep the full slice size when the
    image is large enough).

    Args:
        image_height: Height of the full image in pixels.
        image_width: Width of the full image in pixels.
        slice_height: Height of each slice in pixels (must be > 0).
        slice_width: Width of each slice in pixels (must be > 0).
        overlap_height_ratio: Fractional vertical overlap between
            consecutive rows of slices; must be in [0, 1).
        overlap_width_ratio: Fractional horizontal overlap between
            consecutive columns of slices; must be in [0, 1).

    Returns:
        List of [x_min, y_min, x_max, y_max] slice boxes. Empty when
        image_height <= 0 (no rows to cover).

    Raises:
        ValueError: If a slice dimension is non-positive or an overlap
            ratio is outside [0, 1). Either condition would make the
            tiling loop never advance (silent infinite loop).
    """
    # Guard against parameters that make the step size <= 0: the original
    # loop would spin forever because x_min/y_min never move forward.
    if slice_height <= 0 or slice_width <= 0:
        raise ValueError("slice_height and slice_width must be positive")
    if not (0.0 <= overlap_height_ratio < 1.0) or not (0.0 <= overlap_width_ratio < 1.0):
        raise ValueError("overlap ratios must be in the range [0, 1)")

    slice_bboxes = []
    y_max = y_min = 0
    y_overlap = int(slice_height * overlap_height_ratio)
    x_overlap = int(slice_width * overlap_width_ratio)

    while y_max < image_height:
        x_min = x_max = 0
        y_max = y_min + slice_height

        while x_max < image_width:
            x_max = x_min + slice_width

            # Shift boundary slices back inside the image instead of
            # letting them overrun the edge.
            if y_max > image_height:
                y_max = image_height
                y_min = max(0, image_height - slice_height)

            if x_max > image_width:
                x_max = image_width
                x_min = max(0, image_width - slice_width)

            slice_bboxes.append([x_min, y_min, x_max, y_max])

            x_min = x_max - x_overlap
        y_min = y_max - y_overlap

    return slice_bboxes

def slice_image(
    image: np.ndarray,
    slice_bboxes: List[List[int]]
) -> List[np.ndarray]:
    """Return the crops of ``image`` described by ``slice_bboxes``.

    Each bbox is [x_min, y_min, x_max, y_max]; the crops are numpy views
    into the original array (no copying).
    """
    return [image[y0:y1, x0:x1] for x0, y0, x1, y1 in slice_bboxes]

def shift_bboxes(
    bboxes: List[List[float]],
    slice_coords: List[int]
) -> List[List[float]]:
    """
    Translate boxes from slice-local coordinates to global image coordinates.

    Args:
        bboxes: Boxes as [x_min, y_min, x_max, y_max] in slice coordinates.
        slice_coords: The slice's [x_min, y_min, x_max, y_max] in the image;
            only the top-left corner is used as the offset.

    Returns:
        New list of boxes shifted by the slice origin.
    """
    dx, dy = slice_coords[0], slice_coords[1]
    return [[x1 + dx, y1 + dy, x2 + dx, y2 + dy] for x1, y1, x2, y2 in bboxes]

def batched_nms(
    boxes: torch.Tensor,
    scores: torch.Tensor,
    idxs: torch.Tensor,
    iou_threshold: float = 0.5
) -> torch.Tensor:
    """
    Perform class-aware non-maximum suppression.

    Uses ``torchvision.ops.batched_nms`` when available; otherwise falls
    back to a pure-PyTorch greedy NMS applied independently per class.

    Args:
        boxes: (N, 4) tensor of [x1, y1, x2, y2] boxes.
        scores: (N,) tensor of confidence scores.
        idxs: (N,) integer tensor of class labels; boxes of different
            classes never suppress each other.
        iou_threshold: Boxes with IoU strictly above survivors' IoU
            threshold are discarded (kept when IoU < iou_threshold).

    Returns:
        1-D int64 tensor of kept indices into ``boxes``, sorted by
        descending score (matching the torchvision contract).
    """
    if boxes.numel() == 0:
        return torch.empty((0,), dtype=torch.int64, device=boxes.device)

    # Prefer the efficient C++/CUDA implementation when installed.
    try:
        import torchvision
        return torchvision.ops.batched_nms(boxes, scores, idxs, iou_threshold)
    except ImportError:
        pass

    # Fallback: greedy per-class NMS (slow but standard). The previously
    # attempted ultralytics import was dead code and has been removed.
    keep_indices = []
    for label in idxs.unique():
        mask = idxs == label
        cls_boxes = boxes[mask]
        cls_scores = scores[mask]
        original_indices = torch.where(mask)[0]

        # Process highest-scoring boxes first.
        order = torch.argsort(cls_scores, descending=True)
        cls_boxes = cls_boxes[order]
        original_indices = original_indices[order]

        while cls_boxes.size(0) > 0:
            keep_indices.append(original_indices[0])

            if cls_boxes.size(0) == 1:
                break

            current_box = cls_boxes[0].unsqueeze(0)
            rest_boxes = cls_boxes[1:]

            # IoU of the current box against all remaining boxes.
            x1 = torch.max(current_box[:, 0], rest_boxes[:, 0])
            y1 = torch.max(current_box[:, 1], rest_boxes[:, 1])
            x2 = torch.min(current_box[:, 2], rest_boxes[:, 2])
            y2 = torch.min(current_box[:, 3], rest_boxes[:, 3])

            inter_area = (x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0)
            box_area = (current_box[:, 2] - current_box[:, 0]) * (current_box[:, 3] - current_box[:, 1])
            rest_area = (rest_boxes[:, 2] - rest_boxes[:, 0]) * (rest_boxes[:, 3] - rest_boxes[:, 1])
            union_area = box_area + rest_area - inter_area

            # Epsilon guards against division by zero for degenerate boxes.
            iou = inter_area / (union_area + 1e-6)

            # Keep only boxes that do not overlap the current box too much.
            low_overlap = iou < iou_threshold
            cls_boxes = rest_boxes[low_overlap]
            original_indices = original_indices[1:][low_overlap]

    # torch.stack (not torch.tensor) is the correct way to combine the
    # collected 0-dim index tensors.
    keep = torch.stack(keep_indices).to(torch.int64)
    # Match torchvision.ops.batched_nms: order the result by descending
    # score instead of grouping by class.
    return keep[torch.argsort(scores[keep], descending=True)]