saliacoel commited on
Commit
202cca2
·
verified ·
1 Parent(s): d7c9f29

Upload TensorRTBBoxDetector.py

Browse files
Files changed (1) hide show
  1. TensorRTBBoxDetector.py +412 -0
TensorRTBBoxDetector.py ADDED
@@ -0,0 +1,412 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ from typing import List, Tuple
4
+
5
+ import torch
6
+ import numpy as np
7
+ from ultralytics import YOLO
8
+
9
+ # Impact Pack (for SEG and SEGS helpers)
10
+ import impact.core as core
11
+ from impact.core import SEG
12
+
13
+ # Optional: TensorRT sanity check
14
+ try:
15
+ import tensorrt as trt # type: ignore
16
+ except Exception:
17
+ trt = None
18
+
19
+ # Local helpers
20
+ try:
21
+ # If this folder is a package (has __init__.py), use relative import
22
+ from .utils_salia import (
23
+ NODE_DIR,
24
+ IMGSZ,
25
+ list_local_pt_files,
26
+ tensor_to_pil,
27
+ make_crop_region,
28
+ crop_image,
29
+ crop_ndarray2,
30
+ dilate_mask,
31
+ )
32
+ except ImportError:
33
+ # Fallback: direct import if utils_salia is on sys.path
34
+ from utils_salia import (
35
+ NODE_DIR,
36
+ IMGSZ,
37
+ list_local_pt_files,
38
+ tensor_to_pil,
39
+ make_crop_region,
40
+ crop_image,
41
+ crop_ndarray2,
42
+ dilate_mask,
43
+ )
44
+
45
+ logger = logging.getLogger(__name__)
46
+
47
+
48
+ # -------------------------------------------------------------------------
49
+ # YOLO TensorRT-based BBOX_DETECTOR implementation
50
+ # -------------------------------------------------------------------------
51
+
52
+
53
class TRTYOLOBBoxDetector:
    """
    BBOX_DETECTOR interface compatible with Impact Pack / FaceDetailer.

    Wraps an Ultralytics YOLO model (typically loaded from a TensorRT
    .engine file) and converts its bbox predictions into Impact Pack
    SEGS so FaceDetailer and friends can consume them.

    Methods required:
        - setAux(x)
        - detect(image, threshold, dilation, crop_factor, drop_size=1, detailer_hook=None)
        - detect_combined(image, threshold, dilation)
    """

    def __init__(self, yolo_model: YOLO, device: str = "0"):
        # Ultralytics picks the TensorRT backend automatically when the
        # model was constructed from an .engine file.
        self.bbox_model = yolo_model
        # Ultralytics accepts "0" or "cuda:0"; normalize a bit:
        if device in ("0", "cuda", "cuda:0"):
            self.device = "0"
        else:
            self.device = str(device)

        # FaceDetailer calls setAux('face'); we keep it for compatibility.
        self.aux = None

    def setAux(self, x):
        # Kept for API compatibility. You could use this
        # later to filter by specific labels/classes.
        self.aux = x

    def detect(
        self,
        image: torch.Tensor,
        threshold: float,
        dilation: int,
        crop_factor: float,
        drop_size: int = 1,
        detailer_hook=None,
    ) -> Tuple[Tuple[int, int], List[SEG]]:
        """
        Main detection method used by FaceDetailer.

        Args:
            image: ComfyUI IMAGE tensor [B, H, W, C] in 0..1
            threshold: confidence threshold
            dilation: mask dilation in pixels
            crop_factor: expansion factor for bbox when computing crop_region
            drop_size: minimum bbox width/height to keep
            detailer_hook: optional hook with post_crop_region / post_detection

        Returns:
            SEGS tuple: ( (H, W), [SEG, SEG, ...] )
        """

        if image.dim() != 4:
            raise ValueError(
                "[TRTYOLOBBoxDetector] Expected IMAGE tensor with 4 dims [B, H, W, C]."
            )

        # Only the first image of a batch is used for detection.
        if image.shape[0] != 1:
            logger.warning(
                "[TRTYOLOBBoxDetector] Batch > 1 detected, using only the first image for detection."
            )
            image = image[:1]

        h, w = int(image.shape[1]), int(image.shape[2])
        shape = (h, w)

        # -----------------------------------------------------------------
        # Run YOLO TensorRT via Ultralytics wrapper
        # -----------------------------------------------------------------
        pil_img = tensor_to_pil(image)  # should return a single PIL image for B=1

        # Ultralytics chooses TensorRT backend automatically when you pass an .engine
        # model to YOLO(). Here we only set device & threshold.
        pred_list = self.bbox_model(
            pil_img,
            conf=threshold,
            device=self.device,
            verbose=False,
        )

        # No results object at all -> empty SEGS with the image shape.
        if len(pred_list) == 0:
            return (shape, [])

        pred = pred_list[0]
        boxes = pred.boxes

        if boxes is None or boxes.xyxy is None or boxes.xyxy.shape[0] == 0:
            return (shape, [])

        xyxy = boxes.xyxy.cpu().numpy()  # [N, 4] (x1, y1, x2, y2)
        confs = boxes.conf.cpu().numpy()
        clses = boxes.cls.cpu().numpy().astype(int)
        names = pred.names  # dict: class_index -> class_name

        seg_items: List[SEG] = []

        for i in range(xyxy.shape[0]):
            x1, y1, x2, y2 = xyxy[i]
            score = float(confs[i])
            cls_id = int(clses[i])
            label = names.get(cls_id, str(cls_id))

            # Drop tiny detections: only boxes strictly larger than
            # drop_size in BOTH dimensions are kept.
            box_w = x2 - x1
            box_h = y2 - y1
            if box_w <= drop_size or box_h <= drop_size:
                continue

            # Clamp bbox to image bounds
            x1_i = max(int(np.floor(x1)), 0)
            y1_i = max(int(np.floor(y1)), 0)
            x2_i = min(int(np.ceil(x2)), w)
            y2_i = min(int(np.ceil(y2)), h)
            if x2_i <= x1_i or y2_i <= y1_i:
                continue

            # Rectangular mask from bbox, uint8 0..255
            mask = np.zeros((h, w), dtype=np.uint8)
            mask[y1_i:y2_i, x1_i:x2_i] = 255

            # Optional dilation
            # NOTE(review): the crop_region below is computed from the
            # UNdilated bbox, so a large dilation may extend the mask
            # beyond the crop — confirm this matches Impact Pack behavior.
            if dilation != 0:
                mask = dilate_mask(mask, dilation)

            # Impact bbox order here is (x1, y1, x2, y2)
            item_bbox = [float(x1), float(y1), float(x2), float(y2)]

            # Compute crop region from bbox and crop_factor
            crop_region = make_crop_region(w, h, item_bbox, crop_factor)
            if detailer_hook is not None and hasattr(detailer_hook, "post_crop_region"):
                crop_region = detailer_hook.post_crop_region(
                    w, h, item_bbox, crop_region
                )

            # Crop image + mask
            cropped_image = crop_image(image, crop_region)  # torch [1, h', w', C]
            cropped_mask = crop_ndarray2(mask, crop_region)  # np.uint8 [h', w']

            # SEG field order: (cropped_image, cropped_mask, confidence,
            # crop_region, bbox, label, control_net_wrapper).
            seg = SEG(
                cropped_image,
                cropped_mask,
                score,
                crop_region,
                item_bbox,
                label,
                None,  # control_net_wrapper
            )
            seg_items.append(seg)

        segs = (shape, seg_items)

        # Allow the hook to post-process the final SEGS (e.g. filtering).
        if detailer_hook is not None and hasattr(detailer_hook, "post_detection"):
            segs = detailer_hook.post_detection(segs)

        return segs

    def detect_combined(
        self,
        image: torch.Tensor,
        threshold: float,
        dilation: int,
    ) -> torch.Tensor:
        """
        Optional helper API: returns a combined MASK of all detections.

        Runs detect() with crop_factor=1.0 and drop_size=1, then merges
        every SEG mask into a single mask via Impact Pack's helper.
        """
        segs = self.detect(
            image=image,
            threshold=threshold,
            dilation=dilation,
            crop_factor=1.0,
            drop_size=1,
            detailer_hook=None,
        )
        return core.segs_to_combined_mask(segs)
