---
# Training configuration for the elf_b-xsum run (ELF-B model, XSum data).
# Keep one key per line: a `#` comment runs to end of line, so collapsing
# this file onto a single line turns the whole config into a comment.

# Dataset
data_path: "embedded-language-flows/xsum_train_t5"
eval_data_path: "embedded-language-flows/xsum_validation_t5"
# NOTE(review): max_length exceeds max_input_length by 64 — presumably the
# budget for target/summary tokens; confirm against the data pipeline.
max_length: 1088
max_input_length: 1024
pad_token: eos

# Encoder
encoder_model_name: t5-small
encoder_checkpoint: "embedded-language-flows/t5_small_encoder_jax/t5_small_encoder_jax.pkl"
latent_mean: 0.0
latent_std: 0.2

# Model architecture
model: ELF-B  # ELF-B, ELF-M
bottleneck_dim: 128
num_time_tokens: 4
num_self_cond_cfg_tokens: 4
num_model_mode_tokens: 4

# Denoiser objective
denoiser_p_mean: -1.5
denoiser_p_std: 0.8
denoiser_noise_scale: 2.0
t_eps: 0.05
time_schedule: "logit_normal"

# Decoder objective
decoder_prob: 0.2
decoder_noise_scale: 5.0
decoder_p_mean: 0.8
decoder_p_std: 0.8

# Conditioning / CFG
label_drop_prob: 0.1
self_cond_prob: 0.5

# Training (optimizer + schedule)
epochs: 100
global_batch_size: 512
blr: 0.001
weight_decay: 0.0
warmup_steps: 5000
optimizer: muon

# EMA
ema_decay1: 0.9999

# Sampling
sampling_configs_path: "configs/sampling_configs/cond_sampling_configs.yml"
num_samples: 5000

# Logging & Checkpointing
log_freq: 100
save_freq: 10
save_last_freq: 1000
eval_freq: 10

# Output
output_dir: "outputs/elf_b-xsum"
resume: null

# Wandb
use_wandb: true
wandb_project: elf
wandb_entity: null
wandb_run_name: elf_b-xsum

# Misc
seed: 42