import os

import torch
from safetensors.torch import load_file

from models.controlnet import ControlNetModel
def load_safetensors(model, safetensors_path, strict=True, load_weight_increasement=False):
    """Load weights from a ``.safetensors`` file into ``model``.

    Args:
        model: Module whose state dict is overwritten via ``load_state_dict``.
        safetensors_path: Path to the ``.safetensors`` checkpoint.
        strict: Forwarded to ``load_state_dict`` in full-weight mode. It is
            ignored in increment mode, which always loads with
            ``strict=False``.
        load_weight_increasement: If True, the file is treated as a set of
            deltas. Each stored tensor is added to the model's current tensor
            of the same name, and only those keys are loaded back.

    Raises:
        KeyError: In increment mode, if the file contains a key that the
            model's state dict does not have.
    """
    state_dict = load_file(safetensors_path)

    if not load_weight_increasement:
        model.load_state_dict(state_dict, strict=strict)
        return

    pretrained_state_dict = model.state_dict()
    merged = {}
    for key, delta in state_dict.items():
        try:
            base = pretrained_state_dict[key]
        except KeyError:
            raise KeyError(
                f"weight-increment key {key!r} not found in the model's state dict"
            ) from None
        # load_file returns CPU tensors by default, while the model may already
        # live on another device (e.g. CUDA). Align devices so the addition
        # does not fail with a cross-device RuntimeError.
        merged[key] = delta.to(base.device) + base
    # The delta file need not cover every model key, so load non-strictly.
    model.load_state_dict(merged, strict=False)
# Build the ControlNeXt model, load its trained weights, and export it to ONNX.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

controlnet = ControlNetModel()
# Load the weights into the freshly constructed module before moving it, so
# any weight arithmetic in load_safetensors happens on one device. Then move
# parameters and buffers together.
load_safetensors(controlnet, '/home/ControlNeXt/ControlNeXt-SDXL/controlnet.safetensors')
controlnet.to(device)
# Export the inference graph; set eval mode explicitly rather than relying on
# the exporter's default.
controlnet.eval()

# Dummy inputs only fix dtype/rank for tracing; their values are random.
image = torch.randn((1, 3, 1024, 1024), dtype=torch.float32, device=device)
timestep = torch.rand(1, dtype=torch.float32, device=device)
dummy_inputs = (image, timestep)

onnx_output_path = '/home/new_onnx/cnext/model.onnx'
# Make sure the output directory exists before the exporter writes the file.
_onnx_output_dir = os.path.dirname(onnx_output_path)
if _onnx_output_dir:
    os.makedirs(_onnx_output_dir, exist_ok=True)

torch.onnx.export(
    controlnet,
    dummy_inputs,
    onnx_output_path,
    export_params=True,
    opset_version=18,
    do_constant_folding=True,
    input_names=['controlnext_image', 'timestep'],
    output_names=['sample'],
    # Image batch and spatial size are symbolic in the exported graph.
    # NOTE(review): 'timestep' has no dynamic axis, so it stays shape (1,);
    # confirm the model broadcasts it across a batch larger than 1.
    dynamic_axes={
        'controlnext_image': {0: 'batch_size', 2: 'height', 3: 'width'},
        'sample': {0: 'batch_size'},
    },
)