saliacoel committed on
Commit
4c0e632
·
verified ·
1 Parent(s): 67b99e9

Upload TensorRTBBoxDetector.py

Browse files
Files changed (1) hide show
  1. TensorRTBBoxDetector.py +409 -0
TensorRTBBoxDetector.py ADDED
@@ -0,0 +1,409 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ from typing import List, Tuple
4
+
5
+ import torch
6
+ import numpy as np
7
+ from ultralytics import YOLO
8
+
9
+ # Impact Pack (for SEG and SEGS helpers)
10
+ import impact.core as core
11
+ from impact.core import SEG
12
+
13
+ # Local helpers (your utils_salia)
14
+ try:
15
+ # Package-style import (recommended inside a ComfyUI custom node package)
16
+ from .utils_salia import (
17
+ NODE_DIR,
18
+ IMGSZ,
19
+ list_local_pt_files,
20
+ tensor_to_pil,
21
+ make_crop_region,
22
+ crop_image,
23
+ crop_ndarray2,
24
+ dilate_mask,
25
+ )
26
+ except ImportError:
27
+ # Fallback if utils_salia is importable directly (not as a package)
28
+ from utils_salia import (
29
+ NODE_DIR,
30
+ IMGSZ,
31
+ list_local_pt_files,
32
+ tensor_to_pil,
33
+ make_crop_region,
34
+ crop_image,
35
+ crop_ndarray2,
36
+ dilate_mask,
37
+ )
38
+
39
+
40
+ logger = logging.getLogger(__name__)
41
+
42
+
43
+ # -------------------------------------------------------------------------
44
+ # YOLO TensorRT-based BBOX_DETECTOR implementation
45
+ # -------------------------------------------------------------------------
46
+
47
+
48
class TRTYOLOBBoxDetector:
    """
    BBOX_DETECTOR interface compatible with Impact Pack / FaceDetailer.

    Wraps a callable YOLO model and turns its box predictions into SEG items.

    Required API:
        - setAux(x)
        - detect(image, threshold, dilation, crop_factor, drop_size=1, detailer_hook=None)
        - detect_combined(image, threshold, dilation)

    Annotations that name third-party types are written as strings so they are
    never evaluated when the class is defined.
    """

    def __init__(self, yolo_model: "YOLO", device: str = "0"):
        """
        Args:
            yolo_model: callable model; invoked as
                ``model(pil_image, conf=..., device=..., verbose=False)``.
            device: device string forwarded to the model call. An empty or
                falsy value falls back to "0".
        """
        self.bbox_model = yolo_model
        self.device = device or "0"
        # aux is used as a class name filter, e.g. FaceDetailer calls setAux('face')
        self.aux: str | None = None

    # ------------------------------------------------------------------
    # API: setAux
    # ------------------------------------------------------------------
    def setAux(self, x: "str | None"):
        """
        Store auxiliary info (typically a class filter like 'face').

        FaceDetailer calls setAux('face') before detect() and setAux(None) after.
        """
        self.aux = x

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _label_for(names, cls_id: int) -> str:
        """Return the class name for ``cls_id``, or the index as text if unknown.

        ``names`` may be a list/tuple indexed by class id, a dict-like object
        keyed by class id, or None. The result is always a ``str`` so callers
        can safely call string methods on it.
        """
        if isinstance(names, (list, tuple)):
            label = names[cls_id] if 0 <= cls_id < len(names) else cls_id
        elif names is None:
            label = cls_id
        else:
            # dict-like: {class_index: "name"}
            label = names.get(cls_id, cls_id)
        return str(label)

    # ------------------------------------------------------------------
    # API: detect
    # ------------------------------------------------------------------
    def detect(
        self,
        image: torch.Tensor,
        threshold: float,
        dilation: int,
        crop_factor: float,
        drop_size: int = 1,
        detailer_hook=None,
    ) -> "Tuple[Tuple[int, int], List[SEG]]":
        """
        Main detection method used by FaceDetailer.

        Args:
            image: ComfyUI IMAGE tensor [B, H, W, C] in 0..1.
            threshold: confidence threshold for detections.
            dilation: mask dilation/erosion size in pixels (>0 dilate, <0 erode).
            crop_factor: expansion factor for bbox when computing crop_region.
            drop_size: boxes whose width or height is <= this value are dropped.
            detailer_hook: optional hook with post_crop_region / post_detection.

        Returns:
            SEGS tuple: ( (H, W), [SEG, SEG, ...] )

        Raises:
            ValueError: if ``image`` does not have exactly 4 dimensions.
        """
        if image.dim() != 4:
            raise ValueError(
                "[TRTYOLOBBoxDetector] Expected IMAGE tensor with 4 dims [B, H, W, C]."
            )

        # Original image size
        h, w = int(image.shape[1]), int(image.shape[2])
        shape = (h, w)

        batch = int(image.shape[0])
        if batch == 0:
            # Nothing to run inference on; an empty batch has no detections.
            return (shape, [])
        if batch > 1:
            # Impact Pack detectors typically only use the first image in a batch.
            logger.warning(
                "[TRTYOLOBBoxDetector] Batch > 1 detected; using only the first image for detection."
            )
            image = image[:1]

        # Convert tensor to PIL for Ultralytics inference
        pil_img = tensor_to_pil(image)

        # Run YOLO model prediction with given threshold on the chosen device
        pred_list = self.bbox_model(
            pil_img, conf=threshold, device=self.device, verbose=False
        )
        if len(pred_list) == 0:
            return (shape, [])

        pred = pred_list[0]
        boxes = pred.boxes
        if boxes is None or boxes.xyxy is None or boxes.xyxy.shape[0] == 0:
            return (shape, [])

        xyxy = boxes.xyxy.cpu().numpy()  # [N, 4] (x1, y1, x2, y2)
        confs = boxes.conf.cpu().numpy()  # [N] confidence
        clses = boxes.cls.cpu().numpy().astype(int)  # [N] class indices
        names = pred.names  # class names (can be list/tuple or dict)

        # Normalise the aux filter once instead of once per detection.
        aux_filter = self.aux.lower() if self.aux and isinstance(self.aux, str) else None

        seg_items: "List[SEG]" = []

        for (x1, y1, x2, y2), conf, cls_id in zip(xyxy, confs, clses):
            score = float(conf)
            label = self._label_for(names, int(cls_id))

            # Aux filter (e.g. only keep 'face'): skip detections of other classes.
            if aux_filter is not None and label.lower() != aux_filter:
                continue

            # Drop tiny boxes
            if (x2 - x1) <= drop_size or (y2 - y1) <= drop_size:
                continue

            # Clamp bbox to image bounds (integer pixel coords)
            x1_i = max(int(np.floor(x1)), 0)
            y1_i = max(int(np.floor(y1)), 0)
            x2_i = min(int(np.ceil(x2)), w)
            y2_i = min(int(np.ceil(y2)), h)
            if x2_i <= x1_i or y2_i <= y1_i:
                continue

            # Create full-image mask from bbox (uint8 0/255)
            # NOTE(review): confirm that the consumers of SEG.cropped_mask
            # (e.g. core.segs_to_combined_mask) accept uint8 0/255 masks.
            mask = np.zeros((h, w), dtype=np.uint8)
            mask[y1_i:y2_i, x1_i:x2_i] = 255

            # Optional dilation / erosion via helper
            if dilation:
                mask = dilate_mask(mask, dilation)

            # Impact core uses bbox as [x1, y1, x2, y2]
            item_bbox = [float(x1), float(y1), float(x2), float(y2)]

            # Compute crop region (expanded bbox) in xyxy format
            crop_region = make_crop_region(w, h, item_bbox, crop_factor)
            if detailer_hook is not None and hasattr(detailer_hook, "post_crop_region"):
                crop_region = detailer_hook.post_crop_region(w, h, item_bbox, crop_region)

            # Crop image + mask
            cropped_image = crop_image(image, crop_region)
            cropped_mask = crop_ndarray2(mask, crop_region)

            # Build SEG object for this detection
            seg_items.append(
                SEG(
                    cropped_image,
                    cropped_mask,
                    score,
                    crop_region,
                    item_bbox,
                    label,
                    None,  # control_net_wrapper
                )
            )

        segs = (shape, seg_items)

        # Optional post-detection hook
        if detailer_hook is not None and hasattr(detailer_hook, "post_detection"):
            segs = detailer_hook.post_detection(segs)

        return segs

    # ------------------------------------------------------------------
    # API: detect_combined
    # ------------------------------------------------------------------
    def detect_combined(
        self,
        image: torch.Tensor,
        threshold: float,
        dilation: int,
    ) -> torch.Tensor:
        """
        Optional combined-mask API: returns a single MASK tensor covering all detections.

        Runs detect() with crop_factor=1.0, drop_size=1 and no hook, then
        merges the resulting SEGS with core.segs_to_combined_mask.
        """
        shape, seg_list = self.detect(
            image=image,
            threshold=threshold,
            dilation=dilation,
            crop_factor=1.0,
            drop_size=1,
            detailer_hook=None,
        )
        return core.segs_to_combined_mask((shape, seg_list))