224
+
225
+
226
+ # -------------------------------------------------------------------------
227
+ # NODE 1: TRTYOLOEngineBuilder
228
+ # - Builds TensorRT engine from a .pt file sitting next to this .py
229
+ # imgsz = IMGSZ (H, W) from utils_salia
230
+ # batch = 1
231
+ # half = True (FP16)
232
+ # device = "0"
233
+ # overwrite (exist_ok) = True
234
+ # -------------------------------------------------------------------------
235
+
236
+
237
class TRTYOLOEngineBuilder:
    """
    ComfyUI node that exports a TensorRT engine from a YOLO .pt checkpoint
    located in the same folder as this node file.

    Export settings: imgsz=IMGSZ (H, W), batch=1, half=True (FP16),
    device="0", exist_ok=True when supported by Ultralytics.
    """

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("engine_path",)
    FUNCTION = "build"
    CATEGORY = "ImpactPack/TensorRT"

    @classmethod
    def INPUT_TYPES(cls):
        # Offer every .pt file next to this node; fall back to a
        # placeholder entry when none are present.
        choices = list_local_pt_files()
        if not choices:
            choices = ["face.pt"]

        widget_opts = {
            "default": choices[0],
            "tooltip": (
                "Select a YOLO .pt file that lives in the SAME folder as this node file.\n"
                "Example: 'face.pt' next to TensorRTBBoxDetector.py"
            ),
        }
        return {"required": {"pt_model_name": (choices, widget_opts)}}

    def _check_tensorrt_available(self):
        """
        Preflight check that fails with a readable message when TensorRT
        cannot initialize (instead of a raw pybind11::init() error deep
        inside Ultralytics).
        """
        if trt is None:
            raise RuntimeError(
                "[TRTYOLOEngineBuilder] TensorRT Python package is not available. "
                "Install it via pip (cu12 build) or use an image with TensorRT preinstalled."
            )
        try:
            # Constructing a Builder is where CUDA/driver mismatches surface.
            probe = trt.Builder(trt.Logger(trt.Logger.ERROR))
            del probe
        except Exception as e:
            raise RuntimeError(
                "[TRTYOLOEngineBuilder] TensorRT failed to initialize. "
                "Check that your CUDA driver, CUDA runtime, and TensorRT versions match. "
                f"Original error: {e}"
            ) from e

    def build(self, pt_model_name: str):
        """
        Export 'pt_model_name' (resolved next to this file) to a TensorRT
        engine and return the absolute engine path as a 1-tuple.

        Raises:
            FileNotFoundError: .pt missing, or export produced no engine file.
            RuntimeError: TensorRT unavailable or failed to initialize.
        """
        weights = os.path.join(NODE_DIR, pt_model_name)
        if not os.path.isfile(weights):
            raise FileNotFoundError(
                f"[TRTYOLOEngineBuilder] .pt model not found next to this node: {weights}"
            )

        # Surface TensorRT/driver problems before Ultralytics starts exporting.
        self._check_tensorrt_available()

        logger.info(
            f"[TRTYOLOEngineBuilder] Exporting TensorRT engine from '{weights}' "
            f"with imgsz={IMGSZ} (H,W), batch=1, half=True, device='0', exist_ok=True"
        )

        try:
            # Ultralytics API: export TensorRT engine directly from .pt
            result = YOLO(weights).export(
                format="engine",
                imgsz=IMGSZ,
                half=True,
                device="0",
                exist_ok=True,  # overwrite or reuse if same settings
            )
        except TypeError:
            # Older Ultralytics versions may not accept every kwarg above.
            result = YOLO(weights).export(
                format="engine",
                imgsz=IMGSZ,
                half=True,
                device="0",
            )

        engine = str(result)

        # A relative result is interpreted as relative to NODE_DIR.
        if not os.path.isabs(engine):
            local_candidate = os.path.join(NODE_DIR, engine)
            if os.path.isfile(local_candidate):
                engine = local_candidate

        if not os.path.isfile(engine):
            raise FileNotFoundError(
                f"[TRTYOLOEngineBuilder] Export completed but engine file not found at: {engine}"
            )

        logger.info(f"[TRTYOLOEngineBuilder] Export done. Engine path: {engine}")

        return (engine,)
