File size: 9,136 Bytes
f5925eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c6054f8
f5925eb
c6054f8
 
f5925eb
c6054f8
f5925eb
c6054f8
 
 
 
f5925eb
c6054f8
 
 
 
 
f5925eb
 
 
 
e24ff1f
f5925eb
 
e24ff1f
 
 
 
 
 
 
 
 
1306c71
e24ff1f
 
 
 
1306c71
 
e24ff1f
 
 
 
 
 
 
 
c6054f8
e24ff1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1306c71
e24ff1f
1306c71
e24ff1f
c6054f8
 
e24ff1f
c6054f8
e24ff1f
 
 
 
1306c71
e24ff1f
c6054f8
1306c71
e24ff1f
 
1306c71
e24ff1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f5925eb
 
e24ff1f
f5925eb
 
 
 
 
1306c71
c6054f8
e24ff1f
f5925eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
import torch
import os

import comfy.model_base
import comfy.model_management
import comfy.model_patcher
import comfy.supported_models
import folder_paths

# Register <models_dir>/tensorrt as a search directory for ".engine" files.
# If another extension/config already created a "tensorrt" entry, merge into it
# (its [0] is the directory list, its [1] the extension set) instead of replacing it.
_tensorrt_models_dir = os.path.join(folder_paths.models_dir, "tensorrt")
if "tensorrt" in folder_paths.folder_names_and_paths:
    _tensorrt_entry = folder_paths.folder_names_and_paths["tensorrt"]
    # Don't register the same directory twice (e.g. already configured, or module reloaded).
    if _tensorrt_models_dir not in _tensorrt_entry[0]:
        _tensorrt_entry[0].append(_tensorrt_models_dir)
    _tensorrt_entry[1].add(".engine")
else:
    folder_paths.folder_names_and_paths["tensorrt"] = (
        [_tensorrt_models_dir], {".engine"})

import tensorrt as trt

# Initialise TensorRT's bundled plugin library (no logger, default "" namespace)
# at import time, before any engine is deserialized further down in this module.
trt.init_libnvinfer_plugins(None, "")

# Module-wide TensorRT logger (INFO severity) and runtime; every TrTUnet uses
# this shared `runtime` to deserialize its engine file.
logger = trt.Logger(trt.Logger.INFO)
runtime = trt.Runtime(logger)


def trt_datatype_to_torch(datatype):
    """Return the torch dtype equivalent to a TensorRT ``DataType``.

    Each entry is matched against both the top-level alias (e.g. ``trt.float16``)
    and the ``trt.DataType`` member (e.g. ``HALF``). Either name may be missing
    depending on the TensorRT version (8/9/10), so absent names are skipped.

    Args:
        datatype: a TensorRT DataType, e.g. from ``engine.get_tensor_dtype(name)``.

    Returns:
        The matching ``torch.dtype``; ``torch.float32`` for unrecognised types.
    """
    # (top-level alias, DataType member, torch dtype). Integer/bool types are
    # listed so that such I/O tensors are not misread as float32 buffers.
    candidates = (
        ("float16", "HALF", torch.float16),
        ("float32", "FLOAT", torch.float32),
        ("bfloat16", "BF16", torch.bfloat16),
        ("int32", "INT32", torch.int32),
        ("int64", "INT64", torch.int64),
        ("int8", "INT8", torch.int8),
        ("uint8", "UINT8", torch.uint8),
        ("bool", "BOOL", torch.bool),
    )
    for alias, member, torch_dtype in candidates:
        for trt_value in (getattr(trt, alias, None), getattr(trt.DataType, member, None)):
            if trt_value is not None and datatype == trt_value:
                return torch_dtype
    # Fallback – shouldn't normally hit this for UNets
    return torch.float32


