AlbertTan committed
Commit f5820ba · verified · 1 Parent(s): cc091cf

Upload 8 files

logs/colar/qsa-gsm/colar-final/checkpoints/colar_best.ckpt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d770babf58f8b265cfc5b658862e1d3c67159dc1bdb3be5b3b7982d9514b5fa8
+size 121711010
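
The three lines above are a Git LFS v1 pointer: the checkpoint itself is stored out-of-band, and the pointer records only its SHA-256 (oid) and byte count (size). A minimal Python sketch of verifying a file fetched with `git lfs pull` against these values (the verify_lfs_object helper is hypothetical, not part of this repo; only the standard library is assumed):

import hashlib

# Hypothetical helper: recompute sha256 and size of a file fetched via
# `git lfs pull` and compare them to the values in its LFS pointer.
def verify_lfs_object(path, expected_oid, expected_size):
    sha = hashlib.sha256()
    size = 0
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            sha.update(chunk)
            size += len(chunk)
    return sha.hexdigest() == expected_oid and size == expected_size

# Values copied from the pointer above.
assert verify_lfs_object(
    "logs/colar/qsa-gsm/colar-final/checkpoints/colar_best.ckpt",
    "d770babf58f8b265cfc5b658862e1d3c67159dc1bdb3be5b3b7982d9514b5fa8",
    121711010,
)
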
logs/colar/qsa-gsm/colar-final/events.out.tfevents.60691.0 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9b81c001ee7cf832b7ae0d87d7996efc0a5686a13cded40084c0ff0b4eb5c1ad
+size 1298289
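
The events.out.tfevents.* files in this commit are TensorBoard logs written by the TensorBoardLogger configured in hparams.yaml below. A sketch of inspecting them offline, assuming only that the tensorboard package is installed (the scalar tags are whatever the training script logged; a "monitor" metric is referenced by the checkpoint config below):

from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

# Point the accumulator at the run directory; it picks up all event files in it.
acc = EventAccumulator("logs/colar/qsa-gsm/colar-final")
acc.Reload()
print(acc.Tags()["scalars"])  # list logged scalar tags, e.g. the "monitor" metric
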
logs/colar/qsa-gsm/colar-final/events.out.tfevents.60691.1 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9a57b6327d0a80040cd01677060eb0266a04e9bfb71671abd726beaeb68ef3f2
+size 17551
logs/colar/qsa-gsm/colar-final/hparams.yaml ADDED
@@ -0,0 +1,103 @@
+all_config:
+  trainer:
+    target: lightning.pytorch.trainer.Trainer
+    devices:
+    - 0
+    - 1
+    - 2
+    - 3
+    - 4
+    - 5
+    - 6
+    - 7
+    max_steps: -1
+    check_val_every_n_epoch: 1
+    log_every_n_steps: 10
+    num_sanity_val_steps: 2
+    gradient_clip_val: null
+    reload_dataloaders_every_n_epochs: 0
+    accumulate_grad_batches: 1
+    precision: bf16-mixed
+    use_distributed_sampler: true
+    strategy: auto
+    logger:
+      target: lightning.pytorch.loggers.TensorBoardLogger
+      save_dir: logs/colar
+      name: qsa-gsm
+      version: colar-final
+    max_epochs: 50
+    callbacks:
+    - target: lightning.pytorch.callbacks.ModelCheckpoint
+      save_last: true
+      save_top_k: 3
+      mode: max
+      monitor: monitor
+      auto_insert_metric_name: false
+      filename: epoch{epoch}__step{step}__monitor{monitor:.3f}
+      save_weights_only: true
+  seed: null
+  model:
+    target: src.models.colar.LitCoLaR
+    model_kwargs:
+      model_id: Llama-3.2-1B-Instruct
+      depth: 1
+      sft_method: CoLaR
+      set_pad_as_last_token: false
+      do_lora: true
+      lora_config:
+        r: 128
+        lora_alpha: 32
+      latent_cot_config:
+        ce_weight: 1
+        embed_modeling_weight: 1
+        embed_modeling_loss: mse
+        entropy_weight: -1e-6
+        pred_embed_forward_weight: 0
+        max_compression_factor: 5
+        pred_compressed_cot: true
+        replace_r_with_auto_prob: 0
+        sqrt_mean: true
+      latent_policy_config:
+        lp_determinisitc: false
+        lp_intermediate_size: 2048
+      latent_generation_config:
+        max_n_latent_forward: 64
+        latent_temperature: 1.0
+        compression_factor: 5
+      answer_generation_config:
+        max_new_tokens: 16
+        do_sample: true
+        top_p: 0.9
+        temperature: 1.0
+      do_rl: false
+      rl_config:
+        random_speed_in_group: false
+        filter_dataset: false
+        exp_batch_size: 8
+        group_size: 8
+        punish_latent_length: false
+        clip_grad_norm: 1.0
+        clip_eps: 0.2
+        use_latent_loss: true
+        use_answer_loss: true
+        n_train_samples_per_epoch: 512
+    training_kwargs:
+      optimizer:
+        target: torch.optim.AdamW
+        lr: 0.0001
+        weight_decay: 0.01
+      use_scheduler: false
+      scheduler:
+        target: constant_schedule_with_warmup
+        warmup_steps: 1000
+  dataloader:
+    batch_size: 32
+    val_batch_size: 32
+    num_workers: 32
+    pin_memory: true
+    persistent_workers: true
+  data_module:
+    target: src.datasets.qsa.QSADataModule
+    dataset_name: gsm
+    tiny_dataset: false
+    epoch_scaling: 1
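
Every component in this config is addressed by a target: import path (Trainer, logger, callbacks, model, optimizer). The repo's own builder is not included in this commit, but such configs are conventionally resolved with importlib; a sketch under that assumption (the instantiate name is hypothetical):

import importlib

# Hypothetical resolver: split "pkg.module.Class" into module and attribute,
# import the module, and call the class with the remaining keys as kwargs.
def instantiate(cfg):
    cfg = dict(cfg)
    module_path, attr = cfg.pop("target").rsplit(".", 1)
    cls = getattr(importlib.import_module(module_path), attr)
    return cls(**cfg)

# Usage with the logger entry from the config above.
logger = instantiate({
    "target": "lightning.pytorch.loggers.TensorBoardLogger",
    "save_dir": "logs/colar",
    "name": "qsa-gsm",
    "version": "colar-final",
})
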
logs/colar/qsa-math/colar-rl/checkpoints/colar_best.ckpt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0af9a1be218b34c41d2de6410aa70c3e24637ea9bdc60fa9c7f3ce9774a0888e
+size 124346104
logs/colar/qsa-math/colar-rl/events.out.tfevents.14242.0 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4299501ca8fe2554dbab76ae590c22491a35d2ce5433b4f242bd39e304664d3a
+size 2803637
logs/colar/qsa-math/colar-rl/events.out.tfevents.14242.1 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b04b26dd4a3f82b23c8f733ee4aef31c2811e8d0ea207b76b75ec707d3f77d63
+size 17954
logs/colar/qsa-math/colar-rl/hparams.yaml ADDED
@@ -0,0 +1,96 @@
+all_config:
+  trainer:
+    target: lightning.pytorch.trainer.Trainer
+    devices:
+    - 0
+    max_steps: -1
+    check_val_every_n_epoch: 1
+    log_every_n_steps: 1
+    num_sanity_val_steps: 2
+    gradient_clip_val: null
+    reload_dataloaders_every_n_epochs: 0
+    accumulate_grad_batches: 1
+    precision: 32-true
+    use_distributed_sampler: true
+    strategy: auto
+    logger:
+      target: lightning.pytorch.loggers.TensorBoardLogger
+      save_dir: logs/colar
+      name: qsa-math
+      version:
+    max_epochs: 50
+    callbacks:
+    - target: lightning.pytorch.callbacks.ModelCheckpoint
+      save_last: true
+      save_top_k: 3
+      mode: max
+      monitor: monitor
+      auto_insert_metric_name: false
+      filename: epoch{epoch}__step{step}__monitor{monitor:.3f}
+      save_weights_only: true
+  seed: null
+  model:
+    target: src.models.latent_colar.LitCoLaR
+    model_kwargs:
+      model_id: DeepSeek-R1-Distill-Qwen-1.5B
+      depth: 1
+      sft_method: colar
+      set_pad_as_last_token: false
+      do_lora: true
+      lora_config:
+        r: 128
+        lora_alpha: 32
+      latent_cot_config:
+        ce_weight: 1
+        embed_modeling_weight: 1
+        embed_modeling_loss: nll
+        entropy_weight: 0
+        pred_embed_forward_weight: 0
+        max_compression_factor: 4
+        pred_compressed_cot: true
+        replace_r_with_auto_prob: 0.0
+        sqrt_mean: true
+      latent_policy_config:
+        lp_determinisitc: false
+        lp_intermediate_size: 2048
+      latent_generation_config:
+        max_n_latent_forward: 64
+        latent_temperature: 1.0
+        compression_factor: 2
+      answer_generation_config:
+        max_new_tokens: 16
+        do_sample: true
+        top_p: 0.9
+        temperature: 1.0
+      do_rl: true
+      rl_config:
+        random_speed_in_group: false
+        filter_dataset: false
+        exp_batch_size: 8
+        group_size: 8
+        punish_latent_length: false
+        clip_grad_norm: 1.0
+        clip_eps: 0.2
+        use_latent_loss: true
+        use_answer_loss: true
+        n_train_samples_per_epoch: 512
+    training_kwargs:
+      optimizer:
+        target: torch.optim.AdamW
+        lr: 1.0e-06
+        weight_decay: 0.01
+      use_scheduler: false
+      scheduler:
+        target: constant_schedule_with_warmup
+        warmup_steps: 1000
+  dataloader:
+    batch_size: 4
+    val_batch_size: 32
+    num_workers: 32
+    pin_memory: true
+    persistent_workers: true
+  data_module:
+    target: src.datasets.qsa.QSADataModule
+    dataset_name: math
+    tiny_dataset: false
+    epoch_scaling: 1
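
Both checkpoints were written with save_weights_only: true, so they carry model weights and hyperparameters but no optimizer state. A sketch of reloading the RL checkpoint, assuming src.models.latent_colar.LitCoLaR (the model.target above) is an importable LightningModule that saved its hyperparameters:

from src.models.latent_colar import LitCoLaR  # class path taken from hparams.yaml

# load_from_checkpoint restores weights plus saved hparams; there is no
# optimizer state in this file, so it is for inference or fine-tuning restarts.
model = LitCoLaR.load_from_checkpoint(
    "logs/colar/qsa-math/colar-rl/checkpoints/colar_best.ckpt"
)
model.eval()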