| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel |
| | from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel |
| |
|
# Export settings shared by the SD3 transformer blocks that return both the
# encoder stream and the hidden-state stream (see ``output_names``).
#
#   dummy_input  - input name -> shape tuple of the tensor fed in for tracing
#   output_names - names given to the block outputs, in order
#   dynamic_axes - input name -> {axis index: symbolic axis name}
#
# NOTE(review): the key names mirror the ``torch.onnx.export`` keyword
# arguments; presumably that is the consumer - confirm against the caller.
sd3_common_transformer_block_config = {
    "dummy_input": {
        "hidden_states": (2, 4096, 1536),
        "encoder_hidden_states": (2, 333, 1536),
        "temb": (2, 1536),
    },
    "output_names": ["encoder_hidden_states_out", "hidden_states_out"],
    "dynamic_axes": {
        "hidden_states": {0: "batch_size"},
        "encoder_hidden_states": {0: "batch_size"},
        # Axis 0 of ``temb`` is deliberately given its own symbol ("steps"),
        # matching every other entry in this file.
        "temb": {0: "steps"},
    },
}
| |
|
# Index of the final SD3 transformer block; it is the only one configured with
# a single output, so it gets its own ``output_names``.
_SD3_LAST_BLOCK_INDEX = 23


def _sd3_block_config(output_names):
    """Return an independent SD3 block config that uses *output_names*.

    ``dummy_input`` and ``dynamic_axes`` are copied from
    ``sd3_common_transformer_block_config`` so that all SD3 blocks share one
    source of truth for shapes, while no two entries share a mutable
    container. The copy is taken when this function is called (at import
    time), so later edits to the template do not propagate.

    Args:
        output_names: Iterable of output names for the block, in order.

    Returns:
        A new dict with the keys ``dummy_input``, ``output_names`` and
        ``dynamic_axes`` (in that order).
    """
    template = sd3_common_transformer_block_config
    return {
        # Shape tuples are immutable, so a shallow copy of the mapping is enough.
        "dummy_input": dict(template["dummy_input"]),
        "output_names": list(output_names),
        "dynamic_axes": {
            input_name: dict(axes) for input_name, axes in template["dynamic_axes"].items()
        },
    }


# Per-model export configuration: model class -> block name -> settings.
# Each settings dict has ``dummy_input`` (input name -> shape tuple),
# ``output_names`` and ``dynamic_axes`` (input name -> {axis: symbolic name}).
# NOTE(review): block names look like dotted submodule paths of the model
# class - confirm against the code that looks them up.
ONNX_CONFIG = {
    UNet2DConditionModel: {
        "down_blocks.0": {
            "dummy_input": {
                "hidden_states": (2, 320, 128, 128),
                "temb": (2, 1280),
            },
            "output_names": ["sample", "res_samples_0", "res_samples_1", "res_samples_2"],
            "dynamic_axes": {
                "hidden_states": {0: "batch_size"},
                "temb": {0: "steps"},
            },
        },
        "down_blocks.1": {
            "dummy_input": {
                "hidden_states": (2, 320, 64, 64),
                "temb": (2, 1280),
                "encoder_hidden_states": (2, 77, 2048),
            },
            "output_names": ["sample", "res_samples_0", "res_samples_1", "res_samples_2"],
            "dynamic_axes": {
                "hidden_states": {0: "batch_size"},
                "temb": {0: "steps"},
                "encoder_hidden_states": {0: "batch_size"},
            },
        },
        "down_blocks.2": {
            "dummy_input": {
                "hidden_states": (2, 640, 32, 32),
                "temb": (2, 1280),
                "encoder_hidden_states": (2, 77, 2048),
            },
            "output_names": ["sample", "res_samples_0", "res_samples_1"],
            "dynamic_axes": {
                "hidden_states": {0: "batch_size"},
                "temb": {0: "steps"},
                "encoder_hidden_states": {0: "batch_size"},
            },
        },
        "mid_block": {
            "dummy_input": {
                "hidden_states": (2, 1280, 32, 32),
                "temb": (2, 1280),
                "encoder_hidden_states": (2, 77, 2048),
            },
            "output_names": ["sample"],
            "dynamic_axes": {
                "hidden_states": {0: "batch_size"},
                "temb": {0: "steps"},
                "encoder_hidden_states": {0: "batch_size"},
            },
        },
        "up_blocks.0": {
            "dummy_input": {
                "hidden_states": (2, 1280, 32, 32),
                "res_hidden_states_0": (2, 640, 32, 32),
                "res_hidden_states_1": (2, 1280, 32, 32),
                "res_hidden_states_2": (2, 1280, 32, 32),
                "temb": (2, 1280),
                "encoder_hidden_states": (2, 77, 2048),
            },
            "output_names": ["sample"],
            "dynamic_axes": {
                "hidden_states": {0: "batch_size"},
                "temb": {0: "steps"},
                "encoder_hidden_states": {0: "batch_size"},
                "res_hidden_states_0": {0: "batch_size"},
                "res_hidden_states_1": {0: "batch_size"},
                "res_hidden_states_2": {0: "batch_size"},
            },
        },
        "up_blocks.1": {
            "dummy_input": {
                "hidden_states": (2, 1280, 64, 64),
                "res_hidden_states_0": (2, 320, 64, 64),
                "res_hidden_states_1": (2, 640, 64, 64),
                "res_hidden_states_2": (2, 640, 64, 64),
                "temb": (2, 1280),
                "encoder_hidden_states": (2, 77, 2048),
            },
            "output_names": ["sample"],
            "dynamic_axes": {
                "hidden_states": {0: "batch_size"},
                "temb": {0: "steps"},
                "encoder_hidden_states": {0: "batch_size"},
                "res_hidden_states_0": {0: "batch_size"},
                "res_hidden_states_1": {0: "batch_size"},
                "res_hidden_states_2": {0: "batch_size"},
            },
        },
        # No ``encoder_hidden_states`` input is configured for this block
        # (same as ``down_blocks.0``).
        "up_blocks.2": {
            "dummy_input": {
                "hidden_states": (2, 640, 128, 128),
                "res_hidden_states_0": (2, 320, 128, 128),
                "res_hidden_states_1": (2, 320, 128, 128),
                "res_hidden_states_2": (2, 320, 128, 128),
                "temb": (2, 1280),
            },
            "output_names": ["sample"],
            "dynamic_axes": {
                "hidden_states": {0: "batch_size"},
                "temb": {0: "steps"},
                "res_hidden_states_0": {0: "batch_size"},
                "res_hidden_states_1": {0: "batch_size"},
                "res_hidden_states_2": {0: "batch_size"},
            },
        },
    },
    SD3Transformer2DModel: {
        # Blocks 0..22 use the shared config unchanged; each gets its own copy
        # so that mutating one entry cannot leak into the others.
        **{
            f"transformer_blocks.{i}": _sd3_block_config(
                sd3_common_transformer_block_config["output_names"]
            )
            for i in range(_SD3_LAST_BLOCK_INDEX)
        },
        # The final block has the same inputs but only one named output.
        # NOTE(review): presumably the last block does not return the encoder
        # stream - confirm against the model implementation.
        f"transformer_blocks.{_SD3_LAST_BLOCK_INDEX}": _sd3_block_config(["hidden_states_out"]),
    },
}
| |
|