Upload TensorRTBBoxDetector.py
Browse files- TensorRTBBoxDetector.py +8 -5
TensorRTBBoxDetector.py
CHANGED
|
@@ -173,12 +173,13 @@ class TRTYOLOBBoxDetector:
|
|
| 173 |
continue
|
| 174 |
|
| 175 |
# ------------------------------------------------------------------
|
| 176 |
-
# Create full-image mask from bbox
|
| 177 |
# ------------------------------------------------------------------
|
| 178 |
-
mask = np.zeros((h, w), dtype=np.
|
| 179 |
-
mask[y1_i:y2_i, x1_i:x2_i] =
|
| 180 |
|
| 181 |
-
# Optional dilation / erosion via GPU-aware helper
|
|
|
|
| 182 |
if dilation:
|
| 183 |
mask = dilate_mask(mask, dilation)
|
| 184 |
|
|
@@ -196,7 +197,7 @@ class TRTYOLOBBoxDetector:
|
|
| 196 |
# Crop image + mask
|
| 197 |
# ------------------------------------------------------------------
|
| 198 |
cropped_image = crop_image(image, crop_region) # torch [1, h', w', C]
|
| 199 |
-
cropped_mask = crop_ndarray2(mask, crop_region) # np.
|
| 200 |
|
| 201 |
# Build SEG object for this detection
|
| 202 |
seg = SEG(
|
|
@@ -241,6 +242,8 @@ class TRTYOLOBBoxDetector:
|
|
| 241 |
return core.segs_to_combined_mask((shape, seg_list))
|
| 242 |
|
| 243 |
|
|
|
|
|
|
|
| 244 |
# -------------------------------------------------------------------------
|
| 245 |
# NODE 1: TRTYOLOEngineBuilder
|
| 246 |
# - Builds a TensorRT engine from a .pt file in the node folder.
|
|
|
|
| 173 |
continue
|
| 174 |
|
| 175 |
# ------------------------------------------------------------------
|
| 176 |
+
# Create full-image mask from bbox as float32 in [0, 1]
|
| 177 |
# ------------------------------------------------------------------
|
| 178 |
+
mask = np.zeros((h, w), dtype=np.float32)
|
| 179 |
+
mask[y1_i:y2_i, x1_i:x2_i] = 1.0
|
| 180 |
|
| 181 |
+
# Optional dilation / erosion via GPU-aware helper.
|
| 182 |
+
# IMPORTANT: dilate_mask must return float32 [0,1] as well.
|
| 183 |
if dilation:
|
| 184 |
mask = dilate_mask(mask, dilation)
|
| 185 |
|
|
|
|
| 197 |
# Crop image + mask
|
| 198 |
# ------------------------------------------------------------------
|
| 199 |
cropped_image = crop_image(image, crop_region) # torch [1, h', w', C]
|
| 200 |
+
cropped_mask = crop_ndarray2(mask, crop_region) # np.float32 [h', w'] in [0,1]
|
| 201 |
|
| 202 |
# Build SEG object for this detection
|
| 203 |
seg = SEG(
|
|
|
|
| 242 |
return core.segs_to_combined_mask((shape, seg_list))
|
| 243 |
|
| 244 |
|
| 245 |
+
|
| 246 |
+
|
| 247 |
# -------------------------------------------------------------------------
|
| 248 |
# NODE 1: TRTYOLOEngineBuilder
|
| 249 |
# - Builds a TensorRT engine from a .pt file in the node folder.
|