Using devices [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]
Device count 4
Global device count 4
Global Batch: 512
Node Batch: 512
Device Batch: 128
Loading dataset
Loading dataset
creating model
beta1: 0.9
beta2: 0.999
bootstrap_cfg: 1
bootstrap_dt_bias: 0
bootstrap_ema: 1
bootstrap_every: 8
cfg_scale: 1.5
class_dropout_prob: 0.1
denoise_timesteps: 128
depth: 12
dropout: 0.0
dt_sampling: uniform
hidden_size: 768
lr: 0.0001
mlp_ratio: 4
num_classes: 1000
num_heads: 12
patch_size: 2
sharding: dp
t_sampling: discrete-dt
target_update_rate: 0.999
train_type: naive
use_cosine: 0
use_ema: 0
use_stable_vae: 1
warmup: 0
weight_decay: 0.1
Total devices TPU_0(process=0,(0,0,0,0))
Initializing encoder.
Incoming encoder shape (1, 256, 256, 3)
Encoder layer (1, 256, 256, 128)
doing downsample
Encoder layer (1, 128, 128, 128)
doing downsample
Encoder layer (1, 64, 64, 256)
doing downsample
Encoder layer (1, 32, 32, 512)
Encoder layer (1, 32, 32, 512)
Encoder layer final (1, 32, 32, 512)
Encoder layer final (1, 32, 32, 512)
Final embeddings are size (1, 32, 32, 8)
After quant (1, 32, 32, 4)
encode finished
Decoder incoming shape (1, 32, 32, 4)
Decoder input (1, 32, 32, 512)
Mid Block Decoder layer (1, 32, 32, 512)
Mid Block Decoder layer (1, 32, 32, 512)
Decoder layer (1, 64, 64, 512)
Decoder layer (1, 128, 128, 512)
Decoder layer (1, 256, 256, 256)
Decoder layer (1, 256, 256, 128)
Total num of VQVAE parameters: 67565323
Disc shape (1, 128, 128, 128)
Disc shape (1, 64, 64, 256)
Disc shape (1, 32, 32, 512)
Disc shape (1, 16, 16, 512)
Disc shape (1, 8, 8, 512)
Disc shape (1, 4, 4, 512)
Total num of Discriminator parameters: 23998017
Loaded checkpoint from 15359025 seconds ago.
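Several sizes printed in this log can be cross-checked against the config dump (global batch 512 on 4 devices, hidden_size 768, mlp_ratio 4, patch_size 2, num_classes 1000, depth 12). The Python sketch below is an assumption-based reconstruction, not the training code: the helper names are hypothetical, the 8x VAE downsampling factor is inferred from the 256 to 32 encoder shapes, and the roles given to each Dense layer (modulation, attention projections, MLP) are inferred only from their shapes in the DiT Summary.

```python
# Sketch: recompute sizes printed in this log from the config dump.
# Helper names are hypothetical; layer roles are inferred from shapes only.

GLOBAL_BATCH = 512
NUM_DEVICES = 4
IMAGE_SIZE = 256
VAE_DOWNSAMPLE = 8      # inferred: encoder goes from 256 to 32 spatially
LATENT_CHANNELS = 4     # "After quant (1, 32, 32, 4)"
PATCH_SIZE = 2
HIDDEN = 768
MLP_RATIO = 4
NUM_CLASSES = 1000
DEPTH = 12
FREQ_DIM = 256          # input width of TimestepEmbedder_*/Dense_0


def dense_params(fan_in: int, fan_out: int) -> int:
    """Kernel plus bias of a Dense layer."""
    return fan_in * fan_out + fan_out


def dit_block_params(hidden: int, mlp_ratio: int) -> int:
    """Parameters of one DiTBlock (its LayerNorms list no parameters)."""
    modulation = dense_params(hidden, 6 * hidden)   # Dense_0: 768 -> 4608
    attention = 4 * dense_params(hidden, hidden)    # Dense_1..Dense_4, presumably q, k, v, out
    mlp_hidden = mlp_ratio * hidden
    mlp = dense_params(hidden, mlp_hidden) + dense_params(mlp_hidden, hidden)
    return modulation + attention + mlp


device_batch = GLOBAL_BATCH // NUM_DEVICES
latent_side = IMAGE_SIZE // VAE_DOWNSAMPLE
num_tokens = (latent_side // PATCH_SIZE) ** 2
patch_embed = PATCH_SIZE * PATCH_SIZE * LATENT_CHANNELS * HIDDEN + HIDDEN
timestep_embed = dense_params(FREQ_DIM, HIDDEN) + dense_params(HIDDEN, HIDDEN)
label_embed = (NUM_CLASSES + 1) * HIDDEN  # one extra row, presumably the null label for CFG
per_block = dit_block_params(HIDDEN, MLP_RATIO)

print("device batch:", device_batch)                # 128
print("latent tokens:", num_tokens)                 # 256
print("patch embed params:", patch_embed)           # 13056
print("timestep embedder params:", timestep_embed)  # 787968
print("label embed params:", label_embed)           # 768768
print("params per DiTBlock:", per_block)            # 10628352
print(f"{DEPTH} blocks:", DEPTH * per_block)        # 127540224
```

Each figure matches the corresponding row (or sum of rows) in the DiT Summary. The 12-block total covers only the transformer blocks; it excludes the embedders and any output layer, which this excerpt of the log does not show.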
Loaded model with step 511001
┌──────────────────────────────────────────────────────────────────────────────┐
│                                    TPU 0                                     │
├──────────────────────────────────────────────────────────────────────────────┤
│                                    TPU 1                                     │
├──────────────────────────────────────────────────────────────────────────────┤
│                                    TPU 2                                     │
├──────────────────────────────────────────────────────────────────────────────┤
│                                    TPU 3                                     │
└──────────────────────────────────────────────────────────────────────────────┘
returning model
model done
Input to vae (4, 1, 256, 256, 3)
encode image shape (1, 256, 256, 3)
Initializing encoder.
Incoming encoder shape (1, 256, 256, 3)
Encoder layer (1, 256, 256, 128)
doing downsample
Encoder layer (1, 128, 128, 128)
doing downsample
Encoder layer (1, 64, 64, 256)
doing downsample
Encoder layer (1, 32, 32, 512)
Encoder layer (1, 32, 32, 512)
Encoder layer final (1, 32, 32, 512)
Encoder layer final (1, 32, 32, 512)
Final embeddings are size (1, 32, 32, 8)
After quant (1, 32, 32, 4)
output example shape (4, 1, 32, 32, 4)
Test data shape (4, 256, 256, 3)
x shape (4, 1, 256, 256, 3)
encoded shape (4, 1, 32, 32, 4)
z_vectors shape (1, 32, 32, 4)
Decoder incoming shape (1, 32, 32, 4)
Decoder input (1, 32, 32, 512)
Mid Block Decoder layer (1, 32, 32, 512)
Mid Block Decoder layer (1, 32, 32, 512)
Decoder layer (1, 64, 64, 512)
Decoder layer (1, 128, 128, 512)
Decoder layer (1, 256, 256, 256)
Decoder layer (1, 256, 256, 128)
image shape (4, 1, 256, 256, 3)
decoded img shape (256, 256, 3)
obs shape (4, 32, 32, 4)
DiT: Input of shape (4, 32, 32, 4) dtype float32
DiT: After patch embed, shape is (4, 256, 768) dtype bfloat16
DiT: Patch Embed of shape (4, 256, 768) dtype bfloat16
DiT: Conditioning of shape (1, 768) dtype float32
DiT Summary
┃ path ┃ module ┃ inputs ┃ outputs ┃ params ┃
│ │ DiT │ float32[4,32,32,4], float32[1], float32[1], int32[1] │ bfloat16[4,32,32,4] │ │
│ PatchEmbed_0 │ PatchEmbed │ float32[4,32,32,4] │ bfloat16[4,256,768] │ │
│ PatchEmbed_0/Conv_0 │ Conv │ float32[4,32,32,4] │ bfloat16[4,16,16,768] │ bias: float32[768], kernel: float32[2,2,4,768], 13,056 (52.2 KB) │
│ TimestepEmbedder_0 │ TimestepEmbedder │ float32[1] │ float32[1,768] │ │
│ TimestepEmbedder_0/Dense_0 │ Dense │ bfloat16[1,256] │ bfloat16[1,768] │ bias: float32[768], kernel: float32[256,768], 197,376 (789.5 KB) │
│ TimestepEmbedder_0/Dense_1 │ Dense │ bfloat16[1,768] │ float32[1,768] │ bias: float32[768], kernel: float32[768,768], 590,592 (2.4 MB) │
│ TimestepEmbedder_1 │ TimestepEmbedder │ float32[1] │ float32[1,768] │ │
│ TimestepEmbedder_1/Dense_0 │ Dense │ bfloat16[1,256] │ bfloat16[1,768] │ bias: float32[768], kernel: float32[256,768], 197,376 (789.5 KB) │
│ TimestepEmbedder_1/Dense_1 │ Dense │ bfloat16[1,768] │ float32[1,768] │ bias: float32[768], kernel: float32[768,768], 590,592 (2.4 MB) │
│ LabelEmbedder_0 │ LabelEmbedder │ int32[1] │ bfloat16[1,768] │ │
│ LabelEmbedder_0/Embed_0 │ Embed │ int32[1] │ bfloat16[1,768] │ embedding: float32[1001,768], 768,768 (3.1 MB) │
│ DiTBlock_0 │ DiTBlock │ bfloat16[4,256,768], float32[1,768] │ bfloat16[4,256,768] │ │
│ DiTBlock_0/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608], kernel: float32[768,4608], 3,543,552 (14.2 MB) │
│ DiTBlock_0/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │
│ DiTBlock_0/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768], kernel: float32[768,768], 590,592 (2.4 MB) │
│ DiTBlock_0/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768], kernel: float32[768,768], 590,592 (2.4 MB) │
│ DiTBlock_0/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768], kernel: float32[768,768], 590,592 (2.4 MB) │
│ DiTBlock_0/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768], kernel: float32[768,768], 590,592 (2.4 MB) │
│ DiTBlock_0/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │
│ DiTBlock_0/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │
│ DiTBlock_0/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072], kernel: float32[768,3072], 2,362,368 (9.4 MB) │
│ DiTBlock_0/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │
│ DiTBlock_0/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768], kernel: float32[3072,768], 2,360,064 (9.4 MB) │
│ DiTBlock_0/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │
│ DiTBlock_1 │ DiTBlock │ bfloat16[4,256,768], float32[1,768] │ bfloat16[4,256,768] │ │
│ DiTBlock_1/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608], kernel: float32[768,4608], 3,543,552 (14.2 MB) │
│ DiTBlock_1/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │
│ DiTBlock_1/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768], kernel: float32[768,768], 590,592 (2.4 MB) │
│ DiTBlock_1/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768], kernel: float32[768,768], 590,592 (2.4 MB) │
│ DiTBlock_1/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768], kernel: float32[768,768], 590,592 (2.4 MB) │
│ DiTBlock_1/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768], kernel: float32[768,768], 590,592 (2.4 MB) │
│ DiTBlock_1/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │
│ DiTBlock_1/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │
│ DiTBlock_1/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072], kernel: float32[768,3072], 2,362,368 (9.4 MB) │
│ DiTBlock_1/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │
│ DiTBlock_1/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768], kernel: float32[3072,768], 2,360,064 (9.4 MB) │
│ DiTBlock_1/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │
│ DiTBlock_2 │ DiTBlock │ bfloat16[4,256,768], float32[1,768] │ bfloat16[4,256,768] │ │
│ DiTBlock_2/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608], kernel: float32[768,4608], 3,543,552 (14.2 MB) │
│ DiTBlock_2/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │
│ DiTBlock_2/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768], kernel: float32[768,768], 590,592 (2.4 MB) │
│ DiTBlock_2/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768], kernel: float32[768,768], 590,592 (2.4 MB) │
│ DiTBlock_2/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768], kernel: float32[768,768], 590,592 (2.4 MB) │
│ DiTBlock_2/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768], kernel: float32[768,768], 590,592 (2.4 MB) │
│ DiTBlock_2/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │
│ DiTBlock_2/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │
│ DiTBlock_2/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072], kernel: float32[768,3072], 2,362,368 (9.4 MB) │
│ DiTBlock_2/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │
│ DiTBlock_2/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768], kernel: float32[3072,768], 2,360,064 (9.4 MB) │
│ DiTBlock_2/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │
│ DiTBlock_3 │ DiTBlock │ bfloat16[4,256,768], float32[1,768] │ bfloat16[4,256,768] │ │
│ DiTBlock_3/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608], kernel: float32[768,4608], 3,543,552 (14.2 MB) │
│ DiTBlock_3/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │
│ DiTBlock_3/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768], kernel: float32[768,768], 590,592 (2.4 MB) │
│ DiTBlock_3/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768], kernel: float32[768,768], 590,592 (2.4 MB) │
│ DiTBlock_3/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768], kernel: float32[768,768], 590,592 (2.4 MB) │
│ DiTBlock_3/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768], kernel: float32[768,768], 590,592 (2.4 MB) │
│ DiTBlock_3/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │
│ DiTBlock_3/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │
│ DiTBlock_3/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072], kernel: float32[768,3072], 2,362,368 (9.4 MB) │
│ DiTBlock_3/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │
│ DiTBlock_3/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768], kernel: float32[3072,768], 2,360,064 (9.4 MB) │
│ DiTBlock_3/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │
│ DiTBlock_4 │ DiTBlock │ bfloat16[4,256,768], float32[1,768] │ bfloat16[4,256,768] │ │
│ DiTBlock_4/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608], kernel: float32[768,4608], 3,543,552 (14.2 MB) │
│ DiTBlock_4/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │
│ DiTBlock_4/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768], kernel: float32[768,768], 590,592 (2.4 MB) │
│ DiTBlock_4/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768], kernel: float32[768,768], 590,592 (2.4 MB) │
│ DiTBlock_4/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768], kernel: float32[768,768], 590,592 (2.4 MB) │
│ DiTBlock_4/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768], kernel: float32[768,768], 590,592 (2.4 MB) │
│ DiTBlock_4/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │
│ DiTBlock_4/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │
│ DiTBlock_4/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072], kernel: float32[768,3072], 2,362,368 (9.4 MB) │
│ DiTBlock_4/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │
│ DiTBlock_4/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768], kernel: float32[3072,768], 2,360,064 (9.4 MB) │
│ DiTBlock_4/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │
│ DiTBlock_5 │ DiTBlock │ bfloat16[4,256,768], float32[1,768] │ bfloat16[4,256,768] │ │
│ DiTBlock_5/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608], kernel: float32[768,4608], 3,543,552 (14.2 MB) │
│ DiTBlock_5/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │
│ DiTBlock_5/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768], kernel: float32[768,768], 590,592 (2.4 MB) │
│ DiTBlock_5/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768], kernel: float32[768,768], 590,592 (2.4 MB) │
│ DiTBlock_5/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768], kernel: float32[768,768], 590,592 (2.4 MB) │
│ DiTBlock_5/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768], kernel: float32[768,768], 590,592 (2.4 MB) │
│ DiTBlock_5/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │
│ DiTBlock_5/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │
│ DiTBlock_5/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072], kernel: float32[768,3072], 2,362,368 (9.4 MB) │
│ DiTBlock_5/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │
│ DiTBlock_5/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768], kernel: float32[3072,768], 2,360,064 (9.4 MB) │
│ DiTBlock_5/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │
│ DiTBlock_6 │ DiTBlock │ bfloat16[4,256,768], float32[1,768] │ bfloat16[4,256,768] │ │
│ DiTBlock_6/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608], kernel: float32[768,4608], 3,543,552 (14.2 MB) │
│ DiTBlock_6/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │
│ DiTBlock_6/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768], kernel: float32[768,768], 590,592 (2.4 MB) │
│ DiTBlock_6/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768], kernel: float32[768,768], 590,592 (2.4 MB) │
│ DiTBlock_6/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768], kernel: float32[768,768], 590,592 (2.4 MB) │
│ DiTBlock_6/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768], kernel: float32[768,768], 590,592 (2.4 MB) │
│ DiTBlock_6/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │
│ DiTBlock_6/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │
│ DiTBlock_6/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072], kernel: float32[768,3072], 2,362,368 (9.4 MB) │
│ DiTBlock_6/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │
│ DiTBlock_6/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768], kernel: float32[3072,768], 2,360,064 (9.4 MB) │
│ DiTBlock_6/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │
│ DiTBlock_7 │ DiTBlock │ bfloat16[4,256,768], float32[1,768] │ bfloat16[4,256,768] │ │
│ DiTBlock_7/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608], kernel: float32[768,4608], 3,543,552 (14.2 MB) │
│ DiTBlock_7/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │
│ DiTBlock_7/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768], kernel: float32[768,768], 590,592 (2.4 MB) │
│ DiTBlock_7/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768], kernel: float32[768,768], 590,592 (2.4 MB) │
│ DiTBlock_7/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768], kernel: float32[768,768], 590,592 (2.4 MB) │
│ DiTBlock_7/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768], kernel: float32[768,768], 590,592 (2.4 MB) │
│ DiTBlock_7/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │
│ DiTBlock_7/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │
│ DiTBlock_7/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072], kernel: float32[768,3072], 2,362,368 (9.4 MB) │
│ DiTBlock_7/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │
│ DiTBlock_7/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768], kernel: float32[3072,768], 2,360,064 (9.4 MB) │
│ DiTBlock_7/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │
│ DiTBlock_8 │ DiTBlock │ bfloat16[4,256,768], float32[1,768] │ bfloat16[4,256,768] │ │
│ DiTBlock_8/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608], kernel: float32[768,4608], 3,543,552 (14.2 MB) │
│ DiTBlock_8/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │
│ DiTBlock_8/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768], kernel: float32[768,768], 590,592 (2.4 MB) │
│ DiTBlock_8/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768], kernel: float32[768,768], 590,592 (2.4 MB) │
│ DiTBlock_8/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768], kernel: float32[768,768], 590,592 (2.4 MB) │
│ DiTBlock_8/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768], kernel: float32[768,768], 590,592 (2.4 MB) │
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: 
float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 
MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ FinalLayer_0 │ FinalLayer │ - bfloat16[4,256,768] │ bfloat16[4,256,16] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ FinalLayer_0/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,1536] │ bias: float32[1536] │ │ │ │ │ │ kernel: float32[768,1536] │ │ │ │ │ │ │ │ │ │ │ │ 1,181,184 (4.7 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ FinalLayer_0/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ FinalLayer_0/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,16] │ bias: float32[16] │ │ │ │ │ │ kernel: float32[768,16] │ │ │ │ │ │ │ │ │ │ │ │ 12,304 (49.2 KB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ Embed_0 │ Embed │ int32[1] │ float32[1,1] │ embedding: float32[256,1] │ │ │ │ │ │ │ │ │ │ │ │ 256 (1.0 KB) │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤
│                                  │                  │                       │                 Total │ 131,091,728 (524.4 MB)       │
└──────────────────────────────────┴──────────────────┴───────────────────────┴───────────────────────┴──────────────────────────────┘
Total Parameters: 131,091,728 (524.4 MB)
DiT: Input of shape (4, 32, 32, 4) dtype float32
DiT: After patch embed, shape is (4, 256, 768) dtype bfloat16
DiT: Patch Embed of shape (4, 256, 768) dtype bfloat16
DiT: Conditioning of shape (1, 768) dtype float32
Loaded checkpoint from 110441 seconds ago.
parameter shapes:
('PatchEmbed_0', 'Conv_0', 'kernel'): (2, 2, 4, 768)
('PatchEmbed_0', 'Conv_0', 'bias'): (768,)
('TimestepEmbedder_0', 'Dense_0', 'kernel'): (256, 768)
('TimestepEmbedder_0', 'Dense_0', 'bias'): (768,)
('TimestepEmbedder_0', 'Dense_1', 'kernel'): (768, 768)
('TimestepEmbedder_0', 'Dense_1', 'bias'): (768,)
('TimestepEmbedder_1', 'Dense_0', 'kernel'): (256, 768)
('TimestepEmbedder_1', 'Dense_0', 'bias'): (768,)
('TimestepEmbedder_1', 'Dense_1', 'kernel'): (768, 768)
('TimestepEmbedder_1', 'Dense_1', 'bias'): (768,)
('LabelEmbedder_0', 'Embed_0', 'embedding'): (1001, 768)
('DiTBlock_0', 'Dense_0', 'kernel'): (768, 4608)
('DiTBlock_0', 'Dense_0', 'bias'): (4608,)
('DiTBlock_0', 'Dense_1', 'kernel'): (768, 768)
('DiTBlock_0', 'Dense_1', 'bias'): (768,)
('DiTBlock_0', 'Dense_2', 'kernel'): (768, 768)
('DiTBlock_0', 'Dense_2', 'bias'): (768,)
('DiTBlock_0', 'Dense_3', 'kernel'): (768, 768)
('DiTBlock_0', 'Dense_3', 'bias'): (768,)
('DiTBlock_0', 'Dense_4', 'kernel'): (768, 768)
('DiTBlock_0', 'Dense_4', 'bias'): (768,)
('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072)
('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,)
('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768)
('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'bias'): (768,)
('DiTBlock_1', 'Dense_0', 'kernel'): (768, 4608)
('DiTBlock_1', 'Dense_0', 'bias'): (4608,)
('DiTBlock_1', 'Dense_1', 'kernel'): (768, 768)
('DiTBlock_1', 'Dense_1', 'bias'): (768,)
('DiTBlock_1', 'Dense_2', 'kernel'): (768, 768)
('DiTBlock_1', 'Dense_2', 'bias'): (768,)
('DiTBlock_1', 'Dense_3', 'kernel'): (768, 768)
('DiTBlock_1', 'Dense_3', 'bias'): (768,)
('DiTBlock_1', 'Dense_4', 'kernel'): (768, 768)
('DiTBlock_1', 'Dense_4', 'bias'): (768,)
('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072)
('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,)
('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768)
('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'bias'): (768,)
('DiTBlock_2', 'Dense_0', 'kernel'): (768, 4608)
('DiTBlock_2', 'Dense_0', 'bias'): (4608,)
('DiTBlock_2', 'Dense_1', 'kernel'): (768, 768)
('DiTBlock_2', 'Dense_1', 'bias'): (768,)
('DiTBlock_2', 'Dense_2', 'kernel'): (768, 768)
('DiTBlock_2', 'Dense_2', 'bias'): (768,)
('DiTBlock_2', 'Dense_3', 'kernel'): (768, 768)
('DiTBlock_2', 'Dense_3', 'bias'): (768,)
('DiTBlock_2', 'Dense_4', 'kernel'): (768, 768)
('DiTBlock_2', 'Dense_4', 'bias'): (768,)
('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072)
('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,)
('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768)
('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'bias'): (768,)
('DiTBlock_3', 'Dense_0', 'kernel'): (768, 4608)
('DiTBlock_3', 'Dense_0', 'bias'): (4608,)
('DiTBlock_3', 'Dense_1', 'kernel'): (768, 768)
('DiTBlock_3', 'Dense_1', 'bias'): (768,)
('DiTBlock_3', 'Dense_2', 'kernel'): (768, 768)
('DiTBlock_3', 'Dense_2', 'bias'): (768,)
('DiTBlock_3', 'Dense_3', 'kernel'): (768, 768)
('DiTBlock_3', 'Dense_3', 'bias'): (768,)
('DiTBlock_3', 'Dense_4', 'kernel'): (768, 768)
('DiTBlock_3', 'Dense_4', 'bias'): (768,)
('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072)
('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,)
('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768)
('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'bias'): (768,)
('DiTBlock_4', 'Dense_0', 'kernel'): (768, 4608)
('DiTBlock_4', 'Dense_0', 'bias'): (4608,)
('DiTBlock_4', 'Dense_1', 'kernel'): (768, 768)
('DiTBlock_4', 'Dense_1', 'bias'): (768,)
('DiTBlock_4', 'Dense_2', 'kernel'): (768, 768)
('DiTBlock_4', 'Dense_2', 'bias'): (768,)
('DiTBlock_4', 'Dense_3', 'kernel'): (768, 768)
('DiTBlock_4', 'Dense_3', 'bias'): (768,)
('DiTBlock_4', 'Dense_4', 'kernel'): (768, 768)
('DiTBlock_4', 'Dense_4', 'bias'): (768,)
('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072)
('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,)
('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768)
('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'bias'): (768,)
('DiTBlock_5', 'Dense_0', 'kernel'): (768, 4608)
('DiTBlock_5', 'Dense_0', 'bias'): (4608,)
('DiTBlock_5', 'Dense_1', 'kernel'): (768, 768)
('DiTBlock_5', 'Dense_1', 'bias'): (768,)
('DiTBlock_5', 'Dense_2', 'kernel'): (768, 768)
('DiTBlock_5', 'Dense_2', 'bias'): (768,)
('DiTBlock_5', 'Dense_3', 'kernel'): (768, 768)
('DiTBlock_5', 'Dense_3', 'bias'): (768,)
('DiTBlock_5', 'Dense_4', 'kernel'): (768, 768)
('DiTBlock_5', 'Dense_4', 'bias'): (768,)
('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072)
('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,)
('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768)
('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'bias'): (768,)
('DiTBlock_6', 'Dense_0', 'kernel'): (768, 4608)
('DiTBlock_6', 'Dense_0', 'bias'): (4608,)
('DiTBlock_6', 'Dense_1', 'kernel'): (768, 768)
('DiTBlock_6', 'Dense_1', 'bias'): (768,)
('DiTBlock_6', 'Dense_2', 'kernel'): (768, 768)
('DiTBlock_6', 'Dense_2', 'bias'): (768,)
('DiTBlock_6', 'Dense_3', 'kernel'): (768, 768)
('DiTBlock_6', 'Dense_3', 'bias'): (768,)
('DiTBlock_6', 'Dense_4', 'kernel'): (768, 768)
('DiTBlock_6', 'Dense_4', 'bias'): (768,)
('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072)
('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,)
('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768)
('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'bias'): (768,)
('DiTBlock_7', 'Dense_0', 'kernel'): (768, 4608)
('DiTBlock_7', 'Dense_0', 'bias'): (4608,)
('DiTBlock_7', 'Dense_1', 'kernel'): (768, 768)
('DiTBlock_7', 'Dense_1', 'bias'): (768,)
('DiTBlock_7', 'Dense_2', 'kernel'): (768, 768)
('DiTBlock_7', 'Dense_2', 'bias'): (768,)
('DiTBlock_7', 'Dense_3', 'kernel'): (768, 768)
('DiTBlock_7', 'Dense_3', 'bias'): (768,)
('DiTBlock_7', 'Dense_4', 'kernel'): (768, 768)
('DiTBlock_7', 'Dense_4', 'bias'): (768,)
('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072)
('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,)
('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768)
('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'bias'): (768,)
('DiTBlock_8', 'Dense_0', 'kernel'): (768, 4608)
('DiTBlock_8', 'Dense_0', 'bias'): (4608,)
('DiTBlock_8', 'Dense_1', 'kernel'): (768, 768)
('DiTBlock_8', 'Dense_1', 'bias'): (768,)
('DiTBlock_8', 'Dense_2', 'kernel'): (768, 768)
('DiTBlock_8', 'Dense_2', 'bias'): (768,)
('DiTBlock_8', 'Dense_3', 'kernel'): (768, 768)
('DiTBlock_8', 'Dense_3', 'bias'): (768,)
('DiTBlock_8', 'Dense_4', 'kernel'): (768, 768)
('DiTBlock_8', 'Dense_4', 'bias'): (768,)
('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072)
('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,)
('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768)
('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'bias'): (768,)
('DiTBlock_9', 'Dense_0', 'kernel'): (768, 4608)
('DiTBlock_9', 'Dense_0', 'bias'): (4608,)
('DiTBlock_9', 'Dense_1', 'kernel'): (768, 768)
('DiTBlock_9', 'Dense_1', 'bias'): (768,)
('DiTBlock_9', 'Dense_2', 'kernel'): (768, 768)
('DiTBlock_9', 'Dense_2', 'bias'): (768,)
('DiTBlock_9', 'Dense_3', 'kernel'): (768, 768)
('DiTBlock_9', 'Dense_3', 'bias'): (768,)
('DiTBlock_9', 'Dense_4', 'kernel'): (768, 768)
('DiTBlock_9', 'Dense_4', 'bias'): (768,)
('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072)
('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,)
('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768)
('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'bias'): (768,)
('DiTBlock_10', 'Dense_0', 'kernel'): (768, 4608)
('DiTBlock_10', 'Dense_0', 'bias'): (4608,)
('DiTBlock_10', 'Dense_1', 'kernel'): (768, 768)
('DiTBlock_10', 'Dense_1', 'bias'): (768,)
('DiTBlock_10', 'Dense_2', 'kernel'): (768, 768)
('DiTBlock_10', 'Dense_2', 'bias'): (768,)
('DiTBlock_10', 'Dense_3', 'kernel'): (768, 768)
('DiTBlock_10', 'Dense_3', 'bias'): (768,)
('DiTBlock_10', 'Dense_4', 'kernel'): (768, 768)
('DiTBlock_10', 'Dense_4', 'bias'): (768,)
('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072)
('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,)
('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768)
('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'bias'): (768,)
('DiTBlock_11', 'Dense_0', 'kernel'): (768, 4608)
('DiTBlock_11', 'Dense_0', 'bias'): (4608,)
('DiTBlock_11', 'Dense_1', 'kernel'): (768, 768)
('DiTBlock_11', 'Dense_1', 'bias'): (768,)
('DiTBlock_11', 'Dense_2', 'kernel'): (768, 768)
('DiTBlock_11', 'Dense_2', 'bias'): (768,)
('DiTBlock_11', 'Dense_3', 'kernel'): (768, 768)
('DiTBlock_11', 'Dense_3', 'bias'): (768,)
('DiTBlock_11', 'Dense_4', 'kernel'): (768, 768)
('DiTBlock_11', 'Dense_4', 'bias'): (768,)
('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072)
('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,)
('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768)
('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'bias'): (768,)
('FinalLayer_0', 'Dense_0', 'kernel'): (768, 1536)
('FinalLayer_0', 'Dense_0', 'bias'): (1536,)
('FinalLayer_0', 'Dense_1', 'kernel'): (768, 16)
('FinalLayer_0', 'Dense_1', 'bias'): (16,)
('Embed_0', 'embedding'): (256, 1)
parameter shapes:
('DiTBlock_0', 'Dense_0', 'bias'): (1, 4608)
('DiTBlock_0', 'Dense_0', 'kernel'): (1, 768, 4608)
('DiTBlock_0', 'Dense_1', 'bias'): (1, 768)
('DiTBlock_0', 'Dense_1', 'kernel'): (1, 768, 768)
('DiTBlock_0', 'Dense_2', 'bias'): (1, 768)
('DiTBlock_0', 'Dense_2', 'kernel'): (1, 768, 768)
('DiTBlock_0', 'Dense_3', 'bias'): (1, 768)
('DiTBlock_0', 'Dense_3', 'kernel'): (1, 768, 768)
('DiTBlock_0', 'Dense_4', 'bias'): (1, 768)
('DiTBlock_0', 'Dense_4', 'kernel'): (1, 768, 768)
('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072)
('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072)
('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768)
('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768)
('DiTBlock_1', 'Dense_0', 'bias'): (1, 4608)
('DiTBlock_1', 'Dense_0', 'kernel'): (1, 768, 4608)
('DiTBlock_1', 'Dense_1', 'bias'): (1, 768)
('DiTBlock_1', 'Dense_1', 'kernel'): (1, 768, 768)
('DiTBlock_1', 'Dense_2', 'bias'): (1, 768)
('DiTBlock_1', 'Dense_2', 'kernel'): (1, 768, 768)
('DiTBlock_1', 'Dense_3', 'bias'): (1, 768)
('DiTBlock_1', 'Dense_3', 'kernel'): (1, 768, 768)
('DiTBlock_1', 'Dense_4', 'bias'): (1, 768)
('DiTBlock_1', 'Dense_4', 'kernel'): (1, 768, 768)
('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072)
('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072)
('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768)
('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768)
('DiTBlock_10', 'Dense_0', 'bias'): (1, 4608)
('DiTBlock_10', 'Dense_0', 'kernel'): (1, 768, 4608)
('DiTBlock_10', 'Dense_1', 'bias'): (1, 768)
('DiTBlock_10', 'Dense_1', 'kernel'): (1, 768, 768)
('DiTBlock_10', 'Dense_2', 'bias'): (1, 768)
('DiTBlock_10', 'Dense_2', 'kernel'): (1, 768, 768)
('DiTBlock_10', 'Dense_3', 'bias'): (1, 768)
('DiTBlock_10', 'Dense_3', 'kernel'): (1, 768, 768)
('DiTBlock_10', 'Dense_4', 'bias'): (1, 768)
('DiTBlock_10', 'Dense_4', 'kernel'): (1, 768, 768)
('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072)
('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072)
('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768)
('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768)
('DiTBlock_11', 'Dense_0', 'bias'): (1, 4608)
('DiTBlock_11', 'Dense_0', 'kernel'): (1, 768, 4608)
('DiTBlock_11', 'Dense_1', 'bias'): (1, 768)
('DiTBlock_11', 'Dense_1', 'kernel'): (1, 768, 768)
('DiTBlock_11', 'Dense_2', 'bias'): (1, 768)
('DiTBlock_11', 'Dense_2', 'kernel'): (1, 768, 768)
('DiTBlock_11', 'Dense_3', 'bias'): (1, 768)
('DiTBlock_11', 'Dense_3', 'kernel'): (1, 768, 768)
('DiTBlock_11', 'Dense_4', 'bias'): (1, 768)
('DiTBlock_11', 'Dense_4', 'kernel'): (1, 768, 768)
('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072)
('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072)
('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768)
('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768)
('DiTBlock_2', 'Dense_0', 'bias'): (1, 4608)
('DiTBlock_2', 'Dense_0', 'kernel'): (1, 768, 4608)
('DiTBlock_2', 'Dense_1', 'bias'): (1, 768)
('DiTBlock_2', 'Dense_1', 'kernel'): (1, 768, 768)
('DiTBlock_2', 'Dense_2', 'bias'): (1, 768)
('DiTBlock_2', 'Dense_2', 'kernel'): (1, 768, 768)
('DiTBlock_2', 'Dense_3', 'bias'): (1, 768)
('DiTBlock_2', 'Dense_3', 'kernel'): (1, 768, 768)
('DiTBlock_2', 'Dense_4', 'bias'): (1, 768)
('DiTBlock_2', 'Dense_4', 'kernel'): (1, 768, 768)
('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072)
('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072)
('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768)
('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768)
('DiTBlock_3', 'Dense_0', 'bias'): (1, 4608)
('DiTBlock_3', 'Dense_0', 'kernel'): (1, 768, 4608)
('DiTBlock_3', 'Dense_1', 'bias'): (1, 768)
('DiTBlock_3', 'Dense_1', 'kernel'): (1, 768, 768)
('DiTBlock_3', 'Dense_2', 'bias'): (1, 768)
('DiTBlock_3', 'Dense_2', 'kernel'): (1, 768, 768)
('DiTBlock_3', 'Dense_3', 'bias'): (1, 768)
('DiTBlock_3', 'Dense_3', 'kernel'): (1, 768, 768)
('DiTBlock_3', 'Dense_4', 'bias'): (1, 768)
('DiTBlock_3', 'Dense_4', 'kernel'): (1, 768, 768)
('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072)
('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072)
('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768)
('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768)
('DiTBlock_4', 'Dense_0', 'bias'): (1, 4608)
('DiTBlock_4', 'Dense_0', 'kernel'): (1, 768, 4608)
('DiTBlock_4', 'Dense_1', 'bias'): (1, 768)
('DiTBlock_4', 'Dense_1', 'kernel'): (1, 768, 768)
('DiTBlock_4', 'Dense_2', 'bias'): (1, 768)
('DiTBlock_4', 'Dense_2', 'kernel'): (1, 768, 768)
('DiTBlock_4', 'Dense_3', 'bias'): (1, 768)
('DiTBlock_4', 'Dense_3', 'kernel'): (1, 768, 768)
('DiTBlock_4', 'Dense_4', 'bias'): (1, 768)
('DiTBlock_4', 'Dense_4', 'kernel'): (1, 768, 768)
('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072)
('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072)
('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768)
('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768)
('DiTBlock_5', 'Dense_0', 'bias'): (1, 4608)
('DiTBlock_5', 'Dense_0', 'kernel'): (1, 768, 4608)
('DiTBlock_5', 'Dense_1', 'bias'): (1, 768)
('DiTBlock_5', 'Dense_1', 'kernel'): (1, 768, 768)
('DiTBlock_5', 'Dense_2', 'bias'): (1, 768)
('DiTBlock_5', 'Dense_2', 'kernel'): (1, 768, 768)
('DiTBlock_5', 'Dense_3', 'bias'): (1, 768)
('DiTBlock_5', 'Dense_3', 'kernel'): (1, 768, 768)
('DiTBlock_5', 'Dense_4', 'bias'): (1, 768)
('DiTBlock_5', 'Dense_4', 'kernel'): (1, 768, 768)
('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072)
('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072)
('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768)
('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768)
('DiTBlock_6', 'Dense_0', 'bias'): (1, 4608)
('DiTBlock_6', 'Dense_0', 'kernel'): (1, 768, 4608)
('DiTBlock_6', 'Dense_1', 'bias'): (1, 768)
('DiTBlock_6', 'Dense_1', 'kernel'): (1, 768, 768)
('DiTBlock_6', 'Dense_2', 'bias'): (1, 768)
('DiTBlock_6', 'Dense_2', 'kernel'): (1, 768, 768)
('DiTBlock_6', 'Dense_3', 'bias'): (1, 768)
('DiTBlock_6', 'Dense_3', 'kernel'): (1, 768, 768)
('DiTBlock_6', 'Dense_4', 'bias'): (1, 768)
('DiTBlock_6', 'Dense_4', 'kernel'): (1, 768, 768)
('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072)
('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072)
('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768)
('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768)
('DiTBlock_7', 'Dense_0', 'bias'): (1, 4608)
('DiTBlock_7', 'Dense_0', 'kernel'): (1, 768, 4608)
('DiTBlock_7', 'Dense_1', 'bias'): (1, 768)
('DiTBlock_7', 'Dense_1', 'kernel'): (1, 768, 768)
('DiTBlock_7', 'Dense_2', 'bias'): (1, 768)
('DiTBlock_7', 'Dense_2', 'kernel'): (1, 768, 768)
('DiTBlock_7', 'Dense_3', 'bias'): (1, 768)
('DiTBlock_7', 'Dense_3', 'kernel'): (1, 768, 768)
('DiTBlock_7', 'Dense_4', 'bias'): (1, 768)
('DiTBlock_7', 'Dense_4', 'kernel'): (1, 768, 768)
('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072)
('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072)
('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768)
('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768)
('DiTBlock_8', 'Dense_0', 'bias'): (1, 4608)
('DiTBlock_8', 'Dense_0', 'kernel'): (1, 768, 4608)
('DiTBlock_8', 'Dense_1', 'bias'): (1, 768)
('DiTBlock_8', 'Dense_1', 'kernel'): (1, 768, 768)
('DiTBlock_8', 'Dense_2', 'bias'): (1, 768)
('DiTBlock_8', 'Dense_2', 'kernel'): (1, 768, 768)
('DiTBlock_8', 'Dense_3', 'bias'): (1, 768)
('DiTBlock_8', 'Dense_3', 'kernel'): (1, 768, 768)
('DiTBlock_8', 'Dense_4', 'bias'): (1, 768)
('DiTBlock_8', 'Dense_4', 'kernel'): (1, 768, 768)
('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072)
('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072)
('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768)
('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768)
('DiTBlock_9', 'Dense_0', 'bias'): (1, 4608)
('DiTBlock_9', 'Dense_0', 'kernel'): (1, 768, 4608)
('DiTBlock_9', 'Dense_1', 'bias'): (1, 768)
('DiTBlock_9', 'Dense_1',
'kernel'): (1, 768, 768) ('DiTBlock_9', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_9', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_9', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_9', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_9', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_9', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('Embed_0', 'embedding'): (1, 256, 1) ('FinalLayer_0', 'Dense_0', 'bias'): (1, 1536) ('FinalLayer_0', 'Dense_0', 'kernel'): (1, 768, 1536) ('FinalLayer_0', 'Dense_1', 'bias'): (1, 16) ('FinalLayer_0', 'Dense_1', 'kernel'): (1, 768, 16) ('LabelEmbedder_0', 'Embed_0', 'embedding'): (1, 1001, 768) ('PatchEmbed_0', 'Conv_0', 'bias'): (1, 768) ('PatchEmbed_0', 'Conv_0', 'kernel'): (1, 2, 2, 4, 768) ('TimestepEmbedder_0', 'Dense_0', 'bias'): (1, 768) ('TimestepEmbedder_0', 'Dense_0', 'kernel'): (1, 256, 768) ('TimestepEmbedder_0', 'Dense_1', 'bias'): (1, 768) ('TimestepEmbedder_0', 'Dense_1', 'kernel'): (1, 768, 768) ('TimestepEmbedder_1', 'Dense_0', 'bias'): (1, 768) ('TimestepEmbedder_1', 'Dense_0', 'kernel'): (1, 256, 768) ('TimestepEmbedder_1', 'Dense_1', 'bias'): (1, 768) ('TimestepEmbedder_1', 'Dense_1', 'kernel'): (1, 768, 768)
parameter shapes: ('DiTBlock_0', 'Dense_0', 'bias'): (4608,) ('DiTBlock_0', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_0', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_0', 'Dense_2', 'bias'): (768,) ('DiTBlock_0', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_0', 'Dense_3', 'bias'): (768,) ('DiTBlock_0', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_0', 'Dense_4', 'bias'): (768,) ('DiTBlock_0', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_1', 'Dense_0', 'bias'): (4608,) ('DiTBlock_1', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_1', 'Dense_1', 'bias'): (768,) ('DiTBlock_1', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_1', 'Dense_2', 'bias'): (768,) ('DiTBlock_1', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_1', 'Dense_3', 'bias'): (768,) ('DiTBlock_1', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_1', 'Dense_4', 'bias'): (768,) ('DiTBlock_1', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_10', 'Dense_0', 'bias'): (4608,) ('DiTBlock_10', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_10', 'Dense_1', 'bias'): (768,) ('DiTBlock_10', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_10', 'Dense_2', 'bias'): (768,) ('DiTBlock_10', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_10', 
'Dense_3', 'bias'): (768,) ('DiTBlock_10', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_10', 'Dense_4', 'bias'): (768,) ('DiTBlock_10', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_11', 'Dense_0', 'bias'): (4608,) ('DiTBlock_11', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_11', 'Dense_1', 'bias'): (768,) ('DiTBlock_11', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_11', 'Dense_2', 'bias'): (768,) ('DiTBlock_11', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_11', 'Dense_3', 'bias'): (768,) ('DiTBlock_11', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_11', 'Dense_4', 'bias'): (768,) ('DiTBlock_11', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_2', 'Dense_0', 'bias'): (4608,) ('DiTBlock_2', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_2', 'Dense_1', 'bias'): (768,) ('DiTBlock_2', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_2', 'Dense_2', 'bias'): (768,) ('DiTBlock_2', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_2', 'Dense_3', 'bias'): (768,) ('DiTBlock_2', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_2', 'Dense_4', 'bias'): (768,) ('DiTBlock_2', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_3', 'Dense_0', 'bias'): (4608,) ('DiTBlock_3', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_3', 'Dense_1', 'bias'): (768,) ('DiTBlock_3', 'Dense_1', 'kernel'): (768, 768) 
('DiTBlock_3', 'Dense_2', 'bias'): (768,) ('DiTBlock_3', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_3', 'Dense_3', 'bias'): (768,) ('DiTBlock_3', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_3', 'Dense_4', 'bias'): (768,) ('DiTBlock_3', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_4', 'Dense_0', 'bias'): (4608,) ('DiTBlock_4', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_4', 'Dense_1', 'bias'): (768,) ('DiTBlock_4', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_4', 'Dense_2', 'bias'): (768,) ('DiTBlock_4', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_4', 'Dense_3', 'bias'): (768,) ('DiTBlock_4', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_4', 'Dense_4', 'bias'): (768,) ('DiTBlock_4', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_5', 'Dense_0', 'bias'): (4608,) ('DiTBlock_5', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_5', 'Dense_1', 'bias'): (768,) ('DiTBlock_5', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_5', 'Dense_2', 'bias'): (768,) ('DiTBlock_5', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_5', 'Dense_3', 'bias'): (768,) ('DiTBlock_5', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_5', 'Dense_4', 'bias'): (768,) ('DiTBlock_5', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_6', 'Dense_0', 'bias'): (4608,) ('DiTBlock_6', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_6', 
'Dense_1', 'bias'): (768,) ('DiTBlock_6', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_6', 'Dense_2', 'bias'): (768,) ('DiTBlock_6', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_6', 'Dense_3', 'bias'): (768,) ('DiTBlock_6', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_6', 'Dense_4', 'bias'): (768,) ('DiTBlock_6', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_7', 'Dense_0', 'bias'): (4608,) ('DiTBlock_7', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_7', 'Dense_1', 'bias'): (768,) ('DiTBlock_7', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_7', 'Dense_2', 'bias'): (768,) ('DiTBlock_7', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_7', 'Dense_3', 'bias'): (768,) ('DiTBlock_7', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_7', 'Dense_4', 'bias'): (768,) ('DiTBlock_7', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_8', 'Dense_0', 'bias'): (4608,) ('DiTBlock_8', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_8', 'Dense_1', 'bias'): (768,) ('DiTBlock_8', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_8', 'Dense_2', 'bias'): (768,) ('DiTBlock_8', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_8', 'Dense_3', 'bias'): (768,) ('DiTBlock_8', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_8', 'Dense_4', 'bias'): (768,) ('DiTBlock_8', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_9', 'Dense_0', 
'bias'): (4608,) ('DiTBlock_9', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_9', 'Dense_1', 'bias'): (768,) ('DiTBlock_9', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_9', 'Dense_2', 'bias'): (768,) ('DiTBlock_9', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_9', 'Dense_3', 'bias'): (768,) ('DiTBlock_9', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_9', 'Dense_4', 'bias'): (768,) ('DiTBlock_9', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('Embed_0', 'embedding'): (256, 1) ('FinalLayer_0', 'Dense_0', 'bias'): (1536,) ('FinalLayer_0', 'Dense_0', 'kernel'): (768, 1536) ('FinalLayer_0', 'Dense_1', 'bias'): (16,) ('FinalLayer_0', 'Dense_1', 'kernel'): (768, 16) ('LabelEmbedder_0', 'Embed_0', 'embedding'): (1001, 768) ('PatchEmbed_0', 'Conv_0', 'bias'): (768,) ('PatchEmbed_0', 'Conv_0', 'kernel'): (2, 2, 4, 768) ('TimestepEmbedder_0', 'Dense_0', 'bias'): (768,) ('TimestepEmbedder_0', 'Dense_0', 'kernel'): (256, 768) ('TimestepEmbedder_0', 'Dense_1', 'bias'): (768,) ('TimestepEmbedder_0', 'Dense_1', 'kernel'): (768, 768) ('TimestepEmbedder_1', 'Dense_0', 'bias'): (768,) ('TimestepEmbedder_1', 'Dense_0', 'kernel'): (256, 768) ('TimestepEmbedder_1', 'Dense_1', 'bias'): (768,) ('TimestepEmbedder_1', 'Dense_1', 'kernel'): (768, 768) ┌────────────────────────────────────────────────┐ │ │ │ │ │ │ │ │ │ TPU 0,1,2,3 │ │ │ │ │ │ │ │ │ └────────────────────────────────────────────────┘ ┌─────────────────────────────────────────────────────────────────────────┐ │ │ │ │ │ │ │ │ │ TPU 0,1,2,3 │ │ │ │ │ │ │ │ │ └─────────────────────────────────────────────────────────────────────────┘ doing the else (512, 256, 256, 3) encode image shape (128, 256, 256, 3) Initializing encoder. 
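This part of the log does not print a total parameter count for the DiT, but it follows directly from the unreplicated `parameter shapes:` dump above. A minimal Python sketch that sums those shapes: the shapes are transcribed from the dump, the twelve identical `DiTBlock_*` entries are folded into one template, and the dictionary names (`DIT_BLOCK`, `OTHER`) and layer labels are invented for this sketch. It counts only what the dump lists.

```python
import math

# Shapes transcribed from the unreplicated "parameter shapes:" dump.
# Each entry maps a layer label (invented for this sketch) to the list of
# its parameter shapes, bias first and kernel second, as in the dump.
DIT_BLOCK = {
    "Dense_0": [(4608,), (768, 4608)],
    "Dense_1": [(768,), (768, 768)],
    "Dense_2": [(768,), (768, 768)],
    "Dense_3": [(768,), (768, 768)],
    "Dense_4": [(768,), (768, 768)],
    "MlpBlock_0/Dense_0": [(3072,), (768, 3072)],
    "MlpBlock_0/Dense_1": [(768,), (3072, 768)],
}
NUM_BLOCKS = 12  # DiTBlock_0 ... DiTBlock_11, matching depth: 12

# Every parameter that lives outside the transformer blocks.
OTHER = {
    "Embed_0/embedding": [(256, 1)],
    "FinalLayer_0/Dense_0": [(1536,), (768, 1536)],
    "FinalLayer_0/Dense_1": [(16,), (768, 16)],
    "LabelEmbedder_0/Embed_0/embedding": [(1001, 768)],
    "PatchEmbed_0/Conv_0": [(768,), (2, 2, 4, 768)],
    "TimestepEmbedder_0/Dense_0": [(768,), (256, 768)],
    "TimestepEmbedder_0/Dense_1": [(768,), (768, 768)],
    "TimestepEmbedder_1/Dense_0": [(768,), (256, 768)],
    "TimestepEmbedder_1/Dense_1": [(768,), (768, 768)],
}


def count(groups):
    """Total number of scalars across every shape in a {label: [shapes]} mapping."""
    return sum(math.prod(shape) for shapes in groups.values() for shape in shapes)


per_block = count(DIT_BLOCK)
total = NUM_BLOCKS * per_block + count(OTHER)
print(f"per DiTBlock:         {per_block:,}")
print(f"total DiT parameters: {total:,}")
```

Summed this way, each `DiTBlock_*` holds 10,628,352 parameters and the whole DiT 131,091,728, which is consistent with a DiT-B/2 configuration (depth 12, hidden_size 768, num_heads 12, patch_size 2). Two shapes are easy to check by hand: every block's `Dense_0` emits 4608 = 6 × 768 values (presumably the adaLN shift/scale/gate modulation), and `FinalLayer_0`'s `Dense_1` emits 16 = 2 × 2 × 4 values per token, i.e. patch_size² times the 4 latent channels.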
Incoming encoder shape (128, 256, 256, 3)
Encoder layer (128, 256, 256, 128)
doing downsample
Encoder layer (128, 128, 128, 128)
doing downsample
Encoder layer (128, 64, 64, 256)
doing downsample
Encoder layer (128, 32, 32, 512)
Encoder layer (128, 32, 32, 512)
Encoder layer final (128, 32, 32, 512)
Encoder layer final (128, 32, 32, 512)
Final embeddings are size (128, 32, 32, 8)
After quant (128, 32, 32, 4)
Calc FID for CFG 1.0 and denoise_timesteps 128
DiT: Input of shape (512, 32, 32, 4) dtype float32
DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16
DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16
DiT: Conditioning of shape (512, 768) dtype float32
z_vectors shape (128, 32, 32, 4)
Decoder incoming shape (128, 32, 32, 4)
Decoder input (128, 32, 32, 512)
Mid Block Decoder layer (128, 32, 32, 512)
Mid Block Decoder layer (128, 32, 32, 512)
Decoder layer (128, 64, 64, 512)
Decoder layer (128, 128, 128, 512)
Decoder layer (128, 256, 256, 256)
Decoder layer (128, 256, 256, 128)
FID is 36.521278381347656
Calc FID for CFG 1.0 and denoise_timesteps 64
FID is 37.26447677612305
Calc FID for CFG 1.0 and denoise_timesteps 32
FID is 39.397438049316406
Calc FID for CFG 1.0 and denoise_timesteps 16
FID is 46.02665710449219
Calc FID for CFG 1.0 and denoise_timesteps 8
FID is 69.42753601074219
Calc FID for CFG 1.0 and denoise_timesteps 4
FID is 144.00820922851562
Calc FID for CFG 1.0 and denoise_timesteps 2
FID is 233.5753631591797
Calc FID for CFG 1.0 and denoise_timesteps 1
FID is 220.07077026367188
Calc FID for CFG 1.25 and denoise_timesteps 128
FID is 22.999954223632812
Calc FID for CFG 1.25 and denoise_timesteps 64
FID is 23.585742950439453
Calc FID for CFG 1.25 and denoise_timesteps 32
FID is 25.252933502197266
Calc FID for CFG 1.25 and denoise_timesteps 16
FID is 30.56622314453125
Calc FID for CFG 1.25 and denoise_timesteps 8
FID is 50.273406982421875
Calc FID for CFG 1.25 and denoise_timesteps 4
FID is 119.06546020507812
Calc FID for CFG 1.25 and denoise_timesteps 2
FID is 223.52890014648438
Calc FID for CFG 1.25 and denoise_timesteps 1
FID is 220.658203125
Calc FID for CFG 1.5 and denoise_timesteps 128
FID is 15.342796325683594
Calc FID for CFG 1.5 and denoise_timesteps 64
FID is 15.73843002319336
Calc FID for CFG 1.5 and denoise_timesteps 32
FID is 16.854110717773438
Calc FID for CFG 1.5 and denoise_timesteps 16
FID is 20.724889755249023
Calc FID for CFG 1.5 and denoise_timesteps 8
FID is 36.38886260986328
Calc FID for CFG 1.5 and denoise_timesteps 4
FID is 97.57992553710938
Calc FID for CFG 1.5 and denoise_timesteps 2
FID is 215.600341796875
Calc FID for CFG 1.5 and denoise_timesteps 1
FID is 221.83432006835938
Calc FID for CFG 1.75 and denoise_timesteps 128
FID is 11.401187896728516
Calc FID for CFG 1.75 and denoise_timesteps 64
FID is 11.639893531799316
Calc FID for CFG 1.75 and denoise_timesteps 32
FID is 12.35086727142334
Calc FID for CFG 1.75 and denoise_timesteps 16
FID is 15.005157470703125
Calc FID for CFG 1.75 and denoise_timesteps 8
FID is 26.803192138671875
Calc FID for CFG 1.75 and denoise_timesteps 4
FID is 79.97176361083984
Calc FID for CFG 1.75 and denoise_timesteps 2
FID is 209.54376220703125
Calc FID for CFG 1.75 and denoise_timesteps 1
FID is 222.31361389160156
Calc FID for CFG 2.0 and denoise_timesteps 128
FID is 9.723258972167969
Calc FID for CFG 2.0 and denoise_timesteps 64
FID is 9.815640449523926
Calc FID for CFG 2.0 and denoise_timesteps 32
FID is 10.231249809265137
Calc FID for CFG 2.0 and denoise_timesteps 16
FID is 11.961580276489258
Calc FID for CFG 2.0 and denoise_timesteps 8
FID is 20.644756317138672
Calc FID for CFG 2.0 and denoise_timesteps 4
FID is 66.29497528076172
Calc FID for CFG 2.0 and denoise_timesteps 2
FID is 205.03021240234375
Calc FID for CFG 2.0 and denoise_timesteps 1
FID is 223.5365447998047
Calc FID for CFG 2.25 and denoise_timesteps 128
FID is 9.30825424194336
Calc FID for CFG 2.25 and denoise_timesteps 64
FID is 9.312078475952148
Calc FID for CFG 2.25 and denoise_timesteps 32
FID is 9.459176063537598
Calc FID for CFG 2.25 and denoise_timesteps 16
FID is 10.515265464782715
Calc FID for CFG 2.25 and denoise_timesteps 8
FID is 16.84541893005371
Calc FID for CFG 2.25 and denoise_timesteps 4
FID is 55.62430191040039
Calc FID for CFG 2.25 and denoise_timesteps 2
FID is 201.66688537597656
Calc FID for CFG 2.25 and denoise_timesteps 1
FID is 224.67477416992188
Calc FID for CFG 2.5 and denoise_timesteps 128
FID is 9.604626655578613
Calc FID for CFG 2.5 and denoise_timesteps 64
FID is 9.526775360107422
Calc FID for CFG 2.5 and denoise_timesteps 32
FID is 9.512392044067383
Calc FID for CFG 2.5 and denoise_timesteps 16
FID is 10.053356170654297
Calc FID for CFG 2.5 and denoise_timesteps 8
FID is 14.589271545410156
Calc FID for CFG 2.5 and denoise_timesteps 4
FID is 47.51606750488281
Calc FID for CFG 2.5 and denoise_timesteps 2
FID is 198.90530395507812
Calc FID for CFG 2.5 and denoise_timesteps 1
FID is 225.52285766601562
Calc FID for CFG 2.75 and denoise_timesteps 128
FID is 10.243950843811035
Calc FID for CFG 2.75 and denoise_timesteps 64
FID is 10.139540672302246
Calc FID for CFG 2.75 and denoise_timesteps 32
FID is 10.006643295288086
Calc FID for CFG 2.75 and denoise_timesteps 16
FID is 10.160444259643555
Calc FID for CFG 2.75 and denoise_timesteps 8
FID is 13.304459571838379
Calc FID for CFG 2.75 and denoise_timesteps 4
FID is 41.18268585205078
Calc FID for CFG 2.75 and denoise_timesteps 2
FID is 196.8610382080078
Calc FID for CFG 2.75 and denoise_timesteps 1
FID is 226.20384216308594
Calc FID for CFG 3.0 and denoise_timesteps 128
FID is 11.028402328491211
Calc FID for CFG 3.0 and denoise_timesteps 64
FID is 10.897418975830078
Calc FID for CFG 3.0 and denoise_timesteps 32
FID is 10.685561180114746
Calc FID for CFG 3.0 and denoise_timesteps 16
FID is 10.594181060791016
Calc FID for CFG 3.0 and denoise_timesteps 8
FID is 12.689326286315918
Calc FID for CFG 3.0 and denoise_timesteps 4
FID is 36.416648864746094
Calc FID for CFG 3.0 and denoise_timesteps 2
FID is 195.21243286132812
Calc FID for CFG 3.0 and denoise_timesteps 1
FID is 226.7688751220703
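To compare the sweep across step budgets, a minimal Python sketch that pulls `(CFG, denoise_timesteps, FID)` triples out of raw console output like this log and picks the lowest-FID CFG scale for each step count. The regular expression assumes the exact `Calc FID for CFG … and denoise_timesteps …` and `FID is …` wording printed here, with arbitrary debug lines in between; `parse_sweep`, `best_cfg_per_steps`, and `EXCERPT` (an abridged copy of four runs from the log) are hypothetical names introduced for this sketch.

```python
import re

# One sweep run: the "Calc FID ..." header, any debug lines in between
# (non-greedy, across newlines), then the first "FID is <value>" that follows.
RUN_PATTERN = re.compile(
    r"Calc FID for CFG ([\d.]+) and denoise_timesteps (\d+)"
    r".*?FID is ([\d.]+)",
    re.DOTALL,
)


def parse_sweep(log_text):
    """Return {(cfg, denoise_timesteps): fid} for every run found in log_text."""
    return {
        (float(cfg), int(steps)): float(fid)
        for cfg, steps, fid in RUN_PATTERN.findall(log_text)
    }


def best_cfg_per_steps(results):
    """Map each denoise_timesteps value to the (cfg, fid) pair with the lowest FID."""
    best = {}
    for (cfg, steps), fid in results.items():
        if steps not in best or fid < best[steps][1]:
            best[steps] = (cfg, fid)
    return best


# Abridged excerpt of four runs from the log (most DiT shape lines trimmed).
EXCERPT = """
Calc FID for CFG 2.25 and denoise_timesteps 128
DiT: Input of shape (512, 32, 32, 4) dtype float32
FID is 9.30825424194336
Calc FID for CFG 2.25 and denoise_timesteps 8
FID is 16.84541893005371
Calc FID for CFG 3.0 and denoise_timesteps 128
FID is 11.028402328491211
Calc FID for CFG 3.0 and denoise_timesteps 8
FID is 12.689326286315918
"""

best = best_cfg_per_steps(parse_sweep(EXCERPT))
for steps in sorted(best, reverse=True):
    cfg, fid = best[steps]
    print(f"{steps:>3} steps: best CFG {cfg} (FID {fid:.3f})")
```

Applied to the complete sweep in this log, the same selection gives CFG 2.25 at 128, 64 and 32 steps, 2.5 at 16 steps, 3.0 at 8, 4 and 2 steps, and 1.0 at a single step: the best guidance scale grows as the step budget shrinks. The lowest FID overall is about 9.308 (CFG 2.25, 128 steps), while at 1 or 2 steps no CFG setting gets below 195.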