import torch import os import comfy.model_base import comfy.model_management import comfy.model_patcher import comfy.supported_models import folder_paths if "tensorrt" in folder_paths.folder_names_and_paths: folder_paths.folder_names_and_paths["tensorrt"][0].append( os.path.join(folder_paths.models_dir, "tensorrt")) folder_paths.folder_names_and_paths["tensorrt"][1].add(".engine") else: folder_paths.folder_names_and_paths["tensorrt"] = ( [os.path.join(folder_paths.models_dir, "tensorrt")], {".engine"}) import tensorrt as trt trt.init_libnvinfer_plugins(None, "") logger = trt.Logger(trt.Logger.INFO) runtime = trt.Runtime(logger) def trt_datatype_to_torch(datatype): # Works for TRT 8/9/10 if datatype in (getattr(trt, "float16", None), getattr(trt.DataType, "HALF", None)): return torch.float16 if datatype in (getattr(trt, "float32", None), getattr(trt.DataType, "FLOAT", None)): return torch.float32 if hasattr(trt, "bfloat16") and datatype in ( getattr(trt, "bfloat16", None), getattr(trt.DataType, "BF16", None), ): return torch.bfloat16 if datatype in (getattr(trt, "int32", None), getattr(trt.DataType, "INT32", None)): return torch.int32 # Fallback – shouldn't normally hit this for UNets return torch.float32 class TrTUnet: def __init__(self, engine_path): with open(engine_path, "rb") as f: self.engine = runtime.deserialize_cuda_engine(f.read()) self.context = self.engine.create_execution_context() # Default torch device / dtype for allocations self.device = comfy.model_management.get_torch_device() self.default_dtype = torch.float16 # fallback if something unknown shows up def _trt_dtype_to_torch(self, trt_dtype): dt = trt_datatype_to_torch(trt_dtype) return dt if dt is not None else self.default_dtype def __call__(self, x, timesteps, context, y=None, control=None, transformer_options=None, **kwargs): """ x : [B, C, H, W] timesteps : [B] context : [B, N, D] y : [B, y_dim] (optional, SDXL etc.) """ # ----------------------------- # 1. Build dict of actual inputs # ----------------------------- model_inputs = { "x": x, "timesteps": timesteps, "context": context, } if y is not None: model_inputs["y"] = y # If your engine has extra inputs (e.g. 'guidance' for Flux), # they must either come from kwargs or be absent from the engine. tensor_names = [self.engine.get_tensor_name(i) for i in range(self.engine.num_io_tensors)] input_names = [n for n in tensor_names if self.engine.get_tensor_mode(n) == trt.TensorIOMode.INPUT] output_names = [n for n in tensor_names if self.engine.get_tensor_mode(n) == trt.TensorIOMode.OUTPUT] # Fill missing inputs from kwargs if present for name in input_names: if name in model_inputs: continue if name in kwargs: model_inputs[name] = kwargs[name] if len(model_inputs) != len(input_names): missing = [n for n in input_names if n not in model_inputs] raise RuntimeError( f"TensorRT UNet: missing required inputs for engine: {missing} " f"(have {list(model_inputs.keys())})" ) # ----------------------------- # 2. Convert each input to engine dtype + bind it # ----------------------------- for name in input_names: t = model_inputs[name] # Move to correct device if t.device != self.device: t = t.to(self.device) # Match TensorRT's expected dtype for this tensor trt_dtype = self.engine.get_tensor_dtype(name) torch_dtype = self._trt_dtype_to_torch(trt_dtype) if t.dtype != torch_dtype: t = t.to(dtype=torch_dtype) # Update back (so later code sees the converted tensor if needed) model_inputs[name] = t # Set runtime shape and bind memory self.context.set_input_shape(name, tuple(t.shape)) self.context.set_tensor_address(name, int(t.data_ptr())) # Make sure all shapes are resolved missing = self.context.infer_shapes() if missing: raise RuntimeError(f"TensorRT shape inference failed, unresolved tensors: {missing}") # ----------------------------- # 3. Allocate & bind outputs # ----------------------------- outputs = {} for name in output_names: out_dims = self.context.get_tensor_shape(name) # trt.Dims out_shape = tuple(int(d) for d in out_dims) trt_dtype = self.engine.get_tensor_dtype(name) torch_dtype = self._trt_dtype_to_torch(trt_dtype) out_tensor = torch.empty(out_shape, device=self.device, dtype=torch_dtype) self.context.set_tensor_address(name, int(out_tensor.data_ptr())) outputs[name] = out_tensor # ----------------------------- # 4. Execute on the current torch CUDA stream # ----------------------------- stream = torch.cuda.current_stream(self.device) self.context.execute_async_v3(stream_handle=stream.cuda_stream) # No need to sync explicitly; ComfyUI uses the same default stream. # Return outputs in a stable order out_list = [outputs[name] for name in output_names] return out_list[0] if len(out_list) == 1 else tuple(out_list) def load_state_dict(self, sd, strict=False): pass def state_dict(self): return {} class TensorRTLoader: @classmethod def INPUT_TYPES(s): return {"required": {"unet_name": (folder_paths.get_filename_list("tensorrt"), ), "model_type": (["sdxl_base", "sdxl_refiner", "sd1.x", "sd2.x-768v", "svd", "sd3", "auraflow", "flux_dev", "flux_schnell"], ), }} RETURN_TYPES = ("MODEL",) FUNCTION = "load_unet" CATEGORY = "TensorRT" def load_unet(self, unet_name, model_type): unet_path = folder_paths.get_full_path("tensorrt", unet_name) if not os.path.isfile(unet_path): raise FileNotFoundError(f"File {unet_path} does not exist") unet = TrTUnet(unet_path) if model_type == "sdxl_base": conf = comfy.supported_models.SDXL({"adm_in_channels": 2816}) conf.unet_config["disable_unet_model_creation"] = True model = comfy.model_base.SDXL(conf) elif model_type == "sdxl_refiner": conf = comfy.supported_models.SDXLRefiner( {"adm_in_channels": 2560}) conf.unet_config["disable_unet_model_creation"] = True model = comfy.model_base.SDXLRefiner(conf) elif model_type == "sd1.x": conf = comfy.supported_models.SD15({}) conf.unet_config["disable_unet_model_creation"] = True model = comfy.model_base.BaseModel(conf) elif model_type == "sd2.x-768v": conf = comfy.supported_models.SD20({}) conf.unet_config["disable_unet_model_creation"] = True model = comfy.model_base.BaseModel(conf, model_type=comfy.model_base.ModelType.V_PREDICTION) elif model_type == "svd": conf = comfy.supported_models.SVD_img2vid({}) conf.unet_config["disable_unet_model_creation"] = True model = conf.get_model({}) elif model_type == "sd3": conf = comfy.supported_models.SD3({}) conf.unet_config["disable_unet_model_creation"] = True model = conf.get_model({}) elif model_type == "auraflow": conf = comfy.supported_models.AuraFlow({}) conf.unet_config["disable_unet_model_creation"] = True model = conf.get_model({}) elif model_type == "flux_dev": conf = comfy.supported_models.Flux({}) conf.unet_config["disable_unet_model_creation"] = True model = conf.get_model({}) unet.dtype = torch.bfloat16 #TODO: autodetect elif model_type == "flux_schnell": conf = comfy.supported_models.FluxSchnell({}) conf.unet_config["disable_unet_model_creation"] = True model = conf.get_model({}) unet.dtype = torch.bfloat16 #TODO: autodetect model.diffusion_model = unet model.memory_required = lambda *args, **kwargs: 0 #always pass inputs batched up as much as possible, our TRT code will handle batch splitting return (comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()),) NODE_CLASS_MAPPINGS = { "TensorRTLoader": TensorRTLoader, }