# Training configuration for the `ELF-B` model on the `xsum_train_t5` dataset
# (see `model`, `data_path` and `output_dir` below).
# NOTE(review): the exact meaning and units of each key are defined by the
# code that loads this file, which is not visible here. Notes marked
# "NOTE(review)" are assumptions to confirm, not documented behavior.

# Dataset
data_path: "embedded-language-flows/xsum_train_t5"
eval_data_path: "embedded-language-flows/xsum_validation_t5"
# NOTE(review): max_length exceeds max_input_length by 64 — presumably the
# budget left for the target sequence; confirm against the data loader.
max_length: 1088
max_input_length: 1024
pad_token: eos

# Encoder
encoder_model_name: t5-small
encoder_checkpoint: "embedded-language-flows/t5_small_encoder_jax/t5_small_encoder_jax.pkl"
latent_mean: 0.0
latent_std: 0.2

# Model architecture
model: ELF-B  # options: ELF-B, ELF-M
bottleneck_dim: 128
num_time_tokens: 4
num_self_cond_cfg_tokens: 4
num_model_mode_tokens: 4

# Denoiser objective
denoiser_p_mean: -1.5
denoiser_p_std: 0.8
denoiser_noise_scale: 2.0
t_eps: 0.05
time_schedule: "logit_normal"

# Decoder objective
decoder_prob: 0.2
decoder_noise_scale: 5.0
decoder_p_mean: 0.8
decoder_p_std: 0.8

# Conditioning / CFG
label_drop_prob: 0.1
self_cond_prob: 0.5

# Training (optimizer + schedule)
epochs: 100
global_batch_size: 512
# NOTE(review): `blr` is presumably a base learning rate that the trainer
# rescales (e.g. by batch size) — confirm how it is consumed.
blr: 0.001
weight_decay: 0.0
warmup_steps: 5000
optimizer: muon

# EMA
ema_decay1: 0.9999

# Sampling
sampling_configs_path: "configs/sampling_configs/cond_sampling_configs.yml"
num_samples: 5000

# Logging & Checkpointing
# NOTE(review): units are not stated here. The magnitudes (100 / 1000 vs. 10)
# suggest log_freq and save_last_freq may count steps while save_freq and
# eval_freq count epochs — confirm against the training loop.
log_freq: 100
save_freq: 10
save_last_freq: 1000
eval_freq: 10

# Output
output_dir: "outputs/elf_b-xsum"
# null: no checkpoint path is given to resume from.
resume: null

# Wandb
use_wandb: true
wandb_project: elf
wandb_entity: null
wandb_run_name: elf_b-xsum

# Misc
seed: 42