# Dataset list; each entry names a source repo/split and the turn and
# length limits applied to it.
datasets:
  - name: chatalpaca_multiturn_enriched
    repo_id: BRlkl/chatalpaca-multiturn-enriched
    source_split: train
    format: messages_all_turns
    # Validation subset settings (ratio plus the seed used for the split).
    validation_ratio: 0.02
    split_seed: 17
    # Conversation filters: turn-count bounds and a per-message size cap.
    min_turns: 2
    max_turns: 6
    max_message_chars: 6000
    # NOTE(review): indentation was lost in the flattened source; this key is
    # assumed to belong to the dataset entry above — confirm.
    use_base_chat_template: true
# Model section: base checkpoint identifiers, numeric/attention backend
# switches, and the latent-state ("thought loop") settings.
model:
  base_model_name: google/t5gemma-l-l-prefixlm-it
  initial_model_path: BRlkl/test_1024
  dtype: bfloat16
  attn_implementation: sdpa
  disable_cudnn_sdp: true
  disable_mha_fastpath: true
  magicnorm_eps: 1.0e-6
  # z_slots, max_observation_tokens and max_decoder_tokens are all 1024 here.
  # NOTE(review): presumably these are meant to stay in sync — confirm before
  # changing one of them in isolation.
  z_slots: 1024
  num_time_tokens: 0
  use_explicit_time_features: false
  gate_attention_heads: 4
  max_observation_tokens: 1024
  max_decoder_tokens: 1024
  thought_loop_proposal_mode: observation_hidden_compression
  preserve_observation_encoder_manifold: true
  observation_encoder_use_state_context: true
  latent_attention_mask_mode: full
  hard_state_replace: true
# Training section: optimizer settings, logging/eval cadence, loss weights,
# generation limits, tracked metric keys, and checkpoint selection.
training:
  seed: 17
  num_workers: 2
  gradient_checkpointing: true
  mixed_precision: bf16
  max_grad_norm: 1.0
  weight_decay: 0.01
  # Two learning rates are configured: one for the backbone and a larger one
  # for newly added modules.
  backbone_learning_rate: 5.0e-6
  new_module_learning_rate: 1.0e-4
  adam_beta1: 0.9
  adam_beta2: 0.95
  adam_epsilon: 1.0e-8
  fused_adamw: true
  freeze_gate_head: true
  assistant_feedback_mode: teacher_forced
  log_every_steps: 1
  eval_every_steps: 100
  checkpoint_every_steps: 500
  eval_max_batches: 16
  validation_behavior_max_batches: 4
  # Explicit null instead of a bare empty value (same parsed value).
  # NOTE(review): presumably null means "no cap" — confirm against the loader.
  max_train_examples: null
  max_validation_examples: null
  # The three loss weights below sum to 0.99.
  response_loss_weight: 0.33
  current_user_reconstruction_loss_weight: 0.33
  probe_loss_weight: 0.33
  probe_question_text: "What is everything we have talked about so far? Give exact conversation transcript verbatim in following format: [User 1]: X [Assistant 1]: Y [User 2]: A etc"
  feedback_generation_max_new_tokens: 1024
  feedback_generation_extra_new_tokens: 16
  validation_response_max_new_tokens: 1024
  validation_response_extra_new_tokens: 16
  validation_probe_max_new_tokens: 1024
  validation_probe_extra_new_tokens: 16
  wandb_train_metric_keys:
    - train/loss_total
    - train/loss_response
    - train/loss_current_user_reconstruction
    - train/loss_probe
    - train/response_first_token_exact_match
    # Spelled `current_user` (underscore) to agree with the other
    # current_user_reconstruction keys in this section; the previous
    # `current-user` spelling did not match them.
    # NOTE(review): verify against the metric name the trainer actually emits.
    - train/current_user_reconstruction_first_token_exact_match
    - train/probe_first_token_exact_match
  wandb_validation_metric_keys:
    - validation/loss_total
    - validation/loss_response
    - validation/loss_current_user_reconstruction
    - validation/loss_probe
    - validation/goal_loss
    - validation/response_similarity
    - validation/response_reconstruction_similarity
    - validation/probe_transcript_similarity
  # Checkpoints are selected on validation/goal_loss with mode `min`.
  checkpoint_selection_metric: validation/goal_loss
  checkpoint_selection_mode: min
  validation_response_exact_miss_penalty: 1.0
  validation_reconstruction_similarity_miss_penalty: 1.0
  validation_probe_exact_miss_penalty: 1.0
  validation_probe_similarity_miss_penalty: 2.0
# Batch sizing and schedule settings for the training phase.
# NOTE(review): nesting was lost in the flattened source; this is restored as
# a top-level section like its siblings — confirm it is not meant to sit
# under `training`.
phase:
  micro_batch_size: 10
  eval_batch_size: 10
  gradient_accumulation_steps: 4
  num_train_epochs: 5
  warmup_ratio: 0.03
  shuffle_train: true
# Root directory for preprocessed data (relative path).
cache:
  preprocessed_root: cache/preprocessed_pre_sft_multiturn_simple_transcript
# Output roots for runs and exports (relative paths).
paths:
  run_root: runs_pre_sft_multiturn_simple_transcript
  export_root: exports_multiturn_simple_transcript
# Inference-time prompt format settings.
inference:
  format: predictive_state_multiturn
  use_base_chat_template: true
# Experiment-tracking settings: on/off switch, project, and run name.
wandb:
  enabled: true
  project: samantha-pre-sft
  run_name: t5gemma2-thoughtloop-pre-sft-simple-transcript
# Target model repository and its visibility flag.
# `private: false` is set explicitly here.
hub:
  model_repo_id: BRlkl/test_multiturn_simple_transcript1024_2
  private: false