class TrTUnet:
    """Callable stand-in for a ComfyUI diffusion model, backed by a TensorRT engine.

    The serialized engine is deserialized once. Each call does the following:
    - matches the caller's tensors to the engine's input tensors by name;
    - moves them to the torch device, converts them to the dtype the engine
      declares and makes them contiguous;
    - launches inference on the current torch CUDA stream.

    If the batch is larger than optimization profile 0 allows, the call is run
    in equal sub-batches along dim 0 and the outputs are concatenated.
    """

    def __init__(self, engine_path):
        with open(engine_path, "rb") as f:
            self.engine = runtime.deserialize_cuda_engine(f.read())
        if self.engine is None:
            raise RuntimeError(f"TensorRT: failed to deserialize engine '{engine_path}'")
        self.context = self.engine.create_execution_context()
        if self.context is None:
            raise RuntimeError(f"TensorRT: failed to create an execution context for '{engine_path}'")

        # Default torch device / dtype for allocations
        self.device = comfy.model_management.get_torch_device()
        self.default_dtype = torch.float16  # fallback if something unknown shows up

        # The engine's I/O layout never changes, so resolve names and dtypes once
        # instead of on every sampling step.
        tensor_names = [self.engine.get_tensor_name(i) for i in range(self.engine.num_io_tensors)]
        self.input_names = [n for n in tensor_names
                            if self.engine.get_tensor_mode(n) == trt.TensorIOMode.INPUT]
        self.output_names = [n for n in tensor_names
                             if self.engine.get_tensor_mode(n) == trt.TensorIOMode.OUTPUT]
        self._torch_dtypes = {n: self._trt_dtype_to_torch(self.engine.get_tensor_dtype(n))
                              for n in tensor_names}

        # Model dtype taken from the engine's "x" input when it is a float type.
        # The loader assigns `dtype` explicitly for Flux; other model types
        # previously had no such attribute. NOTE(review): ComfyUI's model wrapper
        # presumably reads `diffusion_model.dtype` to cast inputs — confirm.
        x_dtype = self._torch_dtypes.get("x")
        self.dtype = x_dtype if x_dtype is not None and x_dtype.is_floating_point else self.default_dtype

        # (min_batch, max_batch) of optimization profile 0, or None if unknown.
        self._batch_range = self._query_batch_range()

    def _trt_dtype_to_torch(self, trt_dtype):
        """Map a TensorRT dtype to torch, using `default_dtype` if the mapper yields None."""
        dt = trt_datatype_to_torch(trt_dtype)
        return dt if dt is not None else self.default_dtype

    def _query_batch_range(self):
        """Return (min, max) of dim 0 of the "x" input (else the first input) in profile 0.

        Returns None when no usable profile can be read. Calls are then never
        split, and out-of-range shapes are reported by `set_input_shape`.
        """
        if not self.input_names:
            return None
        ref = "x" if "x" in self.input_names else self.input_names[0]
        try:
            min_shape, _opt_shape, max_shape = self.engine.get_tensor_profile_shape(ref, 0)
        except Exception:  # best effort: static engines may not expose a profile
            return None
        if len(min_shape) == 0 or len(max_shape) == 0:
            return None
        lo, hi = int(min_shape[0]), int(max_shape[0])
        if lo < 1 or hi < lo:
            return None
        return lo, hi

    def _chunk_size(self, batch):
        """Choose the sub-batch size for a batch of `batch` samples.

        Returns `batch` itself when no split is needed (or possible to decide).
        Otherwise returns the largest size in the profile range that divides
        `batch` evenly.
        """
        if self._batch_range is None:
            return batch
        lo, hi = self._batch_range
        if batch <= hi:
            return batch  # fits; a too-small batch is rejected by set_input_shape
        for size in range(hi, lo - 1, -1):
            if batch % size == 0:
                return size
        raise RuntimeError(
            f"TensorRT UNet: batch size {batch} cannot be split into equal chunks "
            f"within the engine's profile batch range [{lo}, {hi}]"
        )

    def _collect_inputs(self, x, timesteps, context, y, extra):
        """Map every engine input name to a caller-supplied tensor.

        Engine inputs not among x/timesteps/context/y are looked up by name in
        `extra` (the call's **kwargs). Supplied tensors the engine does not
        declare (e.g. `y` for an engine built without it) are dropped.

        Raises:
            RuntimeError: if any engine input has no supplied tensor.
        """
        supplied = {"x": x, "timesteps": timesteps, "context": context}
        if y is not None:
            supplied["y"] = y
        for name in self.input_names:
            if name not in supplied and name in extra:
                supplied[name] = extra[name]

        missing = [n for n in self.input_names if n not in supplied]
        if missing:
            raise RuntimeError(
                f"TensorRT UNet: missing required inputs for engine: {missing} "
                f"(have {list(supplied.keys())})"
            )
        return {n: supplied[n] for n in self.input_names}

    def _execute(self, model_inputs):
        """Bind inputs, allocate outputs and enqueue one inference on the current stream.

        Returns the single output tensor, or a tuple in engine output order.
        """
        # Keep the converted tensors referenced until the launch is enqueued:
        # TensorRT only receives raw device pointers.
        bound = {}
        for name in self.input_names:
            t = model_inputs[name].to(device=self.device, dtype=self._torch_dtypes[name])
            # data_ptr() describes a dense buffer only for contiguous tensors;
            # sliced/expanded/transposed views would otherwise be misread.
            t = t.contiguous()
            bound[name] = t

            shape = tuple(t.shape)
            if not self.context.set_input_shape(name, shape):
                raise RuntimeError(
                    f"TensorRT UNet: input '{name}' has shape {shape}, which the "
                    f"engine's optimization profile does not accept"
                )
            self.context.set_tensor_address(name, t.data_ptr())

        # Make sure all shapes are resolved
        unresolved = self.context.infer_shapes()
        if unresolved:
            raise RuntimeError(f"TensorRT shape inference failed, unresolved tensors: {unresolved}")

        outputs = []
        for name in self.output_names:
            shape = tuple(int(d) for d in self.context.get_tensor_shape(name))
            out = torch.empty(shape, device=self.device, dtype=self._torch_dtypes[name])
            self.context.set_tensor_address(name, out.data_ptr())
            outputs.append(out)

        # Enqueue on the current torch stream. Work queued later on that stream
        # is ordered after this launch, so no explicit synchronize is done here.
        stream = torch.cuda.current_stream(self.device)
        if not self.context.execute_async_v3(stream_handle=stream.cuda_stream):
            raise RuntimeError("TensorRT UNet: execute_async_v3 failed to enqueue inference")

        return outputs[0] if len(outputs) == 1 else tuple(outputs)

    def __call__(self, x, timesteps, context, y=None, control=None, transformer_options=None, **kwargs):
        """Run the engine on one (possibly split) batch.

        Args:
            x: latent batch; batch size is read from x.shape[0] (expected [B, C, H, W]).
            timesteps: expected [B].
            context: conditioning, expected [B, N, D].
            y: optional extra conditioning (SDXL etc.), expected [B, y_dim].
            control, transformer_options: accepted for interface compatibility; unused.
            **kwargs: supplies any further engine inputs by tensor name (e.g. 'guidance').

        Returns:
            The output tensor, or a tuple of outputs in engine output order.
            When the batch is split, each output is concatenated along dim 0,
            which assumes dim 0 of every output is the batch dim.

        Raises:
            RuntimeError: missing inputs, shapes outside the engine profile,
                failed shape inference, or a failed launch.
        """
        model_inputs = self._collect_inputs(x, timesteps, context, y, kwargs)

        batch = int(x.shape[0])
        chunk = self._chunk_size(batch)
        if chunk >= batch:
            return self._execute(model_inputs)

        # Batch exceeds the profile maximum: run equal sub-batches. Only tensors
        # whose leading dim equals the batch size are sliced.
        pieces = []
        for start in range(0, batch, chunk):
            sub_inputs = {
                name: t[start:start + chunk] if t.dim() > 0 and t.shape[0] == batch else t
                for name, t in model_inputs.items()
            }
            pieces.append(self._execute(sub_inputs))

        if isinstance(pieces[0], tuple):
            return tuple(torch.cat(parts, dim=0) for parts in zip(*pieces))
        return torch.cat(pieces, dim=0)

    def load_state_dict(self, sd, strict=False):
        """No-op: weights live inside the TensorRT engine."""
        pass

    def state_dict(self):
        """Return an empty dict: there are no torch parameters to expose."""
        return {}





