saliacoel committed · verified
Commit f5925eb · 1 Parent(s): 4c0e632

Upload tensorrt_loader.py

Files changed (1)
  1. tensorrt_loader.py +172 -0
tensorrt_loader.py ADDED
@@ -0,0 +1,172 @@
+ # Put this file in the custom_nodes folder and put your TensorRT engine files in ComfyUI/models/tensorrt/ (you will have to create that directory).
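+ # Expected layout (illustrative paths; the .engine filename is just an example):
+ #   ComfyUI/custom_nodes/tensorrt_loader.py
+ #   ComfyUI/models/tensorrt/sdxl_base.engine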
+
+ import torch
+ import os
+
+ import comfy.model_base
+ import comfy.model_management
+ import comfy.model_patcher
+ import comfy.supported_models
+ import folder_paths
+
+ if "tensorrt" in folder_paths.folder_names_and_paths:
+     folder_paths.folder_names_and_paths["tensorrt"][0].append(
+         os.path.join(folder_paths.models_dir, "tensorrt"))
+     folder_paths.folder_names_and_paths["tensorrt"][1].add(".engine")
+ else:
+     folder_paths.folder_names_and_paths["tensorrt"] = (
+         [os.path.join(folder_paths.models_dir, "tensorrt")], {".engine"})
+
+ import tensorrt as trt
+
+ trt.init_libnvinfer_plugins(None, "")
+
+ logger = trt.Logger(trt.Logger.INFO)
+ runtime = trt.Runtime(logger)
+
+ # Is there a function that already exists for this?
+ def trt_datatype_to_torch(datatype):
+     if datatype == trt.float16:
+         return torch.float16
+     elif datatype == trt.float32:
+         return torch.float32
+     elif datatype == trt.int32:
+         return torch.int32
+     elif datatype == trt.bfloat16:
+         return torch.bfloat16
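+     # e.g. trt_datatype_to_torch(trt.float16) -> torch.float16; TensorRT dtypes
+     # not listed above fall through and return None.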
+
+ class TrTUnet:
+     def __init__(self, engine_path):
+         with open(engine_path, "rb") as f:
+             self.engine = runtime.deserialize_cuda_engine(f.read())
+         self.context = self.engine.create_execution_context()
+         self.dtype = torch.float16
+
+     def set_bindings_shape(self, inputs, split_batch):
+         for k in inputs:
+             shape = inputs[k].shape
+             shape = [shape[0] // split_batch] + list(shape[1:])
+             self.context.set_input_shape(k, shape)
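+         # e.g. with inputs {"x": torch.Size([2, 4, 128, 128])} and split_batch=2,
+         # the engine input shape for "x" becomes [1, 4, 128, 128].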
+
+     def __call__(self, x, timesteps, context, y=None, **kwargs):
+         # Ensure input dtypes match the engine precision (e.g. FP16);
+         # .to() is a no-op when the dtype already matches
+         x = x.to(dtype=self.dtype)
+         timesteps = timesteps.to(dtype=self.dtype)
+         context = context.to(dtype=self.dtype)
+         if y is not None:
+             y = y.to(dtype=self.dtype)
+
+         # Prepare the model inputs list
+         model_inputs = [x, timesteps, context]
+         if y is not None:
+             model_inputs.append(y)
+
+         # Identify input and output names using the TensorRT I/O mode
+         tensor_names = [self.engine.get_tensor_name(i) for i in range(self.engine.num_io_tensors)]
+         input_names = [n for n in tensor_names if self.engine.get_tensor_mode(n) == trt.TensorIOMode.INPUT]
+         output_names = [n for n in tensor_names if self.engine.get_tensor_mode(n) == trt.TensorIOMode.OUTPUT]
+
+         # Ensure we have a matching number of input names and provided tensors
+         if len(input_names) != len(model_inputs):
+             raise RuntimeError(f"Expected {len(input_names)} inputs for TensorRT engine, but got {len(model_inputs)}.")
+
+         # Set runtime shapes for dynamic dims and bind input memory
+         for name, tensor in zip(input_names, model_inputs):
+             self.context.set_input_shape(name, tuple(tensor.shape))
+             self.context.set_tensor_address(name, tensor.data_ptr())
+
+         # Infer shapes (ensures all dynamic dims are resolved)
+         missing = self.context.infer_shapes()
+         if missing:  # any tensor shapes still unspecified means something is wrong
+             raise RuntimeError(f"TensorRT shape inference failed, unresolved tensors: {missing}")
+
+         # Allocate outputs with the resolved shapes and the engine-reported dtype,
+         # on the same device as the inputs, then bind output memory
+         outputs = []
+         for name in output_names:
+             out_shape = [int(d) for d in self.context.get_tensor_shape(name)]
+             out_tensor = torch.empty(out_shape, device=x.device,
+                                      dtype=trt_datatype_to_torch(self.engine.get_tensor_dtype(name)))
+             self.context.set_tensor_address(name, out_tensor.data_ptr())
+             outputs.append(out_tensor)
+
+         # Execute on the stream PyTorch is currently using so tensor reads/writes stay ordered
+         self.context.execute_async_v3(stream_handle=torch.cuda.current_stream(x.device).cuda_stream)
+
+         # If there is only one output tensor, return it directly for convenience
+         return outputs[0] if len(outputs) == 1 else tuple(outputs)
+
+     def load_state_dict(self, sd, strict=False):
+         pass
+
+     def state_dict(self):
+         return {}
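+
+ # Minimal smoke-test sketch for TrTUnet on its own (assumptions: a built SDXL-base
+ # engine at the hypothetical path below, and dummy inputs shaped for one of its
+ # optimization profiles; adjust the path and shapes to your engine):
+ #   unet = TrTUnet("ComfyUI/models/tensorrt/sdxl_base.engine")
+ #   x = torch.zeros((2, 4, 128, 128), device="cuda", dtype=torch.float16)
+ #   t = torch.zeros((2,), device="cuda", dtype=torch.float16)
+ #   c = torch.zeros((2, 77, 2048), device="cuda", dtype=torch.float16)
+ #   y = torch.zeros((2, 2816), device="cuda", dtype=torch.float16)
+ #   eps = unet(x, t, c, y=y)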
+
+
+ class TensorRTLoader:
+     @classmethod
+     def INPUT_TYPES(s):
+         return {"required": {"unet_name": (folder_paths.get_filename_list("tensorrt"), ),
+                              "model_type": (["sdxl_base", "sdxl_refiner", "sd1.x", "sd2.x-768v", "svd", "sd3", "auraflow", "flux_dev", "flux_schnell"], ),
+                              }}
+     RETURN_TYPES = ("MODEL",)
+     FUNCTION = "load_unet"
+     CATEGORY = "TensorRT"
+
+     def load_unet(self, unet_name, model_type):
+         unet_path = folder_paths.get_full_path("tensorrt", unet_name)
+         if not os.path.isfile(unet_path):
+             raise FileNotFoundError(f"File {unet_path} does not exist")
+         unet = TrTUnet(unet_path)
+         if model_type == "sdxl_base":
+             conf = comfy.supported_models.SDXL({"adm_in_channels": 2816})
+             conf.unet_config["disable_unet_model_creation"] = True
+             model = comfy.model_base.SDXL(conf)
+         elif model_type == "sdxl_refiner":
+             conf = comfy.supported_models.SDXLRefiner(
+                 {"adm_in_channels": 2560})
+             conf.unet_config["disable_unet_model_creation"] = True
+             model = comfy.model_base.SDXLRefiner(conf)
+         elif model_type == "sd1.x":
+             conf = comfy.supported_models.SD15({})
+             conf.unet_config["disable_unet_model_creation"] = True
+             model = comfy.model_base.BaseModel(conf)
+         elif model_type == "sd2.x-768v":
+             conf = comfy.supported_models.SD20({})
+             conf.unet_config["disable_unet_model_creation"] = True
+             model = comfy.model_base.BaseModel(conf, model_type=comfy.model_base.ModelType.V_PREDICTION)
+         elif model_type == "svd":
+             conf = comfy.supported_models.SVD_img2vid({})
+             conf.unet_config["disable_unet_model_creation"] = True
+             model = conf.get_model({})
+         elif model_type == "sd3":
+             conf = comfy.supported_models.SD3({})
+             conf.unet_config["disable_unet_model_creation"] = True
+             model = conf.get_model({})
+         elif model_type == "auraflow":
+             conf = comfy.supported_models.AuraFlow({})
+             conf.unet_config["disable_unet_model_creation"] = True
+             model = conf.get_model({})
+         elif model_type == "flux_dev":
+             conf = comfy.supported_models.Flux({})
+             conf.unet_config["disable_unet_model_creation"] = True
+             model = conf.get_model({})
+             unet.dtype = torch.bfloat16  # TODO: autodetect
+         elif model_type == "flux_schnell":
+             conf = comfy.supported_models.FluxSchnell({})
+             conf.unet_config["disable_unet_model_creation"] = True
+             model = conf.get_model({})
+             unet.dtype = torch.bfloat16  # TODO: autodetect
+         model.diffusion_model = unet
+         model.memory_required = lambda *args, **kwargs: 0  # always pass inputs batched up as much as possible, our TRT code will handle batch splitting
+
+         return (comfy.model_patcher.ModelPatcher(model,
+                                                  load_device=comfy.model_management.get_torch_device(),
+                                                  offload_device=comfy.model_management.unet_offload_device()),)
+
+ NODE_CLASS_MAPPINGS = {
+     "TensorRTLoader": TensorRTLoader,
+ }
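
For reference, a minimal sketch of driving the loader outside the ComfyUI graph (assumes a ComfyUI environment where this module imports, and an engine file, hypothetically named sdxl_base.engine here, under models/tensorrt/):

    loader = TensorRTLoader()
    (patched_model,) = loader.load_unet("sdxl_base.engine", "sdxl_base")
    # patched_model is a comfy.model_patcher.ModelPatcher whose diffusion_model
    # is the TrTUnet wrapper defined above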