# Mirror of https://github.com/yuvraj108c/ComfyUI-Upscaler-Tensorrt (commit afaf90f)
import os
import folder_paths
import numpy as np
import torch
from comfy.utils import ProgressBar
from .trt_utilities import Engine
from .utilities import download_file, ColoredLogger, get_final_resolutions
import comfy.model_management as mm
import time
import tensorrt
import json # <--- Import json module
# Module-wide logger with a node-specific prefix.
logger = ColoredLogger("ComfyUI-Upscaler-Tensorrt")
# Supported input-image side lengths (pixels) for the TensorRT engine's
# min/opt/max optimization profile; inputs outside [MIN, MAX] are rejected.
IMAGE_DIM_MIN = 256
IMAGE_DIM_OPT = 512
IMAGE_DIM_MAX = 1280
# --- Function to load configuration ---
def load_node_config(config_filename="load_upscaler_config.json"):
    """Load node configuration from a JSON file next to this module.

    Args:
        config_filename: Name of the JSON config file, resolved relative to
            this module's directory.

    Returns:
        dict: The parsed configuration, or a built-in fallback configuration
        when the file is missing, malformed, or unreadable (errors are logged,
        never raised — the node must still register with sane defaults).
    """
    current_dir = os.path.dirname(__file__)
    config_path = os.path.join(current_dir, config_filename)
    # Fallback in case the file is missing or corrupt.
    default_config = {
        "model": {
            "options": ["4x-UltraSharp"],
            "default": "4x-UltraSharp",
            "tooltip": "Default model (fallback from code)"
        },
        "precision": {
            "options": ["fp16", "fp32"],
            "default": "fp16",
            "tooltip": "Default precision (fallback from code)"
        }
    }
    try:
        # Explicit encoding: JSON files are UTF-8 per RFC 8259. Relying on the
        # platform locale default can break on non-ASCII tooltips (e.g. Windows).
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
        logger.info(f"Successfully loaded configuration from {config_filename}")
        return config
    except FileNotFoundError:
        logger.warning(f"Configuration file '{config_path}' not found. Using default fallback configuration.")
        return default_config
    except json.JSONDecodeError:
        logger.error(f"Error decoding JSON from '{config_path}'. Using default fallback configuration.")
        return default_config
    except Exception as e:
        # Last-resort catch so a bad config can never prevent node registration.
        logger.error(f"An unexpected error occurred while loading '{config_path}': {e}. Using default fallback.")
        return default_config
