Sssplendid committed on
Commit
d8ba42b
·
verified ·
1 Parent(s): a5c7c91

Add 130m/came_lr5e_4_b1_0_9_b2_0_999_eps_1e_6_A100_ppl_23_7861_20260421_151403

Browse files
Files changed (16) hide show
  1. .gitattributes +1 -0
  2. 130m/came_lr5e_4_b1_0_9_b2_0_999_eps_1e_6_A100_ppl_23_7861_20260421_151403/came_lr5e_4_b1_0_9_b2_0_999_eps_1e_6_A100_ppl_23_7861_20260421_151403.txt +0 -0
  3. 130m/came_lr5e_4_b1_0_9_b2_0_999_eps_1e_6_A100_ppl_23_7861_20260421_151403/model_20000/config.json +31 -0
  4. 130m/came_lr5e_4_b1_0_9_b2_0_999_eps_1e_6_A100_ppl_23_7861_20260421_151403/model_20000/model.safetensors +3 -0
  5. 130m/came_lr5e_4_b1_0_9_b2_0_999_eps_1e_6_A100_ppl_23_7861_20260421_151403/model_20000/optimizer.pt +3 -0
  6. 130m/came_lr5e_4_b1_0_9_b2_0_999_eps_1e_6_A100_ppl_23_7861_20260421_151403/model_20000/pytorch_model.bin +3 -0
  7. 130m/came_lr5e_4_b1_0_9_b2_0_999_eps_1e_6_A100_ppl_23_7861_20260421_151403/model_20000/training_state.json +7 -0
  8. 130m/came_lr5e_4_b1_0_9_b2_0_999_eps_1e_6_A100_ppl_23_7861_20260421_151403/wandb.json +3 -0
  9. 130m/came_lr5e_4_b1_0_9_b2_0_999_eps_1e_6_A100_ppl_23_7861_20260421_151403/wandb/debug-internal.log +12 -0
  10. 130m/came_lr5e_4_b1_0_9_b2_0_999_eps_1e_6_A100_ppl_23_7861_20260421_151403/wandb/debug.log +24 -0
  11. 130m/came_lr5e_4_b1_0_9_b2_0_999_eps_1e_6_A100_ppl_23_7861_20260421_151403/wandb/offline-run-20260421_151426-c42l43zw/files/SAC/torchrun_main.py +603 -0
  12. 130m/came_lr5e_4_b1_0_9_b2_0_999_eps_1e_6_A100_ppl_23_7861_20260421_151403/wandb/offline-run-20260421_151426-c42l43zw/files/requirements.txt +134 -0
  13. 130m/came_lr5e_4_b1_0_9_b2_0_999_eps_1e_6_A100_ppl_23_7861_20260421_151403/wandb/offline-run-20260421_151426-c42l43zw/logs/debug-core.log +14 -0
  14. 130m/came_lr5e_4_b1_0_9_b2_0_999_eps_1e_6_A100_ppl_23_7861_20260421_151403/wandb/offline-run-20260421_151426-c42l43zw/logs/debug-internal.log +12 -0
  15. 130m/came_lr5e_4_b1_0_9_b2_0_999_eps_1e_6_A100_ppl_23_7861_20260421_151403/wandb/offline-run-20260421_151426-c42l43zw/logs/debug.log +24 -0
  16. 130m/came_lr5e_4_b1_0_9_b2_0_999_eps_1e_6_A100_ppl_23_7861_20260421_151403/wandb/offline-run-20260421_151426-c42l43zw/run-c42l43zw.wandb +3 -0
.gitattributes CHANGED
@@ -67,3 +67,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
67
  130m/adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_A100_ppl_23_1813_20260416_193855/wandb/offline-run-20260416_193926-lg6xmhwz/run-lg6xmhwz.wandb filter=lfs diff=lfs merge=lfs -text
68
  130m/adan_lr3e_3_b1_0_9_b2_0_92_b3_0_99_eps_1e_8_A100_ppl_22_8442_20260416_193855/wandb/offline-run-20260416_193926-n4jow674/run-n4jow674.wandb filter=lfs diff=lfs merge=lfs -text
69
  130m/apollo_lr1e_2_b1_0_9_b2_0_99_eps_1e_6_scale_1_rank_192_T_200_A100_ppl_22_7386_20260419_234620/wandb/offline-run-20260419_234717-pafttcq9/run-pafttcq9.wandb filter=lfs diff=lfs merge=lfs -text
 
 
67
  130m/adamw_lr1e_3_b1_0_9_b2_0_99_eps_1e_8_A100_ppl_23_1813_20260416_193855/wandb/offline-run-20260416_193926-lg6xmhwz/run-lg6xmhwz.wandb filter=lfs diff=lfs merge=lfs -text
68
  130m/adan_lr3e_3_b1_0_9_b2_0_92_b3_0_99_eps_1e_8_A100_ppl_22_8442_20260416_193855/wandb/offline-run-20260416_193926-n4jow674/run-n4jow674.wandb filter=lfs diff=lfs merge=lfs -text
69
  130m/apollo_lr1e_2_b1_0_9_b2_0_99_eps_1e_6_scale_1_rank_192_T_200_A100_ppl_22_7386_20260419_234620/wandb/offline-run-20260419_234717-pafttcq9/run-pafttcq9.wandb filter=lfs diff=lfs merge=lfs -text
70
+ 130m/came_lr5e_4_b1_0_9_b2_0_999_eps_1e_6_A100_ppl_23_7861_20260421_151403/wandb/offline-run-20260421_151426-c42l43zw/run-c42l43zw.wandb filter=lfs diff=lfs merge=lfs -text
130m/came_lr5e_4_b1_0_9_b2_0_999_eps_1e_6_A100_ppl_23_7861_20260421_151403/came_lr5e_4_b1_0_9_b2_0_999_eps_1e_6_A100_ppl_23_7861_20260421_151403.txt ADDED
The diff for this file is too large to render. See raw diff
 
