Using devices [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]
Device count 4
Global device count 4
Global Batch: 512
Node Batch: 512
Device Batch: 128
/tmp/tmp57xxkh9z
Loading dataset
Loading dataset
creating model
beta1: 0.9
beta2: 0.999
bootstrap_cfg: 1
bootstrap_dt_bias: 0
bootstrap_ema: 1
bootstrap_every: 8
cfg_scale: 1.5
class_dropout_prob: 0.1
denoise_timesteps: 128
depth: 12
dropout: 0.0
dt_sampling: uniform
hidden_size: 768
lr: 0.0001
mlp_ratio: 4
num_classes: 1000
num_heads: 12
patch_size: 2
sharding: dp
t_sampling: discrete-dt
target_update_rate: 0.999
train_type: naive
use_cosine: 0
use_ema: 0
use_stable_vae: 1
warmup: 0
weight_decay: 0.1
Total devices TPU_0(process=0,(0,0,0,0))
Initializing encoder.
Incoming encoder shape (1, 256, 256, 3)
Encoder layer (1, 256, 256, 128)
doing downsample
Encoder layer (1, 128, 128, 128)
doing downsample
Encoder layer (1, 64, 64, 256)
doing downsample
Encoder layer (1, 32, 32, 512)
Encoder layer (1, 32, 32, 512)
Encoder layer final (1, 32, 32, 512)
Encoder layer final (1, 32, 32, 512)
Final embeddings are size (1, 32, 32, 8)
After quant (1, 32, 32, 4)
encode finished
Decoder incoming shape (1, 32, 32, 4)
Decoder input (1, 32, 32, 512)
Mid Block Decoder layer (1, 32, 32, 512)
Mid Block Decoder layer (1, 32, 32, 512)
Decoder layer (1, 64, 64, 512)
Decoder layer (1, 128, 128, 512)
Decoder layer (1, 256, 256, 256)
Decoder layer (1, 256, 256, 128)
Total num of VQVAE parameters: 67565323
Disc shape (1, 128, 128, 128)
Disc shape (1, 64, 64, 256)
Disc shape (1, 32, 32, 512)
Disc shape (1, 16, 16, 512)
Disc shape (1, 8, 8, 512)
Disc shape (1, 4, 4, 512)
Total num of Discriminator parameters: 23998017
Loaded checkpoint from 17943743 seconds ago.
Loaded model with step 323001
┌──────────────────────────────────────────────────────────────────────────────┐
│                                    TPU 0                                     │
├──────────────────────────────────────────────────────────────────────────────┤
│                                    TPU 1                                     │
├──────────────────────────────────────────────────────────────────────────────┤
│                                    TPU 2                                     │
├──────────────────────────────────────────────────────────────────────────────┤
│                                    TPU 3                                     │
└──────────────────────────────────────────────────────────────────────────────┘
returning model
model done
Input to vae (4, 1, 256, 256, 3)
encode image shape (1, 256, 256, 3)
Initializing encoder.
Incoming encoder shape (1, 256, 256, 3)
Encoder layer (1, 256, 256, 128)
doing downsample
Encoder layer (1, 128, 128, 128)
doing downsample
Encoder layer (1, 64, 64, 256)
doing downsample
Encoder layer (1, 32, 32, 512)
Encoder layer (1, 32, 32, 512)
Encoder layer final (1, 32, 32, 512)
Encoder layer final (1, 32, 32, 512)
Final embeddings are size (1, 32, 32, 8)
After quant (1, 32, 32, 4)
output example shape (4, 1, 32, 32, 4)
Test data shape (4, 256, 256, 3)
x shape (4, 1, 256, 256, 3)
encoded shape (4, 1, 32, 32, 4)
z_vectors shape (1, 32, 32, 4)
Decoder incoming shape (1, 32, 32, 4)
Decoder input (1, 32, 32, 512)
Mid Block Decoder layer (1, 32, 32, 512)
Mid Block Decoder layer (1, 32, 32, 512)
Decoder layer (1, 64, 64, 512)
Decoder layer (1, 128, 128, 512)
Decoder layer (1, 256, 256, 256)
Decoder layer (1, 256, 256, 128)
image shape (4, 1, 256, 256, 3)
decoded img shape (256, 256, 3)
obs shape (4, 32, 32, 4)
DiT: Input of shape (4, 32, 32, 4) dtype float32
DiT: After patch embed, shape is (4, 256, 768) dtype bfloat16
DiT: Patch Embed of shape (4, 256, 768) dtype bfloat16
DiT: Conditioning of shape (1, 768) dtype float32

                                                             DiT Summary
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ path                             ┃ module           ┃ inputs                ┃ outputs               ┃ params                       ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ │ DiT │ - float32[4,32,32,4] │ bfloat16[4,32,32,4] │ │ │ │ │ - float32[1] │ │ │ │ │ │ - float32[1] │ │ │ │ │ │ - int32[1] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ PatchEmbed_0 │ PatchEmbed │ float32[4,32,32,4] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ PatchEmbed_0/Conv_0 │ Conv │ float32[4,32,32,4] │ bfloat16[4,16,16,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[2,2,4,768] │ │ │ │ │ │ │ │ │ │ │ │ 13,056 (52.2 KB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ TimestepEmbedder_0 │ TimestepEmbedder │ float32[1] │ float32[1,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ TimestepEmbedder_0/Dense_0 │ Dense │ bfloat16[1,256] │ bfloat16[1,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[256,768] │ │ │ │ │ │ │ │ │ │ │ │ 197,376 (789.5 KB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ TimestepEmbedder_0/Dense_1 │ Dense │ bfloat16[1,768] │ float32[1,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ TimestepEmbedder_1 │ TimestepEmbedder │ float32[1] │ float32[1,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ TimestepEmbedder_1/Dense_0 │ Dense │ 
bfloat16[1,256] │ bfloat16[1,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[256,768] │ │ │ │ │ │ │ │ │ │ │ │ 197,376 (789.5 KB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ TimestepEmbedder_1/Dense_1 │ Dense │ bfloat16[1,768] │ float32[1,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ LabelEmbedder_0 │ LabelEmbedder │ int32[1] │ bfloat16[1,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ LabelEmbedder_0/Embed_0 │ Embed │ int32[1] │ bfloat16[1,768] │ embedding: float32[1001,768] │ │ │ │ │ │ │ │ │ │ │ │ 768,768 (3.1 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3 
│ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ 
DiTBlock_3/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ 
DiTBlock_6/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ 
bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] 
│ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: 
float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 
MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ FinalLayer_0 │ FinalLayer │ - bfloat16[4,256,768] │ bfloat16[4,256,16] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ FinalLayer_0/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,1536] │ bias: float32[1536] │ │ │ │ │ │ kernel: float32[768,1536] │ │ │ │ │ │ │ │ │ │ │ │ 1,181,184 (4.7 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ FinalLayer_0/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ FinalLayer_0/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,16] │ bias: float32[16] │ │ │ │ │ │ kernel: float32[768,16] │ │ │ │ │ │ │ │ │ │ │ │ 12,304 (49.2 KB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ Embed_0 │ Embed │ int32[1] │ float32[1,1] │ embedding: float32[256,1] │ │ │ │ │ │ │ │ │ │ │ │ 256 (1.0 KB) │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤
│ │ │ │ Total │ 131,091,728 (524.4 MB) │
└──────────────────────────────────┴──────────────────┴───────────────────────┴───────────────────────┴──────────────────────────────┘
Total Parameters: 131,091,728 (524.4 MB)
DiT: Input of shape (4, 32, 32, 4) dtype float32
DiT: After patch embed, shape is (4, 256, 768) dtype bfloat16
DiT: Patch Embed of shape (4, 256, 768) dtype bfloat16
DiT: Conditioning of shape (1, 768) dtype float32
Loaded checkpoint from 3912 seconds ago.
parameter shapes:
('PatchEmbed_0', 'Conv_0', 'kernel'): (2, 2, 4, 768) ('PatchEmbed_0', 'Conv_0', 'bias'): (768,) ('TimestepEmbedder_0', 'Dense_0', 'kernel'): (256, 768) ('TimestepEmbedder_0', 'Dense_0', 'bias'): (768,) ('TimestepEmbedder_0', 'Dense_1', 'kernel'): (768, 768) ('TimestepEmbedder_0', 'Dense_1', 'bias'): (768,) ('TimestepEmbedder_1', 'Dense_0', 'kernel'): (256, 768) ('TimestepEmbedder_1', 'Dense_0', 'bias'): (768,) ('TimestepEmbedder_1', 'Dense_1', 'kernel'): (768, 768) ('TimestepEmbedder_1', 'Dense_1', 'bias'): (768,) ('LabelEmbedder_0', 'Embed_0', 'embedding'): (1001, 768) ('DiTBlock_0', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_0', 'Dense_0', 'bias'): (4608,) ('DiTBlock_0', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_0', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_0', 'Dense_2', 'bias'): (768,) ('DiTBlock_0', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_0', 'Dense_3', 'bias'): (768,) ('DiTBlock_0', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_0', 'Dense_4', 'bias'): (768,) ('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_1', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_1', 'Dense_0', 'bias'): (4608,)
('DiTBlock_1', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_1', 'Dense_1', 'bias'): (768,) ('DiTBlock_1', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_1', 'Dense_2', 'bias'): (768,) ('DiTBlock_1', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_1', 'Dense_3', 'bias'): (768,) ('DiTBlock_1', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_1', 'Dense_4', 'bias'): (768,) ('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_2', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_2', 'Dense_0', 'bias'): (4608,) ('DiTBlock_2', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_2', 'Dense_1', 'bias'): (768,) ('DiTBlock_2', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_2', 'Dense_2', 'bias'): (768,) ('DiTBlock_2', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_2', 'Dense_3', 'bias'): (768,) ('DiTBlock_2', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_2', 'Dense_4', 'bias'): (768,) ('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_3', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_3', 'Dense_0', 'bias'): (4608,) ('DiTBlock_3', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_3', 'Dense_1', 'bias'): (768,) ('DiTBlock_3', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_3', 'Dense_2', 'bias'): (768,) ('DiTBlock_3', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_3', 'Dense_3', 'bias'): (768,) ('DiTBlock_3', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_3', 'Dense_4', 'bias'): (768,) ('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_4', 
'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_4', 'Dense_0', 'bias'): (4608,) ('DiTBlock_4', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_4', 'Dense_1', 'bias'): (768,) ('DiTBlock_4', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_4', 'Dense_2', 'bias'): (768,) ('DiTBlock_4', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_4', 'Dense_3', 'bias'): (768,) ('DiTBlock_4', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_4', 'Dense_4', 'bias'): (768,) ('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_5', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_5', 'Dense_0', 'bias'): (4608,) ('DiTBlock_5', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_5', 'Dense_1', 'bias'): (768,) ('DiTBlock_5', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_5', 'Dense_2', 'bias'): (768,) ('DiTBlock_5', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_5', 'Dense_3', 'bias'): (768,) ('DiTBlock_5', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_5', 'Dense_4', 'bias'): (768,) ('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_6', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_6', 'Dense_0', 'bias'): (4608,) ('DiTBlock_6', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_6', 'Dense_1', 'bias'): (768,) ('DiTBlock_6', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_6', 'Dense_2', 'bias'): (768,) ('DiTBlock_6', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_6', 'Dense_3', 'bias'): (768,) ('DiTBlock_6', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_6', 'Dense_4', 'bias'): (768,) ('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 
768) ('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_7', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_7', 'Dense_0', 'bias'): (4608,) ('DiTBlock_7', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_7', 'Dense_1', 'bias'): (768,) ('DiTBlock_7', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_7', 'Dense_2', 'bias'): (768,) ('DiTBlock_7', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_7', 'Dense_3', 'bias'): (768,) ('DiTBlock_7', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_7', 'Dense_4', 'bias'): (768,) ('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_8', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_8', 'Dense_0', 'bias'): (4608,) ('DiTBlock_8', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_8', 'Dense_1', 'bias'): (768,) ('DiTBlock_8', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_8', 'Dense_2', 'bias'): (768,) ('DiTBlock_8', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_8', 'Dense_3', 'bias'): (768,) ('DiTBlock_8', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_8', 'Dense_4', 'bias'): (768,) ('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_9', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_9', 'Dense_0', 'bias'): (4608,) ('DiTBlock_9', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_9', 'Dense_1', 'bias'): (768,) ('DiTBlock_9', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_9', 'Dense_2', 'bias'): (768,) ('DiTBlock_9', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_9', 'Dense_3', 'bias'): (768,) ('DiTBlock_9', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_9', 'Dense_4', 'bias'): (768,) ('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 
'bias'): (3072,) ('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_10', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_10', 'Dense_0', 'bias'): (4608,) ('DiTBlock_10', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_10', 'Dense_1', 'bias'): (768,) ('DiTBlock_10', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_10', 'Dense_2', 'bias'): (768,) ('DiTBlock_10', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_10', 'Dense_3', 'bias'): (768,) ('DiTBlock_10', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_10', 'Dense_4', 'bias'): (768,) ('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_11', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_11', 'Dense_0', 'bias'): (4608,) ('DiTBlock_11', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_11', 'Dense_1', 'bias'): (768,) ('DiTBlock_11', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_11', 'Dense_2', 'bias'): (768,) ('DiTBlock_11', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_11', 'Dense_3', 'bias'): (768,) ('DiTBlock_11', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_11', 'Dense_4', 'bias'): (768,) ('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('FinalLayer_0', 'Dense_0', 'kernel'): (768, 1536) ('FinalLayer_0', 'Dense_0', 'bias'): (1536,) ('FinalLayer_0', 'Dense_1', 'kernel'): (768, 16) ('FinalLayer_0', 'Dense_1', 'bias'): (16,) ('Embed_0', 'embedding'): (256, 1)
parameter shapes:
('DiTBlock_0', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_0', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_0', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_0',
'Dense_2', 'bias'): (1, 768) ('DiTBlock_0', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_0', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_0', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_0', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_0', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_1', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_1', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_1', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_1', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_1', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_1', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_1', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_1', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_1', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_1', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_10', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_10', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_10', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_10', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_10', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_10', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_10', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_10', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_10', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_10', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) 
('DiTBlock_11', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_11', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_11', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_11', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_11', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_11', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_11', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_11', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_11', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_11', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_2', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_2', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_2', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_2', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_2', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_2', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_2', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_2', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_2', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_2', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_3', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_3', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_3', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_3', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_3', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_3', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_3', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_3', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_3', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_3', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 
3072) ('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_4', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_4', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_4', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_4', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_4', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_4', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_4', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_4', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_4', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_4', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_5', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_5', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_5', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_5', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_5', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_5', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_5', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_5', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_5', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_5', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_6', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_6', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_6', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_6', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_6', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_6', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_6', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_6', 
'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_6', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_6', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_7', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_7', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_7', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_7', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_7', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_7', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_7', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_7', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_7', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_7', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_8', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_8', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_8', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_8', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_8', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_8', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_8', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_8', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_8', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_8', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_9', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_9', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_9', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_9', 'Dense_1', 
'kernel'): (1, 768, 768) ('DiTBlock_9', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_9', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_9', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_9', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_9', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_9', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('Embed_0', 'embedding'): (1, 256, 1) ('FinalLayer_0', 'Dense_0', 'bias'): (1, 1536) ('FinalLayer_0', 'Dense_0', 'kernel'): (1, 768, 1536) ('FinalLayer_0', 'Dense_1', 'bias'): (1, 16) ('FinalLayer_0', 'Dense_1', 'kernel'): (1, 768, 16) ('LabelEmbedder_0', 'Embed_0', 'embedding'): (1, 1001, 768) ('PatchEmbed_0', 'Conv_0', 'bias'): (1, 768) ('PatchEmbed_0', 'Conv_0', 'kernel'): (1, 2, 2, 4, 768) ('TimestepEmbedder_0', 'Dense_0', 'bias'): (1, 768) ('TimestepEmbedder_0', 'Dense_0', 'kernel'): (1, 256, 768) ('TimestepEmbedder_0', 'Dense_1', 'bias'): (1, 768) ('TimestepEmbedder_0', 'Dense_1', 'kernel'): (1, 768, 768) ('TimestepEmbedder_1', 'Dense_0', 'bias'): (1, 768) ('TimestepEmbedder_1', 'Dense_0', 'kernel'): (1, 256, 768) ('TimestepEmbedder_1', 'Dense_1', 'bias'): (1, 768) ('TimestepEmbedder_1', 'Dense_1', 'kernel'): (1, 768, 768) parameter shapes: ('DiTBlock_0', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_0', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_0', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_0', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_0', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_0', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_0', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_0', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_0', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 
'kernel'): (1, 768, 3072) ('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_1', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_1', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_1', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_1', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_1', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_1', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_1', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_1', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_1', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_1', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_10', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_10', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_10', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_10', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_10', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_10', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_10', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_10', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_10', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_10', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_11', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_11', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_11', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_11', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_11', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_11', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_11', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_11', 'Dense_3', 'kernel'): (1, 
768, 768) ('DiTBlock_11', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_11', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_2', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_2', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_2', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_2', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_2', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_2', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_2', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_2', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_2', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_2', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_3', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_3', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_3', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_3', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_3', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_3', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_3', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_3', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_3', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_3', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_4', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_4', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_4', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_4', 'Dense_1', 'kernel'): (1, 768, 
768) ('DiTBlock_4', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_4', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_4', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_4', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_4', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_4', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_5', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_5', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_5', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_5', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_5', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_5', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_5', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_5', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_5', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_5', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_6', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_6', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_6', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_6', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_6', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_6', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_6', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_6', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_6', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_6', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) 
('DiTBlock_7', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_7', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_7', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_7', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_7', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_7', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_7', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_7', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_7', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_7', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_8', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_8', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_8', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_8', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_8', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_8', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_8', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_8', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_8', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_8', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_9', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_9', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_9', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_9', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_9', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_9', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_9', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_9', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_9', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_9', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) 
('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('Embed_0', 'embedding'): (1, 256, 1) ('FinalLayer_0', 'Dense_0', 'bias'): (1, 1536) ('FinalLayer_0', 'Dense_0', 'kernel'): (1, 768, 1536) ('FinalLayer_0', 'Dense_1', 'bias'): (1, 16) ('FinalLayer_0', 'Dense_1', 'kernel'): (1, 768, 16) ('LabelEmbedder_0', 'Embed_0', 'embedding'): (1, 1001, 768) ('PatchEmbed_0', 'Conv_0', 'bias'): (1, 768) ('PatchEmbed_0', 'Conv_0', 'kernel'): (1, 2, 2, 4, 768) ('TimestepEmbedder_0', 'Dense_0', 'bias'): (1, 768) ('TimestepEmbedder_0', 'Dense_0', 'kernel'): (1, 256, 768) ('TimestepEmbedder_0', 'Dense_1', 'bias'): (1, 768) ('TimestepEmbedder_0', 'Dense_1', 'kernel'): (1, 768, 768) ('TimestepEmbedder_1', 'Dense_0', 'bias'): (1, 768) ('TimestepEmbedder_1', 'Dense_0', 'kernel'): (1, 256, 768) ('TimestepEmbedder_1', 'Dense_1', 'bias'): (1, 768) ('TimestepEmbedder_1', 'Dense_1', 'kernel'): (1, 768, 768) parameter shapes: ('DiTBlock_0', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_0', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_0', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_0', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_0', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_0', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_0', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_0', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_0', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_1', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_1', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_1', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_1', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_1', 'Dense_2', 
'bias'): (1, 768) ('DiTBlock_1', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_1', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_1', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_1', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_1', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_10', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_10', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_10', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_10', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_10', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_10', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_10', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_10', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_10', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_10', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_11', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_11', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_11', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_11', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_11', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_11', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_11', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_11', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_11', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_11', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) 
('DiTBlock_2', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_2', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_2', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_2', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_2', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_2', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_2', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_2', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_2', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_2', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_3', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_3', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_3', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_3', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_3', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_3', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_3', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_3', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_3', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_3', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_4', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_4', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_4', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_4', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_4', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_4', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_4', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_4', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_4', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_4', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) 
('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_5', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_5', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_5', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_5', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_5', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_5', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_5', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_5', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_5', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_5', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_6', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_6', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_6', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_6', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_6', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_6', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_6', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_6', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_6', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_6', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_7', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_7', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_7', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_7', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_7', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_7', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_7', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_7', 'Dense_3', 
'kernel'): (1, 768, 768) ('DiTBlock_7', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_7', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_8', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_8', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_8', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_8', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_8', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_8', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_8', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_8', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_8', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_8', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_9', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_9', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_9', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_9', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_9', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_9', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_9', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_9', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_9', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_9', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('Embed_0', 'embedding'): (1, 256, 1) ('FinalLayer_0', 'Dense_0', 'bias'): (1, 1536) ('FinalLayer_0', 'Dense_0', 'kernel'): (1, 768, 1536) ('FinalLayer_0', 'Dense_1', 'bias'): 
(1, 16) ('FinalLayer_0', 'Dense_1', 'kernel'): (1, 768, 16) ('LabelEmbedder_0', 'Embed_0', 'embedding'): (1, 1001, 768) ('PatchEmbed_0', 'Conv_0', 'bias'): (1, 768) ('PatchEmbed_0', 'Conv_0', 'kernel'): (1, 2, 2, 4, 768) ('TimestepEmbedder_0', 'Dense_0', 'bias'): (1, 768) ('TimestepEmbedder_0', 'Dense_0', 'kernel'): (1, 256, 768) ('TimestepEmbedder_0', 'Dense_1', 'bias'): (1, 768) ('TimestepEmbedder_0', 'Dense_1', 'kernel'): (1, 768, 768) ('TimestepEmbedder_1', 'Dense_0', 'bias'): (1, 768) ('TimestepEmbedder_1', 'Dense_0', 'kernel'): (1, 256, 768) ('TimestepEmbedder_1', 'Dense_1', 'bias'): (1, 768) ('TimestepEmbedder_1', 'Dense_1', 'kernel'): (1, 768, 768) parameter shapes: ('DiTBlock_0', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_0', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_0', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_0', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_0', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_0', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_0', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_0', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_0', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_1', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_1', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_1', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_1', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_1', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_1', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_1', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_1', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_1', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_1', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 
768, 3072) ('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_10', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_10', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_10', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_10', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_10', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_10', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_10', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_10', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_10', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_10', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_11', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_11', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_11', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_11', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_11', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_11', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_11', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_11', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_11', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_11', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_2', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_2', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_2', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_2', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_2', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_2', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_2', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_2', 'Dense_3', 'kernel'): (1, 768, 768) 
('DiTBlock_2', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_2', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_3', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_3', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_3', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_3', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_3', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_3', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_3', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_3', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_3', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_3', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_4', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_4', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_4', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_4', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_4', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_4', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_4', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_4', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_4', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_4', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_5', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_5', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_5', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_5', 'Dense_1', 'kernel'): (1, 768, 768) 
('DiTBlock_5', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_5', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_5', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_5', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_5', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_5', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_6', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_6', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_6', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_6', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_6', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_6', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_6', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_6', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_6', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_6', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_7', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_7', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_7', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_7', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_7', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_7', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_7', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_7', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_7', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_7', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) 
('DiTBlock_8', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_8', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_8', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_8', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_8', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_8', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_8', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_8', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_8', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_8', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_9', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_9', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_9', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_9', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_9', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_9', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_9', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_9', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_9', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_9', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('Embed_0', 'embedding'): (1, 256, 1) ('FinalLayer_0', 'Dense_0', 'bias'): (1, 1536) ('FinalLayer_0', 'Dense_0', 'kernel'): (1, 768, 1536) ('FinalLayer_0', 'Dense_1', 'bias'): (1, 16) ('FinalLayer_0', 'Dense_1', 'kernel'): (1, 768, 16) ('LabelEmbedder_0', 'Embed_0', 'embedding'): (1, 1001, 768) ('PatchEmbed_0', 'Conv_0', 'bias'): (1, 768) ('PatchEmbed_0', 'Conv_0', 'kernel'): (1, 2, 2, 4, 768) ('TimestepEmbedder_0', 'Dense_0', 'bias'): (1, 768) ('TimestepEmbedder_0', 'Dense_0', 'kernel'): (1, 256, 768) ('TimestepEmbedder_0', 'Dense_1', 
'bias'): (1, 768) ('TimestepEmbedder_0', 'Dense_1', 'kernel'): (1, 768, 768) ('TimestepEmbedder_1', 'Dense_0', 'bias'): (1, 768) ('TimestepEmbedder_1', 'Dense_0', 'kernel'): (1, 256, 768) ('TimestepEmbedder_1', 'Dense_1', 'bias'): (1, 768) ('TimestepEmbedder_1', 'Dense_1', 'kernel'): (1, 768, 768) parameter shapes: ('DiTBlock_0', 'Dense_0', 'bias'): (4608,) ('DiTBlock_0', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_0', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_0', 'Dense_2', 'bias'): (768,) ('DiTBlock_0', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_0', 'Dense_3', 'bias'): (768,) ('DiTBlock_0', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_0', 'Dense_4', 'bias'): (768,) ('DiTBlock_0', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_1', 'Dense_0', 'bias'): (4608,) ('DiTBlock_1', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_1', 'Dense_1', 'bias'): (768,) ('DiTBlock_1', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_1', 'Dense_2', 'bias'): (768,) ('DiTBlock_1', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_1', 'Dense_3', 'bias'): (768,) ('DiTBlock_1', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_1', 'Dense_4', 'bias'): (768,) ('DiTBlock_1', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_10', 'Dense_0', 'bias'): (4608,) ('DiTBlock_10', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_10', 'Dense_1', 'bias'): (768,) ('DiTBlock_10', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_10', 'Dense_2', 'bias'): (768,) ('DiTBlock_10', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_10', 
'Dense_3', 'bias'): (768,) ('DiTBlock_10', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_10', 'Dense_4', 'bias'): (768,) ('DiTBlock_10', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_11', 'Dense_0', 'bias'): (4608,) ('DiTBlock_11', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_11', 'Dense_1', 'bias'): (768,) ('DiTBlock_11', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_11', 'Dense_2', 'bias'): (768,) ('DiTBlock_11', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_11', 'Dense_3', 'bias'): (768,) ('DiTBlock_11', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_11', 'Dense_4', 'bias'): (768,) ('DiTBlock_11', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_2', 'Dense_0', 'bias'): (4608,) ('DiTBlock_2', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_2', 'Dense_1', 'bias'): (768,) ('DiTBlock_2', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_2', 'Dense_2', 'bias'): (768,) ('DiTBlock_2', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_2', 'Dense_3', 'bias'): (768,) ('DiTBlock_2', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_2', 'Dense_4', 'bias'): (768,) ('DiTBlock_2', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_3', 'Dense_0', 'bias'): (4608,) ('DiTBlock_3', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_3', 'Dense_1', 'bias'): (768,) ('DiTBlock_3', 'Dense_1', 'kernel'): (768, 768) 
('DiTBlock_3', 'Dense_2', 'bias'): (768,) ('DiTBlock_3', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_3', 'Dense_3', 'bias'): (768,) ('DiTBlock_3', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_3', 'Dense_4', 'bias'): (768,) ('DiTBlock_3', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_4', 'Dense_0', 'bias'): (4608,) ('DiTBlock_4', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_4', 'Dense_1', 'bias'): (768,) ('DiTBlock_4', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_4', 'Dense_2', 'bias'): (768,) ('DiTBlock_4', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_4', 'Dense_3', 'bias'): (768,) ('DiTBlock_4', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_4', 'Dense_4', 'bias'): (768,) ('DiTBlock_4', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_5', 'Dense_0', 'bias'): (4608,) ('DiTBlock_5', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_5', 'Dense_1', 'bias'): (768,) ('DiTBlock_5', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_5', 'Dense_2', 'bias'): (768,) ('DiTBlock_5', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_5', 'Dense_3', 'bias'): (768,) ('DiTBlock_5', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_5', 'Dense_4', 'bias'): (768,) ('DiTBlock_5', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_6', 'Dense_0', 'bias'): (4608,) ('DiTBlock_6', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_6', 
'Dense_1', 'bias'): (768,) ('DiTBlock_6', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_6', 'Dense_2', 'bias'): (768,) ('DiTBlock_6', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_6', 'Dense_3', 'bias'): (768,) ('DiTBlock_6', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_6', 'Dense_4', 'bias'): (768,) ('DiTBlock_6', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_7', 'Dense_0', 'bias'): (4608,) ('DiTBlock_7', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_7', 'Dense_1', 'bias'): (768,) ('DiTBlock_7', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_7', 'Dense_2', 'bias'): (768,) ('DiTBlock_7', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_7', 'Dense_3', 'bias'): (768,) ('DiTBlock_7', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_7', 'Dense_4', 'bias'): (768,) ('DiTBlock_7', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_8', 'Dense_0', 'bias'): (4608,) ('DiTBlock_8', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_8', 'Dense_1', 'bias'): (768,) ('DiTBlock_8', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_8', 'Dense_2', 'bias'): (768,) ('DiTBlock_8', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_8', 'Dense_3', 'bias'): (768,) ('DiTBlock_8', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_8', 'Dense_4', 'bias'): (768,) ('DiTBlock_8', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_9', 'Dense_0', 
'bias'): (4608,) ('DiTBlock_9', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_9', 'Dense_1', 'bias'): (768,) ('DiTBlock_9', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_9', 'Dense_2', 'bias'): (768,) ('DiTBlock_9', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_9', 'Dense_3', 'bias'): (768,) ('DiTBlock_9', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_9', 'Dense_4', 'bias'): (768,) ('DiTBlock_9', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('Embed_0', 'embedding'): (256, 1) ('FinalLayer_0', 'Dense_0', 'bias'): (1536,) ('FinalLayer_0', 'Dense_0', 'kernel'): (768, 1536) ('FinalLayer_0', 'Dense_1', 'bias'): (16,) ('FinalLayer_0', 'Dense_1', 'kernel'): (768, 16) ('LabelEmbedder_0', 'Embed_0', 'embedding'): (1001, 768) ('PatchEmbed_0', 'Conv_0', 'bias'): (768,) ('PatchEmbed_0', 'Conv_0', 'kernel'): (2, 2, 4, 768) ('TimestepEmbedder_0', 'Dense_0', 'bias'): (768,) ('TimestepEmbedder_0', 'Dense_0', 'kernel'): (256, 768) ('TimestepEmbedder_0', 'Dense_1', 'bias'): (768,) ('TimestepEmbedder_0', 'Dense_1', 'kernel'): (768, 768) ('TimestepEmbedder_1', 'Dense_0', 'bias'): (768,) ('TimestepEmbedder_1', 'Dense_0', 'kernel'): (256, 768) ('TimestepEmbedder_1', 'Dense_1', 'bias'): (768,) ('TimestepEmbedder_1', 'Dense_1', 'kernel'): (768, 768) ┌────────────────────────────────────────────────┐ │ │ │ │ │ │ │ │ │ TPU 0,1,2,3 │ │ │ │ │ │ │ │ │ └────────────────────────────────────────────────┘ ┌─────────────────────────────────────────────────────────────────────────┐ │ │ │ │ │ │ │ │ │ TPU 0,1,2,3 │ │ │ │ │ │ │ │ │ └─────────────────────────────────────────────────────────────────────────┘ doing the else (512, 256, 256, 3) encode image shape (128, 256, 256, 3) Initializing encoder. 
Incoming encoder shape (128, 256, 256, 3)
Encoder layer (128, 256, 256, 128)
doing downsample
Encoder layer (128, 128, 128, 128)
doing downsample
Encoder layer (128, 64, 64, 256)
doing downsample
Encoder layer (128, 32, 32, 512)
Encoder layer (128, 32, 32, 512)
Encoder layer final (128, 32, 32, 512)
Encoder layer final (128, 32, 32, 512)
Final embeddings are size (128, 32, 32, 8)
After quant (128, 32, 32, 4)
Calc FID for CFG 1.0 and denoise_timesteps 128
DiT: Input of shape (512, 32, 32, 4) dtype float32
DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16
DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16
DiT: Conditioning of shape (512, 768) dtype float32
z_vectors shape (128, 32, 32, 4)
Decoder incoming shape (128, 32, 32, 4)
Decoder input (128, 32, 32, 512)
Mid Block Decoder layer (128, 32, 32, 512)
Mid Block Decoder layer (128, 32, 32, 512)
Decoder layer (128, 64, 64, 512)
Decoder layer (128, 128, 128, 512)
Decoder layer (128, 256, 256, 256)
Decoder layer (128, 256, 256, 128)
FID is 35.31022644042969
Calc FID for CFG 1.0 and denoise_timesteps 64
FID is 36.25322723388672
Calc FID for CFG 1.0 and denoise_timesteps 32
FID is 39.04065704345703
Calc FID for CFG 1.0 and denoise_timesteps 16
FID is 47.8662109375
Calc FID for CFG 1.0 and denoise_timesteps 8
FID is 79.06202697753906
Calc FID for CFG 1.0 and denoise_timesteps 4
FID is 205.282470703125
Calc FID for CFG 1.0 and denoise_timesteps 2
FID is 366.40966796875
Calc FID for CFG 1.0 and denoise_timesteps 1
FID is 260.6976013183594
Calc FID for CFG 1.25 and denoise_timesteps 128
FID is 21.369884490966797
Calc FID for CFG 1.25 and denoise_timesteps 64
FID is 22.148216247558594
Calc FID for CFG 1.25 and denoise_timesteps 32
FID is 24.354700088500977
Calc FID for CFG 1.25 and denoise_timesteps 16
FID is 31.724349975585938
Calc FID for CFG 1.25 and denoise_timesteps 8
FID is 58.99041748046875
Calc FID for CFG 1.25 and denoise_timesteps 4
FID is 170.36500549316406
Calc FID for CFG 1.25 and denoise_timesteps 2
FID is 345.7372131347656
Calc FID for CFG 1.25 and denoise_timesteps 1
FID is 250.92877197265625
Calc FID for CFG 1.5 and denoise_timesteps 128
FID is 13.394742965698242
Calc FID for CFG 1.5 and denoise_timesteps 64
FID is 13.958024978637695
Calc FID for CFG 1.5 and denoise_timesteps 32
FID is 15.560650825500488
Calc FID for CFG 1.5 and denoise_timesteps 16
FID is 21.029720306396484
Calc FID for CFG 1.5 and denoise_timesteps 8
FID is 42.93667221069336
Calc FID for CFG 1.5 and denoise_timesteps 4
FID is 140.39520263671875
Calc FID for CFG 1.5 and denoise_timesteps 2
FID is 326.5361633300781
Calc FID for CFG 1.5 and denoise_timesteps 1
FID is 244.94207763671875
Calc FID for CFG 1.75 and denoise_timesteps 128
FID is 9.423892974853516
Calc FID for CFG 1.75 and denoise_timesteps 64
FID is 9.777412414550781
Calc FID for CFG 1.75 and denoise_timesteps 32
FID is 10.83847713470459
Calc FID for CFG 1.75 and denoise_timesteps 16
FID is 14.639281272888184
Calc FID for CFG 1.75 and denoise_timesteps 8
FID is 31.432449340820312
Calc FID for CFG 1.75 and denoise_timesteps 4
FID is 115.02278900146484
Calc FID for CFG 1.75 and denoise_timesteps 2
FID is 309.41192626953125
Calc FID for CFG 1.75 and denoise_timesteps 1
FID is 241.01193237304688
Calc FID for CFG 2.0 and denoise_timesteps 128
FID is 7.828547954559326
Calc FID for CFG 2.0 and denoise_timesteps 64
FID is 8.003972053527832
Calc FID for CFG 2.0 and denoise_timesteps 32
FID is 8.634432792663574
Calc FID for CFG 2.0 and denoise_timesteps 16
FID is 11.16789722442627
Calc FID for CFG 2.0 and denoise_timesteps 8
FID is 23.635202407836914
Calc FID for CFG 2.0 and denoise_timesteps 4
FID is 94.17955017089844
Calc FID for CFG 2.0 and denoise_timesteps 2
FID is 294.5157775878906
Calc FID for CFG 2.0 and denoise_timesteps 1
FID is 238.18911743164062
Calc FID for CFG 2.25 and denoise_timesteps 128
FID is 7.5738983154296875
Calc FID for CFG 2.25 and denoise_timesteps 64
FID is 7.647280216217041
Calc FID for CFG 2.25 and denoise_timesteps 32
FID is 7.974250316619873
Calc FID for CFG 2.25 and denoise_timesteps 16
FID is 9.564567565917969
Calc FID for CFG 2.25 and denoise_timesteps 8
FID is 18.56076431274414
Calc FID for CFG 2.25 and denoise_timesteps 4
FID is 77.56942749023438
Calc FID for CFG 2.25 and denoise_timesteps 2
FID is 281.5552978515625
Calc FID for CFG 2.25 and denoise_timesteps 1
FID is 235.73397827148438
Calc FID for CFG 2.5 and denoise_timesteps 128
FID is 8.044367790222168
Calc FID for CFG 2.5 and denoise_timesteps 64
FID is 8.0531644821167
Calc FID for CFG 2.5 and denoise_timesteps 32
FID is 8.165743827819824
Calc FID for CFG 2.5 and denoise_timesteps 16
FID is 9.1002197265625
Calc FID for CFG 2.5 and denoise_timesteps 8
FID is 15.491735458374023
Calc FID for CFG 2.5 and denoise_timesteps 4
FID is 64.42039489746094
Calc FID for CFG 2.5 and denoise_timesteps 2
FID is 270.5625915527344
Calc FID for CFG 2.5 and denoise_timesteps 1
FID is 233.64608764648438
Calc FID for CFG 2.75 and denoise_timesteps 128
FID is 8.906791687011719
Calc FID for CFG 2.75 and denoise_timesteps 64
FID is 8.859354972839355
Calc FID for CFG 2.75 and denoise_timesteps 32
FID is 8.847463607788086
Calc FID for CFG 2.75 and denoise_timesteps 16
FID is 9.30807876586914
Calc FID for CFG 2.75 and denoise_timesteps 8
FID is 13.719435691833496
Calc FID for CFG 2.75 and denoise_timesteps 4
FID is 54.147682189941406
Calc FID for CFG 2.75 and denoise_timesteps 2
FID is 261.23638916015625
Calc FID for CFG 2.75 and denoise_timesteps 1
FID is 232.02645874023438
Calc FID for CFG 3.0 and denoise_timesteps 128
FID is 9.92387580871582
Calc FID for CFG 3.0 and denoise_timesteps 64
FID is 9.84957504272461
Calc FID for CFG 3.0 and denoise_timesteps 32
FID is 9.755224227905273
Calc FID for CFG 3.0 and denoise_timesteps 16
FID is 9.908406257629395
Calc FID for CFG 3.0 and denoise_timesteps 8
FID is 12.787176132202148
Calc FID for CFG 3.0 and denoise_timesteps 4
FID is 46.10810470581055
Calc FID for CFG 3.0 and denoise_timesteps 2
FID is 253.41326904296875
Calc FID for CFG 3.0 and denoise_timesteps 1
FID is 230.6085205078125
wandb:
wandb: 🚀 View run shortcut_imagenet256 at: https://wandb.ai/daniel-z-kaplan/shortcut/runs/shortcut_imagenet256_20250814_122245_345353_10
wandb: Find logs at: ../../../tmp/tmp57xxkh9z/wandb/run-20250814_122245-shortcut_imagenet256_20250814_122245_345353_10/logs