# --- Load the configuration once when the module is imported ---
# (Read by LoadUpscalerTensorrtModel.INPUT_TYPES; treated as read-only.)
LOAD_UPSCALER_NODE_CONFIG = load_node_config()
class UpscalerTensorrt:
    """ComfyUI node: upscale a batch of images 4x with a loaded TensorRT
    engine, then optionally resize the result to a fixed target resolution.
    """
    @classmethod
    def INPUT_TYPES(s):
        # ComfyUI node input spec: a BHWC image batch, a pre-built engine
        # (from LoadUpscalerTensorrtModel), and an optional post-resize target.
        return {
            "required": {
                "images": ("IMAGE", {"tooltip": f"Images to be upscaled. Resolution must be between {IMAGE_DIM_MIN} and {IMAGE_DIM_MAX} px"}),
                "upscaler_trt_model": ("UPSCALER_TRT_MODEL", {"tooltip": "Tensorrt model built and loaded"}),
                "resize_to": (["none", "HD", "FHD", "2k", "4k", "2x", "3x"],{"tooltip": "Resize the upscaled image to fixed resolutions, optional"}),
            }
        }
    RETURN_NAMES = ("IMAGE",)
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "upscaler_tensorrt"
    CATEGORY = "tensorrt"
    DESCRIPTION = "Upscale images with tensorrt"
    def upscaler_tensorrt(self, images, upscaler_trt_model, resize_to):
        """Run the engine frame-by-frame over the batch.

        Args:
            images: BHWC float tensor (ComfyUI image layout).
            upscaler_trt_model: Engine instance produced by the loader node.
            resize_to: One of the resize presets; "none" keeps the raw 4x size.

        Returns:
            tuple: A single BHWC tensor of upscaled (and optionally resized)
            frames, on ComfyUI's intermediate device.

        Raises:
            ValueError: If H or W falls outside [IMAGE_DIM_MIN, IMAGE_DIM_MAX].
        """
        # Engine expects BCHW; ComfyUI supplies BHWC.
        images_bchw = images.permute(0, 3, 1, 2)
        B, C, H, W = images_bchw.shape
        # Validate against the engine's optimization-profile bounds up front.
        for dim in (H, W):
            if dim > IMAGE_DIM_MAX or dim < IMAGE_DIM_MIN:
                raise ValueError(f"Input image dimensions fall outside of the supported range: {IMAGE_DIM_MIN} to {IMAGE_DIM_MAX} px!\nImage dimensions: {W}px by {H}px")
        # Target size after the optional post-resize (project helper).
        final_width, final_height = get_final_resolutions(W, H, resize_to)
        logger.info(f"Upscaling {B} images from H:{H}, W:{W} to H:{H*4}, W:{W*4} | Final resolution: H:{final_height}, W:{final_width} | resize_to: {resize_to}")
        # Buffers are sized for one frame; the batch is processed one image
        # at a time in the loop below. The model's scale factor is fixed at 4x.
        shape_dict = {
            "input": {"shape": (1, 3, H, W)},
            "output": {"shape": (1, 3, H*4, W*4)},
        }
        upscaler_trt_model.activate()
        upscaler_trt_model.allocate_buffers(shape_dict=shape_dict)
        cudaStream = torch.cuda.current_stream().cuda_stream
        pbar = ProgressBar(B)
        # Split the batch into single-frame tensors for per-frame inference.
        images_list = list(torch.split(images_bchw, split_size_or_sections=1))
        # Pre-allocate the output at the FINAL resolution so resized frames
        # can be written in place.
        upscaled_frames = torch.empty((B, C, final_height, final_width), dtype=torch.float32, device=mm.intermediate_device())
        # Only interpolate when the preset differs from the raw 4x output.
        must_resize = W*4 != final_width or H*4 != final_height
        for i, img in enumerate(images_list):
            result = upscaler_trt_model.infer({"input": img}, cudaStream)
            result = result["output"]
            if must_resize:
                result = torch.nn.functional.interpolate(
                    result,
                    size=(final_height, final_width),
                    mode='bicubic',
                    antialias=True
                )
            upscaled_frames[i] = result.to(mm.intermediate_device())
            pbar.update(1)
        # Back to ComfyUI's BHWC layout.
        output = upscaled_frames.permute(0, 2, 3, 1)
        # Release engine resources and free cached GPU memory.
        upscaler_trt_model.reset()
        mm.soft_empty_cache()
        logger.info(f"Output shape: {output.shape}")
        return (output,)