130m/came_lr5e_4_b1_0_9_b2_0_999_eps_1e_6_A100_ppl_23_7861_20260421_151403/model_20000/config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "LlamaForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 0,
8
+ "dtype": "bfloat16",
9
+ "eos_token_id": 1,
10
+ "head_dim": 64,
11
+ "hidden_act": "silu",
12
+ "hidden_size": 768,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 2048,
15
+ "max_position_embeddings": 2048,
16
+ "max_sequence_length": 1024,
17
+ "mlp_bias": false,
18
+ "model_type": "llama",
19
+ "num_attention_heads": 12,
20
+ "num_hidden_layers": 12,
21
+ "num_key_value_heads": 12,
22
+ "pad_token_id": -1,
23
+ "pretraining_tp": 1,
24
+ "rms_norm_eps": 1e-06,
25
+ "rope_scaling": null,
26
+ "rope_theta": 10000.0,
27
+ "tie_word_embeddings": false,
28
+ "transformers_version": "4.57.3",
29
+ "use_cache": true,
30
+ "vocab_size": 32000
31
+ }
130m/came_lr5e_4_b1_0_9_b2_0_999_eps_1e_6_A100_ppl_23_7861_20260421_151403/model_20000/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:61e1cea2e1097d161efff6aa28a7204d321b1b35254fa4865a12f5979b809fd4
3
+ size 268226272
130m/came_lr5e_4_b1_0_9_b2_0_999_eps_1e_6_A100_ppl_23_7861_20260421_151403/model_20000/optimizer.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:64be9b8d57244cf47c8bee63751bec450c0274a0fa55b0c99a095d3a2bacb0cd
3
+ size 538598330
130m/came_lr5e_4_b1_0_9_b2_0_999_eps_1e_6_A100_ppl_23_7861_20260421_151403/model_20000/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4ff40661f8ec9e1d5571f7f833cca7e7e27bea0b9c42b17d88f429f79a31abce
3
+ size 268262966
130m/came_lr5e_4_b1_0_9_b2_0_999_eps_1e_6_A100_ppl_23_7861_20260421_151403/model_20000/training_state.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "global_step": 20000,
3
+ "update_step": 20000,
4
+ "tokens_seen": 1999942168,
5
+ "tokens_seen_before": 1999842256,
6
+ "update_time": 0.4587056636810303
7
+ }
130m/came_lr5e_4_b1_0_9_b2_0_999_eps_1e_6_A100_ppl_23_7861_20260421_151403/wandb.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "wandb_id": "c42l43zw"
3
+ }
130m/came_lr5e_4_b1_0_9_b2_0_999_eps_1e_6_A100_ppl_23_7861_20260421_151403/wandb/debug-internal.log ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {"time":"2026-04-21T15:14:26.9128003+08:00","level":"INFO","msg":"stream: starting","core version":"0.23.0"}
2
+ {"time":"2026-04-21T15:14:27.141384684+08:00","level":"WARN","msg":"featurechecker: GraphQL client is nil, skipping feature loading"}
3
+ {"time":"2026-04-21T15:14:27.141455586+08:00","level":"INFO","msg":"stream: created new stream","id":"c42l43zw"}
4
+ {"time":"2026-04-21T15:14:27.141499443+08:00","level":"INFO","msg":"handler: started","stream_id":"c42l43zw"}
5
+ {"time":"2026-04-21T15:14:27.147641439+08:00","level":"INFO","msg":"stream: started","id":"c42l43zw"}
6
+ {"time":"2026-04-21T15:14:27.14764711+08:00","level":"INFO","msg":"writer: started","stream_id":"c42l43zw"}
7
+ {"time":"2026-04-21T15:14:27.147658516+08:00","level":"INFO","msg":"sender: started","stream_id":"c42l43zw"}
8
+ {"time":"2026-04-21T15:14:27.149120861+08:00","level":"WARN","msg":"runupserter: server does not expand metric globs but the x_server_side_expand_glob_metrics setting is set; ignoring"}
9
+ {"time":"2026-04-21T16:38:11.997527579+08:00","level":"INFO","msg":"stream: closing","id":"c42l43zw"}
10
+ {"time":"2026-04-21T16:38:11.998219162+08:00","level":"INFO","msg":"handler: closed","stream_id":"c42l43zw"}
11
+ {"time":"2026-04-21T16:38:11.999549553+08:00","level":"INFO","msg":"sender: closed","stream_id":"c42l43zw"}
12
+ {"time":"2026-04-21T16:38:11.999562024+08:00","level":"INFO","msg":"stream: closed","id":"c42l43zw"}
130m/came_lr5e_4_b1_0_9_b2_0_999_eps_1e_6_A100_ppl_23_7861_20260421_151403/wandb/debug.log ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2026-04-21 15:14:26,395 INFO MainThread:80806 [wandb_setup.py:_flush():80] Current SDK version is 0.23.0
2
+ 2026-04-21 15:14:26,395 INFO MainThread:80806 [wandb_setup.py:_flush():80] Configure stats pid to 80806
3
+ 2026-04-21 15:14:26,395 INFO MainThread:80806 [wandb_setup.py:_flush():80] Loading settings from /mnt/petrelfs/panjiabao/.config/wandb/settings
4
+ 2026-04-21 15:14:26,395 INFO MainThread:80806 [wandb_setup.py:_flush():80] Loading settings from /mnt/petrelfs/panjiabao/Optimizer/SAC/wandb/settings
5
+ 2026-04-21 15:14:26,395 INFO MainThread:80806 [wandb_setup.py:_flush():80] Loading settings from environment variables
6
+ 2026-04-21 15:14:26,395 INFO MainThread:80806 [wandb_init.py:setup_run_log_directory():713] Logging user logs to /mnt/dhwfile/tancheng/panjiabao/Result/SAC_C4/work_dirs/130m/came_v3_lr5e_4_b1_0_9_b2_0_999_eps_1e_6_20260421_151403/wandb/offline-run-20260421_151426-c42l43zw/logs/debug.log
7
+ 2026-04-21 15:14:26,395 INFO MainThread:80806 [wandb_init.py:setup_run_log_directory():714] Logging internal logs to /mnt/dhwfile/tancheng/panjiabao/Result/SAC_C4/work_dirs/130m/came_v3_lr5e_4_b1_0_9_b2_0_999_eps_1e_6_20260421_151403/wandb/offline-run-20260421_151426-c42l43zw/logs/debug-internal.log
8
+ 2026-04-21 15:14:26,395 INFO MainThread:80806 [wandb_init.py:init():840] calling init triggers
9
+ 2026-04-21 15:14:26,395 INFO MainThread:80806 [wandb_init.py:init():845] wandb.init called with sweep_config: {}
10
+ config: {'_wandb': {}}
11
+ 2026-04-21 15:14:26,395 INFO MainThread:80806 [wandb_init.py:init():888] starting backend
12
+ 2026-04-21 15:14:26,812 INFO MainThread:80806 [wandb_init.py:init():891] sending inform_init request
13
+ 2026-04-21 15:14:26,843 INFO MainThread:80806 [wandb_init.py:init():899] backend started and connected
14
+ 2026-04-21 15:14:26,847 INFO MainThread:80806 [wandb_init.py:init():969] updated telemetry
15
+ 2026-04-21 15:14:26,894 INFO MainThread:80806 [wandb_init.py:init():993] communicating run to backend with 90.0 second timeout
16
+ 2026-04-21 15:14:27,150 INFO MainThread:80806 [wandb_init.py:init():1040] starting run threads in backend
17
+ 2026-04-21 15:14:27,510 INFO MainThread:80806 [wandb_run.py:_console_start():2504] atexit reg
18
+ 2026-04-21 15:14:27,510 INFO MainThread:80806 [wandb_run.py:_redirect():2352] redirect: wrap_raw
19
+ 2026-04-21 15:14:27,510 INFO MainThread:80806 [wandb_run.py:_redirect():2421] Wrapping output streams.
20
+ 2026-04-21 15:14:27,510 INFO MainThread:80806 [wandb_run.py:_redirect():2444] Redirects installed.
21
+ 2026-04-21 15:14:27,517 INFO MainThread:80806 [wandb_init.py:init():1080] run started, returning control to user process
22
+ 2026-04-21 15:14:32,995 INFO MainThread:80806 [wandb_run.py:_config_callback():1385] config_cb None None {'model_config': 'configs/llama_130m.json', 'exp_config': 'exp_v2/configs/llama_130m.json', 'eval_every': 1000, 'save_every': 20000, 'dtype': 'bfloat16', 'seed': 0, 'compile': True, 'dynamo_suppress_errors': True, 'dynamo_cache_limit': 10000, 'memory_cleanup_frequency': 10000, 'resume_step': None, 'restore_optimizer': False, 'continue_from': None, 'single_gpu': False, 'save_dir': '/mnt/dhwfile/tancheng/panjiabao/Result/SAC_C4/work_dirs/130m/came_v3_lr5e_4_b1_0_9_b2_0_999_eps_1e_6_20260421_151403', 'use_hf_model': False, 'workers': 12, 'batch_size': 128, 'gradient_accumulation': 1, 'total_batch_size': 512, 'warmup_steps': 2000, 'num_training_steps': 20000, 'max_train_tokens': None, 'optimizer': 'came', 'max_length': 256, 'scheduler': 'cosine', 'min_lr_ratio': 0.1, 'weight_decay': 0.0, 'grad_clipping': 0.0, 'activation_checkpointing': False, 'data_path': '/mnt/dhwfile/tancheng/panjiabao/dataset/C4/en', 'data_name': 'en', 'tags': None, 'name': 'test', 'project': 'test', 'unset_wandb': False, 'entity': None, 'wandb_dir': '/mnt/dhwfile/tancheng/panjiabao/Result/SAC_C4/work_dirs/130m/came_v3_lr5e_4_b1_0_9_b2_0_999_eps_1e_6_20260421_151403', 'beta1': 0.9, 'beta2': 0.999, 'beta3': 0.99, 'eps': 1e-06, 'rank': 128, 'update_proj_gap': 50, 'galore_scale': 1.0, 'proj_type': 'std', 'proj_quant': False, 'proj_bits': 8, 'proj_group_size': 256, 'weight_quant': False, 'weight_bits': 8, 'weight_group_size': 256, 'stochastic_round': False, 'simulation': False, 'cos_threshold': 1, 'gamma_proj': 2, 'queue_size': 5, 'proj': 'random', 'scale_type': 'tensor', 'apollo_scale': 1.0, 'scale_front': False, 'n_clusters': 3, 'scale_update_freq': 500, 'scale_level': '1,0,1,1', 'scale_bound': None, 'metric': 'mean', 'align_grad': False, 'dim': 4096, 'n_heads': 32, 'muon_ns_steps': 5, 'muon_momentum': 0.95, 'nproc_per_node': 4, 'max_lr': 0.0005, 'total_params_M': 134.105856, 'dataset': 'c4', 
'model': {'vocab_size': 32000, 'max_position_embeddings': 2048, 'hidden_size': 768, 'intermediate_size': 2048, 'num_hidden_layers': 12, 'num_attention_heads': 12, 'num_key_value_heads': 12, 'hidden_act': 'silu', 'initializer_range': 0.02, 'rms_norm_eps': 1e-06, 'pretraining_tp': 1, 'use_cache': True, 'rope_theta': 10000.0, 'rope_scaling': None, 'attention_bias': False, 'attention_dropout': 0.0, 'mlp_bias': False, 'head_dim': 64, 'return_dict': True, 'output_hidden_states': False, 'torchscript': False, 'dtype': None, 'pruned_heads': {}, 'tie_word_embeddings': False, 'chunk_size_feed_forward': 0, 'is_encoder_decoder': False, 'is_decoder': False, 'cross_attention_hidden_size': None, 'add_cross_attention': False, 'tie_encoder_decoder': False, 'architectures': ['LLaMAForCausalLM'], 'finetuning_task': None, 'id2label': {0: 'LABEL_0', 1: 'LABEL_1'}, 'label2id': {'LABEL_0': 0, 'LABEL_1': 1}, 'task_specific_params': None, 'problem_type': None, 'tokenizer_class': None, 'prefix': None, 'bos_token_id': 0, 'pad_token_id': -1, 'eos_token_id': 1, 'sep_token_id': None, 'decoder_start_token_id': None, 'max_length': 20, 'min_length': 0, 'do_sample': False, 'early_stopping': False, 'num_beams': 1, 'temperature': 1.0, 'top_k': 50, 'top_p': 1.0, 'typical_p': 1.0, 'repetition_penalty': 1.0, 'length_penalty': 1.0, 'no_repeat_ngram_size': 0, 'encoder_no_repeat_ngram_size': 0, 'bad_words_ids': None, 'num_return_sequences': 1, 'output_scores': False, 'return_dict_in_generate': False, 'forced_bos_token_id': None, 'forced_eos_token_id': None, 'remove_invalid_values': False, 'exponential_decay_length_penalty': None, 'suppress_tokens': None, 'begin_suppress_tokens': None, 'num_beam_groups': 1, 'diversity_penalty': 0.0, '_name_or_path': 'configs/llama_130m.json', 'transformers_version': '4.57.3', 'max_sequence_length': 1024, 'model_type': 'llama', 'tf_legacy_loss': False, 'use_bfloat16': False, 'output_attentions': False}, 'world_size': 4, 'device': 'cuda:0'}
23
+ 2026-04-21 16:38:11,996 INFO wandb-AsyncioManager-main:80806 [service_client.py:_forward_responses():80] Reached EOF.
24
+ 2026-04-21 16:38:11,997 INFO wandb-AsyncioManager-main:80806 [mailbox.py:close():137] Closing mailbox, abandoning 0 handles.
130m/came_lr5e_4_b1_0_9_b2_0_999_eps_1e_6_A100_ppl_23_7861_20260421_151403/wandb/offline-run-20260421_151426-c42l43zw/files/SAC/torchrun_main.py ADDED
@@ -0,0 +1,603 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ import os
7
+ import time
8
+ import json
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.distributed as dist
13
+
14
+ from tqdm import tqdm
15
+ from loguru import logger
16
+
17
+ import transformers
18
+
19
+ transformers.logging.set_verbosity_error()
20
+
21
+ import wandb
22
+
23
+ from utils.argparse import parse_args
24
+ from utils.setup import getting_svd_cnt, set_seed, setup_model, saving_model_weight, load_model_weight
25
+ from utils.optimizer_factory import setup_optimization
26
+ from utils.eval import evaluate_model
27
+ from utils.dataloader import setup_dataset
28
+ from utils.modeling_llama import LlamaForCausalLM
29
+ from utils.fake_quantization import QLinear
30
+ from utils.quantization import QScaleLinear
31
+
32
+
33
+ def main(args):
34
+ import torch
35
+ ############ Setup random seed ############
36
+ set_seed(args)
37
+
38
+ ############ Setup DDP environment ############
39
+ assert "LOCAL_RANK" in os.environ, "torchrun should set LOCAL_RANK"
40
+ global_rank = int(os.environ["RANK"])
41
+ local_rank = int(os.environ["LOCAL_RANK"])
42
+ world_size = int(os.environ["WORLD_SIZE"])
43
+ torch.cuda.set_device(local_rank)
44
+
45
+ logger.info(f"Global rank {global_rank}, local rank {local_rank}, device: {torch.cuda.current_device()}")
46
+ dist.init_process_group(backend="nccl", rank=global_rank, world_size=world_size)
47
+
48
+ logger.info("Process group initialized")
49
+ device = f"cuda:{local_rank}"
50
+
51
+ if global_rank != 0:
52
+ logger.remove() # turn off logger
53
+
54
+ logger.info(f"Using dist with rank {global_rank} (only rank 0 will log)")
55
+ logger.info("*" * 40)
56
+ logger.info(f"Starting training with the arguments")
57
+ for k, v in vars(args).items():
58
+ logger.info(f"{k:30} {v}")
59
+ logger.info("*" * 40)
60
+
61
+ ############ Initialize wandb without config (it is passed later) ############
62
+ if (not args.unset_wandb) and global_rank == 0:
63
+ if args.entity is None:
64
+ os.environ['WANDB_MODE'] = 'offline'
65
+ # Set wandb directory for offline mode
66
+ wandb_dir = getattr(args, 'wandb_dir', None) if getattr(args, 'wandb_dir', None) is not None else args.save_dir
67
+ if getattr(args, 'wandb_dir', None) is not None:
68
+ logger.info(f"Wandb directory set to: {wandb_dir}")
69
+ wandb.init(project=args.project, name=args.name, entity=args.entity, dir=wandb_dir)
70
+
71
+ ############ Setup training data ############
72
+ if args.total_batch_size is not None:
73
+ if args.gradient_accumulation is None:
74
+ assert args.total_batch_size % world_size == 0, "total_batch_size must be divisible by world_size"
75
+ args.gradient_accumulation = args.total_batch_size // (args.batch_size * world_size)
76
+ assert args.gradient_accumulation > 0, "gradient_accumulation must be greater than 0"
77
+
78
+ assert (
79
+ args.gradient_accumulation * args.batch_size * world_size == args.total_batch_size
80
+ ), "gradient_accumulation * batch_size * world_size must be equal to total_batch_size"
81
+
82
+ dataloader, tokenizer = setup_dataset(args, global_rank, world_size)
83
+
84
+ ############ Initialize model ############
85
+ model_config, model = setup_model(args)
86
+ # Ensure model has generation_config (fix for transformers version compatibility)
87
+ if model.generation_config is None:
88
+ from transformers import GenerationConfig
89
+ model.generation_config = GenerationConfig()
90
+ model.generation_config.pad_token_id = tokenizer.pad_token_id
91
+
92
+ ############ Resuming from checkpoints ############
93
+ global_step = 0
94
+ update_step = 0
95
+ beginning_step = 0
96
+ tokens_seen = 0
97
+ tokens_seen_before = 0
98
+
99
+ # identifying checkpointing
100
+ if args.continue_from is not None and os.path.exists(args.continue_from):
101
+ # searching the latest checkpoints
102
+ checkpoint_path_list = os.listdir(args.continue_from)
103
+ checkpoint_path_list = [int(x.split("_")[-1]) for x in checkpoint_path_list if x.startswith("model_")]
104
+ if len(checkpoint_path_list) > 0:
105
+ logger.info("Find Checkpoints", checkpoint_path_list)
106
+ beginning_step = max(checkpoint_path_list)
107
+ if args.resume_step is not None:
108
+ beginning_step = args.resume_step
109
+ args.continue_from = os.path.join(args.continue_from, f"model_{beginning_step}")
110
+ logger.info("Continue from", args.continue_from)
111
+ else:
112
+ logger.warning(f"Did not find any checkpoints in {args.continue_from}")
113
+ args.continue_from = None
114
+
115
+ # resuming from checkpointing
116
+ if args.continue_from is not None:
117
+ logger.info("*" * 40)
118
+ logger.info(f"Loading model from {args.continue_from}")
119
+ checkpoint_path = os.path.join(args.continue_from, "pytorch_model.bin")
120
+ if os.path.exists(checkpoint_path):
121
+ load_model_weight(model, checkpoint_path, args)
122
+ logger.info(f"Model successfully loaded (strict=False policy)")
123
+ else:
124
+ # Try safetensors format
125
+ checkpoint_path = os.path.join(args.continue_from, "model.safetensors")
126
+ if os.path.exists(checkpoint_path):
127
+ from safetensors import safe_open
128
+ tensors = {}
129
+ with safe_open(checkpoint_path, framework="pt", device=0) as f:
130
+ for k in f.keys():
131
+ tensors[k] = f.get_tensor(k)
132
+ print(k, tensors[k].shape)
133
+ ret = model.load_state_dict(tensors, strict=False)
134
+ logger.info(f"Model successfully loaded from safetensors (strict=False policy)", ret)
135
+ else:
136
+ logger.warning(f"No model checkpoint found in {args.continue_from}")
137
+
138
+ if os.path.exists(os.path.join(args.continue_from, "training_state.json")):
139
+ logger.info(
140
+ f"Loading training state like global_step, update_step, and tokens_seen from {args.continue_from}"
141
+ )
142
+ with open(os.path.join(args.continue_from, "training_state.json")) as f:
143
+ _old_state = json.load(f)
144
+ global_step = _old_state["global_step"]
145
+ update_step = _old_state["update_step"]
146
+ tokens_seen = _old_state["tokens_seen"]
147
+ tokens_seen_before = _old_state["tokens_seen_before"]
148
+ logger.info(f"global_step : {global_step}")
149
+ logger.info(f"update_step : {update_step}")
150
+ logger.info(f"tokens_seen : {tokens_seen}")
151
+ logger.info(f"tokens_seen_before: {tokens_seen_before}")
152
+ logger.info(f"Will train for {args.num_training_steps - update_step} update steps")
153
+ else:
154
+ logger.warning(f"Did not find training state in {args.continue_from}, global step will start from zero")
155
+ logger.info("*" * 40)
156
+
157
+ ############ Setup model ############
158
+ if args.dtype in ["bf16", "bfloat16"]:
159
+ model = model.to(dtype=torch.bfloat16)
160
+ model = model.to(device=device)
161
+
162
+ for _, module in model.named_modules():
163
+ if isinstance(module, QScaleLinear):
164
+ weight_device = module.weight.device
165
+ module.weight.scales = module.weight.scales.to(device=weight_device)
166
+ module.weight.zeros = module.weight.zeros.to(device=weight_device)
167
+
168
+ n_total_params = sum(p.numel() for p in model.parameters())
169
+ trainable_params = [p for p in model.parameters() if p.requires_grad]
170
+ trainable_params_int8 = [p for p in model.parameters() if hasattr(p, "group_size")]
171
+
172
+ ############ Initialize wandb ############
173
+ run_config = dict(vars(args))
174
+ run_config.update(
175
+ {
176
+ "max_lr": run_config.pop("lr"), # rename lr to max_lr to avoid conflicts with scheduler
177
+ "total_params_M": n_total_params / 1_000_000,
178
+ "dataset": "c4",
179
+ "model": model_config.to_dict(),
180
+ "world_size": world_size,
181
+ "device": str(device),
182
+ }
183
+ )
184
+
185
+ if global_rank == 0:
186
+ if not args.unset_wandb:
187
+ wandb.config.update(run_config, allow_val_change=True)
188
+ wandb.save(os.path.abspath(__file__), policy="now") # save current script
189
+ # fix tqdm visual length to 80 so that the progress bar
190
+ # doesn't jump around when changing from external display to laptop
191
+ pbar = tqdm(total=args.num_training_steps - update_step, desc="Update steps", ncols=80)
192
+
193
+ ############ Initialize optimization ############
194
+ if "galore" in args.optimizer.lower():
195
+ # make parameters with "rank" to a single group, if param_name has "mlp" or "attn"
196
+ lowrank_params = []
197
+ target_modules_list = ["attn", "mlp"]
198
+ for module_name, module in model.named_modules():
199
+ if not (isinstance(module, nn.Linear) or isinstance(module, QScaleLinear) or isinstance(module, QLinear)):
200
+ continue
201
+ if not any(target_key in module_name for target_key in target_modules_list):
202
+ continue
203
+ logger.info(f"Adding {module_name} to GaLore parameters")
204
+ lowrank_params.append(module.weight)
205
+
206
+ id_lowrank_params = [id(p) for p in lowrank_params]
207
+ # make parameters without "rank" to another group
208
+ regular_params = [p for p in model.parameters() if id(p) not in id_lowrank_params]
209
+ # then call low rank optimizer
210
+ param_groups = [
211
+ {"params": regular_params},
212
+ {
213
+ "params": lowrank_params,
214
+ "rank": args.rank,
215
+ "update_proj_gap": args.update_proj_gap,
216
+ "scale": args.galore_scale,
217
+ "proj_type": args.proj_type,
218
+ "quant": args.proj_quant,
219
+ "quant_n_bit": args.proj_bits,
220
+ "quant_group_size": args.proj_group_size,
221
+ "cos_threshold": args.cos_threshold,
222
+ "gamma_proj": args.gamma_proj,
223
+ "queue_size": args.queue_size,
224
+ },
225
+ ]
226
+ elif "apollo" in args.optimizer.lower():
227
+ # make parameters with "rank" to a single group, if param_name has "mlp" or "attn"
228
+ lowrank_params = []
229
+ target_modules_list = ["attn", "mlp"]
230
+ for module_name, module in model.named_modules():
231
+ if not (isinstance(module, nn.Linear) or isinstance(module, QScaleLinear) or isinstance(module, QLinear)):
232
+ continue
233
+ if not any(target_key in module_name for target_key in target_modules_list):
234
+ continue
235
+ logger.info(f"Adding {module_name} to APOLLO parameters")
236
+ lowrank_params.append(module.weight)
237
+
238
+ id_lowrank_params = [id(p) for p in lowrank_params]
239
+ # make parameters without "rank" to another group
240
+ regular_params = [p for p in model.parameters() if id(p) not in id_lowrank_params]
241
+ # then call low rank optimizer
242
+ param_groups = [
243
+ {"params": regular_params},
244
+ {
245
+ "params": lowrank_params,
246
+ "rank": args.rank,
247
+ "update_proj_gap": args.update_proj_gap,
248
+ "scale": args.apollo_scale,
249
+ "proj_type": args.proj_type,
250
+ "proj": args.proj,
251
+ "scale_type": args.scale_type,
252
+ },
253
+ ]
254
+ elif "conda" in args.optimizer.lower():
255
+ # make parameters with "rank" to a single group, if param_name has "mlp" or "attn"
256
+ lowrank_params = []
257
+ target_modules_list = ["attn", "mlp"]
258
+ for module_name, module in model.named_modules():
259
+ if not (isinstance(module, nn.Linear) or isinstance(module, QScaleLinear) or isinstance(module, QLinear)):
260
+ continue
261
+ if not any(target_key in module_name for target_key in target_modules_list):
262
+ continue
263
+ logger.info(f"Adding {module_name} to conda parameters")
264
+ lowrank_params.append(module.weight)
265
+
266
+ id_lowrank_params = [id(p) for p in lowrank_params]
267
+ # make parameters without "rank" to another group
268
+ regular_params = [p for p in model.parameters() if id(p) not in id_lowrank_params]
269
+ # then call low rank optimizer
270
+ param_groups = [
271
+ {"params": regular_params},
272
+ {
273
+ "params": lowrank_params,
274
+ "rank": args.rank,
275
+ "update_proj_gap": args.update_proj_gap,
276
+ "scale": args.apollo_scale,
277
+ "proj_type": args.proj_type,
278
+ "proj": args.proj,
279
+ "scale_type": args.scale_type,
280
+ },
281
+ ]
282
+ else:
283
+ param_groups = None
284
+ id_lowrank_params = None
285
+
286
+ # print params and trainable params
287
+ logger.info(f"\n{model}\n")
288
+ logger.info(f"Total params: {sum(p.numel() for p in model.parameters()) / 1_000_000:.2f}M")
289
+
290
+ if args.simulation:
291
+ num_train_params = sum(p.numel() for p in trainable_params)
292
+ else:
293
+ num_train_params = sum(p.numel() for p in trainable_params) + sum(p.numel() for p in trainable_params_int8)
294
+
295
+ logger.info(f"Trainable params: {num_train_params / 1_000_000:.2f}M")
296
+ if "q_galore" in args.optimizer.lower():
297
+ logger.info(
298
+ f"Trainable params with Q-GaLore enabled: {sum(p.numel() for p in trainable_params_int8) / 1_000_000:.2f}M"
299
+ )
300
+ elif "galore" in args.optimizer.lower():
301
+ logger.info(f"Total params with GaLore enabled: {sum(p.numel() for p in lowrank_params) / 1_000_000:.2f}M")
302
+ elif "q_apollo" in args.optimizer.lower():
303
+ logger.info(
304
+ f"Trainable params with Q-APOLLO enabled: {sum(p.numel() for p in trainable_params_int8) / 1_000_000:.2f}M"
305
+ )
306
+ elif "apollo" in args.optimizer.lower():
307
+ logger.info(f"Total params with APOLLO enabled: {sum(p.numel() for p in lowrank_params) / 1_000_000:.2f}M")
308
+
309
+ logger.info(f"Saving model to {args.save_dir} every {args.save_every} update steps")
310
+
311
+ model, optimizer, scheduler, layer_wise_flag = setup_optimization(
312
+ args, model, trainable_params, param_groups, id_lowrank_params, model_config
313
+ )
314
+
315
+ if layer_wise_flag:
316
+ # will pass optimizer_dict and scheduler_dict out instead of optimizer and scheduler
317
+ optimizer_dict = optimizer
318
+ scheduler_dict = scheduler
319
+
320
+ # Bug-3 fix: wrap with DDP *before* torch.compile per PyTorch recommendation.
321
+ # This ensures gradient reduction hooks are correctly installed on the DDP module,
322
+ # and the compiled graph captures the full DDP+model forward pass.
323
+ # (Issue-5: optimizer.load_state_dict is called after both DDP and compile below.)
324
+ if not args.single_gpu:
325
+ model: LlamaForCausalLM = torch.nn.parallel.DistributedDataParallel(
326
+ model,
327
+ device_ids=[local_rank],
328
+ output_device=local_rank,
329
+ broadcast_buffers=False,
330
+ )
331
+
332
+ # compile the model (after DDP so the compiled graph includes DDP reduction)
333
+ if args.compile:
334
+ print("Compiling the model... (takes a ~minute)")
335
+ unoptimized_model = model
336
+
337
+ # Configure TorchDynamo to suppress errors and fall back to eager mode
338
+ import torch._dynamo
339
+ torch._dynamo.config.suppress_errors = args.dynamo_suppress_errors
340
+ torch._dynamo.config.verbose = False
341
+ # Set cache size limit to prevent memory issues during long training
342
+ torch._dynamo.config.cache_size_limit = args.dynamo_cache_limit
343
+
344
+ model = torch.compile(model) # requires PyTorch 2.0
345
+
346
+ # resume optimizer
347
+ if args.restore_optimizer and args.continue_from is not None:
348
+ logger.info("Restoring optimizer and scheduler from the checkpoint")
349
+ _optimizer_dir = args.continue_from
350
+ optimizer_checkpoint = torch.load(os.path.join(_optimizer_dir, "optimizer.pt"), map_location="cpu")
351
+ optimizer.load_state_dict(optimizer_checkpoint["optimizer"])
352
+ scheduler.load_state_dict(optimizer_checkpoint["scheduler"])
353
+ update_step = optimizer_checkpoint["update_step"]
354
+ beginning_step = update_step
355
+ global_step = optimizer_checkpoint["global_step"]
356
+ logger.info(f"Optimizer and scheduler restored from {_optimizer_dir}")
357
+
358
+ # ##############################
359
+ # TRAINING LOOP
360
+ # we use iterable dataset, so we may never go through all the data
361
+ # ##############################
362
+ # global steps and others are defined above
363
+ pad_idx = tokenizer.pad_token_id
364
+ update_time = time.time()
365
+ local_step = 0 # when continue_from is used, local_step != global_step
366
+ total_svd_count = 0
367
+
368
+ dataloader_iter = iter(dataloader)
369
+
370
+ # Issue-4 fix: accumulate loss across micro-batches so logged loss is the true
371
+ # gradient-accumulation average, not just the last micro-batch.
372
+ accumulated_loss = 0.0
373
+
374
+ # Skip data if resuming from checkpoint
375
+ if update_step != 0:
376
+ skip_batches = args.gradient_accumulation * update_step
377
+ logger.info(f"Skipping {skip_batches} batches to resume from update step {update_step}")
378
+ skipped = 0
379
+ for _ in range(skip_batches):
380
+ # Issue-6 fix: handle StopIteration during skip so all ranks stay aligned
381
+ try:
382
+ next(dataloader_iter)
383
+ except StopIteration:
384
+ logger.warning(
385
+ f"Dataset exhausted during skip at batch {skipped}/{skip_batches}; "
386
+ f"restarting iterator to keep ranks aligned."
387
+ )
388
+ dataloader_iter = iter(dataloader)
389
+ next(dataloader_iter)
390
+ skipped += 1
391
+ logger.info(f"Skipped {skipped} batches successfully")
392
+
393
+ while update_step <= args.num_training_steps:
394
+ try:
395
+ batch = next(dataloader_iter)
396
+ except StopIteration:
397
+ logger.info(f"Dataset completed one epoch. Starting new epoch with reshuffled data.")
398
+ dataloader_iter = iter(dataloader)
399
+ batch = next(dataloader_iter)
400
+
401
+ global_step += 1
402
+ local_step += 1
403
+
404
+ if update_step >= args.num_training_steps:
405
+ logger.info(f"Reached max number of update steps ({args.num_training_steps}). Stopping training.")
406
+ logger.info(f"Rank {global_rank} stopping training.")
407
+ break
408
+
409
+ # forward & backward
410
+ batch = {k: v.to(device) for k, v in batch.items()}
411
+ labels = batch["input_ids"].clone()
412
+ labels[labels == pad_idx] = -100
413
+ tokens_seen += (batch["input_ids"] != pad_idx).sum().item() * world_size
414
+
415
+ loss = model(**batch, labels=labels).loss
416
+
417
+ scaled_loss = loss / args.gradient_accumulation
418
+ scaled_loss.backward()
419
+ accumulated_loss += loss.item() # Issue-4: accumulate before the continue
420
+
421
+ if global_step % args.gradient_accumulation != 0:
422
+ continue
423
+
424
+ # The below code is only executed during the update step
425
+ # Issue-4: compute average loss over all micro-batches in this accumulation window
426
+ avg_loss = accumulated_loss / args.gradient_accumulation
427
+ accumulated_loss = 0.0 # reset for next accumulation window
428
+ # add grad clipping: TODO: add gradient clipping of int8 weight
429
+ if args.grad_clipping != 0.0:
430
+ torch.nn.utils.clip_grad_norm_(trainable_params, args.grad_clipping)
431
+ # Periodic memory cleanup to prevent symbolic tensor issues during long training
432
+ if global_step % args.memory_cleanup_frequency == 0:
433
+ torch.cuda.empty_cache()
434
+ # Clear TorchDynamo cache to prevent memory accumulation
435
+ if args.compile:
436
+ import torch._dynamo
437
+ torch._dynamo.reset()
438
+
439
+ if global_rank == 0:
440
+ pbar.update(1)
441
+ if not layer_wise_flag: # layer-wise updation is done during backward; requires gradient_accumulation equals 1
442
+ optimizer.step()
443
+ scheduler.step()
444
+ optimizer.zero_grad()
445
+
446
+ update_step += 1
447
+ update_time = time.time() - update_time
448
+
449
+ # save checkpoint by save_every
450
+ if local_step > args.gradient_accumulation and update_step % args.save_every == 0 and global_rank == 0:
451
+ current_model_directory = f"{args.save_dir}/model_{update_step}"
452
+ logger.info(f"Saving model and optimizer to {current_model_directory}, update step {update_step}")
453
+ os.makedirs(args.save_dir, exist_ok=True)
454
+ # Bug-1 fix: unwrap DDP/compiled model for saving; works in both single-GPU and multi-GPU modes
455
+ unwrapped_model = model.module if hasattr(model, 'module') else model
456
+ unwrapped_model.save_pretrained(current_model_directory, max_shard_size="500GB", from_pt=True)
457
+ saving_model_weight(unwrapped_model, f"{current_model_directory}/pytorch_model.bin", args)
458
+
459
+ optimizer_checkpoint = {
460
+ "optimizer": optimizer.state_dict(),
461
+ "scheduler": scheduler.state_dict(),
462
+ "update_step": update_step,
463
+ "global_step": global_step,
464
+ "config": run_config,
465
+ "wandb": wandb.run.dir if not args.unset_wandb else None,
466
+ "dtype": args.dtype,
467
+ }
468
+ torch.save(optimizer_checkpoint, f"{current_model_directory}/optimizer.pt")
469
+
470
+ training_state_checkpoint = {
471
+ "global_step": global_step,
472
+ "update_step": update_step,
473
+ "tokens_seen": tokens_seen,
474
+ "tokens_seen_before": tokens_seen_before,
475
+ "update_time": update_time,
476
+ }
477
+ with open(f"{current_model_directory}/training_state.json", "w") as f:
478
+ json.dump(training_state_checkpoint, f, indent=4)
479
+
480
+ # save wandb related info
481
+ if not args.unset_wandb:
482
+ wandb_info = {
483
+ "wandb_id": wandb.run.id,
484
+ }
485
+ with open(f"{args.save_dir}/wandb.json", "w") as f:
486
+ json.dump(wandb_info, f, indent=4)
487
+
488
+ # evaluation
489
+ if update_step % args.eval_every == 0:
490
+ logger.info(f"Performing evaluation at step {update_step}")
491
+ total_loss, evaluated_on_tokens, perplexity = evaluate_model(
492
+ model, tokenizer, pad_idx, global_rank, world_size, device, args
493
+ )
494
+
495
+ if global_rank == 0:
496
+ if not args.unset_wandb:
497
+ wandb.log(
498
+ {
499
+ "eval_loss": total_loss,
500
+ "eval_perplexity": perplexity,
501
+ "eval_tokens": evaluated_on_tokens,
502
+ },
503
+ step=update_step,
504
+ )
505
+ logger.info(f"Eval loss at step {update_step}: {total_loss}, Eval perplexity: {perplexity}")
506
+
507
+ if not layer_wise_flag:
508
+ lr = optimizer.param_groups[0]["lr"]
509
+ else:
510
+ lr = list(optimizer_dict.values())[0].param_groups[0]["lr"]
511
+ tokens_in_update = tokens_seen - tokens_seen_before
512
+ tokens_seen_before = tokens_seen
513
+ batches_in_update = args.gradient_accumulation * world_size
514
+ if not layer_wise_flag:
515
+ total_svd_count = getting_svd_cnt(optimizer)
516
+ else:
517
+ total_svd_count = 0
518
+
519
+ if global_rank == 0:
520
+ if not args.unset_wandb:
521
+ wandb.log(
522
+ {
523
+ "loss": avg_loss,
524
+ "lr": lr,
525
+ "update_step": update_step,
526
+ "tokens_seen": tokens_seen,
527
+ "total_svd_count": total_svd_count,
528
+ "throughput_tokens": tokens_in_update / update_time,
529
+ "throughput_examples": args.total_batch_size / update_time,
530
+ "throughput_batches": batches_in_update / update_time,
531
+ },
532
+ step=update_step,
533
+ )
534
+ update_time = time.time()
535
+
536
+ # ##############################
537
+ # END of training loop
538
+ # ##############################
539
+ logger.info("Training finished")
540
+ if global_rank == 0:
541
+ pbar.close()
542
+
543
+ current_model_directory = f"{args.save_dir}/model_{update_step}"
544
+ if global_rank == 0 and not os.path.exists(current_model_directory):
545
+ logger.info(f"Saving model and optimizer to {current_model_directory}, update step {update_step}")
546
+ os.makedirs(args.save_dir, exist_ok=True)
547
+ # Bug-1 fix: unwrap DDP/compiled model for saving; works in both single-GPU and multi-GPU modes
548
+ unwrapped_model = model.module if hasattr(model, 'module') else model
549
+ unwrapped_model.save_pretrained(current_model_directory, max_shard_size="500GB", from_pt=True)
550
+ saving_model_weight(unwrapped_model, f"{current_model_directory}/pytorch_model.bin", args)
551
+
552
+ optimizer_checkpoint = {
553
+ "optimizer": optimizer.state_dict(),
554
+ "scheduler": scheduler.state_dict(),
555
+ "update_step": update_step,
556
+ "global_step": global_step,
557
+ "config": run_config,
558
+ "wandb": wandb.run.dir if not args.unset_wandb else None,
559
+ "dtype": args.dtype,
560
+ }
561
+ torch.save(optimizer_checkpoint, f"{current_model_directory}/optimizer.pt")
562
+
563
+ training_state_checkpoint = {
564
+ "global_step": global_step,
565
+ "update_step": update_step,
566
+ "tokens_seen": tokens_seen,
567
+ "tokens_seen_before": tokens_seen_before,
568
+ "update_time": update_time,
569
+ }
570
+ with open(f"{current_model_directory}/training_state.json", "w") as f:
571
+ json.dump(training_state_checkpoint, f, indent=4)
572
+
573
+ # Final evaluation
574
+ logger.info("Running final evaluation")
575
+ model.eval()
576
+ del loss, optimizer, scheduler
577
+ import gc
578
+
579
+ gc.collect()
580
+ torch.cuda.empty_cache()
581
+
582
+ total_loss, evaluated_on_tokens, perplexity = evaluate_model(model, tokenizer, pad_idx, global_rank, world_size, device, args)
583
+
584
+ if global_rank == 0:
585
+ if not args.unset_wandb:
586
+ wandb.log(
587
+ {
588
+ "final_eval_loss": total_loss,
589
+ "final_eval_perplexity": perplexity,
590
+ "final_eval_tokens": evaluated_on_tokens,
591
+ },
592
+ step=update_step,
593
+ )
594
+ logger.info(f"Final eval loss: {total_loss}, Final eval perplexity: {perplexity}")
595
+
596
+ logger.info("Script finished successfully")
597
+ print(f"Rank {global_rank} finished successfully")
598
+
599
+
600
# Script entry point: runs only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    print("Starting script")
    # `args` is deliberately bound at module level, exactly as before, in case
    # other code in this file reads it as a global.
    # NOTE(review): passing None presumably makes parse_args read the process
    # command line -- confirm against parse_args.
    args = parse_args(None)
    main(args)
