saliacoel committed on
Commit
1306c71
·
verified ·
1 Parent(s): 18518ac

Upload tensorrt_loader.py

Browse files
Files changed (1) hide show
  1. tensorrt_loader.py +114 -49
tensorrt_loader.py CHANGED
@@ -40,72 +40,137 @@ class TrTUnet:
40
  with open(engine_path, "rb") as f:
41
  self.engine = runtime.deserialize_cuda_engine(f.read())
42
  self.context = self.engine.create_execution_context()
 
43
  self.dtype = torch.float16
44
 
45
  def set_bindings_shape(self, inputs, split_batch):
 
 
46
  for k in inputs:
47
  shape = inputs[k].shape
48
  shape = [shape[0] // split_batch] + list(shape[1:])
49
  self.context.set_input_shape(k, shape)
50
 
51
- def __call__(self, x, timesteps, context, y=None, **kwargs):
52
- # Ensure input types match engine precision (e.g., FP16)
53
- if x.dtype != self.dtype:
54
- x = x.to(dtype=self.dtype)
55
- timesteps = timesteps.to(dtype=self.dtype)
56
- context = context.to(dtype=self.dtype)
 
 
 
 
 
 
 
 
 
 
 
57
  if y is not None:
58
- y = y.to(dtype=self.dtype)
59
-
60
- # Prepare model inputs list
61
- model_inputs = [x, timesteps, context]
62
- if y is not None:
63
- model_inputs.append(y)
64
-
65
- # Set dynamic input shapes for the execution context
66
- tensor_names = [self.engine.get_tensor_name(i) for i in range(self.engine.num_io_tensors)]
67
- # Identify input and output names using TensorRT I/O mode
68
- input_names = [n for n in tensor_names if self.engine.get_tensor_mode(n) == trt.TensorIOMode.INPUT]
69
- output_names = [n for n in tensor_names if self.engine.get_tensor_mode(n) == trt.TensorIOMode.OUTPUT]
70
-
71
- # Ensure we have a matching number of input names and provided tensors
72
- if len(input_names) != len(model_inputs):
73
- raise RuntimeError(f"Expected {len(input_names)} inputs for TensorRT engine, but got {len(model_inputs)}.")
74
-
75
- # Set input shapes and addresses
76
- for name, tensor in zip(input_names, model_inputs):
77
- shape = tuple(tensor.shape)
78
- self.context.set_input_shape(name, shape) # specify runtime shape for dynamic dims
79
- self.context.set_tensor_address(name, tensor.data_ptr()) # bind input memory
80
-
81
- # Infer shapes (ensures all dynamic dims are resolved)
82
- missing = self.context.infer_shapes()
83
- if missing: # if any tensor shapes still unspecified, something is wrong
84
- raise RuntimeError(f"TensorRT shape inference failed, unresolved tensors: {missing}")
85
-
86
- # Allocate outputs with proper shapes
87
- outputs = []
88
- for name in output_names:
89
- out_dims = self.context.get_tensor_shape(name) # get resolved output shape (trt.Dims)
90
- out_shape = [int(d) for d in out_dims] # convert Dims to list of ints
91
- out_tensor = torch.empty(out_shape, device=self.torch_device, dtype=self.torch_dtype)
92
- self.context.set_tensor_address(name, out_tensor.data_ptr()) # bind output memory
93
- outputs.append(out_tensor)
94
-
95
- # Execute the engine (on default CUDA stream or a pre-created stream)
96
- self.context.execute_async_v3(stream_handle=0) # using default stream (0) for simplicity
97
-
98
- # If only one output tensor, return it directly for convenience
99
- return outputs[0] if len(outputs) == 1 else tuple(outputs)
100
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
  def load_state_dict(self, sd, strict=False):
 
103
  pass
104
 
105
  def state_dict(self):
 
106
  return {}
107
 
108
 
 
109
  class TensorRTLoader:
110
  @classmethod
111
  def INPUT_TYPES(s):
 
40
  with open(engine_path, "rb") as f:
41
  self.engine = runtime.deserialize_cuda_engine(f.read())
42
  self.context = self.engine.create_execution_context()
43
+ # default dtype in case something doesn't have a specific TRT dtype
44
  self.dtype = torch.float16
45
 
46
  def set_bindings_shape(self, inputs, split_batch):
47
+ # still here in case something else calls it, but the new __call__
48
+ # no longer uses this split-batch path
49
  for k in inputs:
50
  shape = inputs[k].shape
51
  shape = [shape[0] // split_batch] + list(shape[1:])
52
  self.context.set_input_shape(k, shape)
53
 
54
+ def __call__(self, x, timesteps, context, y=None,
55
+ control=None, transformer_options=None, **kwargs):
56
+ """
57
+ Run the TensorRT UNet.
58
+
59
+ - `control` and `transformer_options` are accepted for API compatibility
60
+ with Comfy, but ignored by the TRT engine.
61
+ - Any extra tensor inputs (e.g. `guidance` for Flux) are taken from
62
+ **kwargs and matched by name to the engine’s input tensors.
63
+ """
64
+
65
+ # Collect all tensors we might need by name
66
+ available = {
67
+ "x": x,
68
+ "timesteps": timesteps,
69
+ "context": context,
70
+ }
71
  if y is not None:
72
+ available["y"] = y
73
+
74
+ # Extra conds (e.g. 'guidance', etc.) may come in via kwargs
75
+ for name, value in kwargs.items():
76
+ if isinstance(value, torch.Tensor):
77
+ available[name] = value
78
+
79
+ # Query engine IO tensors
80
+ tensor_names = [
81
+ self.engine.get_tensor_name(i)
82
+ for i in range(self.engine.num_io_tensors)
83
+ ]
84
+ input_names = [
85
+ n for n in tensor_names
86
+ if self.engine.get_tensor_mode(n) == trt.TensorIOMode.INPUT
87
+ ]
88
+ output_names = [
89
+ n for n in tensor_names
90
+ if self.engine.get_tensor_mode(n) == trt.TensorIOMode.OUTPUT
91
+ ]
92
+
93
+ # Sanity check: we must have a tensor for every input
94
+ missing = [n for n in input_names if n not in available]
95
+ if missing:
96
+ raise RuntimeError(
97
+ f"Missing tensors for TensorRT engine inputs: {missing}. "
98
+ f"Available: {list(available.keys())}"
99
+ )
100
+
101
+ device = x.device
102
+
103
+ # Bind inputs: fix dtype + device, set shapes and addresses
104
+ for name in input_names:
105
+ t = available[name]
106
+
107
+ if not t.is_contiguous():
108
+ t = t.contiguous()
109
+
110
+ # Match engine dtype
111
+ trt_dtype = self.engine.get_tensor_dtype(name)
112
+ torch_dtype = trt_datatype_to_torch(trt_dtype)
113
+ if torch_dtype is None:
114
+ raise RuntimeError(
115
+ f"Unsupported TensorRT dtype {trt_dtype} for input '{name}'"
116
+ )
117
+
118
+ if t.dtype != torch_dtype:
119
+ t = t.to(dtype=torch_dtype)
120
+
121
+ if t.device != device:
122
+ t = t.to(device)
123
+
124
+ # Save back in case we changed it
125
+ available[name] = t
126
+
127
+ # Tell TRT the runtime shape and bind the memory
128
+ self.context.set_input_shape(name, tuple(t.shape))
129
+ self.context.set_tensor_address(name, t.data_ptr())
130
+
131
+ # Let TRT resolve all dynamic shapes (outputs etc.)
132
+ unresolved = self.context.infer_shapes()
133
+ if unresolved:
134
+ raise RuntimeError(
135
+ f"TensorRT shape inference failed, unresolved tensors: {unresolved}"
136
+ )
137
+
138
+ # Allocate and bind outputs
139
+ outputs = []
140
+ for name in output_names:
141
+ dims = self.context.get_tensor_shape(name) # trt.Dims
142
+
143
+ # Guard against the old nbDims == -1 issue
144
+ if hasattr(dims, "nb_dims") and dims.nb_dims < 0:
145
+ raise RuntimeError(f"Output '{name}' has invalid dims: {dims}")
146
+
147
+ shape = [int(d) for d in dims]
148
+ trt_dtype = self.engine.get_tensor_dtype(name)
149
+ torch_dtype = trt_datatype_to_torch(trt_dtype)
150
+
151
+ out = torch.empty(shape, device=device, dtype=torch_dtype)
152
+ self.context.set_tensor_address(name, out.data_ptr())
153
+ outputs.append(out)
154
+
155
+ # Run on the default torch CUDA stream
156
+ stream = torch.cuda.default_stream(device)
157
+ self.context.execute_async_v3(stream_handle=stream.cuda_stream)
158
+
159
+ # Return single tensor or a tuple
160
+ if len(outputs) == 1:
161
+ return outputs[0]
162
+ return tuple(outputs)
163
 
164
  def load_state_dict(self, sd, strict=False):
165
+ # Nothing to load for a serialized TensorRT engine
166
  pass
167
 
168
  def state_dict(self):
169
+ # Keep API compatible with nn.Module
170
  return {}
171
 
172
 
173
+
174
  class TensorRTLoader:
175
  @classmethod
176
  def INPUT_TYPES(s):