335
+
336
+
337
+ # -------------------------------------------------------------------------
338
+ # NODE 2: TRTYOLOBBoxDetectorProvider
339
+ # - Loads the TensorRT engine and wraps it as BBOX_DETECTOR
340
+ # - engine_path can be:
341
+ # * Absolute path
342
+ # * Relative to this Python file's folder
343
+ # -------------------------------------------------------------------------
344
+
345
+
346
class TRTYOLOBBoxDetectorProvider:
    """
    ComfyUI node that loads a TensorRT YOLO engine and wraps it as a
    BBOX_DETECTOR (TRTYOLOBBoxDetector) for Impact Pack / FaceDetailer.

    engine_path can be:
        * an absolute path, or
        * a path relative to this Python file's folder.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "engine_path": (
                    "STRING",
                    {
                        "default": "",
                        "multiline": False,
                        "tooltip": (
                            "Path to .engine file.\n"
                            "Can be an absolute path OR a path RELATIVE to this node's folder.\n"
                            "Typically, you connect this from TRTYOLOEngineBuilder."
                        ),
                    },
                ),
            }
        }

    RETURN_TYPES = ("BBOX_DETECTOR",)
    RETURN_NAMES = ("bbox_detector",)
    FUNCTION = "load"
    CATEGORY = "ImpactPack/TensorRT"

    def load(self, engine_path: str):
        """
        Load the TensorRT engine at engine_path and return a
        (TRTYOLOBBoxDetector,) tuple pinned to device "0".

        Raises:
            ValueError: engine_path is empty or whitespace-only.
            FileNotFoundError: resolved path does not point to a file.
        """
        # FIX: strip BEFORE the emptiness check. Previously a
        # whitespace-only input passed the guard, was stripped to "",
        # and produced a confusing FileNotFoundError pointing at the
        # node folder instead of the clear "engine_path is empty" error.
        engine_path = engine_path.strip() if engine_path else ""
        if not engine_path:
            raise ValueError(
                "[TRTYOLOBBoxDetectorProvider] engine_path is empty. "
                "Connect the output from TRTYOLOEngineBuilder or type a path."
            )

        # If relative, treat as relative to this file's folder
        if not os.path.isabs(engine_path):
            engine_path = os.path.join(NODE_DIR, engine_path)

        if not os.path.isfile(engine_path):
            raise FileNotFoundError(
                f"[TRTYOLOBBoxDetectorProvider] Engine file not found: {engine_path}"
            )

        logger.info(
            f"[TRTYOLOBBoxDetectorProvider] Loading YOLO TensorRT engine from '{engine_path}' on device '0'"
        )

        # Ultralytics will detect it's a TensorRT engine and use TRT backend internally.
        yolo_model = YOLO(engine_path)
        detector = TRTYOLOBBoxDetector(yolo_model, device="0")

        return (detector,)
398
+
399
+
400
+ # -------------------------------------------------------------------------
401
+ # ComfyUI registration
402
+ # -------------------------------------------------------------------------
403
+
404
# Maps ComfyUI node identifiers to the classes that implement them.
NODE_CLASS_MAPPINGS = {
    "TRTYOLOEngineBuilder": TRTYOLOEngineBuilder,
    "TRTYOLOBBoxDetectorProvider": TRTYOLOBBoxDetectorProvider,
}

# Human-readable labels shown in the ComfyUI node picker.
NODE_DISPLAY_NAME_MAPPINGS = {
    "TRTYOLOEngineBuilder": "TensorRT YOLO Engine Builder (1344x768, local .pt)",
    "TRTYOLOBBoxDetectorProvider": "TensorRT YOLO BBox Detector",
}