---
# Model architecture
model:
  input_dim: 3  # RGB images (3 input channels)
  hidden_dim: 512
  num_blocks: 8
  num_heads: 8
  # NOTE(review): stride < patch size presumably yields overlapping patches;
  # confirm against the patch-embedding code.
  patch_size: 8
  patch_stride: 4
  time_freq_dim: 256
  time_max_period: 1024
  mlp_ratio: 4
  use_bias: false
  padding: "SAME"
  pos_embed_cls_token: false
  pos_embed_extra_tokens: 0

# Training parameters
training:
  # Written with a dot and a signed exponent on purpose: YAML 1.1 loaders
  # (e.g. PyYAML) resolve `1.0e-4` as a float, but `1e-4` as a string.
  learning_rate: 1.0e-4
  batch_size: 128
  # Plain digits on purpose: `1_000_000` is an int only under YAML 1.1
  # loaders; the YAML 1.2 core schema reads it as a string.
  num_steps: 1000000  # 1M
  warmup_pct: 0.01  # presumably a fraction of num_steps — confirm
  weight_decay: 0.0
  grad_clip_norm: 100.0

# Checkpointing and logging (intervals presumably counted in training steps)
checkpointing:
  log_every: 1000
  plot_every: 10000
  save_every: 10000
  resume_from_checkpoint: null  # null = no checkpoint to resume from

# Data
data:
  train_split: 0.9  # 90% for training, 10% for testing
  random_seed: 42