130m/came_lr5e_4_b1_0_9_b2_0_999_eps_1e_6_A100_ppl_23_7861_20260421_151403/wandb/offline-run-20260421_151426-c42l43zw/files/requirements.txt ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aria2==0.0.1b0
2
+ anyio==4.12.0
3
+ setuptools==80.9.0
4
+ torchvision==0.20.1+cu121
5
+ pyarrow==20.0.0
6
+ peft==0.17.1
7
+ conda-pack==0.8.1
8
+ scikit-learn==1.6.1
9
+ pyparsing==3.3.1
10
+ sympy==1.13.1
11
+ typer-slim==0.20.1
12
+ pip==25.1.1
13
+ pip==25.3
14
+ fonttools==4.60.2
15
+ packaging==25.0
16
+ click==8.1.8
17
+ accelerate==1.10.1
18
+ psutil==7.2.0
19
+ wheel==0.45.1
20
+ multidict==6.7.0
21
+ requests==2.32.5
22
+ async-timeout==5.0.1
23
+ triton==3.1.0
24
+ loguru==0.7.3
25
+ aiohappyeyeballs==2.6.1
26
+ sentry-sdk==2.48.0
27
+ annotated-types==0.7.0
28
+ certifi==2025.11.12
29
+ nvidia-curand-cu12==10.3.2.106
30
+ shellingham==1.5.4
31
+ package_name==0.1
32
+ wandb==0.23.0
33
+ nvitop==1.6.1
34
+ nvidia-nccl-cu12==2.21.5
35
+ nvidia-cublas-cu12==12.1.3.1
36
+ tokenizers==0.22.1
37
+ nvidia-cusparse-cu12==12.1.0.106
38
+ scipy==1.13.1
39
+ propcache==0.4.1
40
+ nvidia-ml-py==13.580.82
41
+ typing_extensions==4.15.0
42
+ sac==0.1.0
43
+ torch-optimizer==0.3.0
44
+ aria2==0.0.1b0
45
+ h11==0.16.0
46
+ pillow==11.3.0
47
+ PyYAML==6.0.3
48
+ six==1.17.0
49
+ GitPython==3.1.45
50
+ addict==2.4.0
51
+ seaborn==0.13.2
52
+ filelock==3.19.1
53
+ modelscope==1.33.0
54
+ et_xmlfile==2.0.0
55
+ regex==2025.11.3
56
+ nvidia-cufft-cu12==11.0.2.54
57
+ nvidia-cuda-cupti-cu12==12.1.105
58
+ lion-pytorch==0.2.3
59
+ matplotlib==3.9.4
60
+ pandas==2.3.2
61
+ gitdb==4.0.12
62
+ kiwisolver==1.4.7
63
+ idna==3.11
64
+ numpy==2.0.2
65
+ nvidia-cuda-runtime-cu12==12.1.105
66
+ httpx==0.28.1
67
+ frozenlist==1.8.0
68
+ smmap==5.0.2
69
+ datasets==2.14.0
70
+ yarl==1.22.0
71
+ eval_type_backport==0.3.1
72
+ nvidia-cuda-nvrtc-cu12==12.1.105
73
+ huggingface-hub==0.36.0
74
+ torchaudio==2.5.1+cu121
75
+ aiosignal==1.4.0
76
+ importlib_resources==6.5.2
77
+ nvidia-cusolver-cu12==11.4.5.107
78
+ networkx==3.2.1
79
+ tzdata==2025.3
80
+ bitsandbytes==0.42.0
81
+ cycler==0.12.1
82
+ jq==1.10.0
83
+ mpmath==1.3.0
84
+ pydantic_core==2.41.5
85
+ nvidia-cudnn-cu12==9.1.0.70
86
+ typing-inspection==0.4.2
87
+ httpcore==1.0.9
88
+ nvidia-nvtx-cu12==12.1.105
89
+ platformdirs==4.4.0
90
+ MarkupSafe==2.1.5
91
+ multiprocess==0.70.15
92
+ zipp==3.23.0
93
+ transformers==4.57.3
94
+ nvidia-nvjitlink-cu12==12.9.86
95
+ exceptiongroup==1.3.1
96
+ pydantic==2.12.5
97
+ charset-normalizer==3.4.4
98
+ joblib==1.5.3
99
+ dill==0.3.7
100
+ fsspec==2023.9.2
101
+ torch==2.5.1+cu121
102
+ aiohttp==3.13.2
103
+ urllib3==2.6.2
104
+ apollo-torch==1.0.3
105
+ contourpy==1.3.0
106
+ evaluate==0.4.6
107
+ attrs==25.4.0
108
+ pytz==2025.2
109
+ safetensors==0.7.0
110
+ pytorch-ranger==0.1.1
111
+ threadpoolctl==3.6.0
112
+ Jinja2==3.1.6
113
+ protobuf==6.33.2
114
+ python-dateutil==2.9.0.post0
115
+ xxhash==3.6.0
116
+ openpyxl==3.1.5
117
+ hf-xet==1.2.0
118
+ tqdm==4.67.1
119
+ jaraco.context==5.3.0
120
+ platformdirs==4.2.2
121
+ importlib_metadata==8.0.0
122
+ more-itertools==10.3.0
123
+ typing_extensions==4.12.2
124
+ autocommand==2.2.2
125
+ wheel==0.45.1
126
+ zipp==3.19.2
127
+ packaging==24.2
128
+ backports.tarfile==1.2.0
129
+ inflect==7.3.1
130
+ typeguard==4.3.0
131
+ jaraco.functools==4.0.1
132
+ jaraco.collections==5.1.0
133
+ jaraco.text==3.12.1
134
+ tomli==2.0.1
130m/came_lr5e_4_b1_0_9_b2_0_999_eps_1e_6_A100_ppl_23_7861_20260421_151403/wandb/offline-run-20260421_151426-c42l43zw/logs/debug-core.log ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {"time":"2026-04-21T15:14:26.735732983+08:00","level":"INFO","msg":"main: starting server","port-filename":"/tmp/tmpxhim34_u/port-80806.txt","pid":80806,"log-level":0,"disable-analytics":false,"shutdown-on-parent-exit":false,"enable-dcgm-profiling":false}
2
+ {"time":"2026-04-21T15:14:26.738200789+08:00","level":"INFO","msg":"server: will exit if parent process dies","ppid":80806}
3
+ {"time":"2026-04-21T15:14:26.738205728+08:00","level":"INFO","msg":"server: accepting connections","addr":{"Name":"/tmp/wandb-80806-82988-2886939369/socket","Net":"unix"}}
4
+ {"time":"2026-04-21T15:14:26.812546826+08:00","level":"INFO","msg":"connection: ManageConnectionData: new connection created","id":"1(@)"}
5
+ {"time":"2026-04-21T15:14:26.848007733+08:00","level":"INFO","msg":"handleInformInit: received","streamId":"c42l43zw","id":"1(@)"}
6
+ {"time":"2026-04-21T15:14:27.147653649+08:00","level":"INFO","msg":"handleInformInit: stream started","streamId":"c42l43zw","id":"1(@)"}
7
+ {"time":"2026-04-21T16:38:11.996804042+08:00","level":"INFO","msg":"handleInformTeardown: server teardown initiated","id":"1(@)"}
8
+ {"time":"2026-04-21T16:38:11.997505446+08:00","level":"INFO","msg":"connection: closing","id":"1(@)"}
9
+ {"time":"2026-04-21T16:38:11.99801937+08:00","level":"INFO","msg":"connection: closed successfully","id":"1(@)"}
10
+ {"time":"2026-04-21T16:38:11.997526129+08:00","level":"INFO","msg":"server is shutting down"}
11
+ {"time":"2026-04-21T16:38:11.998906394+08:00","level":"INFO","msg":"server: listener closed","addr":{"Name":"/tmp/wandb-80806-82988-2886939369/socket","Net":"unix"}}
12
+ {"time":"2026-04-21T16:38:12.000733294+08:00","level":"INFO","msg":"handleInformTeardown: server shutdown complete","id":"1(@)"}
13
+ {"time":"2026-04-21T16:38:12.001134961+08:00","level":"INFO","msg":"connection: ManageConnectionData: connection closed","id":"1(@)"}
14
+ {"time":"2026-04-21T16:38:12.001608865+08:00","level":"INFO","msg":"server is closed"}
130m/came_lr5e_4_b1_0_9_b2_0_999_eps_1e_6_A100_ppl_23_7861_20260421_151403/wandb/offline-run-20260421_151426-c42l43zw/logs/debug-internal.log ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {"time":"2026-04-21T15:14:26.9128003+08:00","level":"INFO","msg":"stream: starting","core version":"0.23.0"}
2
+ {"time":"2026-04-21T15:14:27.141384684+08:00","level":"WARN","msg":"featurechecker: GraphQL client is nil, skipping feature loading"}
3
+ {"time":"2026-04-21T15:14:27.141455586+08:00","level":"INFO","msg":"stream: created new stream","id":"c42l43zw"}
4
+ {"time":"2026-04-21T15:14:27.141499443+08:00","level":"INFO","msg":"handler: started","stream_id":"c42l43zw"}
5
+ {"time":"2026-04-21T15:14:27.147641439+08:00","level":"INFO","msg":"stream: started","id":"c42l43zw"}
6
+ {"time":"2026-04-21T15:14:27.14764711+08:00","level":"INFO","msg":"writer: started","stream_id":"c42l43zw"}
7
+ {"time":"2026-04-21T15:14:27.147658516+08:00","level":"INFO","msg":"sender: started","stream_id":"c42l43zw"}
8
+ {"time":"2026-04-21T15:14:27.149120861+08:00","level":"WARN","msg":"runupserter: server does not expand metric globs but the x_server_side_expand_glob_metrics setting is set; ignoring"}
9
+ {"time":"2026-04-21T16:38:11.997527579+08:00","level":"INFO","msg":"stream: closing","id":"c42l43zw"}
10
+ {"time":"2026-04-21T16:38:11.998219162+08:00","level":"INFO","msg":"handler: closed","stream_id":"c42l43zw"}
11
+ {"time":"2026-04-21T16:38:11.999549553+08:00","level":"INFO","msg":"sender: closed","stream_id":"c42l43zw"}
12
+ {"time":"2026-04-21T16:38:11.999562024+08:00","level":"INFO","msg":"stream: closed","id":"c42l43zw"}
130m/came_lr5e_4_b1_0_9_b2_0_999_eps_1e_6_A100_ppl_23_7861_20260421_151403/wandb/offline-run-20260421_151426-c42l43zw/logs/debug.log ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2026-04-21 15:14:26,395 INFO MainThread:80806 [wandb_setup.py:_flush():80] Current SDK version is 0.23.0
2
+ 2026-04-21 15:14:26,395 INFO MainThread:80806 [wandb_setup.py:_flush():80] Configure stats pid to 80806
3
+ 2026-04-21 15:14:26,395 INFO MainThread:80806 [wandb_setup.py:_flush():80] Loading settings from /mnt/petrelfs/panjiabao/.config/wandb/settings
4
+ 2026-04-21 15:14:26,395 INFO MainThread:80806 [wandb_setup.py:_flush():80] Loading settings from /mnt/petrelfs/panjiabao/Optimizer/SAC/wandb/settings
5
+ 2026-04-21 15:14:26,395 INFO MainThread:80806 [wandb_setup.py:_flush():80] Loading settings from environment variables
6
+ 2026-04-21 15:14:26,395 INFO MainThread:80806 [wandb_init.py:setup_run_log_directory():713] Logging user logs to /mnt/dhwfile/tancheng/panjiabao/Result/SAC_C4/work_dirs/130m/came_v3_lr5e_4_b1_0_9_b2_0_999_eps_1e_6_20260421_151403/wandb/offline-run-20260421_151426-c42l43zw/logs/debug.log
7
+ 2026-04-21 15:14:26,395 INFO MainThread:80806 [wandb_init.py:setup_run_log_directory():714] Logging internal logs to /mnt/dhwfile/tancheng/panjiabao/Result/SAC_C4/work_dirs/130m/came_v3_lr5e_4_b1_0_9_b2_0_999_eps_1e_6_20260421_151403/wandb/offline-run-20260421_151426-c42l43zw/logs/debug-internal.log
8
+ 2026-04-21 15:14:26,395 INFO MainThread:80806 [wandb_init.py:init():840] calling init triggers
9
+ 2026-04-21 15:14:26,395 INFO MainThread:80806 [wandb_init.py:init():845] wandb.init called with sweep_config: {}
10
+ config: {'_wandb': {}}
11
+ 2026-04-21 15:14:26,395 INFO MainThread:80806 [wandb_init.py:init():888] starting backend
12
+ 2026-04-21 15:14:26,812 INFO MainThread:80806 [wandb_init.py:init():891] sending inform_init request
13
+ 2026-04-21 15:14:26,843 INFO MainThread:80806 [wandb_init.py:init():899] backend started and connected
14
+ 2026-04-21 15:14:26,847 INFO MainThread:80806 [wandb_init.py:init():969] updated telemetry
15
+ 2026-04-21 15:14:26,894 INFO MainThread:80806 [wandb_init.py:init():993] communicating run to backend with 90.0 second timeout
16
+ 2026-04-21 15:14:27,150 INFO MainThread:80806 [wandb_init.py:init():1040] starting run threads in backend
17
+ 2026-04-21 15:14:27,510 INFO MainThread:80806 [wandb_run.py:_console_start():2504] atexit reg
18
+ 2026-04-21 15:14:27,510 INFO MainThread:80806 [wandb_run.py:_redirect():2352] redirect: wrap_raw
19
+ 2026-04-21 15:14:27,510 INFO MainThread:80806 [wandb_run.py:_redirect():2421] Wrapping output streams.
20
+ 2026-04-21 15:14:27,510 INFO MainThread:80806 [wandb_run.py:_redirect():2444] Redirects installed.
21
+ 2026-04-21 15:14:27,517 INFO MainThread:80806 [wandb_init.py:init():1080] run started, returning control to user process
22
+ 2026-04-21 15:14:32,995 INFO MainThread:80806 [wandb_run.py:_config_callback():1385] config_cb None None {'model_config': 'configs/llama_130m.json', 'exp_config': 'exp_v2/configs/llama_130m.json', 'eval_every': 1000, 'save_every': 20000, 'dtype': 'bfloat16', 'seed': 0, 'compile': True, 'dynamo_suppress_errors': True, 'dynamo_cache_limit': 10000, 'memory_cleanup_frequency': 10000, 'resume_step': None, 'restore_optimizer': False, 'continue_from': None, 'single_gpu': False, 'save_dir': '/mnt/dhwfile/tancheng/panjiabao/Result/SAC_C4/work_dirs/130m/came_v3_lr5e_4_b1_0_9_b2_0_999_eps_1e_6_20260421_151403', 'use_hf_model': False, 'workers': 12, 'batch_size': 128, 'gradient_accumulation': 1, 'total_batch_size': 512, 'warmup_steps': 2000, 'num_training_steps': 20000, 'max_train_tokens': None, 'optimizer': 'came', 'max_length': 256, 'scheduler': 'cosine', 'min_lr_ratio': 0.1, 'weight_decay': 0.0, 'grad_clipping': 0.0, 'activation_checkpointing': False, 'data_path': '/mnt/dhwfile/tancheng/panjiabao/dataset/C4/en', 'data_name': 'en', 'tags': None, 'name': 'test', 'project': 'test', 'unset_wandb': False, 'entity': None, 'wandb_dir': '/mnt/dhwfile/tancheng/panjiabao/Result/SAC_C4/work_dirs/130m/came_v3_lr5e_4_b1_0_9_b2_0_999_eps_1e_6_20260421_151403', 'beta1': 0.9, 'beta2': 0.999, 'beta3': 0.99, 'eps': 1e-06, 'rank': 128, 'update_proj_gap': 50, 'galore_scale': 1.0, 'proj_type': 'std', 'proj_quant': False, 'proj_bits': 8, 'proj_group_size': 256, 'weight_quant': False, 'weight_bits': 8, 'weight_group_size': 256, 'stochastic_round': False, 'simulation': False, 'cos_threshold': 1, 'gamma_proj': 2, 'queue_size': 5, 'proj': 'random', 'scale_type': 'tensor', 'apollo_scale': 1.0, 'scale_front': False, 'n_clusters': 3, 'scale_update_freq': 500, 'scale_level': '1,0,1,1', 'scale_bound': None, 'metric': 'mean', 'align_grad': False, 'dim': 4096, 'n_heads': 32, 'muon_ns_steps': 5, 'muon_momentum': 0.95, 'nproc_per_node': 4, 'max_lr': 0.0005, 'total_params_M': 134.105856, 'dataset': 'c4', 
'model': {'vocab_size': 32000, 'max_position_embeddings': 2048, 'hidden_size': 768, 'intermediate_size': 2048, 'num_hidden_layers': 12, 'num_attention_heads': 12, 'num_key_value_heads': 12, 'hidden_act': 'silu', 'initializer_range': 0.02, 'rms_norm_eps': 1e-06, 'pretraining_tp': 1, 'use_cache': True, 'rope_theta': 10000.0, 'rope_scaling': None, 'attention_bias': False, 'attention_dropout': 0.0, 'mlp_bias': False, 'head_dim': 64, 'return_dict': True, 'output_hidden_states': False, 'torchscript': False, 'dtype': None, 'pruned_heads': {}, 'tie_word_embeddings': False, 'chunk_size_feed_forward': 0, 'is_encoder_decoder': False, 'is_decoder': False, 'cross_attention_hidden_size': None, 'add_cross_attention': False, 'tie_encoder_decoder': False, 'architectures': ['LLaMAForCausalLM'], 'finetuning_task': None, 'id2label': {0: 'LABEL_0', 1: 'LABEL_1'}, 'label2id': {'LABEL_0': 0, 'LABEL_1': 1}, 'task_specific_params': None, 'problem_type': None, 'tokenizer_class': None, 'prefix': None, 'bos_token_id': 0, 'pad_token_id': -1, 'eos_token_id': 1, 'sep_token_id': None, 'decoder_start_token_id': None, 'max_length': 20, 'min_length': 0, 'do_sample': False, 'early_stopping': False, 'num_beams': 1, 'temperature': 1.0, 'top_k': 50, 'top_p': 1.0, 'typical_p': 1.0, 'repetition_penalty': 1.0, 'length_penalty': 1.0, 'no_repeat_ngram_size': 0, 'encoder_no_repeat_ngram_size': 0, 'bad_words_ids': None, 'num_return_sequences': 1, 'output_scores': False, 'return_dict_in_generate': False, 'forced_bos_token_id': None, 'forced_eos_token_id': None, 'remove_invalid_values': False, 'exponential_decay_length_penalty': None, 'suppress_tokens': None, 'begin_suppress_tokens': None, 'num_beam_groups': 1, 'diversity_penalty': 0.0, '_name_or_path': 'configs/llama_130m.json', 'transformers_version': '4.57.3', 'max_sequence_length': 1024, 'model_type': 'llama', 'tf_legacy_loss': False, 'use_bfloat16': False, 'output_attentions': False}, 'world_size': 4, 'device': 'cuda:0'}
23
+ 2026-04-21 16:38:11,996 INFO wandb-AsyncioManager-main:80806 [service_client.py:_forward_responses():80] Reached EOF.
24
+ 2026-04-21 16:38:11,997 INFO wandb-AsyncioManager-main:80806 [mailbox.py:close():137] Closing mailbox, abandoning 0 handles.
130m/came_lr5e_4_b1_0_9_b2_0_999_eps_1e_6_A100_ppl_23_7861_20260421_151403/wandb/offline-run-20260421_151426-c42l43zw/run-c42l43zw.wandb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:378b74db2f5ae0dac444534e941b8d9cc3c7209b0bd4e2a109d86685b750c93b
3
+ size 19364076