Using devices [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]
Device count 4
Global device count 4
Global Batch: 512
Node Batch: 512
Device Batch: 128
/tmp/tmpu25f0_2j
Loading dataset
Loading dataset
creating model
beta1: 0.9
beta2: 0.999
bootstrap_cfg: 1
bootstrap_dt_bias: 0
bootstrap_ema: 1
bootstrap_every: 8
cfg_scale: 1.5
class_dropout_prob: 0.1
denoise_timesteps: 128
depth: 12
dropout: 0.0
dt_sampling: uniform
hidden_size: 768
lr: 0.0001
mlp_ratio: 4
num_classes: 1000
num_heads: 12
patch_size: 2
sharding: dp
t_sampling: discrete-dt
target_update_rate: 0.999
train_type: naive
use_cosine: 0
use_ema: 0
use_stable_vae: 1
warmup: 0
weight_decay: 0.1
Total devices TPU_0(process=0,(0,0,0,0))
Initializing encoder.
Incoming encoder shape (1, 256, 256, 3)
Encoder layer (1, 256, 256, 128)
doing downsample
Encoder layer (1, 128, 128, 128)
doing downsample
Encoder layer (1, 64, 64, 256)
doing downsample
Encoder layer (1, 32, 32, 512)
Encoder layer (1, 32, 32, 512)
Encoder layer final (1, 32, 32, 512)
Encoder layer final (1, 32, 32, 512)
Final embeddings are size (1, 32, 32, 8)
After quant (1, 32, 32, 4)
encode finished
Decoder incoming shape (1, 32, 32, 4)
Decoder input (1, 32, 32, 512)
Mid Block Decoder layer (1, 32, 32, 512)
Mid Block Decoder layer (1, 32, 32, 512)
Decoder layer (1, 64, 64, 512)
Decoder layer (1, 128, 128, 512)
Decoder layer (1, 256, 256, 256)
Decoder layer (1, 256, 256, 128)
Total num of VQVAE parameters: 67565323
Disc shape (1, 128, 128, 128)
Disc shape (1, 64, 64, 256)
Disc shape (1, 32, 32, 512)
Disc shape (1, 16, 16, 512)
Disc shape (1, 8, 8, 512)
Disc shape (1, 4, 4, 512)
Total num of Discriminator parameters: 23998017
Loaded checkpoint from 17487837 seconds ago.
Loaded model with step 498001
┌──────────────────────────────────────────────────────────────────────────────┐
│                                    TPU 0                                     │
├──────────────────────────────────────────────────────────────────────────────┤
│                                    TPU 1                                     │
├──────────────────────────────────────────────────────────────────────────────┤
│                                    TPU 2                                     │
├──────────────────────────────────────────────────────────────────────────────┤
│                                    TPU 3                                     │
└──────────────────────────────────────────────────────────────────────────────┘
returning model
model done
Input to vae (4, 1, 256, 256, 3)
encode image shape (1, 256, 256, 3)
Initializing encoder.
Incoming encoder shape (1, 256, 256, 3)
Encoder layer (1, 256, 256, 128)
doing downsample
Encoder layer (1, 128, 128, 128)
doing downsample
Encoder layer (1, 64, 64, 256)
doing downsample
Encoder layer (1, 32, 32, 512)
Encoder layer (1, 32, 32, 512)
Encoder layer final (1, 32, 32, 512)
Encoder layer final (1, 32, 32, 512)
Final embeddings are size (1, 32, 32, 8)
After quant (1, 32, 32, 4)
output example shape (4, 1, 32, 32, 4)
Test data shape (4, 256, 256, 3)
x shape (4, 1, 256, 256, 3)
encoded shape (4, 1, 32, 32, 4)
z_vectors shape (1, 32, 32, 4)
Decoder incoming shape (1, 32, 32, 4)
Decoder input (1, 32, 32, 512)
Mid Block Decoder layer (1, 32, 32, 512)
Mid Block Decoder layer (1, 32, 32, 512)
Decoder layer (1, 64, 64, 512)
Decoder layer (1, 128, 128, 512)
Decoder layer (1, 256, 256, 256)
Decoder layer (1, 256, 256, 128)
image shape (4, 1, 256, 256, 3)
decoded img shape (256, 256, 3)
obs shape (4, 32, 32, 4)
DiT: Input of shape (4, 32, 32, 4) dtype float32
DiT: After patch embed, shape is (4, 256, 768) dtype bfloat16
DiT: Patch Embed of shape (4, 256, 768) dtype bfloat16
DiT: Conditioning of shape (1, 768) dtype float32

