div18 committed on
Commit
863bb8c
·
1 Parent(s): 1b9be85
Files changed (2) hide show
  1. training/config.yaml +1 -1
  2. training/launch_train.py +1 -0
training/config.yaml CHANGED
@@ -46,7 +46,7 @@ loss_type: "reinforce_baseline" # reinforce | reinforce_baseline | grpo
46
  num_episodes_per_iteration: 4 # Safe now: max_seq_len=512 + loss_batch_size=8 + CPU offload
47
  num_iterations: 500 # Total training iterations
48
  parallel_episodes: true # Batch generation across episodes (10x faster)
49
- loss_batch_size: 4 # Qwen3.5 logits = 4×512×151936×4 = 1.24 GB per batch
50
  learning_rate: 2.0e-4
51
  per_device_train_batch_size: 2 # A10G can handle 2 with seq_len=1024
52
  gradient_accumulation_steps: 4 # Effective batch = 2*4 = 8 transitions
 
46
  num_episodes_per_iteration: 4 # Safe now: max_seq_len=512 + loss_batch_size=8 + CPU offload
47
  num_iterations: 500 # Total training iterations
48
  parallel_episodes: true # Batch generation across episodes (10x faster)
49
+ loss_batch_size: 2 # Qwen3.5 logits = 2×512×151936×4 = 0.62 GB per batch
50
  learning_rate: 2.0e-4
51
  per_device_train_batch_size: 2 # A10G can handle 2 with seq_len=1024
52
  gradient_accumulation_steps: 4 # Effective batch = 2*4 = 8 transitions
training/launch_train.py CHANGED
@@ -121,6 +121,7 @@ def build_job_command() -> str:
121
  "done\n"
122
  "\n"
123
  "echo '[bootstrap] Launching training (local server, Hub persistence)...'\n"
 
124
  "ANTIATROPOS_HUB_MODEL_REPO=$HUB_MODEL_REPO "
125
  "ANTIATROPOS_HUB_METRICS_DATASET=$HUB_METRICS_DATASET "
126
  "ANTIATROPOS_ENV_URL=http://localhost:8000 "
 
121
  "done\n"
122
  "\n"
123
  "echo '[bootstrap] Launching training (local server, Hub persistence)...'\n"
124
+ "export PYTORCH_ALLOC_CONF='expandable_segments:True' # required by Qwen3.5 to avoid OOM fragmentation\n"
125
  "ANTIATROPOS_HUB_MODEL_REPO=$HUB_MODEL_REPO "
126
  "ANTIATROPOS_HUB_METRICS_DATASET=$HUB_METRICS_DATASET "
127
  "ANTIATROPOS_ENV_URL=http://localhost:8000 "