Sssplendid commited on
Commit
3123eac
·
verified ·
1 Parent(s): d4b4cad

Add 350m/conda_lr1e_2_b1_0_9_b2_0_99_eps_1e_8_scale_0_25_rank256_T2000_H200_ppl16_4542_20260423_202415

Browse files
Files changed (15) hide show
  1. .gitattributes +1 -0
  2. 350m/conda_lr1e_2_b1_0_9_b2_0_99_eps_1e_8_scale_0_25_rank256_T2000_H200_ppl16_4542_20260423_202415/conda_lr1e_2_b1_0_9_b2_0_99_eps_1e_8_scale_0_25_rank256_T2000_H200_ppl16_4542_20260423_202415.txt +0 -0
  3. 350m/conda_lr1e_2_b1_0_9_b2_0_99_eps_1e_8_scale_0_25_rank256_T2000_H200_ppl16_4542_20260423_202415/model_60000/config.json +31 -0
  4. 350m/conda_lr1e_2_b1_0_9_b2_0_99_eps_1e_8_scale_0_25_rank256_T2000_H200_ppl16_4542_20260423_202415/model_60000/model.safetensors +3 -0
  5. 350m/conda_lr1e_2_b1_0_9_b2_0_99_eps_1e_8_scale_0_25_rank256_T2000_H200_ppl16_4542_20260423_202415/model_60000/optimizer.pt +3 -0
  6. 350m/conda_lr1e_2_b1_0_9_b2_0_99_eps_1e_8_scale_0_25_rank256_T2000_H200_ppl16_4542_20260423_202415/model_60000/pytorch_model.bin +3 -0
  7. 350m/conda_lr1e_2_b1_0_9_b2_0_99_eps_1e_8_scale_0_25_rank256_T2000_H200_ppl16_4542_20260423_202415/model_60000/training_state.json +7 -0
  8. 350m/conda_lr1e_2_b1_0_9_b2_0_99_eps_1e_8_scale_0_25_rank256_T2000_H200_ppl16_4542_20260423_202415/wandb.json +3 -0
  9. 350m/conda_lr1e_2_b1_0_9_b2_0_99_eps_1e_8_scale_0_25_rank256_T2000_H200_ppl16_4542_20260423_202415/wandb/debug-internal.log +15 -0
  10. 350m/conda_lr1e_2_b1_0_9_b2_0_99_eps_1e_8_scale_0_25_rank256_T2000_H200_ppl16_4542_20260423_202415/wandb/debug.log +22 -0
  11. 350m/conda_lr1e_2_b1_0_9_b2_0_99_eps_1e_8_scale_0_25_rank256_T2000_H200_ppl16_4542_20260423_202415/wandb/offline-run-20260423_202422-mawza3ul/files/SAC/torchrun_main.py +612 -0
  12. 350m/conda_lr1e_2_b1_0_9_b2_0_99_eps_1e_8_scale_0_25_rank256_T2000_H200_ppl16_4542_20260423_202415/wandb/offline-run-20260423_202422-mawza3ul/files/requirements.txt +142 -0
  13. 350m/conda_lr1e_2_b1_0_9_b2_0_99_eps_1e_8_scale_0_25_rank256_T2000_H200_ppl16_4542_20260423_202415/wandb/offline-run-20260423_202422-mawza3ul/logs/debug-internal.log +15 -0
  14. 350m/conda_lr1e_2_b1_0_9_b2_0_99_eps_1e_8_scale_0_25_rank256_T2000_H200_ppl16_4542_20260423_202415/wandb/offline-run-20260423_202422-mawza3ul/logs/debug.log +22 -0
  15. 350m/conda_lr1e_2_b1_0_9_b2_0_99_eps_1e_8_scale_0_25_rank256_T2000_H200_ppl16_4542_20260423_202415/wandb/offline-run-20260423_202422-mawza3ul/run-mawza3ul.wandb +3 -0
.gitattributes CHANGED
@@ -58,3 +58,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
58
  60m/soap_lr3e_3_b1_0_9_b2_0_95_eps_1e_8_A100_ppl_29_4706_20260416_193855/wandb/offline-run-20260416_193930-flol3ksy/run-flol3ksy.wandb filter=lfs diff=lfs merge=lfs -text
59
  60m/sophia_lr2e_4_b1_0_9_b2_0_99_eps_1e_8_A100_ppl_36_2695_20260416_193855/wandb/offline-run-20260416_193930-x2s18q4b/run-x2s18q4b.wandb filter=lfs diff=lfs merge=lfs -text
60
  130m/adabelief_lr1e_3_b1_0_9_b2_0_999_eps_1e_16_A100_ppl_23_4537_20260417_181954/wandb/offline-run-20260417_221600-a8vnu42p/run-a8vnu42p.wandb filter=lfs diff=lfs merge=lfs -text
 
 
58
  60m/soap_lr3e_3_b1_0_9_b2_0_95_eps_1e_8_A100_ppl_29_4706_20260416_193855/wandb/offline-run-20260416_193930-flol3ksy/run-flol3ksy.wandb filter=lfs diff=lfs merge=lfs -text
59
  60m/sophia_lr2e_4_b1_0_9_b2_0_99_eps_1e_8_A100_ppl_36_2695_20260416_193855/wandb/offline-run-20260416_193930-x2s18q4b/run-x2s18q4b.wandb filter=lfs diff=lfs merge=lfs -text
60
  130m/adabelief_lr1e_3_b1_0_9_b2_0_999_eps_1e_16_A100_ppl_23_4537_20260417_181954/wandb/offline-run-20260417_221600-a8vnu42p/run-a8vnu42p.wandb filter=lfs diff=lfs merge=lfs -text
61
+ 350m/conda_lr1e_2_b1_0_9_b2_0_99_eps_1e_8_scale_0_25_rank256_T2000_H200_ppl16_4542_20260423_202415/wandb/offline-run-20260423_202422-mawza3ul/run-mawza3ul.wandb filter=lfs diff=lfs merge=lfs -text
350m/conda_lr1e_2_b1_0_9_b2_0_99_eps_1e_8_scale_0_25_rank256_T2000_H200_ppl16_4542_20260423_202415/conda_lr1e_2_b1_0_9_b2_0_99_eps_1e_8_scale_0_25_rank256_T2000_H200_ppl16_4542_20260423_202415.txt ADDED
The diff for this file is too large to render. See raw diff
 