DiT Summary
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ path                             ┃ module           ┃ inputs                ┃ outputs               ┃ params                       ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ │ DiT │ - float32[4,32,32,4] │ bfloat16[4,32,32,4] │ │ │ │ │ - float32[1] │ │ │ │ │ │ - float32[1] │ │ │ │ │ │ - int32[1] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ PatchEmbed_0 │ PatchEmbed │ float32[4,32,32,4] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ PatchEmbed_0/Conv_0 │ Conv │ float32[4,32,32,4] │ bfloat16[4,16,16,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[2,2,4,768] │ │ │ │ │ │ │ │ │ │ │ │ 13,056 (52.2 KB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ TimestepEmbedder_0 │ TimestepEmbedder │ float32[1] │ float32[1,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ TimestepEmbedder_0/Dense_0 │ Dense │ bfloat16[1,256] │ bfloat16[1,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[256,768] │ │ │ │ │ │ │ │ │ │ │ │ 197,376 (789.5 KB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ TimestepEmbedder_0/Dense_1 │ Dense │ bfloat16[1,768] │ float32[1,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ TimestepEmbedder_1 │ TimestepEmbedder │ float32[1] │ float32[1,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ TimestepEmbedder_1/Dense_0 │ Dense │ 
bfloat16[1,256] │ bfloat16[1,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[256,768] │ │ │ │ │ │ │ │ │ │ │ │ 197,376 (789.5 KB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ TimestepEmbedder_1/Dense_1 │ Dense │ bfloat16[1,768] │ float32[1,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ LabelEmbedder_0 │ LabelEmbedder │ int32[1] │ bfloat16[1,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ LabelEmbedder_0/Embed_0 │ Embed │ int32[1] │ bfloat16[1,768] │ embedding: float32[1001,768] │ │ │ │ │ │ │ │ │ │ │ │ 768,768 (3.1 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3 
│ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ 
DiTBlock_3/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ 
DiTBlock_6/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ 
bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] 
│ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: 
float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 
MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ FinalLayer_0 │ FinalLayer │ - bfloat16[4,256,768] │ bfloat16[4,256,16] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ FinalLayer_0/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,1536] │ bias: float32[1536] │ │ │ │ │ │ kernel: float32[768,1536] │ │ │ │ │ │ │ │ │ │ │ │ 1,181,184 (4.7 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ FinalLayer_0/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ FinalLayer_0/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,16] │ bias: float32[16] │ │ │ │ │ │ kernel: float32[768,16] │ │ │ │ │ │ │ │ │ │ │ │ 12,304 (49.2 KB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ Embed_0 │ Embed │ int32[1] │ float32[1,1] │ embedding: float32[256,1] │ │ │ │ │ │ │ │ │ │ │ │ 256 (1.0 KB) │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤
│                                  │                  │                       │                 Total │ 131,091,728 (524.4 MB)       │
└──────────────────────────────────┴──────────────────┴───────────────────────┴───────────────────────┴──────────────────────────────┘
Total Parameters: 131,091,728 (524.4 MB)
DiT: Input of shape (4, 32, 32, 4) dtype float32
DiT: After patch embed, shape is (4, 256, 768) dtype bfloat16
DiT: Patch Embed of shape (4, 256, 768) dtype bfloat16
DiT: Conditioning of shape (1, 768) dtype float32
Loaded checkpoint from 2115 seconds ago.
parameter shapes:
('PatchEmbed_0', 'Conv_0', 'kernel'): (2, 2, 4, 768)
('PatchEmbed_0', 'Conv_0', 'bias'): (768,)
('TimestepEmbedder_0', 'Dense_0', 'kernel'): (256, 768)
('TimestepEmbedder_0', 'Dense_0', 'bias'): (768,)
('TimestepEmbedder_0', 'Dense_1', 'kernel'): (768, 768)
('TimestepEmbedder_0', 'Dense_1', 'bias'): (768,)
('TimestepEmbedder_1', 'Dense_0', 'kernel'): (256, 768)
('TimestepEmbedder_1', 'Dense_0', 'bias'): (768,)
('TimestepEmbedder_1', 'Dense_1', 'kernel'): (768, 768)
('TimestepEmbedder_1', 'Dense_1', 'bias'): (768,)
('LabelEmbedder_0', 'Embed_0', 'embedding'): (1001, 768)
('DiTBlock_0', 'Dense_0', 'kernel'): (768, 4608)
('DiTBlock_0', 'Dense_0', 'bias'): (4608,)
('DiTBlock_0', 'Dense_1', 'kernel'): (768, 768)
('DiTBlock_0', 'Dense_1', 'bias'): (768,)
('DiTBlock_0', 'Dense_2', 'kernel'): (768, 768)
('DiTBlock_0', 'Dense_2', 'bias'): (768,)
('DiTBlock_0', 'Dense_3', 'kernel'): (768, 768)
('DiTBlock_0', 'Dense_3', 'bias'): (768,)
('DiTBlock_0', 'Dense_4', 'kernel'): (768, 768)
('DiTBlock_0', 'Dense_4', 'bias'): (768,)
('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072)
('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,)
('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768)
('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'bias'): (768,)
('DiTBlock_1', 'Dense_0', 'kernel'): (768, 4608)
('DiTBlock_1', 'Dense_0', 'bias'): (4608,)
('DiTBlock_1', 'Dense_1', 'kernel'): (768, 768)
('DiTBlock_1', 'Dense_1', 'bias'): (768,)
('DiTBlock_1', 'Dense_2', 'kernel'): (768, 768)
('DiTBlock_1', 'Dense_2', 'bias'): (768,)
('DiTBlock_1', 'Dense_3', 'kernel'): (768, 768)
('DiTBlock_1', 'Dense_3', 'bias'): (768,)
('DiTBlock_1', 'Dense_4', 'kernel'): (768, 768)
('DiTBlock_1', 'Dense_4', 'bias'): (768,)
('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072)
('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,)
('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768)
('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'bias'): (768,)
('DiTBlock_2', 'Dense_0', 'kernel'): (768, 4608)
('DiTBlock_2', 'Dense_0', 'bias'): (4608,)
('DiTBlock_2', 'Dense_1', 'kernel'): (768, 768)
('DiTBlock_2', 'Dense_1', 'bias'): (768,)
('DiTBlock_2', 'Dense_2', 'kernel'): (768, 768)
('DiTBlock_2', 'Dense_2', 'bias'): (768,)
('DiTBlock_2', 'Dense_3', 'kernel'): (768, 768)
('DiTBlock_2', 'Dense_3', 'bias'): (768,)
('DiTBlock_2', 'Dense_4', 'kernel'): (768, 768)
('DiTBlock_2', 'Dense_4', 'bias'): (768,)
('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072)
('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,)
('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768)
('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'bias'): (768,)
('DiTBlock_3', 'Dense_0', 'kernel'): (768, 4608)
('DiTBlock_3', 'Dense_0', 'bias'): (4608,)
('DiTBlock_3', 'Dense_1', 'kernel'): (768, 768)
('DiTBlock_3', 'Dense_1', 'bias'): (768,)
('DiTBlock_3', 'Dense_2', 'kernel'): (768, 768)
('DiTBlock_3', 'Dense_2', 'bias'): (768,)
('DiTBlock_3', 'Dense_3', 'kernel'): (768, 768)
('DiTBlock_3', 'Dense_3', 'bias'): (768,)
('DiTBlock_3', 'Dense_4', 'kernel'): (768, 768)
('DiTBlock_3', 'Dense_4', 'bias'): (768,)
('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072)
('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,)
('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768)
('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'bias'): (768,)
('DiTBlock_4', 'Dense_0', 'kernel'): (768, 4608)
('DiTBlock_4', 'Dense_0', 'bias'): (4608,)
('DiTBlock_4', 'Dense_1', 'kernel'): (768, 768)
('DiTBlock_4', 'Dense_1', 'bias'): (768,)
('DiTBlock_4', 'Dense_2', 'kernel'): (768, 768)
('DiTBlock_4', 'Dense_2', 'bias'): (768,)
('DiTBlock_4', 'Dense_3', 'kernel'): (768, 768)
('DiTBlock_4', 'Dense_3', 'bias'): (768,)
('DiTBlock_4', 'Dense_4', 'kernel'): (768, 768)
('DiTBlock_4', 'Dense_4', 'bias'): (768,)
('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072)
('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,)
('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768)
('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'bias'): (768,)
('DiTBlock_5', 'Dense_0', 'kernel'): (768, 4608)
('DiTBlock_5', 'Dense_0', 'bias'): (4608,)
('DiTBlock_5', 'Dense_1', 'kernel'): (768, 768)
('DiTBlock_5', 'Dense_1', 'bias'): (768,)
('DiTBlock_5', 'Dense_2', 'kernel'): (768, 768)
('DiTBlock_5', 'Dense_2', 'bias'): (768,)
('DiTBlock_5', 'Dense_3', 'kernel'): (768, 768)
('DiTBlock_5', 'Dense_3', 'bias'): (768,)
('DiTBlock_5', 'Dense_4', 'kernel'): (768, 768)
('DiTBlock_5', 'Dense_4', 'bias'): (768,)
('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072)
('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,)
('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768)
('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'bias'): (768,)
('DiTBlock_6', 'Dense_0', 'kernel'): (768, 4608)
('DiTBlock_6', 'Dense_0', 'bias'): (4608,)
('DiTBlock_6', 'Dense_1', 'kernel'): (768, 768)
('DiTBlock_6', 'Dense_1', 'bias'): (768,)
('DiTBlock_6', 'Dense_2', 'kernel'): (768, 768)
('DiTBlock_6', 'Dense_2', 'bias'): (768,)
('DiTBlock_6', 'Dense_3', 'kernel'): (768, 768)
('DiTBlock_6', 'Dense_3', 'bias'): (768,)
('DiTBlock_6', 'Dense_4', 'kernel'): (768, 768)
('DiTBlock_6', 'Dense_4', 'bias'): (768,)
('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072)
('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,)
('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768)
('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'bias'): (768,)
('DiTBlock_7', 'Dense_0', 'kernel'): (768, 4608)
('DiTBlock_7', 'Dense_0', 'bias'): (4608,)
('DiTBlock_7', 'Dense_1', 'kernel'): (768, 768)
('DiTBlock_7', 'Dense_1', 'bias'): (768,)
('DiTBlock_7', 'Dense_2', 'kernel'): (768, 768)
('DiTBlock_7', 'Dense_2', 'bias'): (768,)
('DiTBlock_7', 'Dense_3', 'kernel'): (768, 768)
('DiTBlock_7', 'Dense_3', 'bias'): (768,)
('DiTBlock_7', 'Dense_4', 'kernel'): (768, 768)
('DiTBlock_7', 'Dense_4', 'bias'): (768,)
('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072)
('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,)
('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768)
('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'bias'): (768,)
('DiTBlock_8', 'Dense_0', 'kernel'): (768, 4608)
('DiTBlock_8', 'Dense_0', 'bias'): (4608,)
('DiTBlock_8', 'Dense_1', 'kernel'): (768, 768)
('DiTBlock_8', 'Dense_1', 'bias'): (768,)
('DiTBlock_8', 'Dense_2', 'kernel'): (768, 768)
('DiTBlock_8', 'Dense_2', 'bias'): (768,)
('DiTBlock_8', 'Dense_3', 'kernel'): (768, 768)
('DiTBlock_8', 'Dense_3', 'bias'): (768,)
('DiTBlock_8', 'Dense_4', 'kernel'): (768, 768)
('DiTBlock_8', 'Dense_4', 'bias'): (768,)
('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072)
('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,)
('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768)
('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'bias'): (768,)
('DiTBlock_9', 'Dense_0', 'kernel'): (768, 4608)
('DiTBlock_9', 'Dense_0', 'bias'): (4608,)
('DiTBlock_9', 'Dense_1', 'kernel'): (768, 768)
('DiTBlock_9', 'Dense_1', 'bias'): (768,)
('DiTBlock_9', 'Dense_2', 'kernel'): (768, 768)
('DiTBlock_9', 'Dense_2', 'bias'): (768,)
('DiTBlock_9', 'Dense_3', 'kernel'): (768, 768)
('DiTBlock_9', 'Dense_3', 'bias'): (768,)
('DiTBlock_9', 'Dense_4', 'kernel'): (768, 768)
('DiTBlock_9', 'Dense_4', 'bias'): (768,)
('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072)
('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,)
('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768)
('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'bias'): (768,)
('DiTBlock_10', 'Dense_0', 'kernel'): (768, 4608)
('DiTBlock_10', 'Dense_0', 'bias'): (4608,)
('DiTBlock_10', 'Dense_1', 'kernel'): (768, 768)
('DiTBlock_10', 'Dense_1', 'bias'): (768,)
('DiTBlock_10', 'Dense_2', 'kernel'): (768, 768)
('DiTBlock_10', 'Dense_2', 'bias'): (768,)
('DiTBlock_10', 'Dense_3', 'kernel'): (768, 768)
('DiTBlock_10', 'Dense_3', 'bias'): (768,)
('DiTBlock_10', 'Dense_4', 'kernel'): (768, 768)
('DiTBlock_10', 'Dense_4', 'bias'): (768,)
('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072)
('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,)
('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768)
('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'bias'): (768,)
('DiTBlock_11', 'Dense_0', 'kernel'): (768, 4608)
('DiTBlock_11', 'Dense_0', 'bias'): (4608,)
('DiTBlock_11', 'Dense_1', 'kernel'): (768, 768)
('DiTBlock_11', 'Dense_1', 'bias'): (768,)
('DiTBlock_11', 'Dense_2', 'kernel'): (768, 768)
('DiTBlock_11', 'Dense_2', 'bias'): (768,)
('DiTBlock_11', 'Dense_3', 'kernel'): (768, 768)
('DiTBlock_11', 'Dense_3', 'bias'): (768,)
('DiTBlock_11', 'Dense_4', 'kernel'): (768, 768)
('DiTBlock_11', 'Dense_4', 'bias'): (768,)
('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072)
('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,)
('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768)
('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'bias'): (768,)
('FinalLayer_0', 'Dense_0', 'kernel'): (768, 1536)
('FinalLayer_0', 'Dense_0', 'bias'): (1536,)
('FinalLayer_0', 'Dense_1', 'kernel'): (768, 16)
('FinalLayer_0', 'Dense_1', 'bias'): (16,)
('Embed_0', 'embedding'): (256, 1)
parameter shapes:
('DiTBlock_0', 'Dense_0', 'bias'): (1, 4608)
('DiTBlock_0', 'Dense_0', 'kernel'): (1, 768, 4608)
('DiTBlock_0', 'Dense_1', 'bias'): (1, 768)
('DiTBlock_0', 'Dense_1', 'kernel'): (1, 768, 768)
('DiTBlock_0', 'Dense_2', 'bias'): (1, 768)
('DiTBlock_0', 'Dense_2', 'kernel'): (1, 768, 768)
('DiTBlock_0', 'Dense_3', 'bias'): (1, 768)
('DiTBlock_0', 'Dense_3', 'kernel'): (1, 768, 768)
('DiTBlock_0', 'Dense_4', 'bias'): (1, 768)
('DiTBlock_0', 'Dense_4', 'kernel'): (1, 768, 768)
('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072)
('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072)
('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768)
('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768)
('DiTBlock_1', 'Dense_0', 'bias'): (1, 4608)
('DiTBlock_1', 'Dense_0', 'kernel'): (1, 768, 4608)
('DiTBlock_1', 'Dense_1', 'bias'): (1, 768)
('DiTBlock_1', 'Dense_1', 'kernel'): (1, 768, 768)
('DiTBlock_1', 'Dense_2', 'bias'): (1, 768)
('DiTBlock_1', 'Dense_2', 'kernel'): (1, 768, 768)
('DiTBlock_1', 'Dense_3', 'bias'): (1, 768)
('DiTBlock_1', 'Dense_3', 'kernel'): (1, 768, 768)
('DiTBlock_1', 'Dense_4', 'bias'): (1, 768)
('DiTBlock_1', 'Dense_4', 'kernel'): (1, 768, 768)
('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072)
('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072)
('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768)
('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768)
('DiTBlock_10', 'Dense_0', 'bias'): (1, 4608)
('DiTBlock_10', 'Dense_0', 'kernel'): (1, 768, 4608)
('DiTBlock_10', 'Dense_1', 'bias'): (1, 768)
('DiTBlock_10', 'Dense_1', 'kernel'): (1, 768, 768)
('DiTBlock_10', 'Dense_2', 'bias'): (1, 768)
('DiTBlock_10', 'Dense_2', 'kernel'): (1, 768, 768)
('DiTBlock_10', 'Dense_3', 'bias'): (1, 768)
('DiTBlock_10', 'Dense_3', 'kernel'): (1, 768, 768)
('DiTBlock_10', 'Dense_4', 'bias'): (1, 768)
('DiTBlock_10', 'Dense_4', 'kernel'): (1, 768, 768)
('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072)
('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072)
('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768)
('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768)
('DiTBlock_11', 'Dense_0', 'bias'): (1, 4608)
('DiTBlock_11', 'Dense_0', 'kernel'): (1, 768, 4608)
('DiTBlock_11', 'Dense_1', 'bias'): (1, 768)
('DiTBlock_11', 'Dense_1', 'kernel'): (1, 768, 768)
('DiTBlock_11', 'Dense_2', 'bias'): (1, 768)
('DiTBlock_11', 'Dense_2', 'kernel'): (1, 768, 768)
('DiTBlock_11', 'Dense_3', 'bias'): (1, 768)
('DiTBlock_11', 'Dense_3', 'kernel'): (1, 768, 768)
('DiTBlock_11', 'Dense_4', 'bias'): (1, 768)
('DiTBlock_11', 'Dense_4', 'kernel'): (1, 768, 768)
('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072)
('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072)
('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768)
('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768)
('DiTBlock_2', 'Dense_0', 'bias'): (1, 4608)
('DiTBlock_2', 'Dense_0', 'kernel'): (1, 768, 4608)
('DiTBlock_2', 'Dense_1', 'bias'): (1, 768)
('DiTBlock_2', 'Dense_1', 'kernel'): (1, 768, 768)
('DiTBlock_2', 'Dense_2', 'bias'): (1, 768)
('DiTBlock_2', 'Dense_2', 'kernel'): (1, 768, 768)
('DiTBlock_2', 'Dense_3', 'bias'): (1, 768)
('DiTBlock_2', 'Dense_3', 'kernel'): (1, 768, 768)
('DiTBlock_2', 'Dense_4', 'bias'): (1, 768)
('DiTBlock_2', 'Dense_4', 'kernel'): (1, 768, 768)
('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072)
('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072)
('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768)
('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768)
('DiTBlock_3', 'Dense_0', 'bias'): (1, 4608)
('DiTBlock_3', 'Dense_0', 'kernel'): (1, 768, 4608)
('DiTBlock_3', 'Dense_1', 'bias'): (1, 768)
('DiTBlock_3', 'Dense_1', 'kernel'): (1, 768, 768)
('DiTBlock_3', 'Dense_2', 'bias'): (1, 768)
('DiTBlock_3', 'Dense_2', 'kernel'): (1, 768, 768)
('DiTBlock_3', 'Dense_3', 'bias'): (1, 768)
('DiTBlock_3', 'Dense_3', 'kernel'): (1, 768, 768)
('DiTBlock_3', 'Dense_4', 'bias'): (1, 768)
('DiTBlock_3', 'Dense_4', 'kernel'): (1, 768, 768)
('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072)
('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072)
('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768)
('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768)
('DiTBlock_4', 'Dense_0', 'bias'): (1, 4608)
('DiTBlock_4', 'Dense_0', 'kernel'): (1, 768, 4608)
('DiTBlock_4', 'Dense_1', 'bias'): (1, 768)
('DiTBlock_4', 'Dense_1', 'kernel'): (1, 768, 768)
('DiTBlock_4', 'Dense_2', 'bias'): (1, 768)
('DiTBlock_4', 'Dense_2', 'kernel'): (1, 768, 768)
('DiTBlock_4', 'Dense_3', 'bias'): (1, 768)
('DiTBlock_4', 'Dense_3', 'kernel'): (1, 768, 768)
('DiTBlock_4', 'Dense_4', 'bias'): (1, 768)
('DiTBlock_4', 'Dense_4', 'kernel'): (1, 768, 768)
('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072)
('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072)
('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768)
('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768)
('DiTBlock_5', 'Dense_0', 'bias'): (1, 4608)
('DiTBlock_5', 'Dense_0', 'kernel'): (1, 768, 4608)
('DiTBlock_5', 'Dense_1', 'bias'): (1, 768)
('DiTBlock_5', 'Dense_1', 'kernel'): (1, 768, 768)
('DiTBlock_5', 'Dense_2', 'bias'): (1, 768)
('DiTBlock_5', 'Dense_2', 'kernel'): (1, 768, 768)
('DiTBlock_5', 'Dense_3', 'bias'): (1, 768)
('DiTBlock_5', 'Dense_3', 'kernel'): (1, 768, 768)
('DiTBlock_5', 'Dense_4', 'bias'): (1, 768)
('DiTBlock_5', 'Dense_4', 'kernel'): (1, 768, 768)
('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072)
('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072)
('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768)
('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768)
('DiTBlock_6', 'Dense_0', 'bias'): (1, 4608)
('DiTBlock_6', 'Dense_0', 'kernel'): (1, 768, 4608)
('DiTBlock_6', 'Dense_1', 'bias'): (1, 768)
('DiTBlock_6', 'Dense_1', 'kernel'): (1, 768, 768)
('DiTBlock_6', 'Dense_2', 'bias'): (1, 768)
('DiTBlock_6', 'Dense_2', 'kernel'): (1, 768, 768)
('DiTBlock_6', 'Dense_3', 'bias'): (1, 768)
('DiTBlock_6', 'Dense_3', 'kernel'): (1, 768, 768)
('DiTBlock_6', 'Dense_4', 'bias'): (1, 768)
('DiTBlock_6', 'Dense_4', 'kernel'): (1, 768, 768)
('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072)
('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072)
('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768)
('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768)
('DiTBlock_7', 'Dense_0', 'bias'): (1, 4608)
('DiTBlock_7', 'Dense_0', 'kernel'): (1, 768, 4608)
('DiTBlock_7', 'Dense_1', 'bias'): (1, 768)
('DiTBlock_7', 'Dense_1', 'kernel'): (1, 768, 768)
('DiTBlock_7', 'Dense_2', 'bias'): (1, 768)
('DiTBlock_7', 'Dense_2', 'kernel'): (1, 768, 768)
('DiTBlock_7', 'Dense_3', 'bias'): (1, 768)
('DiTBlock_7', 'Dense_3', 'kernel'): (1, 768, 768)
('DiTBlock_7', 'Dense_4', 'bias'): (1, 768)
('DiTBlock_7', 'Dense_4', 'kernel'): (1, 768, 768)
('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072)
('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072)
('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768)
('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768)
('DiTBlock_8', 'Dense_0', 'bias'): (1, 4608)
('DiTBlock_8', 'Dense_0', 'kernel'): (1, 768, 4608)
('DiTBlock_8', 'Dense_1', 'bias'): (1, 768)
('DiTBlock_8', 'Dense_1', 'kernel'): (1, 768, 768)
('DiTBlock_8', 'Dense_2', 'bias'): (1, 768)
('DiTBlock_8', 'Dense_2', 'kernel'): (1, 768, 768)
('DiTBlock_8', 'Dense_3', 'bias'): (1, 768)
('DiTBlock_8', 'Dense_3', 'kernel'): (1, 768, 768)
('DiTBlock_8', 'Dense_4', 'bias'): (1, 768)
('DiTBlock_8', 'Dense_4', 'kernel'): (1, 768, 768)
('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072)
('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072)
('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768)
('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768)
('DiTBlock_9', 'Dense_0', 'bias'): (1, 4608)
('DiTBlock_9', 'Dense_0', 'kernel'): (1, 768, 4608)
('DiTBlock_9', 'Dense_1', 'bias'): (1, 768)
('DiTBlock_9', 'Dense_1',
'kernel'): (1, 768, 768) ('DiTBlock_9', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_9', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_9', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_9', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_9', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_9', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('Embed_0', 'embedding'): (1, 256, 1) ('FinalLayer_0', 'Dense_0', 'bias'): (1, 1536) ('FinalLayer_0', 'Dense_0', 'kernel'): (1, 768, 1536) ('FinalLayer_0', 'Dense_1', 'bias'): (1, 16) ('FinalLayer_0', 'Dense_1', 'kernel'): (1, 768, 16) ('LabelEmbedder_0', 'Embed_0', 'embedding'): (1, 1001, 768) ('PatchEmbed_0', 'Conv_0', 'bias'): (1, 768) ('PatchEmbed_0', 'Conv_0', 'kernel'): (1, 2, 2, 4, 768) ('TimestepEmbedder_0', 'Dense_0', 'bias'): (1, 768) ('TimestepEmbedder_0', 'Dense_0', 'kernel'): (1, 256, 768) ('TimestepEmbedder_0', 'Dense_1', 'bias'): (1, 768) ('TimestepEmbedder_0', 'Dense_1', 'kernel'): (1, 768, 768) ('TimestepEmbedder_1', 'Dense_0', 'bias'): (1, 768) ('TimestepEmbedder_1', 'Dense_0', 'kernel'): (1, 256, 768) ('TimestepEmbedder_1', 'Dense_1', 'bias'): (1, 768) ('TimestepEmbedder_1', 'Dense_1', 'kernel'): (1, 768, 768) parameter shapes: ('DiTBlock_0', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_0', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_0', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_0', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_0', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_0', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_0', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_0', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_0', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 
'kernel'): (1, 768, 3072) ('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_1', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_1', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_1', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_1', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_1', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_1', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_1', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_1', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_1', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_1', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_10', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_10', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_10', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_10', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_10', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_10', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_10', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_10', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_10', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_10', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_11', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_11', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_11', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_11', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_11', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_11', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_11', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_11', 'Dense_3', 'kernel'): (1, 
768, 768) ('DiTBlock_11', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_11', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_2', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_2', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_2', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_2', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_2', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_2', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_2', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_2', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_2', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_2', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_3', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_3', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_3', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_3', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_3', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_3', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_3', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_3', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_3', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_3', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_4', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_4', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_4', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_4', 'Dense_1', 'kernel'): (1, 768, 
768) ('DiTBlock_4', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_4', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_4', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_4', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_4', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_4', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_5', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_5', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_5', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_5', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_5', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_5', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_5', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_5', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_5', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_5', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_6', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_6', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_6', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_6', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_6', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_6', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_6', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_6', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_6', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_6', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) 
('DiTBlock_7', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_7', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_7', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_7', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_7', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_7', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_7', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_7', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_7', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_7', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_8', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_8', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_8', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_8', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_8', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_8', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_8', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_8', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_8', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_8', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_9', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_9', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_9', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_9', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_9', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_9', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_9', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_9', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_9', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_9', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) 
('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('Embed_0', 'embedding'): (1, 256, 1) ('FinalLayer_0', 'Dense_0', 'bias'): (1, 1536) ('FinalLayer_0', 'Dense_0', 'kernel'): (1, 768, 1536) ('FinalLayer_0', 'Dense_1', 'bias'): (1, 16) ('FinalLayer_0', 'Dense_1', 'kernel'): (1, 768, 16) ('LabelEmbedder_0', 'Embed_0', 'embedding'): (1, 1001, 768) ('PatchEmbed_0', 'Conv_0', 'bias'): (1, 768) ('PatchEmbed_0', 'Conv_0', 'kernel'): (1, 2, 2, 4, 768) ('TimestepEmbedder_0', 'Dense_0', 'bias'): (1, 768) ('TimestepEmbedder_0', 'Dense_0', 'kernel'): (1, 256, 768) ('TimestepEmbedder_0', 'Dense_1', 'bias'): (1, 768) ('TimestepEmbedder_0', 'Dense_1', 'kernel'): (1, 768, 768) ('TimestepEmbedder_1', 'Dense_0', 'bias'): (1, 768) ('TimestepEmbedder_1', 'Dense_0', 'kernel'): (1, 256, 768) ('TimestepEmbedder_1', 'Dense_1', 'bias'): (1, 768) ('TimestepEmbedder_1', 'Dense_1', 'kernel'): (1, 768, 768) parameter shapes: ('DiTBlock_0', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_0', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_0', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_0', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_0', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_0', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_0', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_0', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_0', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_1', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_1', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_1', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_1', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_1', 'Dense_2', 
'bias'): (1, 768) ('DiTBlock_1', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_1', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_1', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_1', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_1', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_10', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_10', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_10', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_10', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_10', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_10', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_10', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_10', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_10', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_10', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_11', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_11', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_11', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_11', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_11', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_11', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_11', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_11', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_11', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_11', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) 
('DiTBlock_2', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_2', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_2', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_2', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_2', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_2', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_2', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_2', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_2', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_2', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_3', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_3', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_3', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_3', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_3', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_3', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_3', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_3', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_3', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_3', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_4', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_4', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_4', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_4', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_4', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_4', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_4', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_4', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_4', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_4', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) 
('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_5', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_5', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_5', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_5', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_5', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_5', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_5', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_5', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_5', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_5', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_6', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_6', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_6', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_6', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_6', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_6', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_6', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_6', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_6', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_6', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_7', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_7', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_7', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_7', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_7', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_7', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_7', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_7', 'Dense_3', 
'kernel'): (1, 768, 768) ('DiTBlock_7', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_7', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_8', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_8', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_8', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_8', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_8', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_8', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_8', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_8', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_8', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_8', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_9', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_9', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_9', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_9', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_9', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_9', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_9', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_9', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_9', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_9', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('Embed_0', 'embedding'): (1, 256, 1) ('FinalLayer_0', 'Dense_0', 'bias'): (1, 1536) ('FinalLayer_0', 'Dense_0', 'kernel'): (1, 768, 1536) ('FinalLayer_0', 'Dense_1', 'bias'): 
(1, 16) ('FinalLayer_0', 'Dense_1', 'kernel'): (1, 768, 16) ('LabelEmbedder_0', 'Embed_0', 'embedding'): (1, 1001, 768) ('PatchEmbed_0', 'Conv_0', 'bias'): (1, 768) ('PatchEmbed_0', 'Conv_0', 'kernel'): (1, 2, 2, 4, 768) ('TimestepEmbedder_0', 'Dense_0', 'bias'): (1, 768) ('TimestepEmbedder_0', 'Dense_0', 'kernel'): (1, 256, 768) ('TimestepEmbedder_0', 'Dense_1', 'bias'): (1, 768) ('TimestepEmbedder_0', 'Dense_1', 'kernel'): (1, 768, 768) ('TimestepEmbedder_1', 'Dense_0', 'bias'): (1, 768) ('TimestepEmbedder_1', 'Dense_0', 'kernel'): (1, 256, 768) ('TimestepEmbedder_1', 'Dense_1', 'bias'): (1, 768) ('TimestepEmbedder_1', 'Dense_1', 'kernel'): (1, 768, 768) parameter shapes: ('DiTBlock_0', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_0', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_0', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_0', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_0', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_0', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_0', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_0', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_0', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_1', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_1', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_1', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_1', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_1', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_1', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_1', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_1', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_1', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_1', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 
768, 3072) ('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_10', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_10', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_10', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_10', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_10', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_10', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_10', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_10', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_10', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_10', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_11', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_11', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_11', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_11', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_11', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_11', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_11', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_11', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_11', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_11', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_2', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_2', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_2', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_2', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_2', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_2', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_2', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_2', 'Dense_3', 'kernel'): (1, 768, 768) 
('DiTBlock_2', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_2', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_3', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_3', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_3', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_3', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_3', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_3', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_3', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_3', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_3', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_3', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_4', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_4', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_4', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_4', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_4', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_4', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_4', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_4', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_4', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_4', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_5', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_5', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_5', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_5', 'Dense_1', 'kernel'): (1, 768, 768) 
('DiTBlock_5', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_5', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_5', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_5', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_5', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_5', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_6', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_6', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_6', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_6', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_6', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_6', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_6', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_6', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_6', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_6', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_7', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_7', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_7', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_7', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_7', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_7', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_7', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_7', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_7', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_7', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) 
('DiTBlock_8', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_8', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_8', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_8', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_8', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_8', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_8', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_8', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_8', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_8', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_9', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_9', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_9', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_9', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_9', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_9', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_9', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_9', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_9', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_9', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('Embed_0', 'embedding'): (1, 256, 1) ('FinalLayer_0', 'Dense_0', 'bias'): (1, 1536) ('FinalLayer_0', 'Dense_0', 'kernel'): (1, 768, 1536) ('FinalLayer_0', 'Dense_1', 'bias'): (1, 16) ('FinalLayer_0', 'Dense_1', 'kernel'): (1, 768, 16) ('LabelEmbedder_0', 'Embed_0', 'embedding'): (1, 1001, 768) ('PatchEmbed_0', 'Conv_0', 'bias'): (1, 768) ('PatchEmbed_0', 'Conv_0', 'kernel'): (1, 2, 2, 4, 768) ('TimestepEmbedder_0', 'Dense_0', 'bias'): (1, 768) ('TimestepEmbedder_0', 'Dense_0', 'kernel'): (1, 256, 768) ('TimestepEmbedder_0', 'Dense_1', 
'bias'): (1, 768) ('TimestepEmbedder_0', 'Dense_1', 'kernel'): (1, 768, 768) ('TimestepEmbedder_1', 'Dense_0', 'bias'): (1, 768) ('TimestepEmbedder_1', 'Dense_0', 'kernel'): (1, 256, 768) ('TimestepEmbedder_1', 'Dense_1', 'bias'): (1, 768) ('TimestepEmbedder_1', 'Dense_1', 'kernel'): (1, 768, 768) parameter shapes: ('DiTBlock_0', 'Dense_0', 'bias'): (4608,) ('DiTBlock_0', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_0', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_0', 'Dense_2', 'bias'): (768,) ('DiTBlock_0', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_0', 'Dense_3', 'bias'): (768,) ('DiTBlock_0', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_0', 'Dense_4', 'bias'): (768,) ('DiTBlock_0', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_1', 'Dense_0', 'bias'): (4608,) ('DiTBlock_1', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_1', 'Dense_1', 'bias'): (768,) ('DiTBlock_1', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_1', 'Dense_2', 'bias'): (768,) ('DiTBlock_1', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_1', 'Dense_3', 'bias'): (768,) ('DiTBlock_1', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_1', 'Dense_4', 'bias'): (768,) ('DiTBlock_1', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_10', 'Dense_0', 'bias'): (4608,) ('DiTBlock_10', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_10', 'Dense_1', 'bias'): (768,) ('DiTBlock_10', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_10', 'Dense_2', 'bias'): (768,) ('DiTBlock_10', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_10', 
'Dense_3', 'bias'): (768,) ('DiTBlock_10', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_10', 'Dense_4', 'bias'): (768,) ('DiTBlock_10', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_11', 'Dense_0', 'bias'): (4608,) ('DiTBlock_11', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_11', 'Dense_1', 'bias'): (768,) ('DiTBlock_11', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_11', 'Dense_2', 'bias'): (768,) ('DiTBlock_11', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_11', 'Dense_3', 'bias'): (768,) ('DiTBlock_11', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_11', 'Dense_4', 'bias'): (768,) ('DiTBlock_11', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_2', 'Dense_0', 'bias'): (4608,) ('DiTBlock_2', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_2', 'Dense_1', 'bias'): (768,) ('DiTBlock_2', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_2', 'Dense_2', 'bias'): (768,) ('DiTBlock_2', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_2', 'Dense_3', 'bias'): (768,) ('DiTBlock_2', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_2', 'Dense_4', 'bias'): (768,) ('DiTBlock_2', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_3', 'Dense_0', 'bias'): (4608,) ('DiTBlock_3', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_3', 'Dense_1', 'bias'): (768,) ('DiTBlock_3', 'Dense_1', 'kernel'): (768, 768) 
('DiTBlock_3', 'Dense_2', 'bias'): (768,) ('DiTBlock_3', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_3', 'Dense_3', 'bias'): (768,) ('DiTBlock_3', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_3', 'Dense_4', 'bias'): (768,) ('DiTBlock_3', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_4', 'Dense_0', 'bias'): (4608,) ('DiTBlock_4', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_4', 'Dense_1', 'bias'): (768,) ('DiTBlock_4', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_4', 'Dense_2', 'bias'): (768,) ('DiTBlock_4', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_4', 'Dense_3', 'bias'): (768,) ('DiTBlock_4', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_4', 'Dense_4', 'bias'): (768,) ('DiTBlock_4', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_5', 'Dense_0', 'bias'): (4608,) ('DiTBlock_5', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_5', 'Dense_1', 'bias'): (768,) ('DiTBlock_5', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_5', 'Dense_2', 'bias'): (768,) ('DiTBlock_5', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_5', 'Dense_3', 'bias'): (768,) ('DiTBlock_5', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_5', 'Dense_4', 'bias'): (768,) ('DiTBlock_5', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_6', 'Dense_0', 'bias'): (4608,) ('DiTBlock_6', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_6', 
'Dense_1', 'bias'): (768,) ('DiTBlock_6', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_6', 'Dense_2', 'bias'): (768,) ('DiTBlock_6', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_6', 'Dense_3', 'bias'): (768,) ('DiTBlock_6', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_6', 'Dense_4', 'bias'): (768,) ('DiTBlock_6', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_7', 'Dense_0', 'bias'): (4608,) ('DiTBlock_7', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_7', 'Dense_1', 'bias'): (768,) ('DiTBlock_7', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_7', 'Dense_2', 'bias'): (768,) ('DiTBlock_7', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_7', 'Dense_3', 'bias'): (768,) ('DiTBlock_7', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_7', 'Dense_4', 'bias'): (768,) ('DiTBlock_7', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_8', 'Dense_0', 'bias'): (4608,) ('DiTBlock_8', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_8', 'Dense_1', 'bias'): (768,) ('DiTBlock_8', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_8', 'Dense_2', 'bias'): (768,) ('DiTBlock_8', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_8', 'Dense_3', 'bias'): (768,) ('DiTBlock_8', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_8', 'Dense_4', 'bias'): (768,) ('DiTBlock_8', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_9', 'Dense_0', 
'bias'): (4608,) ('DiTBlock_9', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_9', 'Dense_1', 'bias'): (768,) ('DiTBlock_9', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_9', 'Dense_2', 'bias'): (768,) ('DiTBlock_9', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_9', 'Dense_3', 'bias'): (768,) ('DiTBlock_9', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_9', 'Dense_4', 'bias'): (768,) ('DiTBlock_9', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('Embed_0', 'embedding'): (256, 1) ('FinalLayer_0', 'Dense_0', 'bias'): (1536,) ('FinalLayer_0', 'Dense_0', 'kernel'): (768, 1536) ('FinalLayer_0', 'Dense_1', 'bias'): (16,) ('FinalLayer_0', 'Dense_1', 'kernel'): (768, 16) ('LabelEmbedder_0', 'Embed_0', 'embedding'): (1001, 768) ('PatchEmbed_0', 'Conv_0', 'bias'): (768,) ('PatchEmbed_0', 'Conv_0', 'kernel'): (2, 2, 4, 768) ('TimestepEmbedder_0', 'Dense_0', 'bias'): (768,) ('TimestepEmbedder_0', 'Dense_0', 'kernel'): (256, 768) ('TimestepEmbedder_0', 'Dense_1', 'bias'): (768,) ('TimestepEmbedder_0', 'Dense_1', 'kernel'): (768, 768) ('TimestepEmbedder_1', 'Dense_0', 'bias'): (768,) ('TimestepEmbedder_1', 'Dense_0', 'kernel'): (256, 768) ('TimestepEmbedder_1', 'Dense_1', 'bias'): (768,) ('TimestepEmbedder_1', 'Dense_1', 'kernel'): (768, 768) ┌────────────────────────────────────────────────┐ │ │ │ │ │ │ │ │ │ TPU 0,1,2,3 │ │ │ │ │ │ │ │ │ └────────────────────────────────────────────────┘ ┌─────────────────────────────────────────────────────────────────────────┐ │ │ │ │ │ │ │ │ │ TPU 0,1,2,3 │ │ │ │ │ │ │ │ │ └─────────────────────────────────────────────────────────────────────────┘ doing the else (512, 256, 256, 3) encode image shape (128, 256, 256, 3) Initializing encoder. 
Incoming encoder shape (128, 256, 256, 3) Encoder layer (128, 256, 256, 128) doing downsample Encoder layer (128, 128, 128, 128) doing downsample Encoder layer (128, 64, 64, 256) doing downsample Encoder layer (128, 32, 32, 512) Encoder layer (128, 32, 32, 512) Encoder layer final (128, 32, 32, 512) Encoder layer final (128, 32, 32, 512) Final embeddings are size (128, 32, 32, 8) After quant (128, 32, 32, 4) Calc FID for CFG 1.0 and denoise_timesteps 128 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 z_vectors shape (128, 32, 32, 4) Decoder incoming shape (128, 32, 32, 4) Decoder input (128, 32, 32, 512) Mid Block Decoder layer (128, 32, 32, 512) Mid Block Decoder layer (128, 32, 32, 512) Decoder layer (128, 64, 64, 512) Decoder layer (128, 128, 128, 512) Decoder layer (128, 256, 256, 256) Decoder layer (128, 256, 256, 128) FID is 29.85198211669922 (512, 256, 256, 3) Calc FID for CFG 1.0 and denoise_timesteps 64 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 30.429718017578125 (512, 256, 256, 3) Calc FID for CFG 1.0 and denoise_timesteps 32 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 32.383453369140625 (512, 256, 256, 3) Calc FID for CFG 1.0 and denoise_timesteps 16 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 39.473060607910156 (512, 256, 256, 
3) Calc FID for CFG 1.0 and denoise_timesteps 8 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 66.19210815429688 (512, 256, 256, 3) Calc FID for CFG 1.0 and denoise_timesteps 4 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 167.40814208984375 (512, 256, 256, 3) Calc FID for CFG 1.0 and denoise_timesteps 2 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 320.34088134765625 (512, 256, 256, 3) Calc FID for CFG 1.0 and denoise_timesteps 1 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 310.8286437988281 (512, 256, 256, 3) Calc FID for CFG 1.25 and denoise_timesteps 128 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 18.188316345214844 (512, 256, 256, 3) Calc FID for CFG 1.25 and denoise_timesteps 64 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 18.66943359375 (512, 256, 256, 3) Calc FID for CFG 1.25 and denoise_timesteps 32 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After 
patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 20.218494415283203 (512, 256, 256, 3) Calc FID for CFG 1.25 and denoise_timesteps 16 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 25.816804885864258 (512, 256, 256, 3) Calc FID for CFG 1.25 and denoise_timesteps 8 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 48.70721435546875 (512, 256, 256, 3) Calc FID for CFG 1.25 and denoise_timesteps 4 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 140.623291015625 (512, 256, 256, 3) Calc FID for CFG 1.25 and denoise_timesteps 2 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 303.20208740234375 (512, 256, 256, 3) Calc FID for CFG 1.25 and denoise_timesteps 1 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 297.94085693359375 (512, 256, 256, 3) Calc FID for CFG 1.5 and denoise_timesteps 128 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 
DiT: Conditioning of shape (512, 768) dtype float32 FID is 11.75399398803711 (512, 256, 256, 3) Calc FID for CFG 1.5 and denoise_timesteps 64 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 12.077102661132812 (512, 256, 256, 3) Calc FID for CFG 1.5 and denoise_timesteps 32 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 13.16611385345459 (512, 256, 256, 3) Calc FID for CFG 1.5 and denoise_timesteps 16 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 17.34552764892578 (512, 256, 256, 3) Calc FID for CFG 1.5 and denoise_timesteps 8 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 35.340003967285156 (512, 256, 256, 3) Calc FID for CFG 1.5 and denoise_timesteps 4 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 116.47549438476562 (512, 256, 256, 3) Calc FID for CFG 1.5 and denoise_timesteps 2 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 286.4483337402344 (512, 256, 256, 3) Calc FID for CFG 
1.5 and denoise_timesteps 1 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 288.33428955078125 (512, 256, 256, 3) Calc FID for CFG 1.75 and denoise_timesteps 128 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 8.648992538452148 (512, 256, 256, 3) Calc FID for CFG 1.75 and denoise_timesteps 64 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 8.880931854248047 (512, 256, 256, 3) Calc FID for CFG 1.75 and denoise_timesteps 32 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 9.625060081481934 (512, 256, 256, 3) Calc FID for CFG 1.75 and denoise_timesteps 16 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 12.514213562011719 (512, 256, 256, 3) Calc FID for CFG 1.75 and denoise_timesteps 8 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 26.037988662719727 (512, 256, 256, 3) Calc FID for CFG 1.75 and denoise_timesteps 4 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, 
shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 95.8776626586914 (512, 256, 256, 3) Calc FID for CFG 1.75 and denoise_timesteps 2 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 270.989501953125 (512, 256, 256, 3) Calc FID for CFG 1.75 and denoise_timesteps 1 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 282.38409423828125 (512, 256, 256, 3) Calc FID for CFG 2.0 and denoise_timesteps 128 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 7.590020656585693 (512, 256, 256, 3) Calc FID for CFG 2.0 and denoise_timesteps 64 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 7.756609916687012 (512, 256, 256, 3) Calc FID for CFG 2.0 and denoise_timesteps 32 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 8.229378700256348 (512, 256, 256, 3) Calc FID for CFG 2.0 and denoise_timesteps 16 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning 
of shape (512, 768) dtype float32 FID is 10.16218376159668 (512, 256, 256, 3) Calc FID for CFG 2.0 and denoise_timesteps 8 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 20.062997817993164 (512, 256, 256, 3) Calc FID for CFG 2.0 and denoise_timesteps 4 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 78.88784790039062 (512, 256, 256, 3) Calc FID for CFG 2.0 and denoise_timesteps 2 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 257.39532470703125 (512, 256, 256, 3) Calc FID for CFG 2.0 and denoise_timesteps 1 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 278.23712158203125 (512, 256, 256, 3) Calc FID for CFG 2.25 and denoise_timesteps 128 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 7.570075035095215 (512, 256, 256, 3) Calc FID for CFG 2.25 and denoise_timesteps 64 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 7.662998676300049 (512, 256, 256, 3) Calc FID for CFG 2.25 and 
denoise_timesteps 32 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 7.963239669799805 (512, 256, 256, 3) Calc FID for CFG 2.25 and denoise_timesteps 16 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 9.24071216583252 (512, 256, 256, 3) Calc FID for CFG 2.25 and denoise_timesteps 8 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 16.410505294799805 (512, 256, 256, 3) Calc FID for CFG 2.25 and denoise_timesteps 4 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 65.38135528564453 (512, 256, 256, 3) Calc FID for CFG 2.25 and denoise_timesteps 2 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 245.42811584472656 (512, 256, 256, 3) Calc FID for CFG 2.25 and denoise_timesteps 1 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 275.25335693359375 (512, 256, 256, 3) Calc FID for CFG 2.5 and denoise_timesteps 128 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 
256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 8.14083194732666 (512, 256, 256, 3) Calc FID for CFG 2.5 and denoise_timesteps 64 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 8.179957389831543 (512, 256, 256, 3) Calc FID for CFG 2.5 and denoise_timesteps 32 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 8.378838539123535 (512, 256, 256, 3) Calc FID for CFG 2.5 and denoise_timesteps 16 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 9.192809104919434 (512, 256, 256, 3) Calc FID for CFG 2.5 and denoise_timesteps 8 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 14.307737350463867 (512, 256, 256, 3) Calc FID for CFG 2.5 and denoise_timesteps 4 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 54.63789367675781 (512, 256, 256, 3) Calc FID for CFG 2.5 and denoise_timesteps 2 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 
768) dtype float32 FID is 234.91204833984375 (512, 256, 256, 3) Calc FID for CFG 2.5 and denoise_timesteps 1 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 272.7349853515625 (512, 256, 256, 3) Calc FID for CFG 2.75 and denoise_timesteps 128 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 9.0011625289917 (512, 256, 256, 3) Calc FID for CFG 2.75 and denoise_timesteps 64 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 9.022237777709961 (512, 256, 256, 3) Calc FID for CFG 2.75 and denoise_timesteps 32 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 9.119712829589844 (512, 256, 256, 3) Calc FID for CFG 2.75 and denoise_timesteps 16 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 9.611940383911133 (512, 256, 256, 3) Calc FID for CFG 2.75 and denoise_timesteps 8 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 13.223968505859375 (512, 256, 256, 3) Calc FID for CFG 2.75 and denoise_timesteps 4 
DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 46.258277893066406 (512, 256, 256, 3) Calc FID for CFG 2.75 and denoise_timesteps 2 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 225.6319580078125 (512, 256, 256, 3) Calc FID for CFG 2.75 and denoise_timesteps 1 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 270.6058349609375 (512, 256, 256, 3) Calc FID for CFG 3.0 and denoise_timesteps 128 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 10.021092414855957 (512, 256, 256, 3) Calc FID for CFG 3.0 and denoise_timesteps 64 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 10.005136489868164 (512, 256, 256, 3) Calc FID for CFG 3.0 and denoise_timesteps 32 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 DiT: Conditioning of shape (512, 768) dtype float32 FID is 10.013227462768555 (512, 256, 256, 3) Calc FID for CFG 3.0 and denoise_timesteps 16 DiT: Input of shape (512, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (512, 256, 768) dtype 
bfloat16
DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16
DiT: Conditioning of shape (512, 768) dtype float32
FID is 10.302431106567383
(512, 256, 256, 3)
Calc FID for CFG 3.0 and denoise_timesteps 8
DiT: Input of shape (512, 32, 32, 4) dtype float32
DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16
DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16
DiT: Conditioning of shape (512, 768) dtype float32
FID is 12.778925895690918
(512, 256, 256, 3)
Calc FID for CFG 3.0 and denoise_timesteps 4
DiT: Input of shape (512, 32, 32, 4) dtype float32
DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16
DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16
DiT: Conditioning of shape (512, 768) dtype float32
FID is 39.810997009277344
(512, 256, 256, 3)
Calc FID for CFG 3.0 and denoise_timesteps 2
DiT: Input of shape (512, 32, 32, 4) dtype float32
DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16
DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16
DiT: Conditioning of shape (512, 768) dtype float32
FID is 217.42495727539062
(512, 256, 256, 3)
Calc FID for CFG 3.0 and denoise_timesteps 1
DiT: Input of shape (512, 32, 32, 4) dtype float32
DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16
DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16
DiT: Conditioning of shape (512, 768) dtype float32
FID is 268.97784423828125
wandb:
wandb: 🚀 View run shortcut_imagenet256 at: https://wandb.ai/daniel-z-kaplan/shortcut/runs/shortcut_imagenet256_20250809_222623_345353_10
wandb: Find logs at: ../../../tmp/tmpu25f0_2j/wandb/run-20250809_222623-shortcut_imagenet256_20250809_222623_345353_10/logs