Using devices [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]
Device count 4
Global device count 4
Global Batch: 512
Node Batch: 512
Device Batch: 128
/tmp/tmptvnw60fi
Loading dataset
Loading dataset
creating model
beta1: 0.9
beta2: 0.999
bootstrap_cfg: 1
bootstrap_dt_bias: 0
bootstrap_ema: 1
bootstrap_every: 8
cfg_scale: 1.5
class_dropout_prob: 0.1
denoise_timesteps: 128
depth: 12
dropout: 0.0
dt_sampling: uniform
hidden_size: 768
lr: 0.0001
mlp_ratio: 4
num_classes: 1000
num_heads: 12
patch_size: 2
sharding: dp
t_sampling: discrete-dt
target_update_rate: 0.999
train_type: naive
use_cosine: 0
use_ema: 0
use_stable_vae: 1
warmup: 0
weight_decay: 0.1
Total devices TPU_0(process=0,(0,0,0,0))
Initializing encoder.
Incoming encoder shape (1, 256, 256, 3)
Encoder layer (1, 256, 256, 128)
doing downsample
Encoder layer (1, 128, 128, 128)
doing downsample
Encoder layer (1, 64, 64, 256)
doing downsample
Encoder layer (1, 32, 32, 512)
Encoder layer (1, 32, 32, 512)
Encoder layer final (1, 32, 32, 512)
Encoder layer final (1, 32, 32, 512)
Final embeddings are size (1, 32, 32, 8)
After quant (1, 32, 32, 4)
encode finished
Decoder incoming shape (1, 32, 32, 4)
Decoder input (1, 32, 32, 512)
Mid Block Decoder layer (1, 32, 32, 512)
Mid Block Decoder layer (1, 32, 32, 512)
Decoder layer (1, 64, 64, 512)
Decoder layer (1, 128, 128, 512)
Decoder layer (1, 256, 256, 256)
Decoder layer (1, 256, 256, 128)
Total num of VQVAE parameters: 67565323
Disc shape (1, 128, 128, 128)
Disc shape (1, 64, 64, 256)
Disc shape (1, 32, 32, 512)
Disc shape (1, 16, 16, 512)
Disc shape (1, 8, 8, 512)
Disc shape (1, 4, 4, 512)
Total num of Discriminator parameters: 23998017
Loaded checkpoint from 17922328 seconds ago.
Loaded model with step 214001
┌──────────────────────────────────────────────────────────────────────────────┐
│                                    TPU 0                                     │
├──────────────────────────────────────────────────────────────────────────────┤
│                                    TPU 1                                     │
├──────────────────────────────────────────────────────────────────────────────┤
│                                    TPU 2                                     │
├──────────────────────────────────────────────────────────────────────────────┤
│                                    TPU 3                                     │
└──────────────────────────────────────────────────────────────────────────────┘
returning model
model done
Input to vae (4, 1, 256, 256, 3)
encode image shape (1, 256, 256, 3)
Initializing encoder.
Incoming encoder shape (1, 256, 256, 3)
Encoder layer (1, 256, 256, 128)
doing downsample
Encoder layer (1, 128, 128, 128)
doing downsample
Encoder layer (1, 64, 64, 256)
doing downsample
Encoder layer (1, 32, 32, 512)
Encoder layer (1, 32, 32, 512)
Encoder layer final (1, 32, 32, 512)
Encoder layer final (1, 32, 32, 512)
Final embeddings are size (1, 32, 32, 8)
After quant (1, 32, 32, 4)
output example shape (4, 1, 32, 32, 4)
Test data shape (4, 256, 256, 3)
x shape (4, 1, 256, 256, 3)
encoded shape (4, 1, 32, 32, 4)
z_vectors shape (1, 32, 32, 4)
Decoder incoming shape (1, 32, 32, 4)
Decoder input (1, 32, 32, 512)
Mid Block Decoder layer (1, 32, 32, 512)
Mid Block Decoder layer (1, 32, 32, 512)
Decoder layer (1, 64, 64, 512)
Decoder layer (1, 128, 128, 512)
Decoder layer (1, 256, 256, 256)
Decoder layer (1, 256, 256, 128)
image shape (4, 1, 256, 256, 3)
decoded img shape (256, 256, 3)
obs shape (4, 32, 32, 4)
DiT: Input of shape (4, 32, 32, 4) dtype float32
DiT: After patch embed, shape is (4, 256, 768) dtype bfloat16
DiT: Patch Embed of shape (4, 256, 768) dtype bfloat16
DiT: Conditioning of shape (1, 768) dtype float32

                                                             DiT Summary
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ path                             ┃ module           ┃ inputs                ┃ outputs               ┃ params                       ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ │ DiT │ - float32[4,32,32,4] │ bfloat16[4,32,32,4] │ │ │ │ │ - float32[1] │ │ │ │ │ │ - float32[1] │ │ │ │ │ │ - int32[1] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ PatchEmbed_0 │ PatchEmbed │ float32[4,32,32,4] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ PatchEmbed_0/Conv_0 │ Conv │ float32[4,32,32,4] │ bfloat16[4,16,16,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[2,2,4,768] │ │ │ │ │ │ │ │ │ │ │ │ 13,056 (52.2 KB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ TimestepEmbedder_0 │ TimestepEmbedder │ float32[1] │ float32[1,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ TimestepEmbedder_0/Dense_0 │ Dense │ bfloat16[1,256] │ bfloat16[1,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[256,768] │ │ │ │ │ │ │ │ │ │ │ │ 197,376 (789.5 KB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ TimestepEmbedder_0/Dense_1 │ Dense │ bfloat16[1,768] │ float32[1,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ TimestepEmbedder_1 │ TimestepEmbedder │ float32[1] │ float32[1,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ TimestepEmbedder_1/Dense_0 │ Dense │ 
bfloat16[1,256] │ bfloat16[1,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[256,768] │ │ │ │ │ │ │ │ │ │ │ │ 197,376 (789.5 KB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ TimestepEmbedder_1/Dense_1 │ Dense │ bfloat16[1,768] │ float32[1,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ LabelEmbedder_0 │ LabelEmbedder │ int32[1] │ bfloat16[1,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ LabelEmbedder_0/Embed_0 │ Embed │ int32[1] │ bfloat16[1,768] │ embedding: float32[1001,768] │ │ │ │ │ │ │ │ │ │ │ │ 768,768 (3.1 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3 
│ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ 
DiTBlock_3/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ 
DiTBlock_6/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ 
bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] 
│ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: 
float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 
MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ FinalLayer_0 │ FinalLayer │ - bfloat16[4,256,768] │ bfloat16[4,256,16] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ FinalLayer_0/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,1536] │ bias: float32[1536] │ │ │ │ │ │ kernel: float32[768,1536] │ │ │ │ │ │ │ │ │ │ │ │ 1,181,184 (4.7 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ FinalLayer_0/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ FinalLayer_0/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,16] │ bias: float32[16] │ │ │ │ │ │ kernel: float32[768,16] │ │ │ │ │ │ │ │ │ │ │ │ 12,304 (49.2 KB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ Embed_0 │ Embed │ int32[1] │ float32[1,1] │ embedding: float32[256,1] │ │ │ │ │ │ │ │ │ │ │ │ 256 (1.0 KB) │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ │ │ │ Total │ 131,091,728 (524.4 MB) │ └──────────────────────────────────┴──────────────────┴───────────────────────┴───────────────────────┴──────────────────────────────┘ Total Parameters: 131,091,728 (524.4 MB) DiT: Input of shape (4, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (4, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (4, 256, 768) dtype bfloat16 DiT: Conditioning of shape (1, 768) dtype float32 Loaded checkpoint from 8454 seconds ago. parameter shapes: ('PatchEmbed_0', 'Conv_0', 'kernel'): (2, 2, 4, 768) ('PatchEmbed_0', 'Conv_0', 'bias'): (768,) ('TimestepEmbedder_0', 'Dense_0', 'kernel'): (256, 768) ('TimestepEmbedder_0', 'Dense_0', 'bias'): (768,) ('TimestepEmbedder_0', 'Dense_1', 'kernel'): (768, 768) ('TimestepEmbedder_0', 'Dense_1', 'bias'): (768,) ('TimestepEmbedder_1', 'Dense_0', 'kernel'): (256, 768) ('TimestepEmbedder_1', 'Dense_0', 'bias'): (768,) ('TimestepEmbedder_1', 'Dense_1', 'kernel'): (768, 768) ('TimestepEmbedder_1', 'Dense_1', 'bias'): (768,) ('LabelEmbedder_0', 'Embed_0', 'embedding'): (1001, 768) ('DiTBlock_0', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_0', 'Dense_0', 'bias'): (4608,) ('DiTBlock_0', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_0', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_0', 'Dense_2', 'bias'): (768,) ('DiTBlock_0', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_0', 'Dense_3', 'bias'): (768,) ('DiTBlock_0', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_0', 'Dense_4', 'bias'): (768,) ('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_1', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_1', 'Dense_0', 'bias'): (4608,) 
('DiTBlock_1', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_1', 'Dense_1', 'bias'): (768,) ('DiTBlock_1', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_1', 'Dense_2', 'bias'): (768,) ('DiTBlock_1', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_1', 'Dense_3', 'bias'): (768,) ('DiTBlock_1', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_1', 'Dense_4', 'bias'): (768,) ('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_2', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_2', 'Dense_0', 'bias'): (4608,) ('DiTBlock_2', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_2', 'Dense_1', 'bias'): (768,) ('DiTBlock_2', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_2', 'Dense_2', 'bias'): (768,) ('DiTBlock_2', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_2', 'Dense_3', 'bias'): (768,) ('DiTBlock_2', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_2', 'Dense_4', 'bias'): (768,) ('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_3', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_3', 'Dense_0', 'bias'): (4608,) ('DiTBlock_3', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_3', 'Dense_1', 'bias'): (768,) ('DiTBlock_3', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_3', 'Dense_2', 'bias'): (768,) ('DiTBlock_3', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_3', 'Dense_3', 'bias'): (768,) ('DiTBlock_3', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_3', 'Dense_4', 'bias'): (768,) ('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_4', 
'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_4', 'Dense_0', 'bias'): (4608,) ('DiTBlock_4', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_4', 'Dense_1', 'bias'): (768,) ('DiTBlock_4', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_4', 'Dense_2', 'bias'): (768,) ('DiTBlock_4', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_4', 'Dense_3', 'bias'): (768,) ('DiTBlock_4', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_4', 'Dense_4', 'bias'): (768,) ('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_5', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_5', 'Dense_0', 'bias'): (4608,) ('DiTBlock_5', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_5', 'Dense_1', 'bias'): (768,) ('DiTBlock_5', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_5', 'Dense_2', 'bias'): (768,) ('DiTBlock_5', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_5', 'Dense_3', 'bias'): (768,) ('DiTBlock_5', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_5', 'Dense_4', 'bias'): (768,) ('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_6', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_6', 'Dense_0', 'bias'): (4608,) ('DiTBlock_6', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_6', 'Dense_1', 'bias'): (768,) ('DiTBlock_6', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_6', 'Dense_2', 'bias'): (768,) ('DiTBlock_6', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_6', 'Dense_3', 'bias'): (768,) ('DiTBlock_6', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_6', 'Dense_4', 'bias'): (768,) ('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 
768) ('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_7', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_7', 'Dense_0', 'bias'): (4608,) ('DiTBlock_7', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_7', 'Dense_1', 'bias'): (768,) ('DiTBlock_7', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_7', 'Dense_2', 'bias'): (768,) ('DiTBlock_7', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_7', 'Dense_3', 'bias'): (768,) ('DiTBlock_7', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_7', 'Dense_4', 'bias'): (768,) ('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_8', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_8', 'Dense_0', 'bias'): (4608,) ('DiTBlock_8', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_8', 'Dense_1', 'bias'): (768,) ('DiTBlock_8', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_8', 'Dense_2', 'bias'): (768,) ('DiTBlock_8', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_8', 'Dense_3', 'bias'): (768,) ('DiTBlock_8', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_8', 'Dense_4', 'bias'): (768,) ('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_9', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_9', 'Dense_0', 'bias'): (4608,) ('DiTBlock_9', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_9', 'Dense_1', 'bias'): (768,) ('DiTBlock_9', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_9', 'Dense_2', 'bias'): (768,) ('DiTBlock_9', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_9', 'Dense_3', 'bias'): (768,) ('DiTBlock_9', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_9', 'Dense_4', 'bias'): (768,) ('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 
'bias'): (3072,) ('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_10', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_10', 'Dense_0', 'bias'): (4608,) ('DiTBlock_10', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_10', 'Dense_1', 'bias'): (768,) ('DiTBlock_10', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_10', 'Dense_2', 'bias'): (768,) ('DiTBlock_10', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_10', 'Dense_3', 'bias'): (768,) ('DiTBlock_10', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_10', 'Dense_4', 'bias'): (768,) ('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_11', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_11', 'Dense_0', 'bias'): (4608,) ('DiTBlock_11', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_11', 'Dense_1', 'bias'): (768,) ('DiTBlock_11', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_11', 'Dense_2', 'bias'): (768,) ('DiTBlock_11', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_11', 'Dense_3', 'bias'): (768,) ('DiTBlock_11', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_11', 'Dense_4', 'bias'): (768,) ('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('FinalLayer_0', 'Dense_0', 'kernel'): (768, 1536) ('FinalLayer_0', 'Dense_0', 'bias'): (1536,) ('FinalLayer_0', 'Dense_1', 'kernel'): (768, 16) ('FinalLayer_0', 'Dense_1', 'bias'): (16,) ('Embed_0', 'embedding'): (256, 1) parameter shapes: ('DiTBlock_0', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_0', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_0', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_0', 
'Dense_2', 'bias'): (1, 768) ('DiTBlock_0', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_0', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_0', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_0', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_0', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_1', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_1', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_1', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_1', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_1', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_1', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_1', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_1', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_1', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_1', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_10', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_10', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_10', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_10', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_10', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_10', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_10', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_10', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_10', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_10', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) 
('DiTBlock_11', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_11', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_11', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_11', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_11', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_11', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_11', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_11', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_11', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_11', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_2', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_2', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_2', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_2', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_2', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_2', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_2', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_2', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_2', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_2', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_3', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_3', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_3', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_3', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_3', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_3', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_3', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_3', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_3', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_3', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 
3072) ('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_4', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_4', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_4', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_4', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_4', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_4', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_4', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_4', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_4', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_4', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_5', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_5', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_5', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_5', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_5', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_5', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_5', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_5', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_5', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_5', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_6', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_6', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_6', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_6', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_6', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_6', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_6', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_6', 
'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_6', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_6', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_7', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_7', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_7', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_7', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_7', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_7', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_7', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_7', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_7', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_7', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_8', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_8', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_8', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_8', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_8', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_8', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_8', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_8', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_8', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_8', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_9', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_9', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_9', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_9', 'Dense_1', 
'kernel'): (1, 768, 768) ('DiTBlock_9', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_9', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_9', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_9', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_9', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_9', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('Embed_0', 'embedding'): (1, 256, 1) ('FinalLayer_0', 'Dense_0', 'bias'): (1, 1536) ('FinalLayer_0', 'Dense_0', 'kernel'): (1, 768, 1536) ('FinalLayer_0', 'Dense_1', 'bias'): (1, 16) ('FinalLayer_0', 'Dense_1', 'kernel'): (1, 768, 16) ('LabelEmbedder_0', 'Embed_0', 'embedding'): (1, 1001, 768) ('PatchEmbed_0', 'Conv_0', 'bias'): (1, 768) ('PatchEmbed_0', 'Conv_0', 'kernel'): (1, 2, 2, 4, 768) ('TimestepEmbedder_0', 'Dense_0', 'bias'): (1, 768) ('TimestepEmbedder_0', 'Dense_0', 'kernel'): (1, 256, 768) ('TimestepEmbedder_0', 'Dense_1', 'bias'): (1, 768) ('TimestepEmbedder_0', 'Dense_1', 'kernel'): (1, 768, 768) ('TimestepEmbedder_1', 'Dense_0', 'bias'): (1, 768) ('TimestepEmbedder_1', 'Dense_0', 'kernel'): (1, 256, 768) ('TimestepEmbedder_1', 'Dense_1', 'bias'): (1, 768) ('TimestepEmbedder_1', 'Dense_1', 'kernel'): (1, 768, 768)
parameter shapes: ('DiTBlock_0', 'Dense_0', 'bias'): (4608,) ('DiTBlock_0', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_0', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_0', 'Dense_2', 'bias'): (768,) ('DiTBlock_0', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_0', 'Dense_3', 'bias'): (768,) ('DiTBlock_0', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_0', 'Dense_4', 'bias'): (768,) ('DiTBlock_0', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_1', 'Dense_0', 'bias'): (4608,) ('DiTBlock_1', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_1', 'Dense_1', 'bias'): (768,) ('DiTBlock_1', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_1', 'Dense_2', 'bias'): (768,) ('DiTBlock_1', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_1', 'Dense_3', 'bias'): (768,) ('DiTBlock_1', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_1', 'Dense_4', 'bias'): (768,) ('DiTBlock_1', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_10', 'Dense_0', 'bias'): (4608,) ('DiTBlock_10', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_10', 'Dense_1', 'bias'): (768,) ('DiTBlock_10', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_10', 'Dense_2', 'bias'): (768,) ('DiTBlock_10', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_10', 
'Dense_3', 'bias'): (768,) ('DiTBlock_10', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_10', 'Dense_4', 'bias'): (768,) ('DiTBlock_10', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_11', 'Dense_0', 'bias'): (4608,) ('DiTBlock_11', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_11', 'Dense_1', 'bias'): (768,) ('DiTBlock_11', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_11', 'Dense_2', 'bias'): (768,) ('DiTBlock_11', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_11', 'Dense_3', 'bias'): (768,) ('DiTBlock_11', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_11', 'Dense_4', 'bias'): (768,) ('DiTBlock_11', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_2', 'Dense_0', 'bias'): (4608,) ('DiTBlock_2', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_2', 'Dense_1', 'bias'): (768,) ('DiTBlock_2', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_2', 'Dense_2', 'bias'): (768,) ('DiTBlock_2', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_2', 'Dense_3', 'bias'): (768,) ('DiTBlock_2', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_2', 'Dense_4', 'bias'): (768,) ('DiTBlock_2', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_3', 'Dense_0', 'bias'): (4608,) ('DiTBlock_3', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_3', 'Dense_1', 'bias'): (768,) ('DiTBlock_3', 'Dense_1', 'kernel'): (768, 768) 
('DiTBlock_3', 'Dense_2', 'bias'): (768,) ('DiTBlock_3', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_3', 'Dense_3', 'bias'): (768,) ('DiTBlock_3', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_3', 'Dense_4', 'bias'): (768,) ('DiTBlock_3', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_4', 'Dense_0', 'bias'): (4608,) ('DiTBlock_4', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_4', 'Dense_1', 'bias'): (768,) ('DiTBlock_4', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_4', 'Dense_2', 'bias'): (768,) ('DiTBlock_4', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_4', 'Dense_3', 'bias'): (768,) ('DiTBlock_4', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_4', 'Dense_4', 'bias'): (768,) ('DiTBlock_4', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_5', 'Dense_0', 'bias'): (4608,) ('DiTBlock_5', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_5', 'Dense_1', 'bias'): (768,) ('DiTBlock_5', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_5', 'Dense_2', 'bias'): (768,) ('DiTBlock_5', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_5', 'Dense_3', 'bias'): (768,) ('DiTBlock_5', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_5', 'Dense_4', 'bias'): (768,) ('DiTBlock_5', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_6', 'Dense_0', 'bias'): (4608,) ('DiTBlock_6', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_6', 
'Dense_1', 'bias'): (768,) ('DiTBlock_6', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_6', 'Dense_2', 'bias'): (768,) ('DiTBlock_6', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_6', 'Dense_3', 'bias'): (768,) ('DiTBlock_6', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_6', 'Dense_4', 'bias'): (768,) ('DiTBlock_6', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_7', 'Dense_0', 'bias'): (4608,) ('DiTBlock_7', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_7', 'Dense_1', 'bias'): (768,) ('DiTBlock_7', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_7', 'Dense_2', 'bias'): (768,) ('DiTBlock_7', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_7', 'Dense_3', 'bias'): (768,) ('DiTBlock_7', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_7', 'Dense_4', 'bias'): (768,) ('DiTBlock_7', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_8', 'Dense_0', 'bias'): (4608,) ('DiTBlock_8', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_8', 'Dense_1', 'bias'): (768,) ('DiTBlock_8', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_8', 'Dense_2', 'bias'): (768,) ('DiTBlock_8', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_8', 'Dense_3', 'bias'): (768,) ('DiTBlock_8', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_8', 'Dense_4', 'bias'): (768,) ('DiTBlock_8', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_9', 'Dense_0', 
'bias'): (4608,) ('DiTBlock_9', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_9', 'Dense_1', 'bias'): (768,) ('DiTBlock_9', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_9', 'Dense_2', 'bias'): (768,) ('DiTBlock_9', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_9', 'Dense_3', 'bias'): (768,) ('DiTBlock_9', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_9', 'Dense_4', 'bias'): (768,) ('DiTBlock_9', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('Embed_0', 'embedding'): (256, 1) ('FinalLayer_0', 'Dense_0', 'bias'): (1536,) ('FinalLayer_0', 'Dense_0', 'kernel'): (768, 1536) ('FinalLayer_0', 'Dense_1', 'bias'): (16,) ('FinalLayer_0', 'Dense_1', 'kernel'): (768, 16) ('LabelEmbedder_0', 'Embed_0', 'embedding'): (1001, 768) ('PatchEmbed_0', 'Conv_0', 'bias'): (768,) ('PatchEmbed_0', 'Conv_0', 'kernel'): (2, 2, 4, 768) ('TimestepEmbedder_0', 'Dense_0', 'bias'): (768,) ('TimestepEmbedder_0', 'Dense_0', 'kernel'): (256, 768) ('TimestepEmbedder_0', 'Dense_1', 'bias'): (768,) ('TimestepEmbedder_0', 'Dense_1', 'kernel'): (768, 768) ('TimestepEmbedder_1', 'Dense_0', 'bias'): (768,) ('TimestepEmbedder_1', 'Dense_0', 'kernel'): (256, 768) ('TimestepEmbedder_1', 'Dense_1', 'bias'): (768,) ('TimestepEmbedder_1', 'Dense_1', 'kernel'): (768, 768) ┌────────────────────────────────────────────────┐ │ │ │ │ │ │ │ │ │ TPU 0,1,2,3 │ │ │ │ │ │ │ │ │ └────────────────────────────────────────────────┘ ┌─────────────────────────────────────────────────────────────────────────┐ │ │ │ │ │ │ │ │ │ TPU 0,1,2,3 │ │ │ │ │ │ │ │ │ └─────────────────────────────────────────────────────────────────────────┘ doing the else (512, 256, 256, 3) encode image shape (128, 256, 256, 3) Initializing encoder. 
Incoming encoder shape (128, 256, 256, 3) Encoder layer (128, 256, 256, 128) doing downsample Encoder layer (128, 128, 128, 128) doing downsample Encoder layer (128, 64, 64, 256) doing downsample Encoder layer (128, 32, 32, 512) Encoder layer (128, 32, 32, 512) Encoder layer final (128, 32, 32, 512) Encoder layer final (128, 32, 32, 512) Final embeddings are size (128, 32, 32, 8) After quant (128, 32, 32, 4) Calc FID for CFG 1.0 and denoise_timesteps 128 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 z_vectors shape (128, 32, 32, 4) Decoder incoming shape (128, 32, 32, 4) Decoder input (128, 32, 32, 512) Mid Block Decoder layer (128, 32, 32, 512) Mid Block Decoder layer (128, 32, 32, 512) Decoder layer (128, 64, 64, 512) Decoder layer (128, 128, 128, 512) Decoder layer (128, 256, 256, 256) Decoder layer (128, 256, 256, 128) FID is 36.508827209472656 (512, 256, 256, 3) Calc FID for CFG 1.0 and denoise_timesteps 64 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 37.08784484863281 (512, 256, 256, 3) Calc FID for CFG 1.0 and denoise_timesteps 32 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 38.94371032714844 (512, 256, 256, 3) Calc FID for CFG 1.0 and denoise_timesteps 16 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 45.21038055419922 (512, 256, 256, 3) 
Calc FID for CFG 1.0 and denoise_timesteps 8 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 68.17457580566406 (512, 256, 256, 3) Calc FID for CFG 1.0 and denoise_timesteps 4 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 156.36541748046875 (512, 256, 256, 3) Calc FID for CFG 1.0 and denoise_timesteps 2 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 351.53521728515625 (512, 256, 256, 3) Calc FID for CFG 1.0 and denoise_timesteps 1 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 272.97235107421875 (512, 256, 256, 3) Calc FID for CFG 1.25 and denoise_timesteps 128 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 22.692970275878906 (512, 256, 256, 3) Calc FID for CFG 1.25 and denoise_timesteps 64 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 23.221485137939453 (512, 256, 256, 3) Calc FID for CFG 1.25 and denoise_timesteps 32 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After 
patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 24.819042205810547 (512, 256, 256, 3) Calc FID for CFG 1.25 and denoise_timesteps 16 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 30.26287078857422 (512, 256, 256, 3) Calc FID for CFG 1.25 and denoise_timesteps 8 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 50.48595428466797 (512, 256, 256, 3) Calc FID for CFG 1.25 and denoise_timesteps 4 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 131.1754150390625 (512, 256, 256, 3) Calc FID for CFG 1.25 and denoise_timesteps 2 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 328.754150390625 (512, 256, 256, 3) Calc FID for CFG 1.25 and denoise_timesteps 1 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 266.94677734375 (512, 256, 256, 3) Calc FID for CFG 1.5 and denoise_timesteps 128 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: 
Conditioning of shape (512, 768) dtype float32 FID is 14.32100772857666 (512, 256, 256, 3) Calc FID for CFG 1.5 and denoise_timesteps 64 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 14.753721237182617 (512, 256, 256, 3) Calc FID for CFG 1.5 and denoise_timesteps 32 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 16.01094627380371 (512, 256, 256, 3) Calc FID for CFG 1.5 and denoise_timesteps 16 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 20.28468894958496 (512, 256, 256, 3) Calc FID for CFG 1.5 and denoise_timesteps 8 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 36.810386657714844 (512, 256, 256, 3) Calc FID for CFG 1.5 and denoise_timesteps 4 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 108.87908935546875 (512, 256, 256, 3) Calc FID for CFG 1.5 and denoise_timesteps 2 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 308.9098815917969 (512, 256, 256, 3) Calc FID for CFG 1.5 and 
denoise_timesteps 1 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 262.80682373046875 (512, 256, 256, 3) Calc FID for CFG 1.75 and denoise_timesteps 128 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 9.926182746887207 (512, 256, 256, 3) Calc FID for CFG 1.75 and denoise_timesteps 64 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 10.220373153686523 (512, 256, 256, 3) Calc FID for CFG 1.75 and denoise_timesteps 32 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 11.12224292755127 (512, 256, 256, 3) Calc FID for CFG 1.75 and denoise_timesteps 16 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 14.252967834472656 (512, 256, 256, 3) Calc FID for CFG 1.75 and denoise_timesteps 8 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 27.18169403076172 (512, 256, 256, 3) Calc FID for CFG 1.75 and denoise_timesteps 4 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is 
(512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 90.34546661376953 (512, 256, 256, 3) Calc FID for CFG 1.75 and denoise_timesteps 2 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 292.087646484375 (512, 256, 256, 3) Calc FID for CFG 1.75 and denoise_timesteps 1 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 259.931396484375 (512, 256, 256, 3) Calc FID for CFG 2.0 and denoise_timesteps 128 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 8.024701118469238 (512, 256, 256, 3) Calc FID for CFG 2.0 and denoise_timesteps 64 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 8.21211051940918 (512, 256, 256, 3) Calc FID for CFG 2.0 and denoise_timesteps 32 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 8.804532051086426 (512, 256, 256, 3) Calc FID for CFG 2.0 and denoise_timesteps 16 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape 
(512, 768) dtype float32 FID is 10.978399276733398 (512, 256, 256, 3) Calc FID for CFG 2.0 and denoise_timesteps 8 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 20.80478286743164 (512, 256, 256, 3) Calc FID for CFG 2.0 and denoise_timesteps 4 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 75.16299438476562 (512, 256, 256, 3) Calc FID for CFG 2.0 and denoise_timesteps 2 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 277.906005859375 (512, 256, 256, 3) Calc FID for CFG 2.0 and denoise_timesteps 1 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 257.190673828125 (512, 256, 256, 3) Calc FID for CFG 2.25 and denoise_timesteps 128 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 7.494307041168213 (512, 256, 256, 3) Calc FID for CFG 2.25 and denoise_timesteps 64 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 7.625980377197266 (512, 256, 256, 3) Calc FID for CFG 2.25 and denoise_timesteps 32 
DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 8.020242691040039 (512, 256, 256, 3) Calc FID for CFG 2.25 and denoise_timesteps 16 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 9.495993614196777 (512, 256, 256, 3) Calc FID for CFG 2.25 and denoise_timesteps 8 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 16.77423667907715 (512, 256, 256, 3) Calc FID for CFG 2.25 and denoise_timesteps 4 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 62.92259979248047 (512, 256, 256, 3) Calc FID for CFG 2.25 and denoise_timesteps 2 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 265.89666748046875 (512, 256, 256, 3) Calc FID for CFG 2.25 and denoise_timesteps 1 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 254.52099609375 (512, 256, 256, 3) Calc FID for CFG 2.5 and denoise_timesteps 128 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 
DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 7.730135440826416 (512, 256, 256, 3) Calc FID for CFG 2.5 and denoise_timesteps 64 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 7.801293849945068 (512, 256, 256, 3) Calc FID for CFG 2.5 and denoise_timesteps 32 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 8.043386459350586 (512, 256, 256, 3) Calc FID for CFG 2.5 and denoise_timesteps 16 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 9.051283836364746 (512, 256, 256, 3) Calc FID for CFG 2.5 and denoise_timesteps 8 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 14.387025833129883 (512, 256, 256, 3) Calc FID for CFG 2.5 and denoise_timesteps 4 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 53.24984359741211 (512, 256, 256, 3) Calc FID for CFG 2.5 and denoise_timesteps 2 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 
255.63644409179688 (512, 256, 256, 3) Calc FID for CFG 2.5 and denoise_timesteps 1 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 252.06878662109375 (512, 256, 256, 3) Calc FID for CFG 2.75 and denoise_timesteps 128 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 8.402754783630371 (512, 256, 256, 3) Calc FID for CFG 2.75 and denoise_timesteps 64 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 8.434297561645508 (512, 256, 256, 3) Calc FID for CFG 2.75 and denoise_timesteps 32 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 8.552092552185059 (512, 256, 256, 3) Calc FID for CFG 2.75 and denoise_timesteps 16 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 9.209967613220215 (512, 256, 256, 3) Calc FID for CFG 2.75 and denoise_timesteps 8 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 13.06956672668457 (512, 256, 256, 3) Calc FID for CFG 2.75 and denoise_timesteps 4 DiT: Input of shape (512, 
32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 45.642677307128906 (512, 256, 256, 3) Calc FID for CFG 2.75 and denoise_timesteps 2 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 246.61810302734375 (512, 256, 256, 3) Calc FID for CFG 2.75 and denoise_timesteps 1 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 249.8270263671875 (512, 256, 256, 3) Calc FID for CFG 3.0 and denoise_timesteps 128 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 9.276342391967773 (512, 256, 256, 3) Calc FID for CFG 3.0 and denoise_timesteps 64 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 9.26561450958252 (512, 256, 256, 3) Calc FID for CFG 3.0 and denoise_timesteps 32 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 9.315662384033203 (512, 256, 256, 3) Calc FID for CFG 3.0 and denoise_timesteps 16 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of 
shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 9.7174072265625 (512, 256, 256, 3) Calc FID for CFG 3.0 and denoise_timesteps 8 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 12.461471557617188 (512, 256, 256, 3) Calc FID for CFG 3.0 and denoise_timesteps 4 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 39.617332458496094 (512, 256, 256, 3) Calc FID for CFG 3.0 and denoise_timesteps 2 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 238.69842529296875 (512, 256, 256, 3) Calc FID for CFG 3.0 and denoise_timesteps 1 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 247.69577026367188 wandb: wandb: 🚀 View run shortcut_imagenet256 at: https://wandb.ai/daniel-z-kaplan/shortcut/runs/shortcut_imagenet256_20250816_140034_345353_10 wandb: Find logs at: ../../../tmp/tmptvnw60fi/wandb/run-20250816_140034-shortcut_imagenet256_20250816_140034_345353_10/logs
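The sweep above prints one "Calc FID for CFG ... and denoise_timesteps ..." header followed, after some shape messages, by one "FID is ..." value. A minimal sketch of how those pairs could be pulled out of a saved copy of this log is given below. The names `parse_fid_sweep` and `best_setting` are hypothetical helpers introduced for illustration and are not part of the run's codebase. The sample string is abridged from two entries of the log above.

```python
import re

# Hypothetical helper pattern (not from the original codebase): pairs each
# "Calc FID for CFG <c> and denoise_timesteps <n>" header with the next
# "FID is <value>" that follows it. \s+ tolerates line wraps in the log.
PATTERN = re.compile(
    r"Calc\s+FID\s+for\s+CFG\s+(?P<cfg>[\d.]+)\s+and\s+"
    r"denoise_timesteps\s+(?P<steps>\d+)"
    r".*?FID\s+is\s+(?P<fid>[\d.]+)",
    re.DOTALL,
)


def parse_fid_sweep(log_text):
    """Return {(cfg_scale, denoise_timesteps): fid} for every entry found."""
    results = {}
    for m in PATTERN.finditer(log_text):
        results[(float(m["cfg"]), int(m["steps"]))] = float(m["fid"])
    return results


def best_setting(results):
    """Return ((cfg_scale, denoise_timesteps), fid) with the lowest FID."""
    return min(results.items(), key=lambda kv: kv[1])


# Two entries abridged from the log above.
sample = (
    "Calc FID for CFG 2.25 and denoise_timesteps 128 DiT: Input of shape "
    "(512, 32, 32, 4) dtype float32 FID is 7.494307041168213 (512, 256, 256, 3) "
    "Calc FID for CFG 2.5 and denoise_timesteps 128 DiT: Input of shape "
    "(512, 32, 32, 4) dtype float32 FID is 7.730135440826416"
)
sweep = parse_fid_sweep(sample)
print(best_setting(sweep))
```

Run over the full log text, the same function would yield one entry per (CFG, denoise_timesteps) pair, which makes it easier to compare settings than scanning the raw output.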