350m/conda_lr1e_2_b1_0_9_b2_0_99_eps_1e_8_scale_0_25_rank256_T2000_H200_ppl16_4542_20260423_202415/model_60000/config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "LlamaForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 0,
8
+ "dtype": "bfloat16",
9
+ "eos_token_id": 1,
10
+ "head_dim": 64,
11
+ "hidden_act": "silu",
12
+ "hidden_size": 1024,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 2736,
15
+ "max_position_embeddings": 2048,
16
+ "max_sequence_length": 1024,
17
+ "mlp_bias": false,
18
+ "model_type": "llama",
19
+ "num_attention_heads": 16,
20
+ "num_hidden_layers": 24,
21
+ "num_key_value_heads": 16,
22
+ "pad_token_id": -1,
23
+ "pretraining_tp": 1,
24
+ "rms_norm_eps": 1e-06,
25
+ "rope_scaling": null,
26
+ "rope_theta": 10000.0,
27
+ "tie_word_embeddings": false,
28
+ "transformers_version": "4.57.3",
29
+ "use_cache": true,
30
+ "vocab_size": 32000
31
+ }
350m/conda_lr1e_2_b1_0_9_b2_0_99_eps_1e_8_scale_0_25_rank256_T2000_H200_ppl16_4542_20260423_202415/model_60000/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ae4737309e6d86fd1a6704986f79d99c5701b932b40a55b1defe52265b26314c
3
+ size 735967792
350m/conda_lr1e_2_b1_0_9_b2_0_99_eps_1e_8_scale_0_25_rank256_T2000_H200_ppl16_4542_20260423_202415/model_60000/optimizer.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fae0081c25762de738fa1ee41275b545fbcb33e26dbbc712ac740dbffe6185bf
3
+ size 1824395851
350m/conda_lr1e_2_b1_0_9_b2_0_99_eps_1e_8_scale_0_25_rank256_T2000_H200_ppl16_4542_20260423_202415/model_60000/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fc0e795939cf0c432c6675fb69805b3f4beaadf0f7e4d45715ecb5cc46766541
3
+ size 736040495
350m/conda_lr1e_2_b1_0_9_b2_0_99_eps_1e_8_scale_0_25_rank256_T2000_H200_ppl16_4542_20260423_202415/model_60000/training_state.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "global_step": 60000,
3
+ "update_step": 60000,
4
+ "tokens_seen": 5997831592,
5
+ "tokens_seen_before": 5997732940,
6
+ "update_time": 1.0997016429901123
7
+ }
350m/conda_lr1e_2_b1_0_9_b2_0_99_eps_1e_8_scale_0_25_rank256_T2000_H200_ppl16_4542_20260423_202415/wandb.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "wandb_id": "mawza3ul"
3
+ }
350m/conda_lr1e_2_b1_0_9_b2_0_99_eps_1e_8_scale_0_25_rank256_T2000_H200_ppl16_4542_20260423_202415/wandb/debug-internal.log ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {"time":"2026-04-23T20:24:22.4996429+08:00","level":"INFO","msg":"wandb-core"}
2
+ {"time":"2026-04-23T20:24:22.499901972+08:00","level":"INFO","msg":"stream: starting","core version":"0.26.0"}
3
+ {"time":"2026-04-23T20:24:22.63528813+08:00","level":"WARN","msg":"featurechecker: GraphQL client is nil, skipping feature loading"}
4
+ {"time":"2026-04-23T20:24:22.635312622+08:00","level":"WARN","msg":"featurechecker: GraphQL client is nil, skipping feature loading"}
5
+ {"time":"2026-04-23T20:24:22.635332006+08:00","level":"INFO","msg":"stream: created new stream","id":"mawza3ul"}
6
+ {"time":"2026-04-23T20:24:22.635539387+08:00","level":"INFO","msg":"handler: started"}
7
+ {"time":"2026-04-23T20:24:22.636204984+08:00","level":"INFO","msg":"stream: started"}
8
+ {"time":"2026-04-23T20:24:22.636292599+08:00","level":"INFO","msg":"writer: started","stream_id":"mawza3ul"}
9
+ {"time":"2026-04-23T20:24:22.636304487+08:00","level":"INFO","msg":"sender: started"}
10
+ {"time":"2026-04-23T20:24:22.637327692+08:00","level":"WARN","msg":"featurechecker: GraphQL client is nil, skipping feature loading"}
11
+ {"time":"2026-04-23T20:24:22.637343161+08:00","level":"WARN","msg":"runupserter: server does not expand metric globs but the x_server_side_expand_glob_metrics setting is set; ignoring"}
12
+ {"time":"2026-04-24T01:14:24.730723575+08:00","level":"INFO","msg":"stream: closing"}
13
+ {"time":"2026-04-24T01:14:24.753028533+08:00","level":"INFO","msg":"handler: closed"}
14
+ {"time":"2026-04-24T01:14:24.753387495+08:00","level":"INFO","msg":"sender: closed"}
15
+ {"time":"2026-04-24T01:14:24.753400977+08:00","level":"INFO","msg":"stream: closed"}
350m/conda_lr1e_2_b1_0_9_b2_0_99_eps_1e_8_scale_0_25_rank256_T2000_H200_ppl16_4542_20260423_202415/wandb/debug.log ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2026-04-23 20:24:22,097 INFO MainThread:342 [wandb_setup.py:_flush():81] Current SDK version is 0.26.0
2
+ 2026-04-23 20:24:22,098 INFO MainThread:342 [wandb_setup.py:_flush():81] Configure stats pid to 342
3
+ 2026-04-23 20:24:22,098 INFO MainThread:342 [wandb_setup.py:_flush():81] Loading settings from environment variables
4
+ 2026-04-23 20:24:22,098 INFO MainThread:342 [wandb_init.py:setup_run_log_directory():721] Logging user logs to exp_remain_h200/work_dirs/350m/train_350m_conda_lr1e_2_scale0_25_rank256_gap2000_20260423_202415/wandb/offline-run-20260423_202422-mawza3ul/logs/debug.log
5
+ 2026-04-23 20:24:22,098 INFO MainThread:342 [wandb_init.py:setup_run_log_directory():722] Logging internal logs to exp_remain_h200/work_dirs/350m/train_350m_conda_lr1e_2_scale0_25_rank256_gap2000_20260423_202415/wandb/offline-run-20260423_202422-mawza3ul/logs/debug-internal.log
6
+ 2026-04-23 20:24:22,098 INFO MainThread:342 [wandb_init.py:init():848] calling init triggers
7
+ 2026-04-23 20:24:22,098 INFO MainThread:342 [wandb_init.py:init():853] wandb.init called with sweep_config: {}
8
+ config: {'_wandb': {}}
9
+ 2026-04-23 20:24:22,098 INFO MainThread:342 [wandb_init.py:init():896] starting backend
10
+ 2026-04-23 20:24:22,494 INFO MainThread:342 [wandb_init.py:init():911] sending inform_init request
11
+ 2026-04-23 20:24:22,498 INFO MainThread:342 [wandb_init.py:init():919] backend started and connected
12
+ 2026-04-23 20:24:22,498 INFO MainThread:342 [wandb_init.py:init():989] updated telemetry
13
+ 2026-04-23 20:24:22,526 INFO MainThread:342 [wandb_init.py:init():1013] communicating run to backend with 90.0 second timeout
14
+ 2026-04-23 20:24:22,638 INFO MainThread:342 [wandb_init.py:init():1058] starting run threads in backend
15
+ 2026-04-23 20:24:22,712 INFO MainThread:342 [wandb_run.py:_console_start():2542] atexit reg
16
+ 2026-04-23 20:24:22,712 INFO MainThread:342 [wandb_run.py:_redirect():2391] redirect: wrap_raw
17
+ 2026-04-23 20:24:22,712 INFO MainThread:342 [wandb_run.py:_redirect():2460] Wrapping output streams.
18
+ 2026-04-23 20:24:22,712 INFO MainThread:342 [wandb_run.py:_redirect():2483] Redirects installed.
19
+ 2026-04-23 20:24:22,714 INFO MainThread:342 [wandb_init.py:init():1098] run started, returning control to user process
20
+ 2026-04-23 20:24:49,867 INFO MainThread:342 [wandb_run.py:_config_callback():1403] config_cb None None {'model_config': 'configs/llama_350m.json', 'exp_config': 'exp_v2/configs/llama_350m.json', 'eval_every': 1000, 'save_every': 60000, 'dtype': 'bfloat16', 'seed': 0, 'compile': True, 'dynamo_suppress_errors': True, 'dynamo_cache_limit': 10000, 'memory_cleanup_frequency': 10000, 'resume_step': None, 'restore_optimizer': False, 'continue_from': None, 'single_gpu': False, 'save_dir': 'exp_remain_h200/work_dirs/350m/train_350m_conda_lr1e_2_scale0_25_rank256_gap2000_20260423_202415', 'use_hf_model': False, 'workers': 12, 'batch_size': 128, 'gradient_accumulation': 1, 'total_batch_size': 512, 'warmup_steps': 6000, 'num_training_steps': 60000, 'max_train_tokens': None, 'optimizer': 'conda', 'max_length': 256, 'scheduler': 'cosine', 'min_lr_ratio': 0.1, 'weight_decay': 0.0, 'grad_clipping': 0.0, 'activation_checkpointing': False, 'data_path': '/mnt/shared-storage-gpfs2/finebio-shared/optimizer/dataset/C4/en', 'data_name': 'en', 'tags': None, 'name': 'test', 'project': 'test', 'unset_wandb': False, 'entity': None, 'wandb_dir': 'exp_remain_h200/work_dirs/350m/train_350m_conda_lr1e_2_scale0_25_rank256_gap2000_20260423_202415', 'beta1': 0.9, 'beta2': 0.99, 'beta3': 0.99, 'eps': 1e-08, 'rank': 256, 'update_proj_gap': 2000, 'galore_scale': 1.0, 'proj_type': 'std', 'proj_quant': False, 'proj_bits': 8, 'proj_group_size': 256, 'weight_quant': False, 'weight_bits': 8, 'weight_group_size': 256, 'stochastic_round': False, 'simulation': False, 'cos_threshold': 1, 'gamma_proj': 2, 'queue_size': 5, 'proj': 'random', 'scale_type': 'channel', 'apollo_scale': 0.25, 'scale_front': False, 'n_clusters': 3, 'scale_update_freq': 500, 'scale_level': '1,0,1,1', 'scale_bound': None, 'metric': 'mean', 'align_grad': False, 'dim': 4096, 'n_heads': 32, 'muon_ns_steps': 5, 'muon_momentum': 0.95, 'nproc_per_node': 4, 'max_lr': 0.01, 'total_params_M': 367.96928, 'dataset': 'c4', 'model': {'vocab_size': 32000, 'max_position_embeddings': 2048, 'hidden_size': 1024, 'intermediate_size': 2736, 'num_hidden_layers': 24, 'num_attention_heads': 16, 'num_key_value_heads': 16, 'hidden_act': 'silu', 'initializer_range': 0.02, 'rms_norm_eps': 1e-06, 'pretraining_tp': 1, 'use_cache': True, 'rope_theta': 10000.0, 'rope_scaling': None, 'attention_bias': False, 'attention_dropout': 0.0, 'mlp_bias': False, 'head_dim': 64, 'return_dict': True, 'output_hidden_states': False, 'torchscript': False, 'dtype': None, 'pruned_heads': {}, 'tie_word_embeddings': False, 'chunk_size_feed_forward': 0, 'is_encoder_decoder': False, 'is_decoder': False, 'cross_attention_hidden_size': None, 'add_cross_attention': False, 'tie_encoder_decoder': False, 'architectures': ['LLaMAForCausalLM'], 'finetuning_task': None, 'id2label': {0: 'LABEL_0', 1: 'LABEL_1'}, 'label2id': {'LABEL_0': 0, 'LABEL_1': 1}, 'task_specific_params': None, 'problem_type': None, 'tokenizer_class': None, 'prefix': None, 'bos_token_id': 0, 'pad_token_id': -1, 'eos_token_id': 1, 'sep_token_id': None, 'decoder_start_token_id': None, 'max_length': 20, 'min_length': 0, 'do_sample': False, 'early_stopping': False, 'num_beams': 1, 'temperature': 1.0, 'top_k': 50, 'top_p': 1.0, 'typical_p': 1.0, 'repetition_penalty': 1.0, 'length_penalty': 1.0, 'no_repeat_ngram_size': 0, 'encoder_no_repeat_ngram_size': 0, 'bad_words_ids': None, 'num_return_sequences': 1, 'output_scores': False, 'return_dict_in_generate': False, 'forced_bos_token_id': None, 'forced_eos_token_id': None, 'remove_invalid_values': False, 'exponential_decay_length_penalty': None, 'suppress_tokens': None, 'begin_suppress_tokens': None, 'num_beam_groups': 1, 'diversity_penalty': 0.0, '_name_or_path': 'configs/llama_350m.json', 'transformers_version': '4.57.3', 'max_sequence_length': 1024, 'model_type': 'llama', 'tf_legacy_loss': False, 'use_bfloat16': False, 'output_attentions': False}, 'world_size': 4, 'device': 'cuda:0'}
21
+ 2026-04-24 01:14:24,727 INFO wandb-AsyncioManager-main:342 [service_client.py:_forward_responses():134] Reached EOF.
22
+ 2026-04-24 01:14:24,730 INFO wandb-AsyncioManager-main:342 [mailbox.py:close():155] Closing mailbox, abandoning 0 handles.
350m/conda_lr1e_2_b1_0_9_b2_0_99_eps_1e_8_scale_0_25_rank256_T2000_H200_ppl16_4542_20260423_202415/wandb/offline-run-20260423_202422-mawza3ul/files/SAC/torchrun_main.py ADDED
@@ -0,0 +1,612 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ import os
7
+ import time
8
+ import json
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.distributed as dist
13
+
14
+ from tqdm import tqdm
15
+ from loguru import logger
16
+
17
+ import transformers
18
+
19
+ transformers.logging.set_verbosity_error()
20
+
21
+ import wandb
22
+
23
+ from utils.argparse import parse_args
24
+ from utils.setup import getting_svd_cnt, set_seed, setup_model, saving_model_weight, load_model_weight
25
+ from utils.optimizer_factory import setup_optimization
26
+ from utils.eval import evaluate_model
27
+ from utils.dataloader import setup_dataset
28
+ from utils.modeling_llama import LlamaForCausalLM
29
+ from utils.fake_quantization import QLinear
30
+ from utils.quantization import QScaleLinear
31
+ from opt.hybrid import HybridOptimizer, HybridScheduler
32
+
33
+
34
+ def main(args):
35
+ import torch
36
+ ############ Setup random seed ############
37
+ set_seed(args)
38
+
39
+ ############ Setup DDP environment ############
40
+ assert "LOCAL_RANK" in os.environ, "torchrun should set LOCAL_RANK"
41
+ global_rank = int(os.environ["RANK"])
42
+ local_rank = int(os.environ["LOCAL_RANK"])
43
+ world_size = int(os.environ["WORLD_SIZE"])
44
+ torch.cuda.set_device(local_rank)
45
+
46
+ logger.info(f"Global rank {global_rank}, local rank {local_rank}, device: {torch.cuda.current_device()}")
47
+ dist.init_process_group(backend="nccl", rank=global_rank, world_size=world_size)
48
+
49
+ logger.info("Process group initialized")
50
+ device = f"cuda:{local_rank}"
51
+
52
+ if global_rank != 0:
53
+ logger.remove() # turn off logger
54
+
55
+ logger.info(f"Using dist with rank {global_rank} (only rank 0 will log)")
56
+ logger.info("*" * 40)
57
+ logger.info(f"Starting training with the arguments")
58
+ for k, v in vars(args).items():
59
+ logger.info(f"{k:30} {v}")
60
+ logger.info("*" * 40)
61
+
62
+ ############ Initialize wandb without config (it is passed later) ############
63
+ if (not args.unset_wandb) and global_rank == 0:
64
+ if args.entity is None:
65
+ os.environ['WANDB_MODE'] = 'offline'
66
+ # Set wandb directory for offline mode
67
+ wandb_dir = getattr(args, 'wandb_dir', None) if getattr(args, 'wandb_dir', None) is not None else args.save_dir
68
+ if getattr(args, 'wandb_dir', None) is not None:
69
+ logger.info(f"Wandb directory set to: {wandb_dir}")
70
+ wandb.init(project=args.project, name=args.name, entity=args.entity, dir=wandb_dir)
71
+
72
+ ############ Setup training data ############
73
+ if args.total_batch_size is not None:
74
+ if args.gradient_accumulation is None:
75
+ assert args.total_batch_size % world_size == 0, "total_batch_size must be divisible by world_size"
76
+ args.gradient_accumulation = args.total_batch_size // (args.batch_size * world_size)
77
+ assert args.gradient_accumulation > 0, "gradient_accumulation must be greater than 0"
78
+
79
+ assert (
80
+ args.gradient_accumulation * args.batch_size * world_size == args.total_batch_size
81
+ ), "gradient_accumulation * batch_size * world_size must be equal to total_batch_size"
82
+
83
+ dataloader, tokenizer = setup_dataset(args, global_rank, world_size)
84
+
85
+ ############ Initialize model ############
86
+ model_config, model = setup_model(args)
87
+ # Ensure model has generation_config (fix for transformers version compatibility)
88
+ if not hasattr(model, 'generation_config') or model.generation_config is None:
89
+ from transformers import GenerationConfig
90
+ model.generation_config = GenerationConfig()
91
+ model.generation_config.pad_token_id = tokenizer.pad_token_id
92
+
93
+ ############ Resuming from checkpoints ############
94
+ global_step = 0
95
+ update_step = 0
96
+ beginning_step = 0
97
+ tokens_seen = 0
98
+ tokens_seen_before = 0
99
+
100
+ # identifying checkpointing
101
+ if args.continue_from is not None and os.path.exists(args.continue_from):
102
+ # searching the latest checkpoints
103
+ checkpoint_path_list = os.listdir(args.continue_from)
104
+ checkpoint_path_list = [int(x.split("_")[-1]) for x in checkpoint_path_list if x.startswith("model_")]
105
+ if len(checkpoint_path_list) > 0:
106
+ logger.info("Find Checkpoints", checkpoint_path_list)
107
+ beginning_step = max(checkpoint_path_list)
108
+ if args.resume_step is not None:
109
+ beginning_step = args.resume_step
110
+ args.continue_from = os.path.join(args.continue_from, f"model_{beginning_step}")
111
+ logger.info("Continue from", args.continue_from)
112
+ else:
113
+ logger.warning(f"Did not find any checkpoints in {args.continue_from}")
114
+ args.continue_from = None
115
+
116
+ # resuming from checkpointing
117
+ if args.continue_from is not None:
118
+ logger.info("*" * 40)
119
+ logger.info(f"Loading model from {args.continue_from}")
120
+ checkpoint_path = os.path.join(args.continue_from, "pytorch_model.bin")
121
+ if os.path.exists(checkpoint_path):
122
+ load_model_weight(model, checkpoint_path, args)
123
+ logger.info(f"Model successfully loaded (strict=False policy)")
124
+ else:
125
+ # Try safetensors format
126
+ checkpoint_path = os.path.join(args.continue_from, "model.safetensors")
127
+ if os.path.exists(checkpoint_path):
128
+ from safetensors import safe_open
129
+ tensors = {}
130
+ with safe_open(checkpoint_path, framework="pt", device=0) as f:
131
+ for k in f.keys():
132
+ tensors[k] = f.get_tensor(k)
133
+ print(k, tensors[k].shape)
134
+ ret = model.load_state_dict(tensors, strict=False)
135
+ logger.info(f"Model successfully loaded from safetensors (strict=False policy)", ret)
136
+ else:
137
+ logger.warning(f"No model checkpoint found in {args.continue_from}")
138
+
139
+ if os.path.exists(os.path.join(args.continue_from, "training_state.json")):
140
+ logger.info(
141
+ f"Loading training state like global_step, update_step, and tokens_seen from {args.continue_from}"
142
+ )
143
+ with open(os.path.join(args.continue_from, "training_state.json")) as f:
144
+ _old_state = json.load(f)
145
+ global_step = _old_state["global_step"]
146
+ update_step = _old_state["update_step"]
147
+ tokens_seen = _old_state["tokens_seen"]
148
+ tokens_seen_before = _old_state["tokens_seen_before"]
149
+ logger.info(f"global_step : {global_step}")
150
+ logger.info(f"update_step : {update_step}")
151
+ logger.info(f"tokens_seen : {tokens_seen}")
152
+ logger.info(f"tokens_seen_before: {tokens_seen_before}")
153
+ logger.info(f"Will train for {args.num_training_steps - update_step} update steps")
154
+ else:
155
+ logger.warning(f"Did not find training state in {args.continue_from}, global step will start from zero")
156
+ logger.info("*" * 40)
157
+
158
+ ############ Setup model ############
159
+ if args.dtype in ["bf16", "bfloat16"]:
160
+ model = model.to(dtype=torch.bfloat16)
161
+ model = model.to(device=device)
162
+
163
+ for _, module in model.named_modules():
164
+ if isinstance(module, QScaleLinear):
165
+ weight_device = module.weight.device
166
+ module.weight.scales = module.weight.scales.to(device=weight_device)
167
+ module.weight.zeros = module.weight.zeros.to(device=weight_device)
168
+
169
+ n_total_params = sum(p.numel() for p in model.parameters())
170
+ trainable_params = [p for p in model.parameters() if p.requires_grad]
171
+ trainable_params_int8 = [p for p in model.parameters() if hasattr(p, "group_size")]
172
+
173
+ ############ Initialize wandb ############
174
+ run_config = dict(vars(args))
175
+ run_config.update(
176
+ {
177
+ "max_lr": run_config.pop("lr"), # rename lr to max_lr to avoid conflicts with scheduler
178
+ "total_params_M": n_total_params / 1_000_000,
179
+ "dataset": "c4",
180
+ "model": model_config.to_dict(),
181
+ "world_size": world_size,
182
+ "device": str(device),
183
+ }
184
+ )
185
+
186
+ if global_rank == 0:
187
+ if not args.unset_wandb:
188
+ wandb.config.update(run_config, allow_val_change=True)
189
+ wandb.save(os.path.abspath(__file__), policy="now") # save current script
190
+ # fix tqdm visual length to 80 so that the progress bar
191
+ # doesn't jump around when changing from external display to laptop
192
+ pbar = tqdm(total=args.num_training_steps - update_step, desc="Update steps", ncols=80)
193
+
194
+ ############ Initialize optimization ############
195
+ if "galore" in args.optimizer.lower():
196
+ # make parameters with "rank" to a single group, if param_name has "mlp" or "attn"
197
+ lowrank_params = []
198
+ target_modules_list = ["attn", "mlp"]
199
+ for module_name, module in model.named_modules():
200
+ if not (isinstance(module, nn.Linear) or isinstance(module, QScaleLinear) or isinstance(module, QLinear)):
201
+ continue
202
+ if not any(target_key in module_name for target_key in target_modules_list):
203
+ continue
204
+ logger.info(f"Adding {module_name} to GaLore parameters")
205
+ lowrank_params.append(module.weight)
206
+
207
+ id_lowrank_params = [id(p) for p in lowrank_params]
208
+ # make parameters without "rank" to another group
209
+ regular_params = [p for p in model.parameters() if id(p) not in id_lowrank_params]
210
+ # then call low rank optimizer
211
+ param_groups = [
212
+ {"params": regular_params},
213
+ {
214
+ "params": lowrank_params,
215
+ "rank": args.rank,
216
+ "update_proj_gap": args.update_proj_gap,
217
+ "scale": args.galore_scale,
218
+ "proj_type": args.proj_type,
219
+ "quant": args.proj_quant,
220
+ "quant_n_bit": args.proj_bits,
221
+ "quant_group_size": args.proj_group_size,
222
+ "cos_threshold": args.cos_threshold,
223
+ "gamma_proj": args.gamma_proj,
224
+ "queue_size": args.queue_size,
225
+ },
226
+ ]
227
+ elif "apollo" in args.optimizer.lower():
228
+ # make parameters with "rank" to a single group, if param_name has "mlp" or "attn"
229
+ lowrank_params = []
230
+ target_modules_list = ["attn", "mlp"]
231
+ for module_name, module in model.named_modules():
232
+ if not (isinstance(module, nn.Linear) or isinstance(module, QScaleLinear) or isinstance(module, QLinear)):
233
+ continue
234
+ if not any(target_key in module_name for target_key in target_modules_list):
235
+ continue
236
+ logger.info(f"Adding {module_name} to APOLLO parameters")
237
+ lowrank_params.append(module.weight)
238
+
239
+ id_lowrank_params = [id(p) for p in lowrank_params]
240
+ # make parameters without "rank" to another group
241
+ regular_params = [p for p in model.parameters() if id(p) not in id_lowrank_params]
242
+ # then call low rank optimizer
243
+ param_groups = [
244
+ {"params": regular_params},
245
+ {
246
+ "params": lowrank_params,
247
+ "rank": args.rank,
248
+ "update_proj_gap": args.update_proj_gap,
249
+ "scale": args.apollo_scale,
250
+ "proj_type": args.proj_type,
251
+ "proj": args.proj,
252
+ "scale_type": args.scale_type,
253
+ },
254
+ ]
255
+ elif "conda" in args.optimizer.lower():
256
+ # make parameters with "rank" to a single group, if param_name has "mlp" or "attn"
257
+ lowrank_params = []
258
+ target_modules_list = ["attn", "mlp"]
259
+ for module_name, module in model.named_modules():
260
+ if not (isinstance(module, nn.Linear) or isinstance(module, QScaleLinear) or isinstance(module, QLinear)):
261
+ continue
262
+ if not any(target_key in module_name for target_key in target_modules_list):
263
+ continue
264
+ logger.info(f"Adding {module_name} to conda parameters")
265
+ lowrank_params.append(module.weight)
266
+
267
+ id_lowrank_params = [id(p) for p in lowrank_params]
268
+ # make parameters without "rank" to another group
269
+ regular_params = [p for p in model.parameters() if id(p) not in id_lowrank_params]
270
+ # then call low rank optimizer
271
+ param_groups = [
272
+ {"params": regular_params},
273
+ {
274
+ "params": lowrank_params,
275
+ "rank": args.rank,
276
+ "update_proj_gap": args.update_proj_gap,
277
+ "scale": args.apollo_scale,
278
+ "proj_type": args.proj_type,
279
+ "proj": args.proj,
280
+ "scale_type": args.scale_type,
281
+ },
282
+ ]
283
+ else:
284
+ param_groups = None
285
+ id_lowrank_params = None
286
+
287
+ # print params and trainable params
288
+ logger.info(f"\n{model}\n")
289
+ logger.info(f"Total params: {sum(p.numel() for p in model.parameters()) / 1_000_000:.2f}M")
290
+
291
+ if args.simulation:
292
+ num_train_params = sum(p.numel() for p in trainable_params)
293
+ else:
294
+ num_train_params = sum(p.numel() for p in trainable_params) + sum(p.numel() for p in trainable_params_int8)
295
+
296
+ logger.info(f"Trainable params: {num_train_params / 1_000_000:.2f}M")
297
+ if "q_galore" in args.optimizer.lower():
298
+ logger.info(
299
+ f"Trainable params with Q-GaLore enabled: {sum(p.numel() for p in trainable_params_int8) / 1_000_000:.2f}M"
300
+ )
301
+ elif "galore" in args.optimizer.lower():
302
+ logger.info(f"Total params with GaLore enabled: {sum(p.numel() for p in lowrank_params) / 1_000_000:.2f}M")
303
+ elif "q_apollo" in args.optimizer.lower():
304
+ logger.info(
305
+ f"Trainable params with Q-APOLLO enabled: {sum(p.numel() for p in trainable_params_int8) / 1_000_000:.2f}M"
306
+ )
307
+ elif "apollo" in args.optimizer.lower():
308
+ logger.info(f"Total params with APOLLO enabled: {sum(p.numel() for p in lowrank_params) / 1_000_000:.2f}M")
309
+
310
+ logger.info(f"Saving model to {args.save_dir} every {args.save_every} update steps")
311
+
312
+ model, optimizer, scheduler, layer_wise_flag = setup_optimization(
313
+ args, model, trainable_params, param_groups, id_lowrank_params, model_config
314
+ )
315
+
316
+ if layer_wise_flag:
317
+ # will pass optimizer_dict and scheduler_dict out instead of optimizer and scheduler
318
+ optimizer_dict = optimizer
319
+ scheduler_dict = scheduler
320
+
321
+ # Bug-3 fix: wrap with DDP *before* torch.compile per PyTorch recommendation.
322
+ # This ensures gradient reduction hooks are correctly installed on the DDP module,
323
+ # and the compiled graph captures the full DDP+model forward pass.
324
+ # (Issue-5: optimizer.load_state_dict is called after both DDP and compile below.)
325
+ if not args.single_gpu:
326
+ model: LlamaForCausalLM = torch.nn.parallel.DistributedDataParallel(
327
+ model,
328
+ device_ids=[local_rank],
329
+ output_device=local_rank,
330
+ broadcast_buffers=False,
331
+ )
332
+
333
+ # compile the model (after DDP so the compiled graph includes DDP reduction)
334
+ if args.compile:
335
+ print("Compiling the model... (takes a ~minute)")
336
+ unoptimized_model = model
337
+
338
+ # Configure TorchDynamo to suppress errors and fall back to eager mode
339
+ import torch._dynamo
340
+ torch._dynamo.config.suppress_errors = args.dynamo_suppress_errors
341
+ torch._dynamo.config.verbose = False
342
+ # Set cache size limit to prevent memory issues during long training
343
+ torch._dynamo.config.cache_size_limit = args.dynamo_cache_limit
344
+
345
+ model = torch.compile(model) # requires PyTorch 2.0
346
+
347
+ # resume optimizer
348
+ if args.restore_optimizer and args.continue_from is not None:
349
+ logger.info("Restoring optimizer and scheduler from the checkpoint")
350
+ _optimizer_dir = args.continue_from
351
+ optimizer_checkpoint = torch.load(os.path.join(_optimizer_dir, "optimizer.pt"), map_location="cpu")
352
+ optimizer.load_state_dict(optimizer_checkpoint["optimizer"])
353
+ scheduler.load_state_dict(optimizer_checkpoint["scheduler"])
354
+ update_step = optimizer_checkpoint["update_step"]
355
+ beginning_step = update_step
356
+ global_step = optimizer_checkpoint["global_step"]
357
+ logger.info(f"Optimizer and scheduler restored from {_optimizer_dir}")
358
+
359
+ # ##############################
360
+ # TRAINING LOOP
361
+ # we use iterable dataset, so we may never go through all the data
362
+ # ##############################
363
+ # global steps and others are defined above
364
+ pad_idx = tokenizer.pad_token_id
365
+ update_time = time.time()
366
+ local_step = 0 # when continue_from is used, local_step != global_step
367
+ total_svd_count = 0
368
+
369
+ dataloader_iter = iter(dataloader)
370
+
371
+ # Issue-4 fix: accumulate loss across micro-batches so logged loss is the true
372
+ # gradient-accumulation average, not just the last micro-batch.
373
+ accumulated_loss = 0.0
374
+
375
+ # Skip data if resuming from checkpoint
376
+ if update_step != 0:
377
+ skip_batches = args.gradient_accumulation * update_step
378
+ logger.info(f"Skipping {skip_batches} batches to resume from update step {update_step}")
379
+ skipped = 0
380
+ for _ in range(skip_batches):
381
+ # Issue-6 fix: handle StopIteration during skip so all ranks stay aligned
382
+ try:
383
+ next(dataloader_iter)
384
+ except StopIteration:
385
+ logger.warning(
386
+ f"Dataset exhausted during skip at batch {skipped}/{skip_batches}; "
387
+ f"restarting iterator to keep ranks aligned."
388
+ )
389
+ dataloader_iter = iter(dataloader)
390
+ next(dataloader_iter)
391
+ skipped += 1
392
+ logger.info(f"Skipped {skipped} batches successfully")
393
+
394
+ while update_step <= args.num_training_steps:
395
+ try:
396
+ batch = next(dataloader_iter)
397
+ except StopIteration:
398
+ logger.info(f"Dataset completed one epoch. Starting new epoch with reshuffled data.")
399
+ dataloader_iter = iter(dataloader)
400
+ batch = next(dataloader_iter)
401
+
402
+ global_step += 1
403
+ local_step += 1
404
+
405
+ if update_step >= args.num_training_steps:
406
+ logger.info(f"Reached max number of update steps ({args.num_training_steps}). Stopping training.")
407
+ logger.info(f"Rank {global_rank} stopping training.")
408
+ break
409
+
410
+ # forward & backward
411
+ batch = {k: v.to(device) for k, v in batch.items()}
412
+ labels = batch["input_ids"].clone()
413
+ labels[labels == pad_idx] = -100
414
+ tokens_seen += (batch["input_ids"] != pad_idx).sum().item() * world_size
415
+
416
+ loss = model(**batch, labels=labels).loss
417
+
418
+ scaled_loss = loss / args.gradient_accumulation
419
+ scaled_loss.backward()
420
+ accumulated_loss += loss.item() # Issue-4: accumulate before the continue
421
+
422
+ if global_step % args.gradient_accumulation != 0:
423
+ continue
424
+
425
+ # The below code is only executed during the update step
426
+ # Issue-4: compute average loss over all micro-batches in this accumulation window
427
+ avg_loss = accumulated_loss / args.gradient_accumulation
428
+ accumulated_loss = 0.0 # reset for next accumulation window
429
+ # add grad clipping: TODO: add gradient clipping of int8 weight
430
+ if args.grad_clipping != 0.0:
431
+ torch.nn.utils.clip_grad_norm_(trainable_params, args.grad_clipping)
432
+ # Periodic memory cleanup to prevent symbolic tensor issues during long training
433
+ if global_step % args.memory_cleanup_frequency == 0:
434
+ torch.cuda.empty_cache()
435
+ # Clear TorchDynamo cache to prevent memory accumulation
436
+ if args.compile:
437
+ import torch._dynamo
438
+ torch._dynamo.reset()
439
+
440
+ if global_rank == 0:
441
+ pbar.update(1)
442
+ if not layer_wise_flag: # layer-wise updation is done during backward; requires gradient_accumulation equals 1
443
+ optimizer.step()
444
+ scheduler.step()
445
+ optimizer.zero_grad()
446
+
447
+ update_step += 1
448
+ update_time = time.time() - update_time
449
+
450
+ # save checkpoint by save_every
451
+ if local_step > args.gradient_accumulation and update_step % args.save_every == 0 and global_rank == 0:
452
+ current_model_directory = f"{args.save_dir}/model_{update_step}"
453
+ logger.info(f"Saving model and optimizer to {current_model_directory}, update step {update_step}")
454
+ os.makedirs(args.save_dir, exist_ok=True)
455
+ # Bug-1 fix: unwrap DDP/compiled model for saving; works in both single-GPU and multi-GPU modes
456
+ unwrapped_model = model.module if hasattr(model, 'module') else model
457
+ unwrapped_model.save_pretrained(current_model_directory, max_shard_size="500GB", from_pt=True)
458
+ saving_model_weight(unwrapped_model, f"{current_model_directory}/pytorch_model.bin", args)
459
+
460
+ optimizer_checkpoint = {
461
+ "optimizer": optimizer.state_dict(),
462
+ "scheduler": scheduler.state_dict(),
463
+ "update_step": update_step,
464
+ "global_step": global_step,
465
+ "config": run_config,
466
+ "wandb": wandb.run.dir if not args.unset_wandb else None,
467
+ "dtype": args.dtype,
468
+ }
469
+ torch.save(optimizer_checkpoint, f"{current_model_directory}/optimizer.pt")
470
+
471
+ training_state_checkpoint = {
472
+ "global_step": global_step,
473
+ "update_step": update_step,
474
+ "tokens_seen": tokens_seen,
475
+ "tokens_seen_before": tokens_seen_before,
476
+ "update_time": update_time,
477
+ }
478
+ with open(f"{current_model_directory}/training_state.json", "w") as f:
479
+ json.dump(training_state_checkpoint, f, indent=4)
480
+
481
+ # save wandb related info
482
+ if not args.unset_wandb:
483
+ wandb_info = {
484
+ "wandb_id": wandb.run.id,
485
+ }
486
+ with open(f"{args.save_dir}/wandb.json", "w") as f:
487
+ json.dump(wandb_info, f, indent=4)
488
+
489
+ # evaluation
490
+ if update_step % args.eval_every == 0:
491
+ logger.info(f"Performing evaluation at step {update_step}")
492
+ total_loss, evaluated_on_tokens, perplexity = evaluate_model(
493
+ model, tokenizer, pad_idx, global_rank, world_size, device, args
494
+ )
495
+
496
+ if global_rank == 0:
497
+ if not args.unset_wandb:
498
+ wandb.log(
499
+ {
500
+ "eval_loss": total_loss,
501
+ "eval_perplexity": perplexity,
502
+ "eval_tokens": evaluated_on_tokens,
503
+ },
504
+ step=update_step,
505
+ )
506
+ logger.info(f"Eval loss at step {update_step}: {total_loss}, Eval perplexity: {perplexity}")
507
+
508
+ if not layer_wise_flag:
509
+ lr = optimizer.param_groups[0]["lr"]
510
+ else:
511
+ lr = list(optimizer_dict.values())[0].param_groups[0]["lr"]
512
+ tokens_in_update = tokens_seen - tokens_seen_before
513
+ tokens_seen_before = tokens_seen
514
+ batches_in_update = args.gradient_accumulation * world_size
515
+ if not layer_wise_flag:
516
+ total_svd_count = getting_svd_cnt(optimizer)
517
+ else:
518
+ total_svd_count = 0
519
+
520
+ # Build extra lr metrics for hybrid optimizer
521
+ hybrid_lr_log = {}
522
+ if isinstance(optimizer, HybridOptimizer):
523
+ if optimizer.muon is not None:
524
+ hybrid_lr_log["lr_muon"] = optimizer.muon.param_groups[0]["lr"]
525
+ if optimizer.apollo is not None:
526
+ hybrid_lr_log["lr_apollo"] = optimizer.apollo.param_groups[0]["lr"]
527
+ hybrid_lr_log["lr_adamw"] = optimizer.adamw.param_groups[0]["lr"]
528
+
529
+ if global_rank == 0:
530
+ if not args.unset_wandb:
531
+ log_dict = {
532
+ "loss": avg_loss,
533
+ "lr": lr,
534
+ "update_step": update_step,
535
+ "tokens_seen": tokens_seen,
536
+ "total_svd_count": total_svd_count,
537
+ "throughput_tokens": tokens_in_update / update_time,
538
+ "throughput_examples": args.total_batch_size / update_time,
539
+ "throughput_batches": batches_in_update / update_time,
540
+ }
541
+ log_dict.update(hybrid_lr_log)
542
+ wandb.log(log_dict, step=update_step)
543
+ update_time = time.time()
544
+
545
+ # ##############################
546
+ # END of training loop
547
+ # ##############################
548
+ logger.info("Training finished")
549
+ if global_rank == 0:
550
+ pbar.close()
551
+
552
+ current_model_directory = f"{args.save_dir}/model_{update_step}"
553
+ if global_rank == 0 and not os.path.exists(current_model_directory):
554
+ logger.info(f"Saving model and optimizer to {current_model_directory}, update step {update_step}")
555
+ os.makedirs(args.save_dir, exist_ok=True)
556
+ # Bug-1 fix: unwrap DDP/compiled model for saving; works in both single-GPU and multi-GPU modes
557
+ unwrapped_model = model.module if hasattr(model, 'module') else model
558
+ unwrapped_model.save_pretrained(current_model_directory, max_shard_size="500GB", from_pt=True)
559
+ saving_model_weight(unwrapped_model, f"{current_model_directory}/pytorch_model.bin", args)
560
+
561
+ optimizer_checkpoint = {
562
+ "optimizer": optimizer.state_dict(),
563
+ "scheduler": scheduler.state_dict(),
564
+ "update_step": update_step,
565
+ "global_step": global_step,
566
+ "config": run_config,
567
+ "wandb": wandb.run.dir if not args.unset_wandb else None,
568
+ "dtype": args.dtype,
569
+ }
570
+ torch.save(optimizer_checkpoint, f"{current_model_directory}/optimizer.pt")
571
+
572
+ training_state_checkpoint = {
573
+ "global_step": global_step,
574
+ "update_step": update_step,
575
+ "tokens_seen": tokens_seen,
576
+ "tokens_seen_before": tokens_seen_before,
577
+ "update_time": update_time,
578
+ }
579
+ with open(f"{current_model_directory}/training_state.json", "w") as f:
580
+ json.dump(training_state_checkpoint, f, indent=4)
581
+
582
+ # Final evaluation
583
+ logger.info("Running final evaluation")
584
+ model.eval()
585
+ del loss, optimizer, scheduler
586
+ import gc
587
+
588
+ gc.collect()
589
+ torch.cuda.empty_cache()
590
+
591
+ total_loss, evaluated_on_tokens, perplexity = evaluate_model(model, tokenizer, pad_idx, global_rank, world_size, device, args)
592
+
593
+ if global_rank == 0:
594
+ if not args.unset_wandb:
595
+ wandb.log(
596
+ {
597
+ "final_eval_loss": total_loss,
598
+ "final_eval_perplexity": perplexity,
599
+ "final_eval_tokens": evaluated_on_tokens,
600
+ },
601
+ step=update_step,
602
+ )
603
+ logger.info(f"Final eval loss: {total_loss}, Final eval perplexity: {perplexity}")
604
+
605
+ logger.info("Script finished successfully")
606
+ print(f"Rank {global_rank} finished successfully")
607
+
608
+
609
+ if __name__ == "__main__":
610
+ print("Starting script")
611
+ args = parse_args(None)
612
+ main(args)
350m/conda_lr1e_2_b1_0_9_b2_0_99_eps_1e_8_scale_0_25_rank256_T2000_H200_ppl16_4542_20260423_202415/wandb/offline-run-20260423_202422-mawza3ul/files/requirements.txt ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ sac==0.1.0
2
+ packaging==26.0
3
+ setuptools==82.0.1
4
+ wheel==0.46.3
5
+ pip==26.0.1
6
+ torchaudio==2.11.0
7
+ nvidia-cusparselt-cu12==0.7.1
8
+ mpmath==1.3.0
9
+ typing_extensions==4.15.0
10
+ triton==3.4.0
11
+ sympy==1.14.0
12
+ pillow==12.2.0
13
+ nvidia-nvtx-cu12==12.8.90
14
+ nvidia-nvjitlink-cu12==12.8.93
15
+ nvidia-nccl-cu12==2.27.3
16
+ nvidia-curand-cu12==10.3.9.90
17
+ nvidia-cufile-cu12==1.13.1.3
18
+ nvidia-cuda-runtime-cu12==12.8.90
19
+ nvidia-cuda-nvrtc-cu12==12.8.93
20
+ nvidia-cuda-cupti-cu12==12.8.90
21
+ nvidia-cublas-cu12==12.8.4.1
22
+ numpy==2.2.6
23
+ networkx==3.4.2
24
+ MarkupSafe==3.0.3
25
+ aiohappyeyeballs==2.6.1
26
+ filelock==3.28.0
27
+ nvidia-cusparse-cu12==12.5.8.93
28
+ nvidia-cufft-cu12==11.3.3.83
29
+ nvidia-cudnn-cu12==9.10.2.21
30
+ Jinja2==3.1.6
31
+ nvidia-cusolver-cu12==11.7.3.90
32
+ torch==2.8.0+cu128
33
+ torchvision==0.23.0+cu128
34
+ pytz==2026.1.post1
35
+ xxhash==3.6.0
36
+ urllib3==2.6.3
37
+ tzdata==2026.1
38
+ tqdm==4.67.3
39
+ six==1.17.0
40
+ safetensors==0.7.0
41
+ regex==2026.4.4
42
+ PyYAML==6.0.3
43
+ pyarrow==23.0.1
44
+ psutil==7.2.2
45
+ propcache==0.4.1
46
+ multidict==6.7.1
47
+ idna==3.11
48
+ hf-xet==1.4.3
49
+ h11==0.16.0
50
+ fsspec==2026.2.0
51
+ frozenlist==1.8.0
52
+ exceptiongroup==1.3.1
53
+ dill==0.4.1
54
+ charset-normalizer==3.4.7
55
+ certifi==2026.2.25
56
+ attrs==26.1.0
57
+ async-timeout==5.0.1
58
+ yarl==1.23.0
59
+ requests==2.33.1
60
+ python-dateutil==2.9.0.post0
61
+ multiprocess==0.70.19
62
+ httpcore==1.0.9
63
+ anyio==4.13.0
64
+ aiosignal==1.4.0
65
+ pandas==2.3.3
66
+ huggingface_hub==0.36.2
67
+ httpx==0.28.1
68
+ aiohttp==3.13.5
69
+ tokenizers==0.22.2
70
+ accelerate==1.13.0
71
+ transformers==4.57.3
72
+ datasets==4.8.4
73
+ peft==0.19.1
74
+ pytorch-ranger==0.1.1
75
+ lion-pytorch==0.2.4
76
+ bitsandbytes==0.49.2
77
+ torch-optimizer==0.3.0
78
+ apollo-torch==1.0.3
79
+ nvidia-ml-py==13.590.48
80
+ typing-inspection==0.4.2
81
+ threadpoolctl==3.6.0
82
+ smmap==5.0.3
83
+ sentry-sdk==2.58.0
84
+ scipy==1.15.3
85
+ pyparsing==3.3.2
86
+ pydantic_core==2.46.3
87
+ protobuf==7.34.1
88
+ platformdirs==4.9.6
89
+ nvitop==1.6.2
90
+ loguru==0.7.3
91
+ kiwisolver==1.5.0
92
+ joblib==1.5.3
93
+ fonttools==4.62.1
94
+ cycler==0.12.1
95
+ contourpy==1.3.2
96
+ click==8.3.2
97
+ annotated-types==0.7.0
98
+ scikit-learn==1.7.2
99
+ pydantic==2.13.3
100
+ modelscope==1.35.4
101
+ matplotlib==3.10.8
102
+ gitdb==4.0.12
103
+ seaborn==0.13.2
104
+ GitPython==3.1.47
105
+ wandb==0.26.0
106
+ sac==0.1.0
107
+ nvidia-ml-py3==7.352.0
108
+ gitignore_parser==0.1.13
109
+ durationpy==0.10
110
+ dotmap==1.3.30
111
+ wrapt==2.1.2
112
+ websocket-client==1.9.0
113
+ typeguard==4.5.1
114
+ tabulate==0.9.0
115
+ pycparser==3.0
116
+ pyasn1==0.6.3
117
+ py==1.11.0
118
+ oauthlib==3.3.1
119
+ jmespath==1.1.0
120
+ invoke==3.0.3
121
+ elasticsearch==7.17.13
122
+ docutils==0.19
123
+ decorator==5.2.1
124
+ confluent-kafka==2.14.0
125
+ colorama==0.4.6
126
+ bcrypt==5.0.0
127
+ rsa==4.7.2
128
+ retry==0.9.2
129
+ requests-oauthlib==2.0.0
130
+ Deprecated==1.3.1
131
+ cffi==2.0.0
132
+ botocore==1.42.92
133
+ s3transfer==0.16.0
134
+ PyNaCl==1.6.2
135
+ kubernetes==35.0.0
136
+ cryptography==46.0.7
137
+ paramiko==4.0.0
138
+ boto3==1.42.92
139
+ awscli==1.44.82
140
+ megfile==2.2.10.post1
141
+ refile==7.2.7.post3
142
+ brainpp==2.7.12.16
350m/conda_lr1e_2_b1_0_9_b2_0_99_eps_1e_8_scale_0_25_rank256_T2000_H200_ppl16_4542_20260423_202415/wandb/offline-run-20260423_202422-mawza3ul/logs/debug-internal.log ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {"time":"2026-04-23T20:24:22.4996429+08:00","level":"INFO","msg":"wandb-core"}
2
+ {"time":"2026-04-23T20:24:22.499901972+08:00","level":"INFO","msg":"stream: starting","core version":"0.26.0"}
3
+ {"time":"2026-04-23T20:24:22.63528813+08:00","level":"WARN","msg":"featurechecker: GraphQL client is nil, skipping feature loading"}
4
+ {"time":"2026-04-23T20:24:22.635312622+08:00","level":"WARN","msg":"featurechecker: GraphQL client is nil, skipping feature loading"}
5
+ {"time":"2026-04-23T20:24:22.635332006+08:00","level":"INFO","msg":"stream: created new stream","id":"mawza3ul"}
6
+ {"time":"2026-04-23T20:24:22.635539387+08:00","level":"INFO","msg":"handler: started"}
7
+ {"time":"2026-04-23T20:24:22.636204984+08:00","level":"INFO","msg":"stream: started"}
8
+ {"time":"2026-04-23T20:24:22.636292599+08:00","level":"INFO","msg":"writer: started","stream_id":"mawza3ul"}
9
+ {"time":"2026-04-23T20:24:22.636304487+08:00","level":"INFO","msg":"sender: started"}
10
+ {"time":"2026-04-23T20:24:22.637327692+08:00","level":"WARN","msg":"featurechecker: GraphQL client is nil, skipping feature loading"}
11
+ {"time":"2026-04-23T20:24:22.637343161+08:00","level":"WARN","msg":"runupserter: server does not expand metric globs but the x_server_side_expand_glob_metrics setting is set; ignoring"}
12
+ {"time":"2026-04-24T01:14:24.730723575+08:00","level":"INFO","msg":"stream: closing"}
13
+ {"time":"2026-04-24T01:14:24.753028533+08:00","level":"INFO","msg":"handler: closed"}
14
+ {"time":"2026-04-24T01:14:24.753387495+08:00","level":"INFO","msg":"sender: closed"}
15
+ {"time":"2026-04-24T01:14:24.753400977+08:00","level":"INFO","msg":"stream: closed"}
350m/conda_lr1e_2_b1_0_9_b2_0_99_eps_1e_8_scale_0_25_rank256_T2000_H200_ppl16_4542_20260423_202415/wandb/offline-run-20260423_202422-mawza3ul/logs/debug.log ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2026-04-23 20:24:22,097 INFO MainThread:342 [wandb_setup.py:_flush():81] Current SDK version is 0.26.0
2
+ 2026-04-23 20:24:22,098 INFO MainThread:342 [wandb_setup.py:_flush():81] Configure stats pid to 342
3
+ 2026-04-23 20:24:22,098 INFO MainThread:342 [wandb_setup.py:_flush():81] Loading settings from environment variables
4
+ 2026-04-23 20:24:22,098 INFO MainThread:342 [wandb_init.py:setup_run_log_directory():721] Logging user logs to exp_remain_h200/work_dirs/350m/train_350m_conda_lr1e_2_scale0_25_rank256_gap2000_20260423_202415/wandb/offline-run-20260423_202422-mawza3ul/logs/debug.log
5
+ 2026-04-23 20:24:22,098 INFO MainThread:342 [wandb_init.py:setup_run_log_directory():722] Logging internal logs to exp_remain_h200/work_dirs/350m/train_350m_conda_lr1e_2_scale0_25_rank256_gap2000_20260423_202415/wandb/offline-run-20260423_202422-mawza3ul/logs/debug-internal.log
6
+ 2026-04-23 20:24:22,098 INFO MainThread:342 [wandb_init.py:init():848] calling init triggers
7
+ 2026-04-23 20:24:22,098 INFO MainThread:342 [wandb_init.py:init():853] wandb.init called with sweep_config: {}
8
+ config: {'_wandb': {}}
9
+ 2026-04-23 20:24:22,098 INFO MainThread:342 [wandb_init.py:init():896] starting backend
10
+ 2026-04-23 20:24:22,494 INFO MainThread:342 [wandb_init.py:init():911] sending inform_init request
11
+ 2026-04-23 20:24:22,498 INFO MainThread:342 [wandb_init.py:init():919] backend started and connected
12
+ 2026-04-23 20:24:22,498 INFO MainThread:342 [wandb_init.py:init():989] updated telemetry
13
+ 2026-04-23 20:24:22,526 INFO MainThread:342 [wandb_init.py:init():1013] communicating run to backend with 90.0 second timeout
14
+ 2026-04-23 20:24:22,638 INFO MainThread:342 [wandb_init.py:init():1058] starting run threads in backend
15
+ 2026-04-23 20:24:22,712 INFO MainThread:342 [wandb_run.py:_console_start():2542] atexit reg
16
+ 2026-04-23 20:24:22,712 INFO MainThread:342 [wandb_run.py:_redirect():2391] redirect: wrap_raw
17
+ 2026-04-23 20:24:22,712 INFO MainThread:342 [wandb_run.py:_redirect():2460] Wrapping output streams.
18
+ 2026-04-23 20:24:22,712 INFO MainThread:342 [wandb_run.py:_redirect():2483] Redirects installed.
19
+ 2026-04-23 20:24:22,714 INFO MainThread:342 [wandb_init.py:init():1098] run started, returning control to user process
20
+ 2026-04-23 20:24:49,867 INFO MainThread:342 [wandb_run.py:_config_callback():1403] config_cb None None {'model_config': 'configs/llama_350m.json', 'exp_config': 'exp_v2/configs/llama_350m.json', 'eval_every': 1000, 'save_every': 60000, 'dtype': 'bfloat16', 'seed': 0, 'compile': True, 'dynamo_suppress_errors': True, 'dynamo_cache_limit': 10000, 'memory_cleanup_frequency': 10000, 'resume_step': None, 'restore_optimizer': False, 'continue_from': None, 'single_gpu': False, 'save_dir': 'exp_remain_h200/work_dirs/350m/train_350m_conda_lr1e_2_scale0_25_rank256_gap2000_20260423_202415', 'use_hf_model': False, 'workers': 12, 'batch_size': 128, 'gradient_accumulation': 1, 'total_batch_size': 512, 'warmup_steps': 6000, 'num_training_steps': 60000, 'max_train_tokens': None, 'optimizer': 'conda', 'max_length': 256, 'scheduler': 'cosine', 'min_lr_ratio': 0.1, 'weight_decay': 0.0, 'grad_clipping': 0.0, 'activation_checkpointing': False, 'data_path': '/mnt/shared-storage-gpfs2/finebio-shared/optimizer/dataset/C4/en', 'data_name': 'en', 'tags': None, 'name': 'test', 'project': 'test', 'unset_wandb': False, 'entity': None, 'wandb_dir': 'exp_remain_h200/work_dirs/350m/train_350m_conda_lr1e_2_scale0_25_rank256_gap2000_20260423_202415', 'beta1': 0.9, 'beta2': 0.99, 'beta3': 0.99, 'eps': 1e-08, 'rank': 256, 'update_proj_gap': 2000, 'galore_scale': 1.0, 'proj_type': 'std', 'proj_quant': False, 'proj_bits': 8, 'proj_group_size': 256, 'weight_quant': False, 'weight_bits': 8, 'weight_group_size': 256, 'stochastic_round': False, 'simulation': False, 'cos_threshold': 1, 'gamma_proj': 2, 'queue_size': 5, 'proj': 'random', 'scale_type': 'channel', 'apollo_scale': 0.25, 'scale_front': False, 'n_clusters': 3, 'scale_update_freq': 500, 'scale_level': '1,0,1,1', 'scale_bound': None, 'metric': 'mean', 'align_grad': False, 'dim': 4096, 'n_heads': 32, 'muon_ns_steps': 5, 'muon_momentum': 0.95, 'nproc_per_node': 4, 'max_lr': 0.01, 'total_params_M': 367.96928, 'dataset': 'c4', 'model': {'vocab_size': 32000, 'max_position_embeddings': 2048, 'hidden_size': 1024, 'intermediate_size': 2736, 'num_hidden_layers': 24, 'num_attention_heads': 16, 'num_key_value_heads': 16, 'hidden_act': 'silu', 'initializer_range': 0.02, 'rms_norm_eps': 1e-06, 'pretraining_tp': 1, 'use_cache': True, 'rope_theta': 10000.0, 'rope_scaling': None, 'attention_bias': False, 'attention_dropout': 0.0, 'mlp_bias': False, 'head_dim': 64, 'return_dict': True, 'output_hidden_states': False, 'torchscript': False, 'dtype': None, 'pruned_heads': {}, 'tie_word_embeddings': False, 'chunk_size_feed_forward': 0, 'is_encoder_decoder': False, 'is_decoder': False, 'cross_attention_hidden_size': None, 'add_cross_attention': False, 'tie_encoder_decoder': False, 'architectures': ['LLaMAForCausalLM'], 'finetuning_task': None, 'id2label': {0: 'LABEL_0', 1: 'LABEL_1'}, 'label2id': {'LABEL_0': 0, 'LABEL_1': 1}, 'task_specific_params': None, 'problem_type': None, 'tokenizer_class': None, 'prefix': None, 'bos_token_id': 0, 'pad_token_id': -1, 'eos_token_id': 1, 'sep_token_id': None, 'decoder_start_token_id': None, 'max_length': 20, 'min_length': 0, 'do_sample': False, 'early_stopping': False, 'num_beams': 1, 'temperature': 1.0, 'top_k': 50, 'top_p': 1.0, 'typical_p': 1.0, 'repetition_penalty': 1.0, 'length_penalty': 1.0, 'no_repeat_ngram_size': 0, 'encoder_no_repeat_ngram_size': 0, 'bad_words_ids': None, 'num_return_sequences': 1, 'output_scores': False, 'return_dict_in_generate': False, 'forced_bos_token_id': None, 'forced_eos_token_id': None, 'remove_invalid_values': False, 'exponential_decay_length_penalty': None, 'suppress_tokens': None, 'begin_suppress_tokens': None, 'num_beam_groups': 1, 'diversity_penalty': 0.0, '_name_or_path': 'configs/llama_350m.json', 'transformers_version': '4.57.3', 'max_sequence_length': 1024, 'model_type': 'llama', 'tf_legacy_loss': False, 'use_bfloat16': False, 'output_attentions': False}, 'world_size': 4, 'device': 'cuda:0'}
21
+ 2026-04-24 01:14:24,727 INFO wandb-AsyncioManager-main:342 [service_client.py:_forward_responses():134] Reached EOF.
22
+ 2026-04-24 01:14:24,730 INFO wandb-AsyncioManager-main:342 [mailbox.py:close():155] Closing mailbox, abandoning 0 handles.
350m/conda_lr1e_2_b1_0_9_b2_0_99_eps_1e_8_scale_0_25_rank256_T2000_H200_ppl16_4542_20260423_202415/wandb/offline-run-20260423_202422-mawza3ul/run-mawza3ul.wandb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:962d8f0f208560ca403ae8a08e4f5b16a4672f36756493a2b4a1f93e72e37218
3
+ size 59876049