class TensorRTLoader:
    """ComfyUI node that wraps a TensorRT UNet engine into a MODEL output."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"unet_name": (folder_paths.get_filename_list("tensorrt"), ),
                             "model_type": (["sdxl_base", "sdxl_refiner", "sd1.x", "sd2.x-768v", "svd", "sd3", "auraflow", "flux_dev", "flux_schnell"], ),
                             }}
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "load_unet"
    CATEGORY = "TensorRT"

    # Model types whose engine dtype is forced to bfloat16 below.
    _BF16_MODEL_TYPES = ("flux_dev", "flux_schnell")

    @staticmethod
    def _model_builders():
        """Map each model_type to (config factory, model factory taking that config)."""
        sm = comfy.supported_models
        mb = comfy.model_base

        def from_conf(conf):
            return conf.get_model({})

        def v_prediction_base(conf):
            return mb.BaseModel(conf, model_type=mb.ModelType.V_PREDICTION)

        return {
            "sdxl_base": (lambda: sm.SDXL({"adm_in_channels": 2816}), mb.SDXL),
            "sdxl_refiner": (lambda: sm.SDXLRefiner({"adm_in_channels": 2560}), mb.SDXLRefiner),
            "sd1.x": (lambda: sm.SD15({}), mb.BaseModel),
            "sd2.x-768v": (lambda: sm.SD20({}), v_prediction_base),
            "svd": (lambda: sm.SVD_img2vid({}), from_conf),
            "sd3": (lambda: sm.SD3({}), from_conf),
            "auraflow": (lambda: sm.AuraFlow({}), from_conf),
            "flux_dev": (lambda: sm.Flux({}), from_conf),
            "flux_schnell": (lambda: sm.FluxSchnell({}), from_conf),
        }

    def load_unet(self, unet_name, model_type):
        """Load `unet_name` from the "tensorrt" folder as a ComfyUI model of `model_type`.

        Returns:
            A 1-tuple holding a ModelPatcher whose diffusion_model is the TrTUnet.

        Raises:
            ValueError: `model_type` is not one of the supported choices.
            FileNotFoundError: the engine file cannot be resolved or does not exist.
        """
        builders = self._model_builders()
        # Validate before the (expensive) engine load; an unknown type previously
        # surfaced as an UnboundLocalError on `model`.
        if model_type not in builders:
            raise ValueError(f"Unsupported TensorRT model_type: {model_type!r}")

        unet_path = folder_paths.get_full_path("tensorrt", unet_name)
        # get_full_path may return None; os.path.isfile(None) raises TypeError.
        if unet_path is None or not os.path.isfile(unet_path):
            raise FileNotFoundError(f"File {unet_path or unet_name} does not exist")
        unet = TrTUnet(unet_path)

        make_conf, make_model = builders[model_type]
        conf = make_conf()
        # The TensorRT engine replaces the UNet, so ask ComfyUI not to build a torch one.
        conf.unet_config["disable_unet_model_creation"] = True
        model = make_model(conf)
        if model_type in self._BF16_MODEL_TYPES:
            unet.dtype = torch.bfloat16  # TODO: autodetect

        model.diffusion_model = unet
        # Report zero memory so ComfyUI batches inputs as much as possible;
        # fitting batches to the engine's optimization profile is left to TrTUnet.
        model.memory_required = lambda *args, **kwargs: 0

        return (comfy.model_patcher.ModelPatcher(model,
                                                 load_device=comfy.model_management.get_torch_device(),
                                                 offload_device=comfy.model_management.unet_offload_device()),)

# Node registry exported by this module: node type name -> node class.
NODE_CLASS_MAPPINGS = dict(TensorRTLoader=TensorRTLoader)