Using devices [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]
Device count 4
Global device count 4
Global Batch: 512
Node Batch: 512
Device Batch: 128
/tmp/tmp4eouwj9f
Loading dataset
Loading dataset
creating model
beta1: 0.9
beta2: 0.999
bootstrap_cfg: 1
bootstrap_dt_bias: 0
bootstrap_ema: 1
bootstrap_every: 8
cfg_scale: 1.5
class_dropout_prob: 0.1
denoise_timesteps: 128
depth: 12
dropout: 0.0
dt_sampling: uniform
hidden_size: 768
lr: 0.0001
mlp_ratio: 4
num_classes: 1000
num_heads: 12
patch_size: 2
sharding: dp
t_sampling: discrete-dt
target_update_rate: 0.999
train_type: naive
use_cosine: 0
use_ema: 0
use_stable_vae: 1
warmup: 0
weight_decay: 0.1
Total devices TPU_0(process=0,(0,0,0,0))
Initializing encoder.
Incoming encoder shape (1, 256, 256, 3)
Encoder layer (1, 256, 256, 128)
doing downsample
Encoder layer (1, 128, 128, 128)
doing downsample
Encoder layer (1, 64, 64, 256)
doing downsample
Encoder layer (1, 32, 32, 512)
Encoder layer (1, 32, 32, 512)
Encoder layer final (1, 32, 32, 512)
Encoder layer final (1, 32, 32, 512)
Final embeddings are size (1, 32, 32, 8)
After quant (1, 32, 32, 4)
encode finished
Decoder incoming shape (1, 32, 32, 4)
Decoder input (1, 32, 32, 512)
Mid Block Decoder layer (1, 32, 32, 512)
Mid Block Decoder layer (1, 32, 32, 512)
Decoder layer (1, 64, 64, 512)
Decoder layer (1, 128, 128, 512)
Decoder layer (1, 256, 256, 256)
Decoder layer (1, 256, 256, 128)
Total num of VQVAE parameters: 67565323
Disc shape (1, 128, 128, 128)
Disc shape (1, 64, 64, 256)
Disc shape (1, 32, 32, 512)
Disc shape (1, 16, 16, 512)
Disc shape (1, 8, 8, 512)
Disc shape (1, 4, 4, 512)
Total num of Discriminator parameters: 23998017
Loaded checkpoint from 17531801 seconds ago.
Loaded model with step 474001
┌──────────────────────────────────────────────────────────────────────────────┐
│                                    TPU 0                                     │
├──────────────────────────────────────────────────────────────────────────────┤
│                                    TPU 1                                     │
├──────────────────────────────────────────────────────────────────────────────┤
│                                    TPU 2                                     │
├──────────────────────────────────────────────────────────────────────────────┤
│                                    TPU 3                                     │
└──────────────────────────────────────────────────────────────────────────────┘
returning model
model done
Input to vae (4, 1, 256, 256, 3)
encode image shape (1, 256, 256, 3)
Initializing encoder.
Incoming encoder shape (1, 256, 256, 3)
Encoder layer (1, 256, 256, 128)
doing downsample
Encoder layer (1, 128, 128, 128)
doing downsample
Encoder layer (1, 64, 64, 256)
doing downsample
Encoder layer (1, 32, 32, 512)
Encoder layer (1, 32, 32, 512)
Encoder layer final (1, 32, 32, 512)
Encoder layer final (1, 32, 32, 512)
Final embeddings are size (1, 32, 32, 8)
After quant (1, 32, 32, 4)
output example shape (4, 1, 32, 32, 4)
Test data shape (4, 256, 256, 3)
x shape (4, 1, 256, 256, 3)
encoded shape (4, 1, 32, 32, 4)
z_vectors shape (1, 32, 32, 4)
Decoder incoming shape (1, 32, 32, 4)
Decoder input (1, 32, 32, 512)
Mid Block Decoder layer (1, 32, 32, 512)
Mid Block Decoder layer (1, 32, 32, 512)
Decoder layer (1, 64, 64, 512)
Decoder layer (1, 128, 128, 512)
Decoder layer (1, 256, 256, 256)
Decoder layer (1, 256, 256, 128)
image shape (4, 1, 256, 256, 3)
decoded img shape (256, 256, 3)
obs shape (4, 32, 32, 4)
DiT: Input of shape (4, 32, 32, 4) dtype float32
DiT: After patch embed, shape is (4, 256, 768) dtype bfloat16
DiT: Patch Embed of shape (4, 256, 768) dtype bfloat16
DiT: Conditioning of shape (1, 768) dtype float32

