saliacoel commited on
Commit
18518ac
·
verified ·
1 Parent(s): f5925eb

Upload TensorRTBBoxDetector.py

Browse files
Files changed (1) hide show
  1. TensorRTBBoxDetector.py +8 -5
TensorRTBBoxDetector.py CHANGED
@@ -173,12 +173,13 @@ class TRTYOLOBBoxDetector:
173
  continue
174
 
175
  # ------------------------------------------------------------------
176
- # Create full-image mask from bbox (uint8 0/255)
177
  # ------------------------------------------------------------------
178
- mask = np.zeros((h, w), dtype=np.uint8)
179
- mask[y1_i:y2_i, x1_i:x2_i] = 255
180
 
181
- # Optional dilation / erosion via GPU-aware helper
 
182
  if dilation:
183
  mask = dilate_mask(mask, dilation)
184
 
@@ -196,7 +197,7 @@ class TRTYOLOBBoxDetector:
196
  # Crop image + mask
197
  # ------------------------------------------------------------------
198
  cropped_image = crop_image(image, crop_region) # torch [1, h', w', C]
199
- cropped_mask = crop_ndarray2(mask, crop_region) # np.uint8 [h', w']
200
 
201
  # Build SEG object for this detection
202
  seg = SEG(
@@ -241,6 +242,8 @@ class TRTYOLOBBoxDetector:
241
  return core.segs_to_combined_mask((shape, seg_list))
242
 
243
 
 
 
244
  # -------------------------------------------------------------------------
245
  # NODE 1: TRTYOLOEngineBuilder
246
  # - Builds a TensorRT engine from a .pt file in the node folder.
 
173
  continue
174
 
175
  # ------------------------------------------------------------------
176
+ # Create full-image mask from bbox as float32 in [0, 1]
177
  # ------------------------------------------------------------------
178
+ mask = np.zeros((h, w), dtype=np.float32)
179
+ mask[y1_i:y2_i, x1_i:x2_i] = 1.0
180
 
181
+ # Optional dilation / erosion via GPU-aware helper.
182
+ # IMPORTANT: dilate_mask must return float32 [0,1] as well.
183
  if dilation:
184
  mask = dilate_mask(mask, dilation)
185
 
 
197
  # Crop image + mask
198
  # ------------------------------------------------------------------
199
  cropped_image = crop_image(image, crop_region) # torch [1, h', w', C]
200
+ cropped_mask = crop_ndarray2(mask, crop_region) # np.float32 [h', w'] in [0,1]
201
 
202
  # Build SEG object for this detection
203
  seg = SEG(
 
242
  return core.segs_to_combined_mask((shape, seg_list))
243
 
244
 
245
+
246
+
247
  # -------------------------------------------------------------------------
248
  # NODE 1: TRTYOLOEngineBuilder
249
  # - Builds a TensorRT engine from a .pt file in the node folder.