class LoadUpscalerTensorrtModel:
    """ComfyUI node: load a TensorRT upscaler engine, building it from the
    ONNX model (downloaded on demand) when no matching engine file exists.
    """
    @classmethod
    def INPUT_TYPES(cls):
        """Build the node's input spec from the JSON-driven configuration."""
        # Use the configuration pre-loaded at module import time.
        model_config = LOAD_UPSCALER_NODE_CONFIG.get("model", {})
        precision_config = LOAD_UPSCALER_NODE_CONFIG.get("precision", {})
        # Sensible per-key defaults in case the config is partially filled
        # (load_node_config already guards whole-file failures).
        model_options = model_config.get("options", ["4x-UltraSharp"])
        model_default = model_config.get("default", "4x-UltraSharp")
        model_tooltip = model_config.get("tooltip", "Select a model.")
        precision_options = precision_config.get("options", ["fp16", "fp32"])
        precision_default = precision_config.get("default", "fp16")
        precision_tooltip = precision_config.get("tooltip", "Select precision.")
        return {
            "required": {
                "model": (model_options, {"default": model_default, "tooltip": model_tooltip}),
                "precision": (precision_options, {"default": precision_default, "tooltip": precision_tooltip}),
            }
        }
    RETURN_NAMES = ("upscaler_trt_model",)
    RETURN_TYPES = ("UPSCALER_TRT_MODEL",)
    CATEGORY = "tensorrt"
    DESCRIPTION = "Load tensorrt models, they will be built automatically if not found."
    FUNCTION = "load_upscaler_tensorrt_model"
    def load_upscaler_tensorrt_model(self, model, precision):
        """Return a loaded TensorRT Engine for `model` at `precision`.

        Downloads the ONNX model if absent, builds the engine if the cached
        .trt file (keyed by model, precision, profile shapes, and TensorRT
        version) is missing, then loads and returns it.

        Args:
            model: Model name (matches the ONNX filename stem).
            precision: "fp16" or "fp32".

        Returns:
            tuple: A single loaded Engine instance.
        """
        tensorrt_models_dir = os.path.join(folder_paths.models_dir, "tensorrt", "upscaler")
        onnx_models_dir = os.path.join(folder_paths.models_dir, "onnx")
        os.makedirs(tensorrt_models_dir, exist_ok=True)
        os.makedirs(onnx_models_dir, exist_ok=True)
        onnx_model_path = os.path.join(onnx_models_dir, f"{model}.onnx")
        # Optimization-profile bounds: fixed batch of 1, RGB, and the
        # module-wide min/opt/max spatial dimensions.
        engine_channel = 3
        engine_min_batch, engine_opt_batch, engine_max_batch = 1, 1, 1
        engine_min_h, engine_opt_h, engine_max_h = IMAGE_DIM_MIN, IMAGE_DIM_OPT, IMAGE_DIM_MAX
        engine_min_w, engine_opt_w, engine_max_w = IMAGE_DIM_MIN, IMAGE_DIM_OPT, IMAGE_DIM_MAX
        # Encode all profile shapes and the TensorRT version in the filename so
        # incompatible cached engines are never reused.
        tensorrt_model_path = os.path.join(tensorrt_models_dir, f"{model}_{precision}_{engine_min_batch}x{engine_channel}x{engine_min_h}x{engine_min_w}_{engine_opt_batch}x{engine_channel}x{engine_opt_h}x{engine_opt_w}_{engine_max_batch}x{engine_channel}x{engine_max_h}x{engine_max_w}_{tensorrt.__version__}.trt")
        if not os.path.exists(tensorrt_model_path):
            if not os.path.exists(onnx_model_path):
                onnx_model_download_url = f"https://huggingface.co/yuvraj108c/ComfyUI-Upscaler-Onnx/resolve/main/{model}.onnx"
                logger.info(f"Downloading {onnx_model_download_url}")
                download_file(url=onnx_model_download_url, save_path=onnx_model_path)
            else:
                logger.info(f"Onnx model found at: {onnx_model_path}")
            logger.info(f"Building TensorRT engine for {onnx_model_path}: {tensorrt_model_path}")
            mm.soft_empty_cache()
            s = time.time()
            engine = Engine(tensorrt_model_path)
            # BUGFIX: the OPT shape previously used engine_min_w for its width,
            # tuning the engine for 512x256 instead of the intended 512x512.
            engine.build(
                onnx_path=onnx_model_path,
                fp16=(precision == "fp16"),
                input_profile=[
                    {"input": [(engine_min_batch, engine_channel, engine_min_h, engine_min_w),
                               (engine_opt_batch, engine_channel, engine_opt_h, engine_opt_w),
                               (engine_max_batch, engine_channel, engine_max_h, engine_max_w)]},
                ],
            )
            e = time.time()
            logger.info(f"Time taken to build: {(e-s)} seconds")
        logger.info(f"Loading TensorRT engine: {tensorrt_model_path}")
        mm.soft_empty_cache()
        engine = Engine(tensorrt_model_path)
        engine.load()
        return (engine,)
# ComfyUI registration: internal node id -> implementing class.
NODE_CLASS_MAPPINGS = {
    "UpscalerTensorrt": UpscalerTensorrt,
    "LoadUpscalerTensorrtModel": LoadUpscalerTensorrtModel,
}
# ComfyUI registration: internal node id -> human-readable menu label.
NODE_DISPLAY_NAME_MAPPINGS = {
    "UpscalerTensorrt": "Upscaler Tensorrt ⚡",
    "LoadUpscalerTensorrtModel": "Load Upscale Tensorrt Model",
}
# Public API consumed by ComfyUI's custom-node loader.
__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']