242
+
243
+
244
+ # -------------------------------------------------------------------------
245
+ # NODE 1: TRTYOLOEngineBuilder
246
+ # - Builds a TensorRT engine from a .pt file in the node folder.
247
+ # -------------------------------------------------------------------------
248
+
249
+
250
class TRTYOLOEngineBuilder:
    """ComfyUI node that exports a local YOLO ``.pt`` file to a TensorRT engine.

    The node takes the name of a ``.pt`` file located in NODE_DIR and returns
    the path of the exported engine as a STRING output.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs: a combo box of the locally available .pt files."""
        pt_files = list_local_pt_files()
        default_name = pt_files[0] if pt_files else "face.pt"

        return {
            "required": {
                "pt_model_name": (
                    pt_files,
                    {
                        "default": default_name,
                        "tooltip": (
                            "Select a YOLO .pt file that lives in the SAME folder as this node file."
                        ),
                    },
                ),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("engine_path",)
    FUNCTION = "build"
    CATEGORY = "ImpactPack/TensorRT"

    @staticmethod
    def _first_export_path(result) -> str:
        """Normalise an export() return value to a path string ('' if there is none).

        Accepts a single path-like value or a list/tuple of them. ``None`` and
        empty sequences map to the empty string rather than to "None", so the
        caller's emptiness check can detect a failed export.
        """
        if isinstance(result, (list, tuple)):
            result = result[0] if len(result) > 0 else None
        if result is None:
            return ""
        return str(result)

    def build(self, pt_model_name: str):
        """Export ``pt_model_name`` to a TensorRT engine and return its path.

        Args:
            pt_model_name: file name of the .pt model, resolved against NODE_DIR.

        Returns:
            One-element tuple ``(engine_path,)``.

        Raises:
            FileNotFoundError: if the .pt file does not exist.
            RuntimeError: if the export yields no output path.
        """
        # Resolve .pt path relative to this node file
        pt_path = os.path.join(NODE_DIR, pt_model_name)
        if not os.path.isfile(pt_path):
            raise FileNotFoundError(
                f"[TRTYOLOEngineBuilder] .pt model not found: {pt_path}"
            )

        logger.info(
            f"[TRTYOLOEngineBuilder] Exporting TensorRT engine from '{pt_path}' "
            f"with imgsz={IMGSZ} (H,W), batch=1, half=True, device='0', exist_ok=True"
        )

        # Export the model to TensorRT engine format
        try:
            result = YOLO(pt_path).export(
                format="engine",
                imgsz=IMGSZ,
                half=True,
                device="0",
                exist_ok=True,
            )
        except TypeError as exc:
            # Fallback for older Ultralytics versions without 'exist_ok' or similar args.
            # Log the reason so an unrelated TypeError is not retried silently.
            logger.warning(
                "[TRTYOLOEngineBuilder] Export with exist_ok=True raised TypeError (%s); "
                "retrying without it.",
                exc,
            )
            result = YOLO(pt_path).export(
                format="engine",
                imgsz=IMGSZ,
                half=True,
                device="0",
            )

        # Handle return type (path string, Path, or list/tuple of them).
        # The emptiness check must see '' for a missing result, never "None".
        engine_path = self._first_export_path(result)
        if not engine_path:
            raise RuntimeError(
                "[TRTYOLOEngineBuilder] Engine export failed (empty output path)."
            )

        # If the export returned a relative path, try to resolve it robustly.
        if not os.path.isabs(engine_path):
            # 1) next to the .pt model, 2) relative to NODE_DIR.
            for base_dir in (os.path.dirname(pt_path), NODE_DIR):
                candidate = os.path.join(base_dir, engine_path)
                if os.path.isfile(candidate):
                    engine_path = candidate
                    break
            # If still not found, engine_path is left as-is; user may have a runs/... path.

        logger.info(f"[TRTYOLOEngineBuilder] Export complete. Engine path: {engine_path}")
        return (engine_path,)
335
+
336
+
337
+ # -------------------------------------------------------------------------
338
+ # NODE 2: TRTYOLOBBoxDetectorProvider
339
+ # - Loads the TensorRT engine and provides a BBOX_DETECTOR object.
340
+ # -------------------------------------------------------------------------
341
+
342
+
343
class TRTYOLOBBoxDetectorProvider:
    """ComfyUI node that loads a TensorRT ``.engine`` file and returns a BBOX_DETECTOR."""

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs: a single STRING holding the engine path."""
        return {
            "required": {
                "engine_path": (
                    "STRING",
                    {
                        "default": "",
                        "tooltip": (
                            "Path to the TensorRT .engine file.\n"
                            "Can be an absolute path or relative to this node's folder.\n"
                            "Typically use the output of TRTYOLOEngineBuilder."
                        ),
                    },
                ),
            }
        }

    RETURN_TYPES = ("BBOX_DETECTOR",)
    RETURN_NAMES = ("bbox_detector",)
    FUNCTION = "load"
    CATEGORY = "ImpactPack/TensorRT"

    def load(self, engine_path: str):
        """Load the engine at ``engine_path`` and wrap it in a TRTYOLOBBoxDetector.

        Args:
            engine_path: absolute path, or a path relative to NODE_DIR.
                Surrounding whitespace is ignored.

        Returns:
            One-element tuple ``(detector,)``.

        Raises:
            ValueError: if ``engine_path`` is empty, None or whitespace-only.
            FileNotFoundError: if the resolved path is not an existing file.
        """
        # Strip BEFORE the emptiness check so a whitespace-only value is
        # reported as "empty" instead of resolving to the node folder itself.
        engine_path = engine_path.strip() if engine_path else ""
        if not engine_path:
            raise ValueError(
                "[TRTYOLOBBoxDetectorProvider] 'engine_path' is empty. "
                "Provide a valid path or connect from TRTYOLOEngineBuilder."
            )

        # Resolve relative paths against this node's folder
        if not os.path.isabs(engine_path):
            engine_path = os.path.join(NODE_DIR, engine_path)

        if not os.path.isfile(engine_path):
            raise FileNotFoundError(
                f"[TRTYOLOBBoxDetectorProvider] Engine file not found: {engine_path}"
            )

        logger.info(
            f"[TRTYOLOBBoxDetectorProvider] Loading YOLO TensorRT engine from '{engine_path}' on device '0'"
        )

        # Load the TensorRT engine with Ultralytics (TensorRT backend)
        yolo_model = YOLO(engine_path)
        detector = TRTYOLOBBoxDetector(yolo_model, device="0")

        return (detector,)
394
+
395
+
396
+ # -------------------------------------------------------------------------
397
+ # ComfyUI node registration
398
+ # -------------------------------------------------------------------------
399
+
400
+
401
# One row per node: (node identifier, implementing class, display name).
# Both ComfyUI registration dicts below are derived from this table so the
# identifiers cannot drift apart between them.
_NODE_TABLE = (
    (
        "TRTYOLOEngineBuilder",
        TRTYOLOEngineBuilder,
        "TensorRT YOLO Engine Builder (1344x768)",
    ),
    (
        "TRTYOLOBBoxDetectorProvider",
        TRTYOLOBBoxDetectorProvider,
        "TensorRT YOLO BBox Detector",
    ),
)

# node identifier -> node class
NODE_CLASS_MAPPINGS = {node_id: node_cls for node_id, node_cls, _title in _NODE_TABLE}

# node identifier -> display name
NODE_DISPLAY_NAME_MAPPINGS = {node_id: title for node_id, _node_cls, title in _NODE_TABLE}