DiT Summary
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ path ┃ module ┃ inputs ┃ outputs ┃ params ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ │ DiT │ - float32[4,32,32,4] │ bfloat16[4,32,32,4] │ │ │ │ │ - float32[1] │ │ │ │ │ │ - float32[1] │ │ │ │ │ │ - int32[1] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ PatchEmbed_0 │ PatchEmbed │ float32[4,32,32,4] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ PatchEmbed_0/Conv_0 │ Conv │ float32[4,32,32,4] │ bfloat16[4,16,16,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[2,2,4,768] │ │ │ │ │ │ │ │ │ │ │ │ 13,056 (52.2 KB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ TimestepEmbedder_0 │ TimestepEmbedder │ float32[1] │ float32[1,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ TimestepEmbedder_0/Dense_0 │ Dense │ bfloat16[1,256] │ bfloat16[1,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[256,768] │ │ │ │ │ │ │ │ │ │ │ │ 197,376 (789.5 KB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ TimestepEmbedder_0/Dense_1 │ Dense │ bfloat16[1,768] │ float32[1,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ TimestepEmbedder_1 │ TimestepEmbedder │ float32[1] │ float32[1,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ TimestepEmbedder_1/Dense_0 │ Dense │ 
bfloat16[1,256] │ bfloat16[1,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[256,768] │ │ │ │ │ │ │ │ │ │ │ │ 197,376 (789.5 KB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ TimestepEmbedder_1/Dense_1 │ Dense │ bfloat16[1,768] │ float32[1,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ LabelEmbedder_0 │ LabelEmbedder │ int32[1] │ bfloat16[1,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ LabelEmbedder_0/Embed_0 │ Embed │ int32[1] │ bfloat16[1,768] │ embedding: float32[1001,768] │ │ │ │ │ │ │ │ │ │ │ │ 768,768 (3.1 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_0/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_1/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_2/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3 
│ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ 
DiTBlock_3/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_3/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_4/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_5/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ 
DiTBlock_6/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ 
bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_6/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] 
│ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_7/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_8/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: 
float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_9/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 
MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_10/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ │ │ │ │ │ kernel: float32[768,4608] │ │ │ │ │ │ │ │ │ │ │ │ 3,543,552 (14.2 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[768,768] │ │ │ │ │ │ │ │ │ │ │ │ 590,592 (2.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ │ │ │ │ │ kernel: float32[768,3072] │ │ │ │ │ │ │ │ │ │ │ │ 2,362,368 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ │ │ │ │ │ kernel: float32[3072,768] │ │ │ │ │ │ │ │ │ │ │ │ 2,360,064 (9.4 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ DiTBlock_11/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ FinalLayer_0 │ FinalLayer │ - bfloat16[4,256,768] │ bfloat16[4,256,16] │ │ │ │ │ - float32[1,768] │ │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ FinalLayer_0/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,1536] │ bias: float32[1536] │ │ │ │ │ │ kernel: float32[768,1536] │ │ │ │ │ │ │ │ │ │ │ │ 1,181,184 (4.7 MB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ FinalLayer_0/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ FinalLayer_0/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,16] │ bias: float32[16] │ │ │ │ │ │ kernel: float32[768,16] │ │ │ │ │ │ │ │ │ │ │ │ 12,304 (49.2 KB) │ ├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │ Embed_0 │ Embed │ int32[1] │ float32[1,1] │ embedding: float32[256,1] │ │ │ │ │ │ │ │ │ │ │ │ 256 (1.0 KB) │ 
├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ │   │   │   │  Total │ 131,091,728 (524.4 MB)  │ └──────────────────────────────────┴──────────────────┴───────────────────────┴───────────────────────┴──────────────────────────────┘    Total Parameters: 131,091,728 (524.4 MB)  DiT: Input of shape (4, 32, 32, 4) dtype float32 DiT: After patch embed, shape is (4, 256, 768) dtype bfloat16 DiT: Patch Embed of shape (4, 256, 768) dtype bfloat16 DiT: Conditioning of shape (1, 768) dtype float32 Loaded checkpoint from 4024 seconds ago. parameter shapes: ('PatchEmbed_0', 'Conv_0', 'kernel'): (2, 2, 4, 768) ('PatchEmbed_0', 'Conv_0', 'bias'): (768,) ('TimestepEmbedder_0', 'Dense_0', 'kernel'): (256, 768) ('TimestepEmbedder_0', 'Dense_0', 'bias'): (768,) ('TimestepEmbedder_0', 'Dense_1', 'kernel'): (768, 768) ('TimestepEmbedder_0', 'Dense_1', 'bias'): (768,) ('TimestepEmbedder_1', 'Dense_0', 'kernel'): (256, 768) ('TimestepEmbedder_1', 'Dense_0', 'bias'): (768,) ('TimestepEmbedder_1', 'Dense_1', 'kernel'): (768, 768) ('TimestepEmbedder_1', 'Dense_1', 'bias'): (768,) ('LabelEmbedder_0', 'Embed_0', 'embedding'): (1001, 768) ('DiTBlock_0', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_0', 'Dense_0', 'bias'): (4608,) ('DiTBlock_0', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_0', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_0', 'Dense_2', 'bias'): (768,) ('DiTBlock_0', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_0', 'Dense_3', 'bias'): (768,) ('DiTBlock_0', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_0', 'Dense_4', 'bias'): (768,) ('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_1', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_1', 'Dense_0', 'bias'): (4608,) 
('DiTBlock_1', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_1', 'Dense_1', 'bias'): (768,) ('DiTBlock_1', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_1', 'Dense_2', 'bias'): (768,) ('DiTBlock_1', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_1', 'Dense_3', 'bias'): (768,) ('DiTBlock_1', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_1', 'Dense_4', 'bias'): (768,) ('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_2', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_2', 'Dense_0', 'bias'): (4608,) ('DiTBlock_2', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_2', 'Dense_1', 'bias'): (768,) ('DiTBlock_2', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_2', 'Dense_2', 'bias'): (768,) ('DiTBlock_2', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_2', 'Dense_3', 'bias'): (768,) ('DiTBlock_2', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_2', 'Dense_4', 'bias'): (768,) ('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_3', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_3', 'Dense_0', 'bias'): (4608,) ('DiTBlock_3', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_3', 'Dense_1', 'bias'): (768,) ('DiTBlock_3', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_3', 'Dense_2', 'bias'): (768,) ('DiTBlock_3', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_3', 'Dense_3', 'bias'): (768,) ('DiTBlock_3', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_3', 'Dense_4', 'bias'): (768,) ('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_4', 
'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_4', 'Dense_0', 'bias'): (4608,) ('DiTBlock_4', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_4', 'Dense_1', 'bias'): (768,) ('DiTBlock_4', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_4', 'Dense_2', 'bias'): (768,) ('DiTBlock_4', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_4', 'Dense_3', 'bias'): (768,) ('DiTBlock_4', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_4', 'Dense_4', 'bias'): (768,) ('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_5', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_5', 'Dense_0', 'bias'): (4608,) ('DiTBlock_5', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_5', 'Dense_1', 'bias'): (768,) ('DiTBlock_5', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_5', 'Dense_2', 'bias'): (768,) ('DiTBlock_5', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_5', 'Dense_3', 'bias'): (768,) ('DiTBlock_5', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_5', 'Dense_4', 'bias'): (768,) ('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_6', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_6', 'Dense_0', 'bias'): (4608,) ('DiTBlock_6', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_6', 'Dense_1', 'bias'): (768,) ('DiTBlock_6', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_6', 'Dense_2', 'bias'): (768,) ('DiTBlock_6', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_6', 'Dense_3', 'bias'): (768,) ('DiTBlock_6', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_6', 'Dense_4', 'bias'): (768,) ('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 
768) ('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_7', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_7', 'Dense_0', 'bias'): (4608,) ('DiTBlock_7', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_7', 'Dense_1', 'bias'): (768,) ('DiTBlock_7', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_7', 'Dense_2', 'bias'): (768,) ('DiTBlock_7', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_7', 'Dense_3', 'bias'): (768,) ('DiTBlock_7', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_7', 'Dense_4', 'bias'): (768,) ('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_8', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_8', 'Dense_0', 'bias'): (4608,) ('DiTBlock_8', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_8', 'Dense_1', 'bias'): (768,) ('DiTBlock_8', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_8', 'Dense_2', 'bias'): (768,) ('DiTBlock_8', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_8', 'Dense_3', 'bias'): (768,) ('DiTBlock_8', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_8', 'Dense_4', 'bias'): (768,) ('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_9', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_9', 'Dense_0', 'bias'): (4608,) ('DiTBlock_9', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_9', 'Dense_1', 'bias'): (768,) ('DiTBlock_9', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_9', 'Dense_2', 'bias'): (768,) ('DiTBlock_9', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_9', 'Dense_3', 'bias'): (768,) ('DiTBlock_9', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_9', 'Dense_4', 'bias'): (768,) ('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 
'bias'): (3072,) ('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_10', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_10', 'Dense_0', 'bias'): (4608,) ('DiTBlock_10', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_10', 'Dense_1', 'bias'): (768,) ('DiTBlock_10', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_10', 'Dense_2', 'bias'): (768,) ('DiTBlock_10', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_10', 'Dense_3', 'bias'): (768,) ('DiTBlock_10', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_10', 'Dense_4', 'bias'): (768,) ('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_11', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_11', 'Dense_0', 'bias'): (4608,) ('DiTBlock_11', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_11', 'Dense_1', 'bias'): (768,) ('DiTBlock_11', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_11', 'Dense_2', 'bias'): (768,) ('DiTBlock_11', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_11', 'Dense_3', 'bias'): (768,) ('DiTBlock_11', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_11', 'Dense_4', 'bias'): (768,) ('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('FinalLayer_0', 'Dense_0', 'kernel'): (768, 1536) ('FinalLayer_0', 'Dense_0', 'bias'): (1536,) ('FinalLayer_0', 'Dense_1', 'kernel'): (768, 16) ('FinalLayer_0', 'Dense_1', 'bias'): (16,) ('Embed_0', 'embedding'): (256, 1) parameter shapes: ('DiTBlock_0', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_0', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_0', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_0', 
'Dense_2', 'bias'): (1, 768) ('DiTBlock_0', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_0', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_0', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_0', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_0', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_1', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_1', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_1', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_1', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_1', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_1', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_1', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_1', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_1', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_1', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_10', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_10', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_10', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_10', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_10', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_10', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_10', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_10', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_10', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_10', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) 
('DiTBlock_11', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_11', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_11', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_11', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_11', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_11', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_11', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_11', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_11', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_11', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_2', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_2', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_2', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_2', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_2', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_2', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_2', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_2', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_2', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_2', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_3', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_3', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_3', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_3', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_3', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_3', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_3', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_3', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_3', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_3', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 
3072) ('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_4', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_4', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_4', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_4', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_4', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_4', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_4', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_4', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_4', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_4', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_5', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_5', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_5', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_5', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_5', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_5', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_5', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_5', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_5', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_5', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_6', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_6', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_6', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_6', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_6', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_6', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_6', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_6', 
'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_6', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_6', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_7', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_7', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_7', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_7', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_7', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_7', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_7', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_7', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_7', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_7', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_8', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_8', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_8', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_8', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_8', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_8', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_8', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_8', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_8', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_8', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_9', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_9', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_9', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_9', 'Dense_1', 
'kernel'): (1, 768, 768) ('DiTBlock_9', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_9', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_9', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_9', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_9', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_9', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('Embed_0', 'embedding'): (1, 256, 1) ('FinalLayer_0', 'Dense_0', 'bias'): (1, 1536) ('FinalLayer_0', 'Dense_0', 'kernel'): (1, 768, 1536) ('FinalLayer_0', 'Dense_1', 'bias'): (1, 16) ('FinalLayer_0', 'Dense_1', 'kernel'): (1, 768, 16) ('LabelEmbedder_0', 'Embed_0', 'embedding'): (1, 1001, 768) ('PatchEmbed_0', 'Conv_0', 'bias'): (1, 768) ('PatchEmbed_0', 'Conv_0', 'kernel'): (1, 2, 2, 4, 768) ('TimestepEmbedder_0', 'Dense_0', 'bias'): (1, 768) ('TimestepEmbedder_0', 'Dense_0', 'kernel'): (1, 256, 768) ('TimestepEmbedder_0', 'Dense_1', 'bias'): (1, 768) ('TimestepEmbedder_0', 'Dense_1', 'kernel'): (1, 768, 768) ('TimestepEmbedder_1', 'Dense_0', 'bias'): (1, 768) ('TimestepEmbedder_1', 'Dense_0', 'kernel'): (1, 256, 768) ('TimestepEmbedder_1', 'Dense_1', 'bias'): (1, 768) ('TimestepEmbedder_1', 'Dense_1', 'kernel'): (1, 768, 768) parameter shapes: ('DiTBlock_0', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_0', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_0', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_0', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_0', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_0', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_0', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_0', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_0', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 
'kernel'): (1, 768, 3072) ('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_1', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_1', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_1', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_1', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_1', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_1', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_1', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_1', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_1', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_1', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_10', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_10', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_10', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_10', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_10', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_10', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_10', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_10', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_10', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_10', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_11', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_11', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_11', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_11', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_11', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_11', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_11', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_11', 'Dense_3', 'kernel'): (1, 
768, 768) ('DiTBlock_11', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_11', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_2', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_2', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_2', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_2', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_2', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_2', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_2', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_2', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_2', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_2', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_3', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_3', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_3', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_3', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_3', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_3', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_3', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_3', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_3', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_3', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_4', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_4', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_4', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_4', 'Dense_1', 'kernel'): (1, 768, 
768) ('DiTBlock_4', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_4', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_4', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_4', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_4', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_4', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_5', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_5', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_5', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_5', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_5', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_5', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_5', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_5', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_5', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_5', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_6', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_6', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_6', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_6', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_6', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_6', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_6', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_6', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_6', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_6', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) 
('DiTBlock_7', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_7', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_7', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_7', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_7', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_7', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_7', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_7', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_7', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_7', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_8', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_8', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_8', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_8', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_8', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_8', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_8', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_8', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_8', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_8', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) ('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('DiTBlock_9', 'Dense_0', 'bias'): (1, 4608) ('DiTBlock_9', 'Dense_0', 'kernel'): (1, 768, 4608) ('DiTBlock_9', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_9', 'Dense_1', 'kernel'): (1, 768, 768) ('DiTBlock_9', 'Dense_2', 'bias'): (1, 768) ('DiTBlock_9', 'Dense_2', 'kernel'): (1, 768, 768) ('DiTBlock_9', 'Dense_3', 'bias'): (1, 768) ('DiTBlock_9', 'Dense_3', 'kernel'): (1, 768, 768) ('DiTBlock_9', 'Dense_4', 'bias'): (1, 768) ('DiTBlock_9', 'Dense_4', 'kernel'): (1, 768, 768) ('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) 
('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) ('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) ('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) ('Embed_0', 'embedding'): (1, 256, 1) ('FinalLayer_0', 'Dense_0', 'bias'): (1, 1536) ('FinalLayer_0', 'Dense_0', 'kernel'): (1, 768, 1536) ('FinalLayer_0', 'Dense_1', 'bias'): (1, 16) ('FinalLayer_0', 'Dense_1', 'kernel'): (1, 768, 16) ('LabelEmbedder_0', 'Embed_0', 'embedding'): (1, 1001, 768) ('PatchEmbed_0', 'Conv_0', 'bias'): (1, 768) ('PatchEmbed_0', 'Conv_0', 'kernel'): (1, 2, 2, 4, 768) ('TimestepEmbedder_0', 'Dense_0', 'bias'): (1, 768) ('TimestepEmbedder_0', 'Dense_0', 'kernel'): (1, 256, 768) ('TimestepEmbedder_0', 'Dense_1', 'bias'): (1, 768) ('TimestepEmbedder_0', 'Dense_1', 'kernel'): (1, 768, 768) ('TimestepEmbedder_1', 'Dense_0', 'bias'): (1, 768) ('TimestepEmbedder_1', 'Dense_0', 'kernel'): (1, 256, 768) ('TimestepEmbedder_1', 'Dense_1', 'bias'): (1, 768) ('TimestepEmbedder_1', 'Dense_1', 'kernel'): (1, 768, 768)
parameter shapes: ('DiTBlock_0', 'Dense_0', 'bias'): (4608,) ('DiTBlock_0', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_0', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_0', 'Dense_2', 'bias'): (768,) ('DiTBlock_0', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_0', 'Dense_3', 'bias'): (768,) ('DiTBlock_0', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_0', 'Dense_4', 'bias'): (768,) ('DiTBlock_0', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_1', 'Dense_0', 'bias'): (4608,) ('DiTBlock_1', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_1', 'Dense_1', 'bias'): (768,) ('DiTBlock_1', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_1', 'Dense_2', 'bias'): (768,) ('DiTBlock_1', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_1', 'Dense_3', 'bias'): (768,) ('DiTBlock_1', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_1', 'Dense_4', 'bias'): (768,) ('DiTBlock_1', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_10', 'Dense_0', 'bias'): (4608,) ('DiTBlock_10', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_10', 'Dense_1', 'bias'): (768,) ('DiTBlock_10', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_10', 'Dense_2', 'bias'): (768,) ('DiTBlock_10', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_10', 
'Dense_3', 'bias'): (768,) ('DiTBlock_10', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_10', 'Dense_4', 'bias'): (768,) ('DiTBlock_10', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_11', 'Dense_0', 'bias'): (4608,) ('DiTBlock_11', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_11', 'Dense_1', 'bias'): (768,) ('DiTBlock_11', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_11', 'Dense_2', 'bias'): (768,) ('DiTBlock_11', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_11', 'Dense_3', 'bias'): (768,) ('DiTBlock_11', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_11', 'Dense_4', 'bias'): (768,) ('DiTBlock_11', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_2', 'Dense_0', 'bias'): (4608,) ('DiTBlock_2', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_2', 'Dense_1', 'bias'): (768,) ('DiTBlock_2', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_2', 'Dense_2', 'bias'): (768,) ('DiTBlock_2', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_2', 'Dense_3', 'bias'): (768,) ('DiTBlock_2', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_2', 'Dense_4', 'bias'): (768,) ('DiTBlock_2', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_3', 'Dense_0', 'bias'): (4608,) ('DiTBlock_3', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_3', 'Dense_1', 'bias'): (768,) ('DiTBlock_3', 'Dense_1', 'kernel'): (768, 768) 
('DiTBlock_3', 'Dense_2', 'bias'): (768,) ('DiTBlock_3', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_3', 'Dense_3', 'bias'): (768,) ('DiTBlock_3', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_3', 'Dense_4', 'bias'): (768,) ('DiTBlock_3', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_4', 'Dense_0', 'bias'): (4608,) ('DiTBlock_4', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_4', 'Dense_1', 'bias'): (768,) ('DiTBlock_4', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_4', 'Dense_2', 'bias'): (768,) ('DiTBlock_4', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_4', 'Dense_3', 'bias'): (768,) ('DiTBlock_4', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_4', 'Dense_4', 'bias'): (768,) ('DiTBlock_4', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_5', 'Dense_0', 'bias'): (4608,) ('DiTBlock_5', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_5', 'Dense_1', 'bias'): (768,) ('DiTBlock_5', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_5', 'Dense_2', 'bias'): (768,) ('DiTBlock_5', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_5', 'Dense_3', 'bias'): (768,) ('DiTBlock_5', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_5', 'Dense_4', 'bias'): (768,) ('DiTBlock_5', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_6', 'Dense_0', 'bias'): (4608,) ('DiTBlock_6', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_6', 
'Dense_1', 'bias'): (768,) ('DiTBlock_6', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_6', 'Dense_2', 'bias'): (768,) ('DiTBlock_6', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_6', 'Dense_3', 'bias'): (768,) ('DiTBlock_6', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_6', 'Dense_4', 'bias'): (768,) ('DiTBlock_6', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_7', 'Dense_0', 'bias'): (4608,) ('DiTBlock_7', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_7', 'Dense_1', 'bias'): (768,) ('DiTBlock_7', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_7', 'Dense_2', 'bias'): (768,) ('DiTBlock_7', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_7', 'Dense_3', 'bias'): (768,) ('DiTBlock_7', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_7', 'Dense_4', 'bias'): (768,) ('DiTBlock_7', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_8', 'Dense_0', 'bias'): (4608,) ('DiTBlock_8', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_8', 'Dense_1', 'bias'): (768,) ('DiTBlock_8', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_8', 'Dense_2', 'bias'): (768,) ('DiTBlock_8', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_8', 'Dense_3', 'bias'): (768,) ('DiTBlock_8', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_8', 'Dense_4', 'bias'): (768,) ('DiTBlock_8', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('DiTBlock_9', 'Dense_0', 
'bias'): (4608,) ('DiTBlock_9', 'Dense_0', 'kernel'): (768, 4608) ('DiTBlock_9', 'Dense_1', 'bias'): (768,) ('DiTBlock_9', 'Dense_1', 'kernel'): (768, 768) ('DiTBlock_9', 'Dense_2', 'bias'): (768,) ('DiTBlock_9', 'Dense_2', 'kernel'): (768, 768) ('DiTBlock_9', 'Dense_3', 'bias'): (768,) ('DiTBlock_9', 'Dense_3', 'kernel'): (768, 768) ('DiTBlock_9', 'Dense_4', 'bias'): (768,) ('DiTBlock_9', 'Dense_4', 'kernel'): (768, 768) ('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) ('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) ('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) ('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) ('Embed_0', 'embedding'): (256, 1) ('FinalLayer_0', 'Dense_0', 'bias'): (1536,) ('FinalLayer_0', 'Dense_0', 'kernel'): (768, 1536) ('FinalLayer_0', 'Dense_1', 'bias'): (16,) ('FinalLayer_0', 'Dense_1', 'kernel'): (768, 16) ('LabelEmbedder_0', 'Embed_0', 'embedding'): (1001, 768) ('PatchEmbed_0', 'Conv_0', 'bias'): (768,) ('PatchEmbed_0', 'Conv_0', 'kernel'): (2, 2, 4, 768) ('TimestepEmbedder_0', 'Dense_0', 'bias'): (768,) ('TimestepEmbedder_0', 'Dense_0', 'kernel'): (256, 768) ('TimestepEmbedder_0', 'Dense_1', 'bias'): (768,) ('TimestepEmbedder_0', 'Dense_1', 'kernel'): (768, 768) ('TimestepEmbedder_1', 'Dense_0', 'bias'): (768,) ('TimestepEmbedder_1', 'Dense_0', 'kernel'): (256, 768) ('TimestepEmbedder_1', 'Dense_1', 'bias'): (768,) ('TimestepEmbedder_1', 'Dense_1', 'kernel'): (768, 768) ┌────────────────────────────────────────────────┐ │ │ │ │ │ │ │ │ │ TPU 0,1,2,3 │ │ │ │ │ │ │ │ │ └────────────────────────────────────────────────┘ ┌─────────────────────────────────────────────────────────────────────────┐ │ │ │ │ │ │ │ │ │ TPU 0,1,2,3 │ │ │ │ │ │ │ │ │ └─────────────────────────────────────────────────────────────────────────┘ doing the else (512, 256, 256, 3) encode image shape (128, 256, 256, 3) Initializing encoder. 
Incoming encoder shape (128, 256, 256, 3)
Encoder layer (128, 256, 256, 128)
doing downsample
Encoder layer (128, 128, 128, 128)
doing downsample
Encoder layer (128, 64, 64, 256)
doing downsample
Encoder layer (128, 32, 32, 512)
Encoder layer (128, 32, 32, 512)
Encoder layer final (128, 32, 32, 512)
Encoder layer final (128, 32, 32, 512)
Final embeddings are size (128, 32, 32, 8)
After quant (128, 32, 32, 4)
Calc FID for CFG 1.0 and denoise_timesteps 128
DiT: Input of shape (512, 32, 32, 4) dtype float32
DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16
DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16
DiT: Conditioning of shape (512, 768) dtype float32
z_vectors shape (128, 32, 32, 4)
Decoder incoming shape (128, 32, 32, 4)
Decoder input (128, 32, 32, 512)
Mid Block Decoder layer (128, 32, 32, 512)
Mid Block Decoder layer (128, 32, 32, 512)
Decoder layer (128, 64, 64, 512)
Decoder layer (128, 128, 128, 512)
Decoder layer (128, 256, 256, 256)
Decoder layer (128, 256, 256, 128)
FID is 37.710113525390625
Calc FID for CFG 1.0 and denoise_timesteps 64
FID is 38.55653381347656
Calc FID for CFG 1.0 and denoise_timesteps 32
FID is 40.906532287597656
Calc FID for CFG 1.0 and denoise_timesteps 16
FID is 48.274147033691406
Calc FID for CFG 1.0 and denoise_timesteps 8
FID is 72.10687255859375
Calc FID for CFG 1.0 and denoise_timesteps 4
FID is 157.84823608398438
Calc FID for CFG 1.0 and denoise_timesteps 2
FID is 302.2505798339844
Calc FID for CFG 1.0 and denoise_timesteps 1
FID is 286.669921875
Calc FID for CFG 1.25 and denoise_timesteps 128
FID is 23.236591339111328
Calc FID for CFG 1.25 and denoise_timesteps 64
FID is 23.893098831176758
Calc FID for CFG 1.25 and denoise_timesteps 32
FID is 25.84885597229004
Calc FID for CFG 1.25 and denoise_timesteps 16
FID is 32.232967376708984
Calc FID for CFG 1.25 and denoise_timesteps 8
FID is 53.88352966308594
Calc FID for CFG 1.25 and denoise_timesteps 4
FID is 132.78050231933594
Calc FID for CFG 1.25 and denoise_timesteps 2
FID is 291.7098083496094
Calc FID for CFG 1.25 and denoise_timesteps 1
FID is 275.8406677246094
Calc FID for CFG 1.5 and denoise_timesteps 128
FID is 14.714788436889648
Calc FID for CFG 1.5 and denoise_timesteps 64
FID is 15.207860946655273
Calc FID for CFG 1.5 and denoise_timesteps 32
FID is 16.68800163269043
Calc FID for CFG 1.5 and denoise_timesteps 16
FID is 21.549667358398438
Calc FID for CFG 1.5 and denoise_timesteps 8
FID is 39.529205322265625
Calc FID for CFG 1.5 and denoise_timesteps 4
FID is 111.15298461914062
Calc FID for CFG 1.5 and denoise_timesteps 2
FID is 282.57464599609375
Calc FID for CFG 1.5 and denoise_timesteps 1
FID is 268.36480712890625
Calc FID for CFG 1.75 and denoise_timesteps 128
FID is 10.178242683410645
Calc FID for CFG 1.75 and denoise_timesteps 64
FID is 10.496488571166992
Calc FID for CFG 1.75 and denoise_timesteps 32
FID is 11.531961441040039
Calc FID for CFG 1.75 and denoise_timesteps 16
FID is 15.042491912841797
Calc FID for CFG 1.75 and denoise_timesteps 8
FID is 29.286884307861328
Calc FID for CFG 1.75 and denoise_timesteps 4
FID is 92.95084381103516
Calc FID for CFG 1.75 and denoise_timesteps 2
FID is 274.617431640625
Calc FID for CFG 1.75 and denoise_timesteps 1
FID is 263.70489501953125
Calc FID for CFG 2.0 and denoise_timesteps 128
FID is 8.128108024597168
Calc FID for CFG 2.0 and denoise_timesteps 64
FID is 8.341901779174805
Calc FID for CFG 2.0 and denoise_timesteps 32
FID is 8.977622985839844
Calc FID for CFG 2.0 and denoise_timesteps 16
FID is 11.461089134216309
Calc FID for CFG 2.0 and denoise_timesteps 8
FID is 22.358848571777344
Calc FID for CFG 2.0 and denoise_timesteps 4
FID is 78.08787536621094
Calc FID for CFG 2.0 and denoise_timesteps 2
FID is 267.7125244140625
Calc FID for CFG 2.0 and denoise_timesteps 1
FID is 260.5279846191406
Calc FID for CFG 2.25 and denoise_timesteps 128
FID is 7.579560279846191
Calc FID for CFG 2.25 and denoise_timesteps 64
FID is 7.693774223327637
Calc FID for CFG 2.25 and denoise_timesteps 32
FID is 8.097635269165039
Calc FID for CFG 2.25 and denoise_timesteps 16
FID is 9.761984825134277
Calc FID for CFG 2.25 and denoise_timesteps 8
FID is 17.961179733276367
Calc FID for CFG 2.25 and denoise_timesteps 4
FID is 66.0443115234375
Calc FID for CFG 2.25 and denoise_timesteps 2
FID is 261.72833251953125
Calc FID for CFG 2.25 and denoise_timesteps 1
FID is 257.94683837890625
Calc FID for CFG 2.5 and denoise_timesteps 128
FID is 7.839601039886475
Calc FID for CFG 2.5 and denoise_timesteps 64
FID is 7.884913444519043
Calc FID for CFG 2.5 and denoise_timesteps 32
FID is 8.095534324645996
Calc FID for CFG 2.5 and denoise_timesteps 16
FID is 9.196731567382812
Calc FID for CFG 2.5 and denoise_timesteps 8
FID is 15.235790252685547
Calc FID for CFG 2.5 and denoise_timesteps 4
FID is 56.361209869384766
Calc FID for CFG 2.5 and denoise_timesteps 2
FID is 256.31207275390625
Calc FID for CFG 2.5 and denoise_timesteps 1
FID is 255.73171997070312
Calc FID for CFG 2.75 and denoise_timesteps 128
FID is 8.575051307678223
Calc FID for CFG 2.75 and denoise_timesteps 64
FID is 8.555830001831055
Calc FID for CFG 2.75 and denoise_timesteps 32
FID is 8.639049530029297
Calc FID for CFG 2.75 and denoise_timesteps 16
FID is 9.298751831054688
Calc FID for CFG 2.75 and denoise_timesteps 8
FID is 13.668627738952637
Calc FID for CFG 2.75 and denoise_timesteps 4
FID is 48.6970329284668
Calc FID for CFG 2.75 and denoise_timesteps 2
FID is 251.49713134765625
Calc FID for CFG 2.75 and denoise_timesteps 1
FID is 253.87701416015625
Calc FID for CFG 3.0 and denoise_timesteps 128
FID is 9.52678108215332
Calc FID for CFG 3.0 and denoise_timesteps 64
FID is 9.481215476989746
Calc FID for CFG 3.0 and denoise_timesteps 32
FID is 9.447056770324707
Calc FID for CFG 3.0 and denoise_timesteps 16
FID is 9.796664237976074
Calc FID for CFG 3.0 and denoise_timesteps 8
FID is 12.878477096557617
Calc FID for CFG 3.0 and denoise_timesteps 4
FID is 42.54828643798828
Calc FID for CFG 3.0 and denoise_timesteps 2
FID is 247.09524536132812
Calc FID for CFG 3.0 and denoise_timesteps 1
FID is 252.1986083984375
wandb:
wandb: 🚀 View run shortcut_imagenet256 at: https://wandb.ai/daniel-z-kaplan/shortcut/runs/shortcut_imagenet256_20250807_234640_345353_10
wandb: Find logs at: ../../../tmp/tmp4eouwj9f/wandb/run-20250807_234640-shortcut_imagenet256_20250807_234640_345353_10/logs