natmin322 committed on
Commit
3c81e8e
·
1 Parent(s): 6c25d90
improve_gainlora/IDEA_Overall.md CHANGED
@@ -566,6 +566,8 @@ This $A_t$ ensures that the **maximum task-relevant variance** is captured in the null-
566
  | C4: Preconditioner | `precompute_preconditioners()` → eigendecomposition | `cl_trainer_specroute.py` |
567
  | **C5: Data-informed init** | **`pre_task_data_collection()` → `eigh(Q@C@Q)` → set `lora_A.data`** | **`cl_trainer_specroute.py`** |
568
  | C5: Fallback | max eigval < 1e-6 → skip C5, keep Kaiming + InfLoRA projection | `cl_trainer_specroute.py` |
569
+ | **V10a: Learned Routing** | **`Trans_input` + `prompt_key` gating with exact post-step GPM constraints** | **`t5_specroute.py` & `cl_trainer_specroute.py`** |
570
+ | **V10b: Grassmann Routing** | **Geometry-based routing via Grassmannian distance on batch principal subspaces** | **`t5_specroute.py`** |
571
 
572
  ---
573
 
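The C5 rows above describe a data-informed initialization (`eigh(Q@C@Q)` written into `lora_A.data`) with a fallback when the largest eigenvalue is below 1e-6. A minimal NumPy sketch of that idea follows. It is not the code in `cl_trainer_specroute.py`: the function name `data_informed_init`, its arguments, and the reading of `Q` as the projector onto the orthogonal complement of a previously protected basis `U_prev` and of `C` as the input covariance of the new task are assumptions made for illustration.

```python
import numpy as np


def data_informed_init(X, U_prev, r, eps=1e-6):
    """Hypothetical helper: return an (r, d) init for lora_A, or None for the fallback.

    X      : (n, d) inputs collected for the new task (assumed layout).
    U_prev : (d, k) orthonormal basis of the subspace protected for old tasks (assumed).
    r      : LoRA rank.
    eps    : fallback threshold on the largest eigenvalue (1e-6 in the table).
    """
    d = X.shape[1]
    C = X.T @ X / X.shape[0]              # (d, d) input covariance
    Q = np.eye(d) - U_prev @ U_prev.T     # projector onto the null space of old tasks
    M = Q @ C @ Q
    M = (M + M.T) / 2                     # symmetrize against round-off before eigh
    eigvals, eigvecs = np.linalg.eigh(M)  # eigenvalues in ascending order
    if eigvals[-1] < eps:
        # No usable variance left in the null space: caller keeps its default init.
        return None
    top = eigvecs[:, ::-1][:, :r]         # top-r eigenvectors as columns, shape (d, r)
    return top.T                          # rows are the proposed lora_A directions
```

Under these assumptions the returned rows are orthonormal and orthogonal to `U_prev`, so the initialization already satisfies the null-space constraint. A `None` return corresponds to the "skip C5" row of the table.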
improve_gainlora/SPECROUTE_IDEA.md CHANGED
@@ -433,5 +433,7 @@ The Routing–Protection Duality Theorem (Theorem 1) assumes $h \in \mathrm{span
437
- **Key constraint**: Any direction must keep zero-replay AND maintain Routing–Protection Duality narrative (SpecRoute's core theoretical contribution). Oracle routing during training is valid; inference routing must remain parameter-free for the claim to hold.
 
433
  | **Adaptive GPM threshold** | ⬜ Pending | Relax constraint for later tasks to preserve capacity |
434
  | **Same-domain routing** | ⬜ Research | Geometry-based (no labels, no data) task similarity for routing |
435
  | **Rank expansion** | ⬜ Pending | Increase r for later tasks to compensate null-space shrinkage |
436
+ | **V10a Learned Routing** | ✅ Implemented | Relax parameter-free constraint; use ROOT's MLP & prompt keys with strict GPM |
437
+ | **V10b Grassmann Routing** | ✅ Implemented | Geometry-based routing using Grassmannian distance on batch principal subspaces |
438
 
439
+ **Key constraint**: Any direction must keep zero-replay AND maintain Routing–Protection Duality narrative (SpecRoute's core theoretical contribution). Oracle routing during training is valid; inference routing must remain parameter-free for the claim to hold (V10b satisfies this; V10a relaxes it for empirical bounding).
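The V10b row describes routing by Grassmannian distance on batch principal subspaces. The sketch below shows one way such a parameter-free router could look. It is not the implementation in `t5_specroute.py`: the function names, the uncentered SVD used to get the batch subspace, and the geodesic distance (the l2 norm of the principal angles) are all assumptions made for illustration.

```python
import numpy as np


def principal_subspace(H, k):
    """Orthonormal (d, k) basis of the top-k right singular directions of a batch H (n, d)."""
    _, _, Vt = np.linalg.svd(H, full_matrices=False)
    return Vt[:k].T


def grassmann_distance(U, V):
    """Geodesic distance between span(U) and span(V): l2 norm of the principal angles."""
    s = np.linalg.svd(U.T @ V, compute_uv=False)   # cosines of the principal angles
    theta = np.arccos(np.clip(s, -1.0, 1.0))       # clip guards against round-off
    return float(np.linalg.norm(theta))


def route(H, task_bases, k):
    """Return (index of the closest task subspace, list of distances to every task)."""
    U = principal_subspace(H, k)
    dists = [grassmann_distance(U, B) for B in task_bases]
    return int(np.argmin(dists)), dists
```

Nothing here is trained: the only stored state is one orthonormal basis per task, which is why a router of this kind stays compatible with the parameter-free inference requirement stated in the key constraint.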
improve_gainlora/T5_small/gen_script_long_order3_t5_small_specroute_v10a.sh ADDED
@@ -0,0 +1,906 @@
1
+ #!/bin/bash
2
+ #SBATCH -J cl
3
+ #SBATCH -o cl-%j.out
4
+ #SBATCH -p compute
5
+ #SBATCH -N 1
6
+ #SBATCH -t 20:00:00
7
+ #SBATCH --mem 128G
8
+ #SBATCH --gres=gpu:2
9
+
10
+ export CUDA_DEVICE_ORDER="PCI_BUS_ID"
11
+
12
+ port=$(shuf -i25000-30000 -n1)
13
+
14
+ # ============================================================
15
+ # Auto-detect GPU count and type for optimal parallelism
16
+ # ============================================================
17
+ NUM_GPUS=$(nvidia-smi -L 2>/dev/null | wc -l)
18
+ GPU_MEM=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits 2>/dev/null | head -1)
19
+
20
+ if [ -z "$GPU_MEM" ]; then
21
+ echo "ERROR: No GPU detected!"
22
+ exit 1
23
+ fi
24
+
25
+ # Classify GPU by VRAM: anything under 20000 MB is treated as a T4
26
+ if [ "$GPU_MEM" -lt 20000 ]; then
27
+ IS_T4=1
28
+ echo "[GPU] Detected T4 GPUs (${GPU_MEM}MB VRAM each)"
29
+ else
30
+ IS_T4=0
31
+ echo "[GPU] Detected high-memory GPUs (${GPU_MEM}MB VRAM each)"
32
+ fi
33
+
34
+ # Determine parallelism strategy
35
+ if [ "$IS_T4" -eq 1 ] && [ "$NUM_GPUS" -ge 2 ]; then
36
+ GPU_MODE="t4_2gpu"
37
+ GPU_IDS="0,1"
38
+ FP16_FLAG=""
39
+ echo "[GPU] Strategy: 2x T4 DataParallel + fp32 + gradient_checkpointing"
40
+ elif [ "$IS_T4" -eq 1 ]; then
41
+ GPU_MODE="t4_1gpu"
42
+ GPU_IDS="${1:-0}"
43
+ FP16_FLAG=""
44
+ echo "[GPU] Strategy: 1x T4 + fp32 + gradient_checkpointing"
45
+ else
46
+ GPU_MODE="a100"
47
+ GPU_IDS="${1:-0}"
48
+ FP16_FLAG=""
49
+ echo "[GPU] Strategy: A100 (single GPU, fp32)"
50
+ fi
51
+
52
+ echo "[GPU] Using CUDA_VISIBLE_DEVICES=$GPU_IDS"
53
+ echo "============================================================"
54
+ echo ""
55
+
56
+ if [ "$GPU_MODE" = "t4_2gpu" ]; then
57
+ BSZ=16; GA=1; EVAL_BSZ=256
58
+ elif [ "$GPU_MODE" = "t4_1gpu" ]; then
59
+ BSZ=32; GA=1; EVAL_BSZ=256
60
+ else
61
+ BSZ=64; GA=1; EVAL_BSZ=512
62
+ fi
63
+
64
+ CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
65
+ --do_train \
66
+ --do_predict \
67
+ --predict_with_generate \
68
+ --model_name_or_path "$2" \
69
+ --data_dir CL_Benchmark \
70
+ --task_order yelp,amazon,mnli,cb,copa,qqp,rte,imdb,sst2,dbpedia,agnews,yahoo,multirc,boolq,wic \
71
+ --task_config_dir configs/gen_script_long_order3_t5_configs/yelp \
72
+ --output_dir logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/1-yelp \
73
+ --per_device_train_batch_size $BSZ \
74
+ --per_device_eval_batch_size $EVAL_BSZ \
75
+ --gradient_accumulation_steps $GA \
76
+ --learning_rate 0.0003 \
77
+ --num_train_epochs 10 \
78
+ --run_name gen_script_long_order3_t5_small_specroute_v10a \
79
+ --max_source_length 512 \
80
+ --max_target_length 50 \
81
+ --generation_max_length 50 \
82
+ --add_task_name False \
83
+ --add_dataset_name False \
84
+ --overwrite_output_dir \
85
+ --overwrite_cache \
86
+ --lr_scheduler_type constant \
87
+ --warmup_steps 0 \
88
+ --logging_strategy steps \
89
+ --logging_steps 10 \
90
+ --metric_for_best_model eval_exact_match \
91
+ --evaluation_strategy epoch \
92
+ --save_strategy epoch \
93
+ --save_total_limit 1 \
94
+ --load_best_model_at_end \
95
+ --lora_r 8 \
96
+ --lora_alpha 32 \
97
+ --lora_dropout 0.0 \
98
+ --data_replay_freq -1 \
99
+ --mlp_hidden_dim 100 \
100
+ --model_name specroute \
101
+ --routing_mode learned \
102
+ --threshold 0.995 \
103
+ --transthreshold 0.995 \
104
+ $FP16_FLAG
105
+
106
+ rm -rf logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/1-yelp/checkpoint*
107
+
108
+ sleep 5
109
+
110
+ if [ "$GPU_MODE" = "t4_2gpu" ]; then
111
+ BSZ=16; GA=1; EVAL_BSZ=256
112
+ elif [ "$GPU_MODE" = "t4_1gpu" ]; then
113
+ BSZ=32; GA=1; EVAL_BSZ=256
114
+ else
115
+ BSZ=64; GA=1; EVAL_BSZ=512
116
+ fi
117
+
118
+ CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
119
+ --do_train \
120
+ --load_checkpoint_from logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/1-yelp/saved_weights/trans_input.pt \
121
+ --previous_prompt_key_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/1-yelp/saved_weights/prompts_keys_till_now.pt \
122
+ --do_predict \
123
+ --predict_with_generate \
124
+ --model_name_or_path "$2" \
125
+ --previous_lora_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/1-yelp/saved_weights \
126
+ --data_dir CL_Benchmark \
127
+ --task_order yelp,amazon,mnli,cb,copa,qqp,rte,imdb,sst2,dbpedia,agnews,yahoo,multirc,boolq,wic \
128
+ --task_config_dir configs/gen_script_long_order3_t5_configs/amazon \
129
+ --output_dir logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/2-amazon \
130
+ --per_device_train_batch_size $BSZ \
131
+ --per_device_eval_batch_size $EVAL_BSZ \
132
+ --gradient_accumulation_steps $GA \
133
+ --learning_rate 0.0003 \
134
+ --num_train_epochs 10 \
135
+ --run_name gen_script_long_order3_t5_small_specroute_v10a \
136
+ --max_source_length 512 \
137
+ --max_target_length 50 \
138
+ --generation_max_length 50 \
139
+ --add_task_name False \
140
+ --add_dataset_name False \
141
+ --overwrite_output_dir \
142
+ --overwrite_cache \
143
+ --lr_scheduler_type constant \
144
+ --warmup_steps 0 \
145
+ --logging_strategy steps \
146
+ --logging_steps 10 \
147
+ --metric_for_best_model eval_exact_match_for_amazon \
148
+ --evaluation_strategy epoch \
149
+ --save_strategy epoch \
150
+ --save_total_limit 1 \
151
+ --load_best_model_at_end \
152
+ --lora_r 8 \
153
+ --lora_alpha 32 \
154
+ --lora_dropout 0.0 \
155
+ --data_replay_freq -1 \
156
+ --mlp_hidden_dim 100 \
157
+ --model_name specroute \
158
+ --routing_mode learned \
159
+ --threshold 0.995 \
160
+ --transthreshold 0.995 \
161
+ $FP16_FLAG
162
+
163
+ rm -rf logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/2-amazon/checkpoint*
164
+
165
+ sleep 5
166
+
167
+ if [ "$GPU_MODE" = "t4_2gpu" ]; then
168
+ BSZ=16; GA=1; EVAL_BSZ=256
169
+ elif [ "$GPU_MODE" = "t4_1gpu" ]; then
170
+ BSZ=32; GA=1; EVAL_BSZ=256
171
+ else
172
+ BSZ=64; GA=1; EVAL_BSZ=512
173
+ fi
174
+
175
+ CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
176
+ --do_train \
177
+ --load_checkpoint_from logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/2-amazon/saved_weights/trans_input.pt \
178
+ --previous_prompt_key_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/2-amazon/saved_weights/prompts_keys_till_now.pt \
179
+ --do_predict \
180
+ --predict_with_generate \
181
+ --model_name_or_path "$2" \
182
+ --previous_lora_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/1-yelp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/2-amazon/saved_weights \
183
+ --data_dir CL_Benchmark \
184
+ --task_order yelp,amazon,mnli,cb,copa,qqp,rte,imdb,sst2,dbpedia,agnews,yahoo,multirc,boolq,wic \
185
+ --task_config_dir configs/gen_script_long_order3_t5_configs/mnli \
186
+ --output_dir logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/3-mnli \
187
+ --per_device_train_batch_size $BSZ \
188
+ --per_device_eval_batch_size $EVAL_BSZ \
189
+ --gradient_accumulation_steps $GA \
190
+ --learning_rate 0.0003 \
191
+ --num_train_epochs 10 \
192
+ --run_name gen_script_long_order3_t5_small_specroute_v10a \
193
+ --max_source_length 512 \
194
+ --max_target_length 50 \
195
+ --generation_max_length 50 \
196
+ --add_task_name False \
197
+ --add_dataset_name False \
198
+ --overwrite_output_dir \
199
+ --overwrite_cache \
200
+ --lr_scheduler_type constant \
201
+ --warmup_steps 0 \
202
+ --logging_strategy steps \
203
+ --logging_steps 10 \
204
+ --metric_for_best_model eval_exact_match_for_mnli \
205
+ --evaluation_strategy epoch \
206
+ --save_strategy epoch \
207
+ --save_total_limit 1 \
208
+ --load_best_model_at_end \
209
+ --lora_r 8 \
210
+ --lora_alpha 32 \
211
+ --lora_dropout 0.0 \
212
+ --data_replay_freq -1 \
213
+ --mlp_hidden_dim 100 \
214
+ --model_name specroute \
215
+ --routing_mode learned \
216
+ --threshold 0.995 \
217
+ --transthreshold 0.995 \
218
+ $FP16_FLAG
219
+
220
+ rm -rf logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/3-mnli/checkpoint*
221
+
222
+ sleep 5
223
+
224
+ if [ "$GPU_MODE" = "t4_2gpu" ]; then
225
+ BSZ=16; GA=1; EVAL_BSZ=256
226
+ elif [ "$GPU_MODE" = "t4_1gpu" ]; then
227
+ BSZ=32; GA=1; EVAL_BSZ=256
228
+ else
229
+ BSZ=64; GA=1; EVAL_BSZ=512
230
+ fi
231
+
232
+ CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
233
+ --do_train \
234
+ --load_checkpoint_from logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/3-mnli/saved_weights/trans_input.pt \
235
+ --previous_prompt_key_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/3-mnli/saved_weights/prompts_keys_till_now.pt \
236
+ --do_predict \
237
+ --predict_with_generate \
238
+ --model_name_or_path "$2" \
239
+ --previous_lora_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/1-yelp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/2-amazon/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/3-mnli/saved_weights \
240
+ --data_dir CL_Benchmark \
241
+ --task_order yelp,amazon,mnli,cb,copa,qqp,rte,imdb,sst2,dbpedia,agnews,yahoo,multirc,boolq,wic \
242
+ --task_config_dir configs/gen_script_long_order3_t5_configs/cb \
243
+ --output_dir logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/4-cb \
244
+ --per_device_train_batch_size $BSZ \
245
+ --per_device_eval_batch_size $EVAL_BSZ \
246
+ --gradient_accumulation_steps $GA \
247
+ --learning_rate 0.0003 \
248
+ --num_train_epochs 10 \
249
+ --run_name gen_script_long_order3_t5_small_specroute_v10a \
250
+ --max_source_length 512 \
251
+ --max_target_length 50 \
252
+ --generation_max_length 50 \
253
+ --add_task_name False \
254
+ --add_dataset_name False \
255
+ --overwrite_output_dir \
256
+ --overwrite_cache \
257
+ --lr_scheduler_type constant \
258
+ --warmup_steps 0 \
259
+ --logging_strategy steps \
260
+ --logging_steps 10 \
261
+ --metric_for_best_model eval_exact_match_for_cb \
262
+ --evaluation_strategy epoch \
263
+ --save_strategy epoch \
264
+ --save_total_limit 1 \
265
+ --load_best_model_at_end \
266
+ --lora_r 8 \
267
+ --lora_alpha 32 \
268
+ --lora_dropout 0.0 \
269
+ --data_replay_freq -1 \
270
+ --mlp_hidden_dim 100 \
271
+ --model_name specroute \
272
+ --routing_mode learned \
273
+ --threshold 0.995 \
274
+ --transthreshold 0.995 \
275
+ $FP16_FLAG
276
+
277
+ rm -rf logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/4-cb/checkpoint*
278
+
279
+ sleep 5
280
+
281
+ if [ "$GPU_MODE" = "t4_2gpu" ]; then
282
+ BSZ=16; GA=1; EVAL_BSZ=256
283
+ elif [ "$GPU_MODE" = "t4_1gpu" ]; then
284
+ BSZ=32; GA=1; EVAL_BSZ=256
285
+ else
286
+ BSZ=64; GA=1; EVAL_BSZ=512
287
+ fi
288
+
289
+ CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
290
+ --do_train \
291
+ --load_checkpoint_from logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/4-cb/saved_weights/trans_input.pt \
292
+ --previous_prompt_key_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/4-cb/saved_weights/prompts_keys_till_now.pt \
293
+ --do_predict \
294
+ --predict_with_generate \
295
+ --model_name_or_path "$2" \
296
+ --previous_lora_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/1-yelp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/2-amazon/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/3-mnli/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/4-cb/saved_weights \
297
+ --data_dir CL_Benchmark \
298
+ --task_order yelp,amazon,mnli,cb,copa,qqp,rte,imdb,sst2,dbpedia,agnews,yahoo,multirc,boolq,wic \
299
+ --task_config_dir configs/gen_script_long_order3_t5_configs/copa \
300
+ --output_dir logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/5-copa \
301
+ --per_device_train_batch_size $BSZ \
302
+ --per_device_eval_batch_size $EVAL_BSZ \
303
+ --gradient_accumulation_steps $GA \
304
+ --learning_rate 0.0003 \
305
+ --num_train_epochs 10 \
306
+ --run_name gen_script_long_order3_t5_small_specroute_v10a \
307
+ --max_source_length 512 \
308
+ --max_target_length 50 \
309
+ --generation_max_length 50 \
310
+ --add_task_name False \
311
+ --add_dataset_name False \
312
+ --overwrite_output_dir \
313
+ --overwrite_cache \
314
+ --lr_scheduler_type constant \
315
+ --warmup_steps 0 \
316
+ --logging_strategy steps \
317
+ --logging_steps 10 \
318
+ --metric_for_best_model eval_exact_match_for_copa \
319
+ --evaluation_strategy epoch \
320
+ --save_strategy epoch \
321
+ --save_total_limit 1 \
322
+ --load_best_model_at_end \
323
+ --lora_r 8 \
324
+ --lora_alpha 32 \
325
+ --lora_dropout 0.0 \
326
+ --data_replay_freq -1 \
327
+ --mlp_hidden_dim 100 \
328
+ --model_name specroute \
329
+ --routing_mode learned \
330
+ --threshold 0.995 \
331
+ --transthreshold 0.995 \
332
+ $FP16_FLAG
333
+
334
+ rm -rf logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/5-copa/checkpoint*
335
+
336
+ sleep 5
337
+
338
+ if [ "$GPU_MODE" = "t4_2gpu" ]; then
339
+ BSZ=16; GA=1; EVAL_BSZ=256
340
+ elif [ "$GPU_MODE" = "t4_1gpu" ]; then
341
+ BSZ=32; GA=1; EVAL_BSZ=256
342
+ else
343
+ BSZ=64; GA=1; EVAL_BSZ=512
344
+ fi
345
+
346
+ CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
347
+ --do_train \
348
+ --load_checkpoint_from logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/5-copa/saved_weights/trans_input.pt \
349
+ --previous_prompt_key_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/5-copa/saved_weights/prompts_keys_till_now.pt \
350
+ --do_predict \
351
+ --predict_with_generate \
352
+ --model_name_or_path "$2" \
353
+ --previous_lora_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/1-yelp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/2-amazon/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/3-mnli/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/4-cb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/5-copa/saved_weights \
354
+ --data_dir CL_Benchmark \
355
+ --task_order yelp,amazon,mnli,cb,copa,qqp,rte,imdb,sst2,dbpedia,agnews,yahoo,multirc,boolq,wic \
356
+ --task_config_dir configs/gen_script_long_order3_t5_configs/qqp \
357
+ --output_dir logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/6-qqp \
358
+ --per_device_train_batch_size $BSZ \
359
+ --per_device_eval_batch_size $EVAL_BSZ \
360
+ --gradient_accumulation_steps $GA \
361
+ --learning_rate 0.0003 \
362
+ --num_train_epochs 10 \
363
+ --run_name gen_script_long_order3_t5_small_specroute_v10a \
364
+ --max_source_length 512 \
365
+ --max_target_length 50 \
366
+ --generation_max_length 50 \
367
+ --add_task_name False \
368
+ --add_dataset_name False \
369
+ --overwrite_output_dir \
370
+ --overwrite_cache \
371
+ --lr_scheduler_type constant \
372
+ --warmup_steps 0 \
373
+ --logging_strategy steps \
374
+ --logging_steps 10 \
375
+ --metric_for_best_model eval_exact_match_for_qqp \
376
+ --evaluation_strategy epoch \
377
+ --save_strategy epoch \
378
+ --save_total_limit 1 \
379
+ --load_best_model_at_end \
380
+ --lora_r 8 \
381
+ --lora_alpha 32 \
382
+ --lora_dropout 0.0 \
383
+ --data_replay_freq -1 \
384
+ --mlp_hidden_dim 100 \
385
+ --model_name specroute \
386
+ --routing_mode learned \
387
+ --threshold 0.995 \
388
+ --transthreshold 0.995 \
389
+ $FP16_FLAG
390
+
391
+ rm -rf logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/6-qqp/checkpoint*
392
+
393
+ sleep 5
394
+
395
+ if [ "$GPU_MODE" = "t4_2gpu" ]; then
396
+ BSZ=16; GA=1; EVAL_BSZ=256
397
+ elif [ "$GPU_MODE" = "t4_1gpu" ]; then
398
+ BSZ=32; GA=1; EVAL_BSZ=256
399
+ else
400
+ BSZ=64; GA=1; EVAL_BSZ=512
401
+ fi
402
+
403
+ CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
404
+ --do_train \
405
+ --load_checkpoint_from logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/6-qqp/saved_weights/trans_input.pt \
406
+ --previous_prompt_key_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/6-qqp/saved_weights/prompts_keys_till_now.pt \
407
+ --do_predict \
408
+ --predict_with_generate \
409
+ --model_name_or_path "$2" \
410
+ --previous_lora_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/1-yelp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/2-amazon/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/3-mnli/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/4-cb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/5-copa/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/6-qqp/saved_weights \
411
+ --data_dir CL_Benchmark \
412
+ --task_order yelp,amazon,mnli,cb,copa,qqp,rte,imdb,sst2,dbpedia,agnews,yahoo,multirc,boolq,wic \
413
+ --task_config_dir configs/gen_script_long_order3_t5_configs/rte \
414
+ --output_dir logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/7-rte \
415
+ --per_device_train_batch_size $BSZ \
416
+ --per_device_eval_batch_size $EVAL_BSZ \
417
+ --gradient_accumulation_steps $GA \
418
+ --learning_rate 0.0003 \
419
+ --num_train_epochs 10 \
420
+ --run_name gen_script_long_order3_t5_small_specroute_v10a \
421
+ --max_source_length 512 \
422
+ --max_target_length 50 \
423
+ --generation_max_length 50 \
424
+ --add_task_name False \
425
+ --add_dataset_name False \
426
+ --overwrite_output_dir \
427
+ --overwrite_cache \
428
+ --lr_scheduler_type constant \
429
+ --warmup_steps 0 \
430
+ --logging_strategy steps \
431
+ --logging_steps 10 \
432
+ --metric_for_best_model eval_exact_match_for_rte \
433
+ --evaluation_strategy epoch \
434
+ --save_strategy epoch \
435
+ --save_total_limit 1 \
436
+ --load_best_model_at_end \
437
+ --lora_r 8 \
438
+ --lora_alpha 32 \
439
+ --lora_dropout 0.0 \
440
+ --data_replay_freq -1 \
441
+ --mlp_hidden_dim 100 \
442
+ --model_name specroute \
443
+ --routing_mode learned \
444
+ --threshold 0.995 \
445
+ --transthreshold 0.995 \
446
+ $FP16_FLAG
447
+
448
+ rm -rf logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/7-rte/checkpoint*
449
+
450
+ sleep 5
451
+
452
+ if [ "$GPU_MODE" = "t4_2gpu" ]; then
453
+ BSZ=16; GA=1; EVAL_BSZ=256
454
+ elif [ "$GPU_MODE" = "t4_1gpu" ]; then
455
+ BSZ=32; GA=1; EVAL_BSZ=256
456
+ else
457
+ BSZ=64; GA=1; EVAL_BSZ=512
458
+ fi
459
+
460
+ CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
461
+ --do_train \
462
+ --load_checkpoint_from logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/7-rte/saved_weights/trans_input.pt \
463
+ --previous_prompt_key_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/7-rte/saved_weights/prompts_keys_till_now.pt \
464
+ --do_predict \
465
+ --predict_with_generate \
466
+ --model_name_or_path "$2" \
467
+ --previous_lora_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/1-yelp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/2-amazon/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/3-mnli/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/4-cb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/5-copa/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/6-qqp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/7-rte/saved_weights \
468
+ --data_dir CL_Benchmark \
469
+ --task_order yelp,amazon,mnli,cb,copa,qqp,rte,imdb,sst2,dbpedia,agnews,yahoo,multirc,boolq,wic \
470
+ --task_config_dir configs/gen_script_long_order3_t5_configs/imdb \
471
+ --output_dir logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/8-imdb \
472
+ --per_device_train_batch_size $BSZ \
473
+ --per_device_eval_batch_size $EVAL_BSZ \
474
+ --gradient_accumulation_steps $GA \
475
+ --learning_rate 0.0003 \
476
+ --num_train_epochs 10 \
477
+ --run_name gen_script_long_order3_t5_small_specroute_v10a \
478
+ --max_source_length 512 \
479
+ --max_target_length 50 \
480
+ --generation_max_length 50 \
481
+ --add_task_name False \
482
+ --add_dataset_name False \
483
+ --overwrite_output_dir \
484
+ --overwrite_cache \
485
+ --lr_scheduler_type constant \
486
+ --warmup_steps 0 \
487
+ --logging_strategy steps \
488
+ --logging_steps 10 \
489
+ --metric_for_best_model eval_exact_match_for_imdb \
490
+ --evaluation_strategy epoch \
491
+ --save_strategy epoch \
492
+ --save_total_limit 1 \
493
+ --load_best_model_at_end \
494
+ --lora_r 8 \
495
+ --lora_alpha 32 \
496
+ --lora_dropout 0.0 \
497
+ --data_replay_freq -1 \
498
+ --mlp_hidden_dim 100 \
499
+ --model_name specroute \
500
+ --routing_mode learned \
501
+ --threshold 0.995 \
502
+ --transthreshold 0.995 \
503
+ $FP16_FLAG
504
+
505
+ rm -rf logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/8-imdb/checkpoint*
506
+
507
+ sleep 5
508
+
509
+ if [ "$GPU_MODE" = "t4_2gpu" ]; then
510
+ BSZ=16; GA=1; EVAL_BSZ=256
511
+ elif [ "$GPU_MODE" = "t4_1gpu" ]; then
512
+ BSZ=32; GA=1; EVAL_BSZ=256
513
+ else
514
+ BSZ=64; GA=1; EVAL_BSZ=512
515
+ fi
516
+
517
+ CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
518
+ --do_train \
519
+ --load_checkpoint_from logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/8-imdb/saved_weights/trans_input.pt \
520
+ --previous_prompt_key_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/8-imdb/saved_weights/prompts_keys_till_now.pt \
521
+ --do_predict \
522
+ --predict_with_generate \
523
+ --model_name_or_path "$2" \
524
+ --previous_lora_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/1-yelp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/2-amazon/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/3-mnli/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/4-cb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/5-copa/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/6-qqp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/7-rte/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/8-imdb/saved_weights \
525
+ --data_dir CL_Benchmark \
526
+ --task_order yelp,amazon,mnli,cb,copa,qqp,rte,imdb,sst2,dbpedia,agnews,yahoo,multirc,boolq,wic \
527
+ --task_config_dir configs/gen_script_long_order3_t5_configs/sst2 \
528
+ --output_dir logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/9-sst2 \
529
+ --per_device_train_batch_size $BSZ \
530
+ --per_device_eval_batch_size $EVAL_BSZ \
531
+ --gradient_accumulation_steps $GA \
532
+ --learning_rate 0.0003 \
533
+ --num_train_epochs 10 \
534
+ --run_name gen_script_long_order3_t5_small_specroute_v10a \
535
+ --max_source_length 512 \
536
+ --max_target_length 50 \
537
+ --generation_max_length 50 \
538
+ --add_task_name False \
539
+ --add_dataset_name False \
540
+ --overwrite_output_dir \
541
+ --overwrite_cache \
542
+ --lr_scheduler_type constant \
543
+ --warmup_steps 0 \
544
+ --logging_strategy steps \
545
+ --logging_steps 10 \
546
+ --metric_for_best_model eval_exact_match_for_sst2 \
547
+ --evaluation_strategy epoch \
548
+ --save_strategy epoch \
549
+ --save_total_limit 1 \
550
+ --load_best_model_at_end \
551
+ --lora_r 8 \
552
+ --lora_alpha 32 \
553
+ --lora_dropout 0.0 \
554
+ --data_replay_freq -1 \
555
+ --mlp_hidden_dim 100 \
556
+ --model_name specroute \
557
+ --routing_mode learned \
558
+ --threshold 0.995 \
559
+ --transthreshold 0.995 \
560
+ $FP16_FLAG
561
+
562
+ rm -rf logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/9-sst2/checkpoint*
563
+
564
+ sleep 5
565
+
566
+ if [ "$GPU_MODE" = "t4_2gpu" ]; then
567
+ BSZ=16; GA=1; EVAL_BSZ=256
568
+ elif [ "$GPU_MODE" = "t4_1gpu" ]; then
569
+ BSZ=32; GA=1; EVAL_BSZ=256
570
+ else
571
+ BSZ=64; GA=1; EVAL_BSZ=512
572
+ fi
573
+
574
+ CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
575
+ --do_train \
576
+ --load_checkpoint_from logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/9-sst2/saved_weights/trans_input.pt \
577
+ --previous_prompt_key_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/9-sst2/saved_weights/prompts_keys_till_now.pt \
578
+ --do_predict \
579
+ --predict_with_generate \
580
+ --model_name_or_path "$2" \
581
+ --previous_lora_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/1-yelp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/2-amazon/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/3-mnli/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/4-cb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/5-copa/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/6-qqp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/7-rte/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/8-imdb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/9-sst2/saved_weights \
582
+ --data_dir CL_Benchmark \
583
+ --task_order yelp,amazon,mnli,cb,copa,qqp,rte,imdb,sst2,dbpedia,agnews,yahoo,multirc,boolq,wic \
584
+ --task_config_dir configs/gen_script_long_order3_t5_configs/dbpedia \
585
+ --output_dir logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/10-dbpedia \
586
+ --per_device_train_batch_size $BSZ \
587
+ --per_device_eval_batch_size $EVAL_BSZ \
588
+ --gradient_accumulation_steps $GA \
589
+ --learning_rate 0.0003 \
590
+ --num_train_epochs 10 \
591
+ --run_name gen_script_long_order3_t5_small_specroute_v10a \
592
+ --max_source_length 512 \
593
+ --max_target_length 50 \
594
+ --generation_max_length 50 \
595
+ --add_task_name False \
596
+ --add_dataset_name False \
597
+ --overwrite_output_dir \
598
+ --overwrite_cache \
599
+ --lr_scheduler_type constant \
600
+ --warmup_steps 0 \
601
+ --logging_strategy steps \
602
+ --logging_steps 10 \
603
+ --metric_for_best_model eval_exact_match_for_dbpedia \
604
+ --evaluation_strategy epoch \
605
+ --save_strategy epoch \
606
+ --save_total_limit 1 \
607
+ --load_best_model_at_end \
608
+ --lora_r 8 \
609
+ --lora_alpha 32 \
610
+ --lora_dropout 0.0 \
611
+ --data_replay_freq -1 \
612
+ --mlp_hidden_dim 100 \
613
+ --model_name specroute \
614
+ --routing_mode learned \
615
+ --threshold 0.995 \
616
+ --transthreshold 0.995 \
617
+ $FP16_FLAG
618
+
619
+ rm -rf logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/10-dbpedia/checkpoint*
620
+
621
+ sleep 5
622
+
623
+ if [ "$GPU_MODE" = "t4_2gpu" ]; then
624
+ BSZ=16; GA=1; EVAL_BSZ=256
625
+ elif [ "$GPU_MODE" = "t4_1gpu" ]; then
626
+ BSZ=32; GA=1; EVAL_BSZ=256
627
+ else
628
+ BSZ=64; GA=1; EVAL_BSZ=512
629
+ fi
630
+
631
+ CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
632
+ --do_train \
633
+ --load_checkpoint_from logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/10-dbpedia/saved_weights/trans_input.pt \
634
+ --previous_prompt_key_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/10-dbpedia/saved_weights/prompts_keys_till_now.pt \
635
+ --do_predict \
636
+ --predict_with_generate \
637
+ --model_name_or_path $2 \
638
+ --previous_lora_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/1-yelp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/2-amazon/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/3-mnli/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/4-cb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/5-copa/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/6-qqp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/7-rte/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/8-imdb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/9-sst2/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/10-dbpedia/saved_weights \
639
+ --data_dir CL_Benchmark \
640
+ --task_order yelp,amazon,mnli,cb,copa,qqp,rte,imdb,sst2,dbpedia,agnews,yahoo,multirc,boolq,wic \
641
+ --task_config_dir configs/gen_script_long_order3_t5_configs/agnews \
642
+ --output_dir logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/11-agnews \
643
+ --per_device_train_batch_size $BSZ \
644
+ --per_device_eval_batch_size $EVAL_BSZ \
645
+ --gradient_accumulation_steps $GA \
646
+ --learning_rate 0.0003 \
647
+ --num_train_epochs 10 \
648
+ --run_name gen_script_long_order3_t5_small_specroute_v10a \
649
+ --max_source_length 512 \
650
+ --max_target_length 50 \
651
+ --generation_max_length 50 \
652
+ --add_task_name False \
653
+ --add_dataset_name False \
654
+ --overwrite_output_dir \
655
+ --overwrite_cache \
656
+ --lr_scheduler_type constant \
657
+ --warmup_steps 0 \
658
+ --logging_strategy steps \
659
+ --logging_steps 10 \
660
+ --metric_for_best_model eval_exact_match_for_agnews \
661
+ --evaluation_strategy epoch \
662
+ --save_strategy epoch \
663
+ --save_total_limit 1 \
664
+ --load_best_model_at_end \
665
+ --lora_r 8 \
666
+ --lora_alpha 32 \
667
+ --lora_dropout 0.0 \
668
+ --data_replay_freq -1 \
669
+ --mlp_hidden_dim 100 \
670
+ --model_name specroute \
671
+ --routing_mode learned \
672
+ --threshold 0.995 \
673
+ --transthreshold 0.995 \
674
+ $FP16_FLAG
675
+
676
+ rm -rf logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/11-agnews/checkpoint*
677
+
678
+ sleep 5
679
+
680
+ if [ "$GPU_MODE" = "t4_2gpu" ]; then
681
+ BSZ=16; GA=1; EVAL_BSZ=256
682
+ elif [ "$GPU_MODE" = "t4_1gpu" ]; then
683
+ BSZ=32; GA=1; EVAL_BSZ=256
684
+ else
685
+ BSZ=64; GA=1; EVAL_BSZ=512
686
+ fi
687
+
688
+ CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
689
+ --do_train \
690
+ --load_checkpoint_from logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/11-agnews/saved_weights/trans_input.pt \
691
+ --previous_prompt_key_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/11-agnews/saved_weights/prompts_keys_till_now.pt \
692
+ --do_predict \
693
+ --predict_with_generate \
694
+ --model_name_or_path $2 \
695
+ --previous_lora_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/1-yelp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/2-amazon/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/3-mnli/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/4-cb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/5-copa/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/6-qqp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/7-rte/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/8-imdb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/9-sst2/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/10-dbpedia/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/11-agnews/saved_weights \
696
+ --data_dir CL_Benchmark \
697
+ --task_order yelp,amazon,mnli,cb,copa,qqp,rte,imdb,sst2,dbpedia,agnews,yahoo,multirc,boolq,wic \
698
+ --task_config_dir configs/gen_script_long_order3_t5_configs/yahoo \
699
+ --output_dir logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/12-yahoo \
700
+ --per_device_train_batch_size $BSZ \
701
+ --per_device_eval_batch_size $EVAL_BSZ \
702
+ --gradient_accumulation_steps $GA \
703
+ --learning_rate 0.0003 \
704
+ --num_train_epochs 10 \
705
+ --run_name gen_script_long_order3_t5_small_specroute_v10a \
706
+ --max_source_length 512 \
707
+ --max_target_length 50 \
708
+ --generation_max_length 50 \
709
+ --add_task_name False \
710
+ --add_dataset_name False \
711
+ --overwrite_output_dir \
712
+ --overwrite_cache \
713
+ --lr_scheduler_type constant \
714
+ --warmup_steps 0 \
715
+ --logging_strategy steps \
716
+ --logging_steps 10 \
717
+ --metric_for_best_model eval_exact_match_for_yahoo \
718
+ --evaluation_strategy epoch \
719
+ --save_strategy epoch \
720
+ --save_total_limit 1 \
721
+ --load_best_model_at_end \
722
+ --lora_r 8 \
723
+ --lora_alpha 32 \
724
+ --lora_dropout 0.0 \
725
+ --data_replay_freq -1 \
726
+ --mlp_hidden_dim 100 \
727
+ --model_name specroute \
728
+ --routing_mode learned \
729
+ --threshold 0.995 \
730
+ --transthreshold 0.995 \
731
+ $FP16_FLAG
732
+
733
+ rm -rf logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/12-yahoo/checkpoint*
734
+
735
+ sleep 5
736
+
737
+ if [ "$GPU_MODE" = "t4_2gpu" ]; then
738
+ BSZ=16; GA=1; EVAL_BSZ=256
739
+ elif [ "$GPU_MODE" = "t4_1gpu" ]; then
740
+ BSZ=32; GA=1; EVAL_BSZ=256
741
+ else
742
+ BSZ=64; GA=1; EVAL_BSZ=512
743
+ fi
744
+
745
+ CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
746
+ --do_train \
747
+ --load_checkpoint_from logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/12-yahoo/saved_weights/trans_input.pt \
748
+ --previous_prompt_key_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/12-yahoo/saved_weights/prompts_keys_till_now.pt \
749
+ --do_predict \
750
+ --predict_with_generate \
751
+ --model_name_or_path $2 \
752
+ --previous_lora_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/1-yelp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/2-amazon/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/3-mnli/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/4-cb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/5-copa/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/6-qqp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/7-rte/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/8-imdb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/9-sst2/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/10-dbpedia/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/11-agnews/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/12-yahoo/saved_weights \
753
+ --data_dir CL_Benchmark \
754
+ --task_order yelp,amazon,mnli,cb,copa,qqp,rte,imdb,sst2,dbpedia,agnews,yahoo,multirc,boolq,wic \
755
+ --task_config_dir configs/gen_script_long_order3_t5_configs/multirc \
756
+ --output_dir logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/13-multirc \
757
+ --per_device_train_batch_size $BSZ \
758
+ --per_device_eval_batch_size $EVAL_BSZ \
759
+ --gradient_accumulation_steps $GA \
760
+ --learning_rate 0.0003 \
761
+ --num_train_epochs 10 \
762
+ --run_name gen_script_long_order3_t5_small_specroute_v10a \
763
+ --max_source_length 512 \
764
+ --max_target_length 50 \
765
+ --generation_max_length 50 \
766
+ --add_task_name False \
767
+ --add_dataset_name False \
768
+ --overwrite_output_dir \
769
+ --overwrite_cache \
770
+ --lr_scheduler_type constant \
771
+ --warmup_steps 0 \
772
+ --logging_strategy steps \
773
+ --logging_steps 10 \
774
+ --metric_for_best_model eval_exact_match_for_multirc \
775
+ --evaluation_strategy epoch \
776
+ --save_strategy epoch \
777
+ --save_total_limit 1 \
778
+ --load_best_model_at_end \
779
+ --lora_r 8 \
780
+ --lora_alpha 32 \
781
+ --lora_dropout 0.0 \
782
+ --data_replay_freq -1 \
783
+ --mlp_hidden_dim 100 \
784
+ --model_name specroute \
785
+ --routing_mode learned \
786
+ --threshold 0.995 \
787
+ --transthreshold 0.995 \
788
+ $FP16_FLAG
789
+
790
+ rm -rf logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/13-multirc/checkpoint*
791
+
792
+ sleep 5
793
+
794
+ if [ "$GPU_MODE" = "t4_2gpu" ]; then
795
+ BSZ=16; GA=1; EVAL_BSZ=256
796
+ elif [ "$GPU_MODE" = "t4_1gpu" ]; then
797
+ BSZ=32; GA=1; EVAL_BSZ=256
798
+ else
799
+ BSZ=64; GA=1; EVAL_BSZ=512
800
+ fi
801
+
802
+ CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
803
+ --do_train \
804
+ --load_checkpoint_from logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/13-multirc/saved_weights/trans_input.pt \
805
+ --previous_prompt_key_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/13-multirc/saved_weights/prompts_keys_till_now.pt \
806
+ --do_predict \
807
+ --predict_with_generate \
808
+ --model_name_or_path $2 \
809
+ --previous_lora_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/1-yelp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/2-amazon/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/3-mnli/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/4-cb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/5-copa/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/6-qqp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/7-rte/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/8-imdb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/9-sst2/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/10-dbpedia/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/11-agnews/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/12-yahoo/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/13-multirc/saved_weights \
810
+ --data_dir CL_Benchmark \
811
+ --task_order yelp,amazon,mnli,cb,copa,qqp,rte,imdb,sst2,dbpedia,agnews,yahoo,multirc,boolq,wic \
812
+ --task_config_dir configs/gen_script_long_order3_t5_configs/boolq \
813
+ --output_dir logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/14-boolq \
814
+ --per_device_train_batch_size $BSZ \
815
+ --per_device_eval_batch_size $EVAL_BSZ \
816
+ --gradient_accumulation_steps $GA \
817
+ --learning_rate 0.0003 \
818
+ --num_train_epochs 10 \
819
+ --run_name gen_script_long_order3_t5_small_specroute_v10a \
820
+ --max_source_length 512 \
821
+ --max_target_length 50 \
822
+ --generation_max_length 50 \
823
+ --add_task_name False \
824
+ --add_dataset_name False \
825
+ --overwrite_output_dir \
826
+ --overwrite_cache \
827
+ --lr_scheduler_type constant \
828
+ --warmup_steps 0 \
829
+ --logging_strategy steps \
830
+ --logging_steps 10 \
831
+ --metric_for_best_model eval_exact_match_for_boolq \
832
+ --evaluation_strategy epoch \
833
+ --save_strategy epoch \
834
+ --save_total_limit 1 \
835
+ --load_best_model_at_end \
836
+ --lora_r 8 \
837
+ --lora_alpha 32 \
838
+ --lora_dropout 0.0 \
839
+ --data_replay_freq -1 \
840
+ --mlp_hidden_dim 100 \
841
+ --model_name specroute \
842
+ --routing_mode learned \
843
+ --threshold 0.995 \
844
+ --transthreshold 0.995 \
845
+ $FP16_FLAG
846
+
847
+ rm -rf logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/14-boolq/checkpoint*
848
+
849
+ sleep 5
850
+
851
+ if [ "$GPU_MODE" = "t4_2gpu" ]; then
852
+ BSZ=16; GA=1; EVAL_BSZ=256
853
+ elif [ "$GPU_MODE" = "t4_1gpu" ]; then
854
+ BSZ=32; GA=1; EVAL_BSZ=256
855
+ else
856
+ BSZ=64; GA=1; EVAL_BSZ=512
857
+ fi
858
+
859
+ CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
860
+ --do_train \
861
+ --load_checkpoint_from logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/14-boolq/saved_weights/trans_input.pt \
862
+ --previous_prompt_key_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/14-boolq/saved_weights/prompts_keys_till_now.pt \
863
+ --do_predict \
864
+ --predict_with_generate \
865
+ --model_name_or_path $2 \
866
+ --previous_lora_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/1-yelp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/2-amazon/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/3-mnli/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/4-cb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/5-copa/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/6-qqp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/7-rte/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/8-imdb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/9-sst2/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/10-dbpedia/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/11-agnews/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/12-yahoo/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/13-multirc/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/14-boolq/saved_weights \
867
+ --data_dir CL_Benchmark \
868
+ --task_order yelp,amazon,mnli,cb,copa,qqp,rte,imdb,sst2,dbpedia,agnews,yahoo,multirc,boolq,wic \
869
+ --task_config_dir configs/gen_script_long_order3_t5_configs/wic \
870
+ --output_dir logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/15-wic \
871
+ --per_device_train_batch_size $BSZ \
872
+ --per_device_eval_batch_size $EVAL_BSZ \
873
+ --gradient_accumulation_steps $GA \
874
+ --learning_rate 0.0003 \
875
+ --num_train_epochs 10 \
876
+ --run_name gen_script_long_order3_t5_small_specroute_v10a \
877
+ --max_source_length 512 \
878
+ --max_target_length 50 \
879
+ --generation_max_length 50 \
880
+ --add_task_name False \
881
+ --add_dataset_name False \
882
+ --overwrite_output_dir \
883
+ --overwrite_cache \
884
+ --lr_scheduler_type constant \
885
+ --warmup_steps 0 \
886
+ --logging_strategy steps \
887
+ --logging_steps 10 \
888
+ --metric_for_best_model eval_exact_match_for_wic \
889
+ --evaluation_strategy epoch \
890
+ --save_strategy epoch \
891
+ --save_total_limit 1 \
892
+ --load_best_model_at_end \
893
+ --lora_r 8 \
894
+ --lora_alpha 32 \
895
+ --lora_dropout 0.0 \
896
+ --data_replay_freq -1 \
897
+ --mlp_hidden_dim 100 \
898
+ --model_name specroute \
899
+ --routing_mode learned \
900
+ --threshold 0.995 \
901
+ --transthreshold 0.995 \
902
+ $FP16_FLAG
903
+
904
+ rm -rf logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10a/outputs/15-wic/checkpoint*
905
+
906
+ sleep 5
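Note on the script above: each task block repeats the same command and only extends `--previous_lora_path` by one `saved_weights` directory. The sketch below shows how that cumulative list could be derived from the task order instead of being written out by hand. It is only an illustration under stated assumptions: the `previous_lora_paths` helper and the `RUN` and `TASKS` variables are hypothetical names that do not appear in the committed script, and the directory layout is copied from the paths shown in the diff.

```shell
# Hypothetical helper (not part of the commit): build the comma-separated
# --previous_lora_path value for the task at 1-based position $1.
RUN=gen_script_long_order3_t5_small_specroute_v10a
TASKS="yelp amazon mnli cb copa qqp rte imdb sst2 dbpedia agnews yahoo multirc boolq wic"

previous_lora_paths() {
  upto=$1   # position of the task about to be trained
  out=""
  i=1
  for task in $TASKS; do
    # Only tasks trained before position $upto contribute a path.
    [ "$i" -ge "$upto" ] && break
    out="${out:+$out,}logs_and_outputs/$RUN/outputs/$i-$task/saved_weights"
    i=$((i + 1))
  done
  printf '%s\n' "$out"
}

# Example: the list for the third task (mnli) contains 1-yelp and 2-amazon.
previous_lora_paths 3
```

For the first task the helper prints an empty line, which matches the first block of each script, where `--previous_lora_path` is omitted entirely.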
improve_gainlora/T5_small/gen_script_long_order3_t5_small_specroute_v10b.sh ADDED
@@ -0,0 +1,878 @@
1
+ #!/bin/bash
2
+ #SBATCH -J cl
3
+ #SBATCH -o cl-%j.out
4
+ #SBATCH -p compute
5
+ #SBATCH -N 1
6
+ #SBATCH -t 20:00:00
7
+ #SBATCH --mem 128G
8
+ #SBATCH --gres=gpu:2
9
+
10
+ export CUDA_DEVICE_ORDER="PCI_BUS_ID"
11
+
12
+ port=$(shuf -i25000-30000 -n1)
13
+
14
+ # ============================================================
15
+ # Auto-detect GPU count and type for optimal parallelism
16
+ # ============================================================
17
+ NUM_GPUS=$(nvidia-smi -L 2>/dev/null | wc -l)
18
+ GPU_MEM=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits 2>/dev/null | head -1)
19
+
20
+ if [ -z "$GPU_MEM" ]; then
21
+ echo "ERROR: No GPU detected!"
22
+ exit 1
23
+ fi
24
+
25
+ # Determine GPU type
26
+ if [ "$GPU_MEM" -lt 20000 ]; then
27
+ IS_T4=1
28
+ echo "[GPU] Detected T4 GPUs (${GPU_MEM}MB VRAM each)"
29
+ else
30
+ IS_T4=0
31
+ echo "[GPU] Detected high-memory GPUs (${GPU_MEM}MB VRAM each)"
32
+ fi
33
+
34
+ # Determine parallelism strategy
35
+ if [ "$IS_T4" -eq 1 ] && [ "$NUM_GPUS" -ge 2 ]; then
36
+ GPU_MODE="t4_2gpu"
37
+ GPU_IDS="0,1"
38
+ FP16_FLAG=""
39
+ echo "[GPU] Strategy: 2x T4 DataParallel + fp32 + gradient_checkpointing"
40
+ elif [ "$IS_T4" -eq 1 ]; then
41
+ GPU_MODE="t4_1gpu"
42
+ GPU_IDS="${1:-0}"
43
+ FP16_FLAG=""
44
+ echo "[GPU] Strategy: 1x T4 + fp32 + gradient_checkpointing"
45
+ else
46
+ GPU_MODE="a100"
47
+ GPU_IDS="${1:-0}"
48
+ FP16_FLAG=""
49
+ echo "[GPU] Strategy: A100 (single GPU, fp32)"
50
+ fi
51
+
52
+ echo "[GPU] Using CUDA_VISIBLE_DEVICES=$GPU_IDS"
53
+ echo "============================================================"
54
+ echo ""
55
+
56
+ if [ "$GPU_MODE" = "t4_2gpu" ]; then
57
+ BSZ=16; GA=1; EVAL_BSZ=256
58
+ elif [ "$GPU_MODE" = "t4_1gpu" ]; then
59
+ BSZ=32; GA=1; EVAL_BSZ=256
60
+ else
61
+ BSZ=64; GA=1; EVAL_BSZ=512
62
+ fi
63
+
64
+ CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
65
+ --do_train \
66
+ --do_predict \
67
+ --predict_with_generate \
68
+ --model_name_or_path $2 \
69
+ --data_dir CL_Benchmark \
70
+ --task_order yelp,amazon,mnli,cb,copa,qqp,rte,imdb,sst2,dbpedia,agnews,yahoo,multirc,boolq,wic \
71
+ --task_config_dir configs/gen_script_long_order3_t5_configs/yelp \
72
+ --output_dir logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/1-yelp \
73
+ --per_device_train_batch_size $BSZ \
74
+ --per_device_eval_batch_size $EVAL_BSZ \
75
+ --gradient_accumulation_steps $GA \
76
+ --learning_rate 0.0003 \
77
+ --num_train_epochs 10 \
78
+ --run_name gen_script_long_order3_t5_small_specroute_v10b \
79
+ --max_source_length 512 \
80
+ --max_target_length 50 \
81
+ --generation_max_length 50 \
82
+ --add_task_name False \
83
+ --add_dataset_name False \
84
+ --overwrite_output_dir \
85
+ --overwrite_cache \
86
+ --lr_scheduler_type constant \
87
+ --warmup_steps 0 \
88
+ --logging_strategy steps \
89
+ --logging_steps 10 \
90
+ --metric_for_best_model eval_exact_match \
91
+ --evaluation_strategy epoch \
92
+ --save_strategy epoch \
93
+ --save_total_limit 1 \
94
+ --load_best_model_at_end \
95
+ --lora_r 8 \
96
+ --lora_alpha 32 \
97
+ --lora_dropout 0.0 \
98
+ --data_replay_freq -1 \
99
+ --mlp_hidden_dim 100 \
100
+ --model_name specroute \
101
+ --routing_mode grassmann \
102
+ --threshold 0.995 \
103
+ --transthreshold 0.995 \
104
+ $FP16_FLAG
105
+
106
+ rm -rf logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/1-yelp/checkpoint*
107
+
108
+ sleep 5
109
+
110
+ if [ "$GPU_MODE" = "t4_2gpu" ]; then
111
+ BSZ=16; GA=1; EVAL_BSZ=256
112
+ elif [ "$GPU_MODE" = "t4_1gpu" ]; then
113
+ BSZ=32; GA=1; EVAL_BSZ=256
114
+ else
115
+ BSZ=64; GA=1; EVAL_BSZ=512
116
+ fi
117
+
118
+ CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
119
+ --do_train \
120
+ --do_predict \
121
+ --predict_with_generate \
122
+ --model_name_or_path $2 \
123
+ --previous_lora_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/1-yelp/saved_weights \
124
+ --data_dir CL_Benchmark \
125
+ --task_order yelp,amazon,mnli,cb,copa,qqp,rte,imdb,sst2,dbpedia,agnews,yahoo,multirc,boolq,wic \
126
+ --task_config_dir configs/gen_script_long_order3_t5_configs/amazon \
127
+ --output_dir logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/2-amazon \
128
+ --per_device_train_batch_size $BSZ \
129
+ --per_device_eval_batch_size $EVAL_BSZ \
130
+ --gradient_accumulation_steps $GA \
131
+ --learning_rate 0.0003 \
132
+ --num_train_epochs 10 \
133
+ --run_name gen_script_long_order3_t5_small_specroute_v10b \
134
+ --max_source_length 512 \
135
+ --max_target_length 50 \
136
+ --generation_max_length 50 \
137
+ --add_task_name False \
138
+ --add_dataset_name False \
139
+ --overwrite_output_dir \
140
+ --overwrite_cache \
141
+ --lr_scheduler_type constant \
142
+ --warmup_steps 0 \
143
+ --logging_strategy steps \
144
+ --logging_steps 10 \
145
+ --metric_for_best_model eval_exact_match_for_amazon \
146
+ --evaluation_strategy epoch \
147
+ --save_strategy epoch \
148
+ --save_total_limit 1 \
149
+ --load_best_model_at_end \
150
+ --lora_r 8 \
151
+ --lora_alpha 32 \
152
+ --lora_dropout 0.0 \
153
+ --data_replay_freq -1 \
154
+ --mlp_hidden_dim 100 \
155
+ --model_name specroute \
156
+ --routing_mode grassmann \
157
+ --threshold 0.995 \
158
+ --transthreshold 0.995 \
159
+ $FP16_FLAG
160
+
161
+ rm -rf logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/2-amazon/checkpoint*
162
+
163
+ sleep 5
164
+
165
+ if [ "$GPU_MODE" = "t4_2gpu" ]; then
166
+ BSZ=16; GA=1; EVAL_BSZ=256
167
+ elif [ "$GPU_MODE" = "t4_1gpu" ]; then
168
+ BSZ=32; GA=1; EVAL_BSZ=256
169
+ else
170
+ BSZ=64; GA=1; EVAL_BSZ=512
171
+ fi
172
+
173
+ CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
174
+ --do_train \
175
+ --do_predict \
176
+ --predict_with_generate \
177
+ --model_name_or_path $2 \
178
+ --previous_lora_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/1-yelp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/2-amazon/saved_weights \
179
+ --data_dir CL_Benchmark \
180
+ --task_order yelp,amazon,mnli,cb,copa,qqp,rte,imdb,sst2,dbpedia,agnews,yahoo,multirc,boolq,wic \
181
+ --task_config_dir configs/gen_script_long_order3_t5_configs/mnli \
182
+ --output_dir logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/3-mnli \
183
+ --per_device_train_batch_size $BSZ \
184
+ --per_device_eval_batch_size $EVAL_BSZ \
185
+ --gradient_accumulation_steps $GA \
186
+ --learning_rate 0.0003 \
187
+ --num_train_epochs 10 \
188
+ --run_name gen_script_long_order3_t5_small_specroute_v10b \
189
+ --max_source_length 512 \
190
+ --max_target_length 50 \
191
+ --generation_max_length 50 \
192
+ --add_task_name False \
193
+ --add_dataset_name False \
194
+ --overwrite_output_dir \
195
+ --overwrite_cache \
196
+ --lr_scheduler_type constant \
197
+ --warmup_steps 0 \
198
+ --logging_strategy steps \
199
+ --logging_steps 10 \
200
+ --metric_for_best_model eval_exact_match_for_mnli \
201
+ --evaluation_strategy epoch \
202
+ --save_strategy epoch \
203
+ --save_total_limit 1 \
204
+ --load_best_model_at_end \
205
+ --lora_r 8 \
206
+ --lora_alpha 32 \
207
+ --lora_dropout 0.0 \
208
+ --data_replay_freq -1 \
209
+ --mlp_hidden_dim 100 \
210
+ --model_name specroute \
211
+ --routing_mode grassmann \
212
+ --threshold 0.995 \
213
+ --transthreshold 0.995 \
214
+ $FP16_FLAG
215
+
216
+ rm -rf logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/3-mnli/checkpoint*
217
+
218
+ sleep 5
219
+
220
+ if [ "$GPU_MODE" = "t4_2gpu" ]; then
221
+ BSZ=16; GA=1; EVAL_BSZ=256
222
+ elif [ "$GPU_MODE" = "t4_1gpu" ]; then
223
+ BSZ=32; GA=1; EVAL_BSZ=256
224
+ else
225
+ BSZ=64; GA=1; EVAL_BSZ=512
226
+ fi
227
+
228
+ CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
229
+ --do_train \
230
+ --do_predict \
231
+ --predict_with_generate \
232
+ --model_name_or_path $2 \
233
+ --previous_lora_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/1-yelp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/2-amazon/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/3-mnli/saved_weights \
234
+ --data_dir CL_Benchmark \
235
+ --task_order yelp,amazon,mnli,cb,copa,qqp,rte,imdb,sst2,dbpedia,agnews,yahoo,multirc,boolq,wic \
236
+ --task_config_dir configs/gen_script_long_order3_t5_configs/cb \
237
+ --output_dir logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/4-cb \
238
+ --per_device_train_batch_size $BSZ \
239
+ --per_device_eval_batch_size $EVAL_BSZ \
240
+ --gradient_accumulation_steps $GA \
241
+ --learning_rate 0.0003 \
242
+ --num_train_epochs 10 \
243
+ --run_name gen_script_long_order3_t5_small_specroute_v10b \
244
+ --max_source_length 512 \
245
+ --max_target_length 50 \
246
+ --generation_max_length 50 \
247
+ --add_task_name False \
248
+ --add_dataset_name False \
249
+ --overwrite_output_dir \
250
+ --overwrite_cache \
251
+ --lr_scheduler_type constant \
252
+ --warmup_steps 0 \
253
+ --logging_strategy steps \
254
+ --logging_steps 10 \
255
+ --metric_for_best_model eval_exact_match_for_cb \
256
+ --evaluation_strategy epoch \
257
+ --save_strategy epoch \
258
+ --save_total_limit 1 \
259
+ --load_best_model_at_end \
260
+ --lora_r 8 \
261
+ --lora_alpha 32 \
262
+ --lora_dropout 0.0 \
263
+ --data_replay_freq -1 \
264
+ --mlp_hidden_dim 100 \
265
+ --model_name specroute \
266
+ --routing_mode grassmann \
267
+ --threshold 0.995 \
268
+ --transthreshold 0.995 \
269
+ $FP16_FLAG
270
+
271
+ rm -rf logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/4-cb/checkpoint*
272
+
273
+ sleep 5
274
+
275
+ if [ "$GPU_MODE" = "t4_2gpu" ]; then
276
+ BSZ=16; GA=1; EVAL_BSZ=256
277
+ elif [ "$GPU_MODE" = "t4_1gpu" ]; then
278
+ BSZ=32; GA=1; EVAL_BSZ=256
279
+ else
280
+ BSZ=64; GA=1; EVAL_BSZ=512
281
+ fi
282
+
283
+ CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
284
+ --do_train \
285
+ --do_predict \
286
+ --predict_with_generate \
287
+ --model_name_or_path $2 \
288
+ --previous_lora_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/1-yelp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/2-amazon/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/3-mnli/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/4-cb/saved_weights \
289
+ --data_dir CL_Benchmark \
290
+ --task_order yelp,amazon,mnli,cb,copa,qqp,rte,imdb,sst2,dbpedia,agnews,yahoo,multirc,boolq,wic \
291
+ --task_config_dir configs/gen_script_long_order3_t5_configs/copa \
292
+ --output_dir logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/5-copa \
293
+ --per_device_train_batch_size $BSZ \
294
+ --per_device_eval_batch_size $EVAL_BSZ \
295
+ --gradient_accumulation_steps $GA \
296
+ --learning_rate 0.0003 \
297
+ --num_train_epochs 10 \
298
+ --run_name gen_script_long_order3_t5_small_specroute_v10b \
299
+ --max_source_length 512 \
300
+ --max_target_length 50 \
301
+ --generation_max_length 50 \
302
+ --add_task_name False \
303
+ --add_dataset_name False \
304
+ --overwrite_output_dir \
305
+ --overwrite_cache \
306
+ --lr_scheduler_type constant \
307
+ --warmup_steps 0 \
308
+ --logging_strategy steps \
309
+ --logging_steps 10 \
310
+ --metric_for_best_model eval_exact_match_for_copa \
311
+ --evaluation_strategy epoch \
312
+ --save_strategy epoch \
313
+ --save_total_limit 1 \
314
+ --load_best_model_at_end \
315
+ --lora_r 8 \
316
+ --lora_alpha 32 \
317
+ --lora_dropout 0.0 \
318
+ --data_replay_freq -1 \
319
+ --mlp_hidden_dim 100 \
320
+ --model_name specroute \
321
+ --routing_mode grassmann \
322
+ --threshold 0.995 \
323
+ --transthreshold 0.995 \
324
+ $FP16_FLAG
325
+
326
+ rm -rf logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/5-copa/checkpoint*
327
+
328
+ sleep 5
329
+
330
+ if [ "$GPU_MODE" = "t4_2gpu" ]; then
331
+ BSZ=16; GA=1; EVAL_BSZ=256
332
+ elif [ "$GPU_MODE" = "t4_1gpu" ]; then
333
+ BSZ=32; GA=1; EVAL_BSZ=256
334
+ else
335
+ BSZ=64; GA=1; EVAL_BSZ=512
336
+ fi
337
+
338
+ CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
339
+ --do_train \
340
+ --do_predict \
341
+ --predict_with_generate \
342
+ --model_name_or_path $2 \
343
+ --previous_lora_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/1-yelp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/2-amazon/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/3-mnli/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/4-cb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/5-copa/saved_weights \
344
+ --data_dir CL_Benchmark \
345
+ --task_order yelp,amazon,mnli,cb,copa,qqp,rte,imdb,sst2,dbpedia,agnews,yahoo,multirc,boolq,wic \
346
+ --task_config_dir configs/gen_script_long_order3_t5_configs/qqp \
347
+ --output_dir logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/6-qqp \
348
+ --per_device_train_batch_size $BSZ \
349
+ --per_device_eval_batch_size $EVAL_BSZ \
350
+ --gradient_accumulation_steps $GA \
351
+ --learning_rate 0.0003 \
352
+ --num_train_epochs 10 \
353
+ --run_name gen_script_long_order3_t5_small_specroute_v10b \
354
+ --max_source_length 512 \
355
+ --max_target_length 50 \
356
+ --generation_max_length 50 \
357
+ --add_task_name False \
358
+ --add_dataset_name False \
359
+ --overwrite_output_dir \
360
+ --overwrite_cache \
361
+ --lr_scheduler_type constant \
362
+ --warmup_steps 0 \
363
+ --logging_strategy steps \
364
+ --logging_steps 10 \
365
+ --metric_for_best_model eval_exact_match_for_qqp \
366
+ --evaluation_strategy epoch \
367
+ --save_strategy epoch \
368
+ --save_total_limit 1 \
369
+ --load_best_model_at_end \
370
+ --lora_r 8 \
371
+ --lora_alpha 32 \
372
+ --lora_dropout 0.0 \
373
+ --data_replay_freq -1 \
374
+ --mlp_hidden_dim 100 \
375
+ --model_name specroute \
376
+ --routing_mode grassmann \
377
+ --threshold 0.995 \
378
+ --transthreshold 0.995 \
379
+ $FP16_FLAG
380
+
381
+ rm -rf logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/6-qqp/checkpoint*
382
+
383
+ sleep 5
384
+
385
+ if [ "$GPU_MODE" = "t4_2gpu" ]; then
386
+ BSZ=16; GA=1; EVAL_BSZ=256
387
+ elif [ "$GPU_MODE" = "t4_1gpu" ]; then
388
+ BSZ=32; GA=1; EVAL_BSZ=256
389
+ else
390
+ BSZ=64; GA=1; EVAL_BSZ=512
391
+ fi
392
+
393
+ CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
394
+ --do_train \
395
+ --do_predict \
396
+ --predict_with_generate \
397
+ --model_name_or_path $2 \
398
+ --previous_lora_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/1-yelp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/2-amazon/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/3-mnli/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/4-cb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/5-copa/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/6-qqp/saved_weights \
399
+ --data_dir CL_Benchmark \
400
+ --task_order yelp,amazon,mnli,cb,copa,qqp,rte,imdb,sst2,dbpedia,agnews,yahoo,multirc,boolq,wic \
401
+ --task_config_dir configs/gen_script_long_order3_t5_configs/rte \
402
+ --output_dir logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/7-rte \
403
+ --per_device_train_batch_size $BSZ \
404
+ --per_device_eval_batch_size $EVAL_BSZ \
405
+ --gradient_accumulation_steps $GA \
406
+ --learning_rate 0.0003 \
407
+ --num_train_epochs 10 \
408
+ --run_name gen_script_long_order3_t5_small_specroute_v10b \
409
+ --max_source_length 512 \
410
+ --max_target_length 50 \
411
+ --generation_max_length 50 \
412
+ --add_task_name False \
413
+ --add_dataset_name False \
414
+ --overwrite_output_dir \
415
+ --overwrite_cache \
416
+ --lr_scheduler_type constant \
417
+ --warmup_steps 0 \
418
+ --logging_strategy steps \
419
+ --logging_steps 10 \
420
+ --metric_for_best_model eval_exact_match_for_rte \
421
+ --evaluation_strategy epoch \
422
+ --save_strategy epoch \
423
+ --save_total_limit 1 \
424
+ --load_best_model_at_end \
425
+ --lora_r 8 \
426
+ --lora_alpha 32 \
427
+ --lora_dropout 0.0 \
428
+ --data_replay_freq -1 \
429
+ --mlp_hidden_dim 100 \
430
+ --model_name specroute \
431
+ --routing_mode grassmann \
432
+ --threshold 0.995 \
433
+ --transthreshold 0.995 \
434
+ $FP16_FLAG
435
+
436
+ rm -rf logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/7-rte/checkpoint*
437
+
438
+ sleep 5
439
+
440
+ if [ "$GPU_MODE" = "t4_2gpu" ]; then
441
+ BSZ=16; GA=1; EVAL_BSZ=256
442
+ elif [ "$GPU_MODE" = "t4_1gpu" ]; then
443
+ BSZ=32; GA=1; EVAL_BSZ=256
444
+ else
445
+ BSZ=64; GA=1; EVAL_BSZ=512
446
+ fi
447
+
448
+ CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
449
+ --do_train \
450
+ --do_predict \
451
+ --predict_with_generate \
452
+ --model_name_or_path $2 \
453
+ --previous_lora_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/1-yelp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/2-amazon/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/3-mnli/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/4-cb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/5-copa/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/6-qqp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/7-rte/saved_weights \
454
+ --data_dir CL_Benchmark \
455
+ --task_order yelp,amazon,mnli,cb,copa,qqp,rte,imdb,sst2,dbpedia,agnews,yahoo,multirc,boolq,wic \
456
+ --task_config_dir configs/gen_script_long_order3_t5_configs/imdb \
457
+ --output_dir logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/8-imdb \
458
+ --per_device_train_batch_size $BSZ \
459
+ --per_device_eval_batch_size $EVAL_BSZ \
460
+ --gradient_accumulation_steps $GA \
461
+ --learning_rate 0.0003 \
462
+ --num_train_epochs 10 \
463
+ --run_name gen_script_long_order3_t5_small_specroute_v10b \
464
+ --max_source_length 512 \
465
+ --max_target_length 50 \
466
+ --generation_max_length 50 \
467
+ --add_task_name False \
468
+ --add_dataset_name False \
469
+ --overwrite_output_dir \
470
+ --overwrite_cache \
471
+ --lr_scheduler_type constant \
472
+ --warmup_steps 0 \
473
+ --logging_strategy steps \
474
+ --logging_steps 10 \
475
+ --metric_for_best_model eval_exact_match_for_imdb \
476
+ --evaluation_strategy epoch \
477
+ --save_strategy epoch \
478
+ --save_total_limit 1 \
479
+ --load_best_model_at_end \
480
+ --lora_r 8 \
481
+ --lora_alpha 32 \
482
+ --lora_dropout 0.0 \
483
+ --data_replay_freq -1 \
484
+ --mlp_hidden_dim 100 \
485
+ --model_name specroute \
486
+ --routing_mode grassmann \
487
+ --threshold 0.995 \
488
+ --transthreshold 0.995 \
489
+ $FP16_FLAG
490
+
491
+ rm -rf logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/8-imdb/checkpoint*
492
+
493
+ sleep 5
494
+
495
+ if [ "$GPU_MODE" = "t4_2gpu" ]; then
496
+ BSZ=16; GA=1; EVAL_BSZ=256
497
+ elif [ "$GPU_MODE" = "t4_1gpu" ]; then
498
+ BSZ=32; GA=1; EVAL_BSZ=256
499
+ else
500
+ BSZ=64; GA=1; EVAL_BSZ=512
501
+ fi
502
+
503
+ CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
504
+ --do_train \
505
+ --do_predict \
506
+ --predict_with_generate \
507
+ --model_name_or_path $2 \
508
+ --previous_lora_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/1-yelp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/2-amazon/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/3-mnli/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/4-cb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/5-copa/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/6-qqp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/7-rte/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/8-imdb/saved_weights \
509
+ --data_dir CL_Benchmark \
510
+ --task_order yelp,amazon,mnli,cb,copa,qqp,rte,imdb,sst2,dbpedia,agnews,yahoo,multirc,boolq,wic \
511
+ --task_config_dir configs/gen_script_long_order3_t5_configs/sst2 \
512
+ --output_dir logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/9-sst2 \
513
+ --per_device_train_batch_size $BSZ \
514
+ --per_device_eval_batch_size $EVAL_BSZ \
515
+ --gradient_accumulation_steps $GA \
516
+ --learning_rate 0.0003 \
517
+ --num_train_epochs 10 \
518
+ --run_name gen_script_long_order3_t5_small_specroute_v10b \
519
+ --max_source_length 512 \
520
+ --max_target_length 50 \
521
+ --generation_max_length 50 \
522
+ --add_task_name False \
523
+ --add_dataset_name False \
524
+ --overwrite_output_dir \
525
+ --overwrite_cache \
526
+ --lr_scheduler_type constant \
527
+ --warmup_steps 0 \
528
+ --logging_strategy steps \
529
+ --logging_steps 10 \
530
+ --metric_for_best_model eval_exact_match_for_sst2 \
531
+ --evaluation_strategy epoch \
532
+ --save_strategy epoch \
533
+ --save_total_limit 1 \
534
+ --load_best_model_at_end \
535
+ --lora_r 8 \
536
+ --lora_alpha 32 \
537
+ --lora_dropout 0.0 \
538
+ --data_replay_freq -1 \
539
+ --mlp_hidden_dim 100 \
540
+ --model_name specroute \
541
+ --routing_mode grassmann \
542
+ --threshold 0.995 \
543
+ --transthreshold 0.995 \
544
+ $FP16_FLAG
545
+
546
+ rm -rf logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/9-sst2/checkpoint*
547
+
548
+ sleep 5
549
+
550
+ if [ "$GPU_MODE" = "t4_2gpu" ]; then
551
+ BSZ=16; GA=1; EVAL_BSZ=256
552
+ elif [ "$GPU_MODE" = "t4_1gpu" ]; then
553
+ BSZ=32; GA=1; EVAL_BSZ=256
554
+ else
555
+ BSZ=64; GA=1; EVAL_BSZ=512
556
+ fi
557
+
558
+ CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
559
+ --do_train \
560
+ --do_predict \
561
+ --predict_with_generate \
562
+ --model_name_or_path $2 \
563
+ --previous_lora_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/1-yelp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/2-amazon/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/3-mnli/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/4-cb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/5-copa/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/6-qqp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/7-rte/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/8-imdb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/9-sst2/saved_weights \
564
+ --data_dir CL_Benchmark \
565
+ --task_order yelp,amazon,mnli,cb,copa,qqp,rte,imdb,sst2,dbpedia,agnews,yahoo,multirc,boolq,wic \
566
+ --task_config_dir configs/gen_script_long_order3_t5_configs/dbpedia \
567
+ --output_dir logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/10-dbpedia \
568
+ --per_device_train_batch_size $BSZ \
569
+ --per_device_eval_batch_size $EVAL_BSZ \
570
+ --gradient_accumulation_steps $GA \
571
+ --learning_rate 0.0003 \
572
+ --num_train_epochs 10 \
573
+ --run_name gen_script_long_order3_t5_small_specroute_v10b \
574
+ --max_source_length 512 \
575
+ --max_target_length 50 \
576
+ --generation_max_length 50 \
577
+ --add_task_name False \
578
+ --add_dataset_name False \
579
+ --overwrite_output_dir \
580
+ --overwrite_cache \
581
+ --lr_scheduler_type constant \
582
+ --warmup_steps 0 \
583
+ --logging_strategy steps \
584
+ --logging_steps 10 \
585
+ --metric_for_best_model eval_exact_match_for_dbpedia \
586
+ --evaluation_strategy epoch \
587
+ --save_strategy epoch \
588
+ --save_total_limit 1 \
589
+ --load_best_model_at_end \
590
+ --lora_r 8 \
591
+ --lora_alpha 32 \
592
+ --lora_dropout 0.0 \
593
+ --data_replay_freq -1 \
594
+ --mlp_hidden_dim 100 \
595
+ --model_name specroute \
596
+ --routing_mode grassmann \
597
+ --threshold 0.995 \
598
+ --transthreshold 0.995 \
599
+ $FP16_FLAG
600
+
601
+ rm -rf logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/10-dbpedia/checkpoint*
602
+
603
+ sleep 5
604
+
605
+ if [ "$GPU_MODE" = "t4_2gpu" ]; then
606
+ BSZ=16; GA=1; EVAL_BSZ=256
607
+ elif [ "$GPU_MODE" = "t4_1gpu" ]; then
608
+ BSZ=32; GA=1; EVAL_BSZ=256
609
+ else
610
+ BSZ=64; GA=1; EVAL_BSZ=512
611
+ fi
612
+
613
+ CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
614
+ --do_train \
615
+ --do_predict \
616
+ --predict_with_generate \
617
+ --model_name_or_path $2 \
618
+ --previous_lora_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/1-yelp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/2-amazon/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/3-mnli/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/4-cb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/5-copa/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/6-qqp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/7-rte/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/8-imdb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/9-sst2/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/10-dbpedia/saved_weights \
619
+ --data_dir CL_Benchmark \
620
+ --task_order yelp,amazon,mnli,cb,copa,qqp,rte,imdb,sst2,dbpedia,agnews,yahoo,multirc,boolq,wic \
621
+ --task_config_dir configs/gen_script_long_order3_t5_configs/agnews \
622
+ --output_dir logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/11-agnews \
623
+ --per_device_train_batch_size $BSZ \
624
+ --per_device_eval_batch_size $EVAL_BSZ \
625
+ --gradient_accumulation_steps $GA \
626
+ --learning_rate 0.0003 \
627
+ --num_train_epochs 10 \
628
+ --run_name gen_script_long_order3_t5_small_specroute_v10b \
629
+ --max_source_length 512 \
630
+ --max_target_length 50 \
631
+ --generation_max_length 50 \
632
+ --add_task_name False \
633
+ --add_dataset_name False \
634
+ --overwrite_output_dir \
635
+ --overwrite_cache \
636
+ --lr_scheduler_type constant \
637
+ --warmup_steps 0 \
638
+ --logging_strategy steps \
639
+ --logging_steps 10 \
640
+ --metric_for_best_model eval_exact_match_for_agnews \
641
+ --evaluation_strategy epoch \
642
+ --save_strategy epoch \
643
+ --save_total_limit 1 \
644
+ --load_best_model_at_end \
645
+ --lora_r 8 \
646
+ --lora_alpha 32 \
647
+ --lora_dropout 0.0 \
648
+ --data_replay_freq -1 \
649
+ --mlp_hidden_dim 100 \
650
+ --model_name specroute \
651
+ --routing_mode grassmann \
652
+ --threshold 0.995 \
653
+ --transthreshold 0.995 \
654
+ $FP16_FLAG
655
+
656
+ rm -rf logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/11-agnews/checkpoint*
657
+
658
+ sleep 5
659
+
660
+ if [ "$GPU_MODE" = "t4_2gpu" ]; then
661
+ BSZ=16; GA=1; EVAL_BSZ=256
662
+ elif [ "$GPU_MODE" = "t4_1gpu" ]; then
663
+ BSZ=32; GA=1; EVAL_BSZ=256
664
+ else
665
+ BSZ=64; GA=1; EVAL_BSZ=512
666
+ fi
667
+
668
+ CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
669
+ --do_train \
670
+ --do_predict \
671
+ --predict_with_generate \
672
+ --model_name_or_path $2 \
673
+ --previous_lora_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/1-yelp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/2-amazon/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/3-mnli/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/4-cb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/5-copa/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/6-qqp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/7-rte/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/8-imdb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/9-sst2/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/10-dbpedia/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/11-agnews/saved_weights \
674
+ --data_dir CL_Benchmark \
675
+ --task_order yelp,amazon,mnli,cb,copa,qqp,rte,imdb,sst2,dbpedia,agnews,yahoo,multirc,boolq,wic \
676
+ --task_config_dir configs/gen_script_long_order3_t5_configs/yahoo \
677
+ --output_dir logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/12-yahoo \
678
+ --per_device_train_batch_size $BSZ \
679
+ --per_device_eval_batch_size $EVAL_BSZ \
680
+ --gradient_accumulation_steps $GA \
681
+ --learning_rate 0.0003 \
682
+ --num_train_epochs 10 \
683
+ --run_name gen_script_long_order3_t5_small_specroute_v10b \
684
+ --max_source_length 512 \
685
+ --max_target_length 50 \
686
+ --generation_max_length 50 \
687
+ --add_task_name False \
688
+ --add_dataset_name False \
689
+ --overwrite_output_dir \
690
+ --overwrite_cache \
691
+ --lr_scheduler_type constant \
692
+ --warmup_steps 0 \
693
+ --logging_strategy steps \
694
+ --logging_steps 10 \
695
+ --metric_for_best_model eval_exact_match_for_yahoo \
696
+ --evaluation_strategy epoch \
697
+ --save_strategy epoch \
698
+ --save_total_limit 1 \
699
+ --load_best_model_at_end \
700
+ --lora_r 8 \
701
+ --lora_alpha 32 \
702
+ --lora_dropout 0.0 \
703
+ --data_replay_freq -1 \
704
+ --mlp_hidden_dim 100 \
705
+ --model_name specroute \
706
+ --routing_mode grassmann \
707
+ --threshold 0.995 \
708
+ --transthreshold 0.995 \
709
+ $FP16_FLAG
710
+
711
+ rm -rf logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/12-yahoo/checkpoint*
712
+
713
+ sleep 5
714
+
715
+ if [ "$GPU_MODE" = "t4_2gpu" ]; then
716
+ BSZ=16; GA=1; EVAL_BSZ=256
717
+ elif [ "$GPU_MODE" = "t4_1gpu" ]; then
718
+ BSZ=32; GA=1; EVAL_BSZ=256
719
+ else
720
+ BSZ=64; GA=1; EVAL_BSZ=512
721
+ fi
722
+
723
+ CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
724
+ --do_train \
725
+ --do_predict \
726
+ --predict_with_generate \
727
+ --model_name_or_path $2 \
728
+ --previous_lora_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/1-yelp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/2-amazon/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/3-mnli/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/4-cb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/5-copa/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/6-qqp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/7-rte/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/8-imdb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/9-sst2/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/10-dbpedia/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/11-agnews/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/12-yahoo/saved_weights \
729
+ --data_dir CL_Benchmark \
730
+ --task_order yelp,amazon,mnli,cb,copa,qqp,rte,imdb,sst2,dbpedia,agnews,yahoo,multirc,boolq,wic \
731
+ --task_config_dir configs/gen_script_long_order3_t5_configs/multirc \
732
+ --output_dir logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/13-multirc \
733
+ --per_device_train_batch_size $BSZ \
734
+ --per_device_eval_batch_size $EVAL_BSZ \
735
+ --gradient_accumulation_steps $GA \
736
+ --learning_rate 0.0003 \
737
+ --num_train_epochs 10 \
738
+ --run_name gen_script_long_order3_t5_small_specroute_v10b \
739
+ --max_source_length 512 \
740
+ --max_target_length 50 \
741
+ --generation_max_length 50 \
742
+ --add_task_name False \
743
+ --add_dataset_name False \
744
+ --overwrite_output_dir \
745
+ --overwrite_cache \
746
+ --lr_scheduler_type constant \
747
+ --warmup_steps 0 \
748
+ --logging_strategy steps \
749
+ --logging_steps 10 \
750
+ --metric_for_best_model eval_exact_match_for_multirc \
751
+ --evaluation_strategy epoch \
752
+ --save_strategy epoch \
753
+ --save_total_limit 1 \
754
+ --load_best_model_at_end \
755
+ --lora_r 8 \
756
+ --lora_alpha 32 \
757
+ --lora_dropout 0.0 \
758
+ --data_replay_freq -1 \
759
+ --mlp_hidden_dim 100 \
760
+ --model_name specroute \
761
+ --routing_mode grassmann \
762
+ --threshold 0.995 \
763
+ --transthreshold 0.995 \
764
+ $FP16_FLAG
765
+
766
+ rm -rf logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/13-multirc/checkpoint*
767
+
768
+ sleep 5
769
+
770
+ if [ "$GPU_MODE" = "t4_2gpu" ]; then
771
+ BSZ=16; GA=1; EVAL_BSZ=256
772
+ elif [ "$GPU_MODE" = "t4_1gpu" ]; then
773
+ BSZ=32; GA=1; EVAL_BSZ=256
774
+ else
775
+ BSZ=64; GA=1; EVAL_BSZ=512
776
+ fi
777
+
778
+ CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
779
+ --do_train \
780
+ --do_predict \
781
+ --predict_with_generate \
782
+ --model_name_or_path $2 \
783
+ --previous_lora_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/1-yelp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/2-amazon/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/3-mnli/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/4-cb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/5-copa/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/6-qqp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/7-rte/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/8-imdb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/9-sst2/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/10-dbpedia/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/11-agnews/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/12-yahoo/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/13-multirc/saved_weights \
784
+ --data_dir CL_Benchmark \
785
+ --task_order yelp,amazon,mnli,cb,copa,qqp,rte,imdb,sst2,dbpedia,agnews,yahoo,multirc,boolq,wic \
786
+ --task_config_dir configs/gen_script_long_order3_t5_configs/boolq \
787
+ --output_dir logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/14-boolq \
788
+ --per_device_train_batch_size $BSZ \
789
+ --per_device_eval_batch_size $EVAL_BSZ \
790
+ --gradient_accumulation_steps $GA \
791
+ --learning_rate 0.0003 \
792
+ --num_train_epochs 10 \
793
+ --run_name gen_script_long_order3_t5_small_specroute_v10b \
794
+ --max_source_length 512 \
795
+ --max_target_length 50 \
796
+ --generation_max_length 50 \
797
+ --add_task_name False \
798
+ --add_dataset_name False \
799
+ --overwrite_output_dir \
800
+ --overwrite_cache \
801
+ --lr_scheduler_type constant \
802
+ --warmup_steps 0 \
803
+ --logging_strategy steps \
804
+ --logging_steps 10 \
805
+ --metric_for_best_model eval_exact_match_for_boolq \
806
+ --evaluation_strategy epoch \
807
+ --save_strategy epoch \
808
+ --save_total_limit 1 \
809
+ --load_best_model_at_end \
810
+ --lora_r 8 \
811
+ --lora_alpha 32 \
812
+ --lora_dropout 0.0 \
813
+ --data_replay_freq -1 \
814
+ --mlp_hidden_dim 100 \
815
+ --model_name specroute \
816
+ --routing_mode grassmann \
817
+ --threshold 0.995 \
818
+ --transthreshold 0.995 \
819
+ $FP16_FLAG
820
+
821
+ rm -rf logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/14-boolq/checkpoint*
822
+
823
+ sleep 5
824
+
825
+ if [ "$GPU_MODE" = "t4_2gpu" ]; then
826
+ BSZ=16; GA=1; EVAL_BSZ=256
827
+ elif [ "$GPU_MODE" = "t4_1gpu" ]; then
828
+ BSZ=32; GA=1; EVAL_BSZ=256
829
+ else
830
+ BSZ=64; GA=1; EVAL_BSZ=512
831
+ fi
832
+
833
+ CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
834
+ --do_train \
835
+ --do_predict \
836
+ --predict_with_generate \
837
+ --model_name_or_path $2 \
838
+ --previous_lora_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/1-yelp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/2-amazon/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/3-mnli/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/4-cb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/5-copa/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/6-qqp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/7-rte/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/8-imdb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/9-sst2/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/10-dbpedia/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/11-agnews/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/12-yahoo/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/13-multirc/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/14-boolq/saved_weights \
839
+ --data_dir CL_Benchmark \
840
+ --task_order yelp,amazon,mnli,cb,copa,qqp,rte,imdb,sst2,dbpedia,agnews,yahoo,multirc,boolq,wic \
841
+ --task_config_dir configs/gen_script_long_order3_t5_configs/wic \
842
+ --output_dir logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/15-wic \
843
+ --per_device_train_batch_size $BSZ \
844
+ --per_device_eval_batch_size $EVAL_BSZ \
845
+ --gradient_accumulation_steps $GA \
846
+ --learning_rate 0.0003 \
847
+ --num_train_epochs 10 \
848
+ --run_name gen_script_long_order3_t5_small_specroute_v10b \
849
+ --max_source_length 512 \
850
+ --max_target_length 50 \
851
+ --generation_max_length 50 \
852
+ --add_task_name False \
853
+ --add_dataset_name False \
854
+ --overwrite_output_dir \
855
+ --overwrite_cache \
856
+ --lr_scheduler_type constant \
857
+ --warmup_steps 0 \
858
+ --logging_strategy steps \
859
+ --logging_steps 10 \
860
+ --metric_for_best_model eval_exact_match_for_wic \
861
+ --evaluation_strategy epoch \
862
+ --save_strategy epoch \
863
+ --save_total_limit 1 \
864
+ --load_best_model_at_end \
865
+ --lora_r 8 \
866
+ --lora_alpha 32 \
867
+ --lora_dropout 0.0 \
868
+ --data_replay_freq -1 \
869
+ --mlp_hidden_dim 100 \
870
+ --model_name specroute \
871
+ --routing_mode grassmann \
872
+ --threshold 0.995 \
873
+ --transthreshold 0.995 \
874
+ $FP16_FLAG
875
+
876
+ rm -rf logs_and_outputs/gen_script_long_order3_t5_small_specroute_v10b/outputs/15-wic/checkpoint*
877
+
878
+ sleep 5
improve_gainlora/discuss_AI.txt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:6d90f4755ebe7ef6c899e0eef85a1866bc5629ae25ce0ec3fd56616fa92644c4
3
- size 3934
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:199005383467e5e167f68250698a393b3ab431cad10597bb6377b1cb52019985
3
+ size 20466
improve_gainlora/generate_v10_scripts.py ADDED
@@ -0,0 +1,45 @@
+ import re
+ import os
+
+ with open("T5_small/gen_script_long_order3_t5_small_gainlora_inflora.sh", "r") as f:
+     gainlora_content = f.read()
+
+ with open("T5_small/gen_script_long_order3_t5_small_specroute.sh", "r") as f:
+     specroute_content = f.read()
+
+ def create_script(mode, suffix):
+     new_content = specroute_content.replace("gen_script_long_order3_t5_small_specroute", f"gen_script_long_order3_t5_small_specroute_{suffix}")
+     new_content = new_content.replace("--model_name specroute \\", f"--model_name specroute \\\n --routing_mode {mode} \\")
+
+     if mode == "learned":
+         # Extract previous_prompt_key_path and load_checkpoint_from from gainlora
+         blocks = new_content.split("python src/run_t5.py")
+         final_content = blocks[0]
+
+         gainlora_blocks = gainlora_content.split("python src/run_t5.py")
+
+         for i in range(1, len(blocks)):
+             block = blocks[i]
+             gainlora_block = gainlora_blocks[i]
+
+             m1 = re.search(r'--load_checkpoint_from\s+([^\s\\]+)', gainlora_block)
+             m2 = re.search(r'--previous_prompt_key_path\s+([^\s\\]+)', gainlora_block)
+
+             args_to_add = ""
+             if m1:
+                 path1 = m1.group(1).replace("gen_script_long_order3_t5_small_gainlora_inflora", "gen_script_long_order3_t5_small_specroute_v10a")
+                 args_to_add += f" --load_checkpoint_from {path1} \\\n"
+             if m2:
+                 path2 = m2.group(1).replace("gen_script_long_order3_t5_small_gainlora_inflora", "gen_script_long_order3_t5_small_specroute_v10a")
+                 args_to_add += f" --previous_prompt_key_path {path2} \\\n"
+
+             final_content += "python src/run_t5.py" + block.replace(" --do_train \\\n", f" --do_train \\\n{args_to_add}")
+
+         new_content = final_content
+
+     with open(f"T5_small/gen_script_long_order3_t5_small_specroute_{suffix}.sh", "w") as f:
+         f.write(new_content)
+     print(f"Created T5_small/gen_script_long_order3_t5_small_specroute_{suffix}.sh")
+
+ create_script("learned", "v10a")
+ create_script("grassmann", "v10b")
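Review note: the flag-copying step in `generate_v10_scripts.py` (find a flag's value in one command block with a regex, then splice it in after `--do_train`) can be checked in isolation. The following is only a minimal sketch of that step, not the committed code. The helper names `extract_flag` and `inject_after_do_train`, the sample blocks, and the `outputs_a`/`outputs_b` paths are all hypothetical and exist purely for illustration.

```python
import re

# Same shape as the regexes in the generator: flag name, whitespace,
# then a value that stops at whitespace or a line-continuation backslash.
FLAG_PATTERN = r'--{flag}\s+([^\s\\]+)'


def extract_flag(block, flag):
    """Return the value following --flag in a shell command block, or None."""
    match = re.search(FLAG_PATTERN.format(flag=re.escape(flag)), block)
    return match.group(1) if match else None


def inject_after_do_train(block, flag, value):
    """Insert '--flag value \\' on the line right after '--do_train \\'."""
    marker = "--do_train \\\n"
    if marker not in block:
        return block  # nothing to anchor on; leave the block unchanged
    return block.replace(marker, f"{marker}   --{flag} {value} \\\n", 1)


# Hypothetical sample blocks (not taken from the real scripts).
source_block = (
    " \\\n   --do_train \\\n"
    "   --load_checkpoint_from outputs_a/1-yelp/saved_weights/trans_input.pt \\\n"
    "   --lora_r 8\n"
)
target_block = " \\\n   --do_train \\\n   --lora_r 8\n"

value = extract_flag(source_block, "load_checkpoint_from")
patched = inject_after_do_train(
    target_block, "load_checkpoint_from", value.replace("outputs_a", "outputs_b")
)
print(patched)
```

Returning the block unchanged when the `--do_train` marker is absent keeps a whitespace mismatch visible in the generated script. Whether that is preferable to raising an error is a design choice left open here.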
improve_gainlora/src/cl_trainer_specroute.py CHANGED
@@ -81,6 +81,51 @@ class PeriodicGCCallback(TrainerCallback):
81
  return control
82
 
83
 
84
  class SpecRoute_Trainer(Seq2SeqTrainer):
85
 
86
  def __init__(self, model, args, train_dataset, cur_task_id, task_order,
@@ -89,6 +134,9 @@ class SpecRoute_Trainer(Seq2SeqTrainer):
89
  lambda_entropy=0.0, use_preconditioning=False,
90
  precond_eps=1e-6, entropy_warmup_ratio=0.1,
91
  n_batches_c5=100):
92
  super().__init__(
93
  model=model, args=args, train_dataset=train_dataset,
94
  eval_dataset=eval_dataset, tokenizer=tokenizer,
@@ -259,13 +307,14 @@ class SpecRoute_Trainer(Seq2SeqTrainer):
259
  print(f'[C5] Covariance collected for {len(self._task_covariance)} layers.')
260
 
261
  def load_previous_reg_matrix(self):
262
- """Load LoRA GPM bases from previous task. No trans_input GPM needed."""
263
  log_path = os.path.dirname(self.args.output_dir)
264
  local_dir = os.path.basename(self.args.output_dir)
265
  print(log_path)
266
 
267
  all_dirs = os.listdir(log_path)
268
  reg_matrix = []
 
269
  for all_dir in all_dirs:
270
  if not os.path.isdir(os.path.join(log_path, all_dir)):
271
  continue
@@ -277,22 +326,38 @@ class SpecRoute_Trainer(Seq2SeqTrainer):
277
  os.path.join(os.path.join(log_path, all_dir), "reg_{}.pt".format(i))
278
  ))
279
  i += 1
280
  print(os.path.join(log_path, all_dir))
281
  print(len(reg_matrix))
282
  break
283
- return reg_matrix, eval(local_dir.split('-')[0]) - 1
284
 
285
  def get_reg_matrix(self):
286
  """
287
  Project current LoRA A into null-space of old tasks' GPM bases.
288
  No prompt_key/trans_input operations.
289
  """
290
- self.feature_list, self._cur_task = self.load_previous_reg_matrix()
291
 
292
  if len(self.feature_list) == 0:
293
  # First task: no constraints
294
  return
295
 
296
  # Compute projection matrices for LoRA GPM
297
  self.feature_mat, i = [], 0
298
  for name, module in self.model.named_modules():
@@ -366,10 +431,9 @@ class SpecRoute_Trainer(Seq2SeqTrainer):
366
  def get_repsentation(self):
367
  """
368
  Collect LoRA input covariance and compute GPM bases via SVD.
369
- ESA: Use constant threshold (no increasing schedule).
370
- No trans_input features collected.
371
  """
372
- self.feature_list, self._cur_task = self.load_previous_reg_matrix()
373
 
374
  train_dataloader = self.get_train_dataloader()
375
  if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
@@ -382,6 +446,11 @@ class SpecRoute_Trainer(Seq2SeqTrainer):
382
  module.get_feature = True
383
  module.stage = 0
384
 
385
  print('begin get representation')
386
  with torch.no_grad():
387
  for step, inputs in enumerate(train_dataloader):
@@ -395,6 +464,10 @@ class SpecRoute_Trainer(Seq2SeqTrainer):
395
  break
396
  print('end get representation')
397
 
398
  # Collect LoRA covariance matrices
399
  mat_list = []
400
  for name, module in self.model.named_modules():
@@ -469,6 +542,32 @@ class SpecRoute_Trainer(Seq2SeqTrainer):
469
  else:
470
  self.feature_list[i][index] = from_dlpack(Ui.toDlpack())
471
 
472
  print('-' * 40)
473
  print('Gradient Constraints Summary')
474
  print('-' * 40)
@@ -485,8 +584,12 @@ class SpecRoute_Trainer(Seq2SeqTrainer):
485
  for i in range(len(self.feature_list)):
486
  torch.save(self.feature_list[i], os.path.join(self.args.output_dir, 'reg_{}.pt'.format(i)))
487
 
488
- # No trans_input GPM to save
489
-
490
  # training_step: removed — base Seq2SeqTrainer handles it correctly.
491
  # SpecRoute has no memory replay or custom training_step logic.
492
 
 
81
  return control
82
 
83
 
84
+ class TransInputGPMCallback(TrainerCallback):
85
+ """V10a: Apply GPM projection to trans_input and prompt_key after optimizer step."""
86
+ def __init__(self, trainer):
87
+ self.trainer = trainer
88
+
89
+ def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
90
+ if getattr(self.trainer, "cur_task_id", 0) > 1 and getattr(self.trainer.model.encoder, "routing_mode", "") == "learned":
91
+ from copy import deepcopy
92
+ self.trainer._old_trans_input_0 = deepcopy(self.trainer.model.encoder.trans_input[0].weight.detach())
93
+ self.trainer._old_trans_input_1 = deepcopy(self.trainer.model.encoder.trans_input[2].weight.detach())
94
+ self.trainer._old_prompt_key = deepcopy(self.trainer.model.encoder.prompt_key.detach())
95
+
96
+ def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
97
+ if getattr(self.trainer, "cur_task_id", 0) > 1 and getattr(self.trainer.model.encoder, "routing_mode", "") == "learned":
98
+ if not hasattr(self.trainer, "feature_trans_mat") or not self.trainer.feature_trans_mat:
99
+ return
100
+
101
+ from copy import deepcopy
102
+ new_trans_input_0 = deepcopy(self.trainer.model.encoder.trans_input[0].weight.detach())
103
+ new_trans_input_1 = deepcopy(self.trainer.model.encoder.trans_input[2].weight.detach())
104
+ new_trans_input_0norm = new_trans_input_0.norm(dim=1, keepdim=True)
105
+ new_trans_input_1norm = new_trans_input_1.norm(dim=1, keepdim=True)
106
+
107
+ new_prompt_key = deepcopy(self.trainer.model.encoder.prompt_key.detach())
108
+ new_prompt_key_norm = new_prompt_key.norm(dim=1, keepdim=True)
109
+
110
+ old_trans_input_0 = self.trainer._old_trans_input_0
111
+ old_trans_input_1 = self.trainer._old_trans_input_1
112
+ old_prompt_key = self.trainer._old_prompt_key
113
+
114
+ for index in self.trainer.feature_trans_mat[0].keys():
115
+ new_trans_input_0[:,index*self.trainer.model.encoder.step:(index+1)*self.trainer.model.encoder.step] = self.trainer.model.encoder.trans_input[0].weight.detach()[:,index*self.trainer.model.encoder.step:(index+1)*self.trainer.model.encoder.step] - torch.mm(self.trainer.model.encoder.trans_input[0].weight.detach()[:,index*self.trainer.model.encoder.step:(index+1)*self.trainer.model.encoder.step]-old_trans_input_0[:,index*self.trainer.model.encoder.step:(index+1)*self.trainer.model.encoder.step], self.trainer.feature_trans_mat[0][index])
116
+ new_prompt_key[:,index*self.trainer.model.encoder.step:(index+1)*self.trainer.model.encoder.step] = self.trainer.model.encoder.prompt_key.detach()[:,index*self.trainer.model.encoder.step:(index+1)*self.trainer.model.encoder.step] - torch.mm(self.trainer.model.encoder.prompt_key.detach()[:,index*self.trainer.model.encoder.step:(index+1)*self.trainer.model.encoder.step]-old_prompt_key[:,index*self.trainer.model.encoder.step:(index+1)*self.trainer.model.encoder.step], self.trainer.feature_trans_mat[2][index])
117
+ new_trans_input_1 = self.trainer.model.encoder.trans_input[2].weight.detach() - torch.mm(self.trainer.model.encoder.trans_input[2].weight.detach()-old_trans_input_1, self.trainer.feature_trans_mat[1])
118
+
119
+ new_trans_input_0 = new_trans_input_0*new_trans_input_0norm / new_trans_input_0.norm(dim=1, keepdim=True).clamp(min=1e-12)
120
+ new_trans_input_1 = new_trans_input_1*new_trans_input_1norm / new_trans_input_1.norm(dim=1, keepdim=True).clamp(min=1e-12)
121
+ new_prompt_key = new_prompt_key*new_prompt_key_norm / new_prompt_key.norm(dim=1, keepdim=True).clamp(min=1e-12)
122
+
123
+ self.trainer.model.encoder.trans_input[0].weight.data.copy_(new_trans_input_0)
124
+ self.trainer.model.encoder.trans_input[2].weight.data.copy_(new_trans_input_1)
125
+ self.trainer.model.encoder.prompt_key.data.copy_(new_prompt_key)
126
+ return control
127
+
128
+
129
  class SpecRoute_Trainer(Seq2SeqTrainer):
130
 
131
  def __init__(self, model, args, train_dataset, cur_task_id, task_order,
 
134
  lambda_entropy=0.0, use_preconditioning=False,
135
  precond_eps=1e-6, entropy_warmup_ratio=0.1,
136
  n_batches_c5=100):
137
+ if callbacks is None:
138
+ callbacks = []
139
+ callbacks.append(TransInputGPMCallback(self))
140
  super().__init__(
141
  model=model, args=args, train_dataset=train_dataset,
142
  eval_dataset=eval_dataset, tokenizer=tokenizer,
 
307
  print(f'[C5] Covariance collected for {len(self._task_covariance)} layers.')
308
 
309
  def load_previous_reg_matrix(self):
310
+ """Load LoRA GPM bases from previous task. Also load trans_input GPM if learned routing."""
311
  log_path = os.path.dirname(self.args.output_dir)
312
  local_dir = os.path.basename(self.args.output_dir)
313
  print(log_path)
314
 
315
  all_dirs = os.listdir(log_path)
316
  reg_matrix = []
317
+ reg_trans_matrix = []
318
  for all_dir in all_dirs:
319
  if not os.path.isdir(os.path.join(log_path, all_dir)):
320
  continue
 
326
  os.path.join(os.path.join(log_path, all_dir), "reg_{}.pt".format(i))
327
  ))
328
  i += 1
329
+ if getattr(self.model.encoder, "routing_mode", "") == "learned":
330
+ reg_trans_matrix.append(torch.load(os.path.join(os.path.join(log_path, all_dir, 'trans_input'), "reg_0.pt"), weights_only=True))
331
+ reg_trans_matrix.append(torch.load(os.path.join(os.path.join(log_path, all_dir, 'trans_input'), "reg_1.pt"), weights_only=True))
332
+ reg_trans_matrix.append(torch.load(os.path.join(os.path.join(log_path, all_dir, 'trans_input'), "reg_2.pt"), weights_only=True))
333
+
334
  print(os.path.join(log_path, all_dir))
335
  print(len(reg_matrix))
336
  break
337
+ return reg_matrix, reg_trans_matrix, eval(local_dir.split('-')[0]) - 1
338
 
339
  def get_reg_matrix(self):
340
  """
341
  Project current LoRA A into null-space of old tasks' GPM bases.
342
  For learned routing (V10a), also builds the trans_input/prompt_key projection matrices.
343
  """
344
+ self.feature_list, self.feature_trans_list, self._cur_task = self.load_previous_reg_matrix()
345
 
346
  if len(self.feature_list) == 0:
347
  # First task: no constraints
348
  return
349
 
350
+ if getattr(self.model.encoder, "routing_mode", "") == "learned":
351
+ self.feature_trans_mat = []
352
+ for i in range(len(self.feature_trans_list)):
353
+ if i == 1:
354
+ self.feature_trans_mat.append(torch.mm(self.feature_trans_list[i], self.feature_trans_list[i].T).to("cuda:0"))
355
+ else:
356
+ feature_trans_mat = {}
357
+ for index in self.feature_trans_list[i].keys():
358
+ feature_trans_mat[index] = torch.mm(self.feature_trans_list[i][index], self.feature_trans_list[i][index].T).to("cuda:0")
359
+ self.feature_trans_mat.append(feature_trans_mat)
360
+
361
  # Compute projection matrices for LoRA GPM
362
  self.feature_mat, i = [], 0
363
  for name, module in self.model.named_modules():
 
431
  def get_repsentation(self):
432
  """
433
  Collect LoRA input covariance and compute GPM bases via SVD.
434
+ For V10a (learned routing), also collect trans_input covariance.
 
435
  """
436
+ self.feature_list, self.feature_trans_list, self._cur_task = self.load_previous_reg_matrix()
437
 
438
  train_dataloader = self.get_train_dataloader()
439
  if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
 
446
  module.get_feature = True
447
  module.stage = 0
448
 
449
+ # V10a: enable trans_input covariance collection
450
+ if getattr(self.model.encoder, "routing_mode", "") == "learned":
451
+ self.model.encoder.get_chunk(self.args.chunk)
452
+ self.model.encoder.get_trans_feature = True
453
+
454
  print('begin get representation')
455
  with torch.no_grad():
456
  for step, inputs in enumerate(train_dataloader):
 
464
  break
465
  print('end get representation')
466
 
467
+ # V10a: disable trans_input collection after forward pass
468
+ if getattr(self.model.encoder, "routing_mode", "") == "learned":
469
+ self.model.encoder.get_trans_feature = False
470
+
471
  # Collect LoRA covariance matrices
472
  mat_list = []
473
  for name, module in self.model.named_modules():
 
542
  else:
543
  self.feature_list[i][index] = from_dlpack(Ui.toDlpack())
544
 
545
+ # Collect trans_input GPM bases if learned routing
546
+ if getattr(self.model.encoder, "routing_mode", "") == "learned":
547
+ mat_trans_list = []
548
+ if self.model.encoder.matrix_trans_2.sum() != 0:
549
+ mat_trans_list.append(self.model.encoder.matrix_trans_1)
550
+ mat_trans_list.append(self.model.encoder.matrix_trans_2)
551
+ mat_trans_list.append(self.model.encoder.matrix_trans_3)
552
+
553
+ self.feature_trans_list, self.feature_trans_mat = [], []
554
+ for i in range(len(mat_trans_list)):
555
+ if i == 1:
556
+ U, S, Vh = torch.linalg.svd(mat_trans_list[i].data, full_matrices=False)
557
+ sval_total = (S**2).sum()
558
+ sval_ratio = (S**2)/sval_total
559
+ r = np.sum(np.cumsum(sval_ratio.cpu().numpy()) < self.args.transthreshold) + 1
560
+ self.feature_trans_list.append(U[:,0:r].float())
561
+ else:
562
+ feature_trans_list, feature_trans_mat = {}, {}
563
+ for index in mat_trans_list[i].keys():
564
+ U, S, Vh = torch.linalg.svd(mat_trans_list[i][index].data, full_matrices=False)
565
+ sval_total = (S**2).sum()
566
+ sval_ratio = (S**2)/sval_total
567
+ r = np.sum(np.cumsum(sval_ratio.cpu().numpy()) < self.args.transthreshold) + 1
568
+ feature_trans_list[index] = U[:,0:r].float()
569
+ self.feature_trans_list.append(feature_trans_list)
570
+
571
  print('-' * 40)
572
  print('Gradient Constraints Summary')
573
  print('-' * 40)
 
584
  for i in range(len(self.feature_list)):
585
  torch.save(self.feature_list[i], os.path.join(self.args.output_dir, 'reg_{}.pt'.format(i)))
586
 
587
+ # Save trans_input GPM bases
588
+ if getattr(self.model.encoder, "routing_mode", "") == "learned" and hasattr(self, "feature_trans_list"):
589
+ os.makedirs(os.path.join(self.args.output_dir, 'trans_input'), exist_ok=True)
590
+ for i in range(len(self.feature_trans_list)):
591
+ torch.save(self.feature_trans_list[i], os.path.join(self.args.output_dir, 'trans_input', 'reg_{}.pt'.format(i)))
592
+
593
  # training_step: removed — base Seq2SeqTrainer handles it correctly.
594
  # SpecRoute has no memory replay or custom training_step logic.
595
 
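The post-step projection that `TransInputGPMCallback` applies to `trans_input` and `prompt_key` can be illustrated in isolation. The following is a minimal NumPy sketch, not the repository's code: `project_step` and `restore_row_norms` are hypothetical helper names, the shapes are arbitrary, and a random orthonormal matrix stands in for the saved GPM bases.

```python
import numpy as np


def project_step(new_w, old_w, basis):
    """Remove from the update (new_w - old_w) its component inside span(basis).

    basis: (d, k) matrix with orthonormal columns (assumed here).
    new_w, old_w: (out, d) weights after and before the optimizer step.
    """
    proj = basis @ basis.T                   # (d, d) projector onto the protected subspace
    return new_w - (new_w - old_w) @ proj    # same form as the callback's update


def restore_row_norms(w, target_norms, eps=1e-12):
    """Rescale each row of w so that its norm equals target_norms."""
    current = np.maximum(np.linalg.norm(w, axis=1, keepdims=True), eps)
    return w * target_norms / current


rng = np.random.default_rng(0)
d, k, out = 16, 4, 8
basis, _ = np.linalg.qr(rng.standard_normal((d, k)))    # orthonormal columns
old_w = rng.standard_normal((out, d))
new_w = old_w + 0.1 * rng.standard_normal((out, d))     # stand-in for one optimizer step

projected = project_step(new_w, old_w, basis)
# The projected update has no component along the protected directions.
residual = np.abs((projected - old_w) @ basis).max()

# Restore the row norms measured before projection, as the callback does.
step_norms = np.linalg.norm(new_w, axis=1, keepdims=True)
rescaled = restore_row_norms(projected, step_norms)
```

In this sketch the orthogonality check applies to the projected update, before rescaling; the rescaled weights differ from the projected ones by a per-row scale factor.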
improve_gainlora/src/run_t5.py CHANGED
@@ -172,6 +172,10 @@ class ModelArguments:
172
  "Adaptive bias = T*ln(alpha*n_old/(1-alpha)). Set 0 to use fixed training_bias."
173
  },
174
  )
175
 
176
  # C4: Spectrally-Conditioned LoRA Training
177
  lambda_entropy: Optional[float] = field(
@@ -367,6 +371,10 @@ class TrainingArguments(Seq2SeqTrainingArguments):
367
  default='SAPT',
368
  metadata={"help": "models' name"}
369
  )
370
  chunk: Optional[int] = field(
371
  default=1,
372
  metadata={"help": "models' name"}
@@ -512,6 +520,7 @@ def main():
512
  'seq_len': data_args.max_source_length,
513
  'mlp_hidden_dim': model_args.mlp_hidden_dim,
514
  'attn_temperature': model_args.attn_temperature,
 
515
  'previous_lora_path': model_args.previous_lora_path,
516
  'previous_prompt_key_path': model_args.previous_prompt_key_path,
517
  'task_id': cur_task_id,
@@ -566,13 +575,13 @@ def main():
566
  device = torch.device(f"cuda:{local_rank}")
567
  except:
568
  device = torch.device(f"cuda:0")
569
- if model_args.load_checkpoint_from and training_args.model_name != 'specroute':
570
  if not os.path.exists(model_args.load_checkpoint_from):
571
  logger.warning(f"load_checkpoint_from not found: {model_args.load_checkpoint_from}, skipping load")
572
  else:
573
  print("----------Loading Previous Query Projection Layer----------")
574
  model.encoder.trans_input.load_state_dict(torch.load(model_args.load_checkpoint_from, map_location=device))
575
- if training_args.model_name in ['gainlora_inflora', 'gainlora_olora']:
576
  model.encoder.previous_trans_input.input_linear[0].data.copy_(torch.load(model_args.load_checkpoint_from, map_location=device)['0.weight'])
577
  model.encoder.previous_trans_input.output_linear[0].data.copy_(torch.load(model_args.load_checkpoint_from, map_location=device)['2.weight'])
578
  model.encoder.previous_trans_input.state_dict()
@@ -661,6 +670,9 @@ def main():
661
  param.requires_grad = False
662
  if "lora_B" in name and "previous_lora_weights" not in name:
663
  param.requires_grad = True
664
 
665
  total_params, params = 0, 0
666
  for n, p in model.named_parameters():
@@ -1029,6 +1041,17 @@ def main():
1029
  signatures = compute_spectral_signatures(trainer.model, config)
1030
  torch.save(signatures, os.path.join(save_path, 'spectral_signatures.pt'))
1031
  print("----------Saved spectral signatures----------")
1032
  # Only save tokenizer for non-specroute (specroute never reloads it)
1033
  if training_args.model_name != 'specroute':
1034
  tokenizer.save_pretrained(save_path)
 
172
  "Adaptive bias = T*ln(alpha*n_old/(1-alpha)). Set 0 to use fixed training_bias."
173
  },
174
  )
175
+ previous_prompt_key_path: Optional[str] = field(
176
+ default=None,
177
+ metadata={"help": "Path to the previous key prompt layer."}
178
+ )
179
 
180
  # C4: Spectrally-Conditioned LoRA Training
181
  lambda_entropy: Optional[float] = field(
 
371
  default='SAPT',
372
  metadata={"help": "models' name"}
373
  )
374
+ routing_mode: Optional[str] = field(
375
+ default='spectral',
376
+ metadata={"help": "Routing mode for SpecRoute"}
377
+ )
378
  chunk: Optional[int] = field(
379
  default=1,
380
  metadata={"help": "models' name"}
 
520
  'seq_len': data_args.max_source_length,
521
  'mlp_hidden_dim': model_args.mlp_hidden_dim,
522
  'attn_temperature': model_args.attn_temperature,
523
+ 'routing_mode': training_args.routing_mode,
524
  'previous_lora_path': model_args.previous_lora_path,
525
  'previous_prompt_key_path': model_args.previous_prompt_key_path,
526
  'task_id': cur_task_id,
 
575
  device = torch.device(f"cuda:{local_rank}")
576
  except:
577
  device = torch.device(f"cuda:0")
578
+ if model_args.load_checkpoint_from and (training_args.model_name != 'specroute' or getattr(training_args, "routing_mode", "") == "learned"):
579
  if not os.path.exists(model_args.load_checkpoint_from):
580
  logger.warning(f"load_checkpoint_from not found: {model_args.load_checkpoint_from}, skipping load")
581
  else:
582
  print("----------Loading Previous Query Projection Layer----------")
583
  model.encoder.trans_input.load_state_dict(torch.load(model_args.load_checkpoint_from, map_location=device))
584
+ if training_args.model_name in ['gainlora_inflora', 'gainlora_olora'] or (training_args.model_name == 'specroute' and getattr(training_args, "routing_mode", "") == "learned"):
585
  model.encoder.previous_trans_input.input_linear[0].data.copy_(torch.load(model_args.load_checkpoint_from, map_location=device)['0.weight'])
586
  model.encoder.previous_trans_input.output_linear[0].data.copy_(torch.load(model_args.load_checkpoint_from, map_location=device)['2.weight'])
587
  model.encoder.previous_trans_input.state_dict()
 
670
  param.requires_grad = False
671
  if "lora_B" in name and "previous_lora_weights" not in name:
672
  param.requires_grad = True
673
+ if getattr(training_args, "routing_mode", "") == "learned":
674
+ if ("trans_input" in name and "previous_trans_input" not in name) or "prompt_key" in name:
675
+ param.requires_grad = True
676
 
677
  total_params, params = 0, 0
678
  for n, p in model.named_parameters():
 
1041
  signatures = compute_spectral_signatures(trainer.model, config)
1042
  torch.save(signatures, os.path.join(save_path, 'spectral_signatures.pt'))
1043
  print("----------Saved spectral signatures----------")
1044
+
1045
+ if getattr(training_args, "routing_mode", "") == "learned":
1046
+ from copy import deepcopy
1047
+ if not prompt_config["run_single"]:
1048
+ if prompt_config["previous_prompt_key_path"] is not None:
1049
+ previous_trans_input = deepcopy(trainer.model.encoder.previous_trans_input.state_dict())
1050
+ torch.save(previous_trans_input, os.path.join(save_path, 'previous_trans_input.pt'))
1051
+ torch.save(torch.cat([trainer.model.encoder.prompt_key, trainer.model.encoder.previous_prompts_keys], dim=0).data, os.path.join(save_path, 'prompts_keys_till_now.pt'))
1052
+ else:
1053
+ torch.save(trainer.model.encoder.prompt_key.data, os.path.join(save_path, 'prompts_keys_till_now.pt'))
1054
+ torch.save(trainer.model.encoder.trans_input.state_dict(), os.path.join(save_path, 'trans_input.pt'))
1055
  # Only save tokenizer for non-specroute (specroute never reloads it)
1056
  if training_args.model_name != 'specroute':
1057
  tokenizer.save_pretrained(save_path)
improve_gainlora/src/t5_specroute.py CHANGED
@@ -57,6 +57,7 @@ from t5_gainlora_inflora import (
57
  T5LayerCrossAttention,
58
  T5Block,
59
  T5PreTrainedModel,
 
60
  )
61
 
62
  logger = logging.get_logger(__name__)
@@ -145,16 +146,43 @@ class T5Stack(T5PreTrainedModel):
145
  self.prompt_config = prompt_config
146
 
147
  if not self.is_decoder and not prompt_config["run_single"]:
148
- # ===== Spectral routing: NO learned parameters for routing =====
149
- # Spectral signatures loaded from previous tasks' saved weights
 
150
  self.spectral_signatures = [] # List[dict] — one dict per old task
151
- self.routing_temperature = prompt_config.get('attn_temperature', 1.0)
152
 
153
- # Adaptive training bias: beta = T * ln(alpha * n_old / (1 - alpha))
154
- # Ensures current task gets consistent routing weight ~alpha regardless
155
- # of total number of tasks (fixes softmax dilution with constant bias).
156
- # At inference, same A-row formula without bias (V8 symmetric routing).
157
- self._target_routing_alpha = prompt_config.get('target_routing_alpha', 0.8)
158
 
159
  # For inference logging
160
  self.all_attn_weights = []
@@ -174,7 +202,139 @@ class T5Stack(T5PreTrainedModel):
174
  # The old format (with 'value' param) causes transformers to silently ignore
175
  # gradient_checkpointing_kwargs (including use_reentrant=False).
176
 
177
- def compute_spectral_routing(self, avg_inputs_embeds):
178
  """
179
  V9: Routing with oracle-training / spectral-inference split + calibration.
180
 
@@ -202,9 +362,6 @@ class T5Stack(T5PreTrainedModel):
202
  Returns:
203
  (B, n_tasks, 1) routing weights: oracle one-hot (training) or top-1 (inference)
204
  """
205
- h = avg_inputs_embeds # (B, 1, d_model)
206
- h_norm_sq = (h ** 2).sum(dim=-1) + 1e-8 # (B, 1)
207
-
208
  fits = []
209
 
210
  # === CURRENT TASK: A-row fit ===
@@ -380,29 +537,36 @@ class T5Stack(T5PreTrainedModel):
380
  avg_inputs_embeds = (attention_mask.unsqueeze(-1) * inputs_embeds).sum(dim=1, keepdim=True) / _mask_count
381
 
382
  if not self.is_decoder and not self.prompt_config["run_single"]:
383
- if len(self.spectral_signatures) > 0:
384
- # Multi-task: compute routing
385
- key_attention_weights = self.compute_spectral_routing(avg_inputs_embeds)
386
- # Detach: routing weights are shared across all gradient-checkpointed
387
- # blocks via closure. Without detach, the second block's backward
388
- # fails with "backward through graph a second time" because the
389
- # first block already freed the shared graph (inputs_embeds -> routing).
390
- # Safe because routing uses lora_A.data (detached) and frozen signatures.
391
- key_attention_weights = key_attention_weights.detach()
392
-
393
- if self.is_inference:
394
- self.all_attn_weights.append(
395
- key_attention_weights.squeeze().mean(dim=0, keepdim=True).detach().to(torch.float).cpu().numpy()
396
- )
397
  else:
398
- # First task or no previous info: single LoRA, weight = 1
399
- key_attention_weights = torch.ones(
400
- batch_size, 1, 1, device=inputs_embeds.device, dtype=inputs_embeds.dtype
401
- )
402
- if self.is_inference:
403
- self.all_attn_weights.append(
404
- key_attention_weights.squeeze(2).mean(dim=0, keepdim=True).detach().to(torch.float).cpu().numpy()
405
  )
406
  self.key_attention_weights = key_attention_weights
407
  else:
408
  # Decoder or run_single: use whatever was passed (from encoder)
 
57
  T5LayerCrossAttention,
58
  T5Block,
59
  T5PreTrainedModel,
60
+ Trans_input,
61
  )
62
 
63
  logger = logging.get_logger(__name__)
 
146
  self.prompt_config = prompt_config
147
 
148
  if not self.is_decoder and not prompt_config["run_single"]:
149
+ self.routing_mode = prompt_config.get("routing_mode", "spectral")
150
+
151
+ # Common for all spectral/grassmann modes
152
  self.spectral_signatures = [] # List[dict] — one dict per old task
153
+
154
+ if self.routing_mode == "learned":
155
+ # V10a: Learned routing matching GainLoRA ROOT exactly
156
+ self.prompt_key = nn.Parameter(torch.randn((1, config.d_model)))
157
+ nn.init.uniform_(self.prompt_key, -1, 1)
158
+
159
+ self.trans_input = nn.Sequential(
160
+ nn.Linear(config.d_model, prompt_config["mlp_hidden_dim"], bias=False),
161
+ nn.SiLU(),
162
+ nn.Linear(prompt_config["mlp_hidden_dim"], config.d_model, bias=False),
163
+ nn.SiLU(),
164
+ )
165
 
166
+ self.get_trans_feature = False
167
+ self.stage_trans = 0
168
+ self.matrix_trans_1 = torch.zeros(config.d_model, config.d_model)
169
+ self.matrix_trans_2 = torch.zeros(prompt_config["mlp_hidden_dim"], prompt_config["mlp_hidden_dim"])
170
+ self.n_trans_matrix = 0
171
+
172
+ self.previous_prompts_keys = None
173
+ if prompt_config.get("previous_prompt_key_path") is not None and prompt_config.get("task_id", 0):
174
+ print("----------Loading Previous Keys----------")
175
+ self.previous_prompts_keys = nn.Parameter(torch.randn((prompt_config["task_id"], config.d_model)))
176
+ self.previous_prompts_keys.data = torch.load(prompt_config["previous_prompt_key_path"], weights_only=True)
177
+ self.previous_prompts_keys.requires_grad = False
178
+
179
+ self.previous_trans_input = Trans_input(config.d_model, prompt_config["mlp_hidden_dim"], prompt_config["task_id"])
180
+ for param in self.previous_trans_input.parameters():
181
+ param.requires_grad = False
182
+ else:
183
+ # V8/V9/V10b: Spectral routing parameters
184
+ self.routing_temperature = prompt_config.get('attn_temperature', 1.0)
185
+ self._target_routing_alpha = prompt_config.get('target_routing_alpha', 0.8)
186
 
187
  # For inference logging
188
  self.all_attn_weights = []
 
202
  # The old format (with 'value' param) causes transformers to silently ignore
203
  # gradient_checkpointing_kwargs (including use_reentrant=False).
204
 
205
+ def get_chunk(self, chunk):
206
+ if self.routing_mode == "learned":
207
+ self.chunk_trans = chunk
208
+ self.index_trans, self.step_trans = chunk, self.config.d_model // chunk
209
+ self.step, self.index = self.step_trans, self.index_trans
210
+ self.matrix_trans_1, self.matrix_trans_3, self.n_trans_matrix = {}, {}, {}
211
+ for idx in range(self.index_trans):
212
+ self.matrix_trans_1[idx] = torch.zeros(self.step_trans, self.step_trans).cuda()
213
+ self.matrix_trans_3[idx] = torch.zeros(self.step_trans, self.step_trans).cuda()
214
+ self.n_trans_matrix[idx] = 0
215
+ self.matrix_trans_2 = self.matrix_trans_2.cuda()
216
+
217
+ def get_matrix3(self, x, medium, x_final):
218
+ if self.routing_mode == "learned":
219
+ for idx in range(self.index_trans):
220
+ m1_curr = torch.bmm(x[:,:,idx*self.step_trans:(idx+1)*self.step_trans].detach().permute(0, 2, 1), x[:,:,idx*self.step_trans:(idx+1)*self.step_trans].detach()).sum(dim=0).float()/(x.shape[0]*x.shape[1])
221
+ m3_curr = torch.bmm(x_final[:,:,idx*self.step_trans:(idx+1)*self.step_trans].detach().permute(0, 2, 1), x_final[:,:,idx*self.step_trans:(idx+1)*self.step_trans].detach()).sum(dim=0).float()/(x_final.shape[0]*x_final.shape[1])
222
+
223
+ if len(self.matrix_trans_1) > 0 and isinstance(self.matrix_trans_1.get(idx), torch.Tensor) and self.matrix_trans_1.get(idx).sum() != 0:
224
+ self.matrix_trans_1[idx] = (self.matrix_trans_1[idx]*self.n_trans_matrix[idx] + m1_curr)/(self.n_trans_matrix[idx] + x.shape[0]*x.shape[1])
225
+ self.matrix_trans_3[idx] = (self.matrix_trans_3[idx]*self.n_trans_matrix[idx] + m3_curr)/(self.n_trans_matrix[idx] + x_final.shape[0]*x_final.shape[1])
226
+ else:
227
+ self.matrix_trans_1[idx] = m1_curr
228
+ self.matrix_trans_3[idx] = m3_curr
229
+ self.n_trans_matrix[idx] += x.shape[0]*x.shape[1]
230
+
231
+ if self.matrix_trans_2.sum() == 0:
232
+ self.matrix_trans_2 = torch.bmm(medium.detach().permute(0, 2, 1), medium.detach()).sum(dim=0).float()/(medium.shape[0]*medium.shape[1])
233
+ else:
234
+ self.matrix_trans_2 = (self.matrix_trans_2*self.n_trans_matrix[0] + torch.bmm(medium.detach().permute(0, 2, 1), medium.detach()).sum(dim=0).float())/(self.n_trans_matrix[0] + medium.shape[0]*medium.shape[1])
235
+
236
+ def cal_attention(self, prompt_key, x, return_logits=False):
237
+ # ROOT-style routing similarity
238
+ x = x/(x.norm(dim=-1,keepdim=True) + 1e-12)
239
+ prompt_key = prompt_key/(prompt_key.norm(dim=-1,keepdim=True) + 1e-12)
240
+ attn_scores = (x*prompt_key).sum(dim=-1, keepdim=True)
241
+ weights = torch.abs(torch.nn.functional.sigmoid(attn_scores*4)*2-1)
242
+ if not return_logits:
243
+ return weights
244
+ else:
245
+ return attn_scores
246
+
247
+ def compute_learned_routing(self, avg_inputs_embeds, batch_size):
248
+ """V10a: Learned MLP Routing copying GainLoRA exactly"""
249
+ prompt_key = self.prompt_key
250
+ if self.previous_prompts_keys is not None:
251
+ prompt_key = self.prompt_key.to(prompt_key.device)
252
+ past_prompt_key = torch.cat([prompt_key.repeat(batch_size, 1, 1), self.previous_prompts_keys.repeat(batch_size, 1, 1)], dim=1)
253
+
254
+ medium = self.trans_input[1](self.trans_input[0](avg_inputs_embeds))
255
+ x = self.trans_input[3](self.trans_input[2](medium))
256
+ if getattr(self, "get_trans_feature", False):
257
+ self.get_matrix3(avg_inputs_embeds, medium, x)
258
+
259
+ past_x = torch.cat([x, self.previous_trans_input(avg_inputs_embeds)], dim=1)
260
+ key_attention_weights = self.cal_attention(past_prompt_key, past_x)
261
+ else:
262
+ medium = self.trans_input[1](self.trans_input[0](avg_inputs_embeds))
263
+ x = self.trans_input[3](self.trans_input[2](medium))
264
+ if getattr(self, "get_trans_feature", False):
265
+ self.get_matrix3(avg_inputs_embeds, medium, x)
266
+
267
+ key_attention_weights = self.cal_attention(prompt_key.repeat(batch_size, 1, 1), x)
268
+ return key_attention_weights
269
+
270
+ def compute_grassmann_routing(self, h, h_norm_sq):
271
+ """V10b: Grassmann Distance Routing
272
+ Calculates principal angles between batch local subspace and candidate A_t subspaces.
273
+ """
274
+ B, _, d_model = h.shape
275
+ if self.training or B < 8:
276
+ # Fallback to A-row fit for very small batches or training (oracle handles training)
277
+ return self.compute_spectral_routing(h, h_norm_sq)
278
+
279
+ fits = []
280
+ r = self.block[0].layer[0].SelfAttention.lora_q.r
281
+
282
+ # Batch PCA to get local subspace U_batch (using SVD)
283
+ # h is (B, 1, d_model) -> reshape to (B, d_model)
284
+ h_flat = h.squeeze(1)
285
+ # torch.linalg.svd returns (U, S, Vh) where Vh = V^T
286
+ # We want right singular vectors V: h_flat = U @ diag(S) @ Vh, so V = Vh.T
287
+ _, _, Vh_batch = torch.linalg.svd(h_flat - h_flat.mean(dim=0, keepdim=True), full_matrices=False)
288
+ U_batch = Vh_batch[:r, :] # Vh is (min(B,d), d), so first r rows = top-r right sing. vectors, shape (r, d_model)
289
+
290
+ # Current task Grassmann dist
291
+ current_layer_dists = []
292
+ for block in self.block:
293
+ attn = block.layer[0].SelfAttention
294
+ for lora in [attn.lora_q, attn.lora_v]:
295
+ A = lora.lora_A.data.float().to(h.device) # (r, d_model)
296
+ # SVD of A: A = U @ diag(S) @ Vh_A, so the rows of Vh_A are the right singular vectors of A
297
+ _, _, Vh_A = torch.linalg.svd(A, full_matrices=False) # A is (r, d_model), Vh_A is (r, d_model)
298
+ U_A = Vh_A[:r, :] # (r, d_model) — top-r right singular vectors of A, forming the subspace
299
+
300
+ # Grassmann distance via principal angles
301
+ # cos(theta_i) = singular values of U_batch @ U_A^T
302
+ M = torch.matmul(U_batch, U_A.T) # (r, r)
303
+ angles = torch.linalg.svdvals(M).clamp(-1.0, 1.0)
304
+ principal_angles = torch.acos(angles)
305
+ dist = torch.sqrt(torch.sum(principal_angles**2))
306
+ current_layer_dists.append(dist)
307
+
308
+ current_dist = torch.stack(current_layer_dists).mean(dim=0).item()
309
+ fits.append(1.0 / (current_dist + 1e-4)) # Inverse dist as affinity
310
+
311
+ # Old tasks
312
+ for sig_dict in self.spectral_signatures:
313
+ task_dists = []
314
+ for key, sig_data in sig_dict.items():
315
+ if not key.startswith('enc.'):
316
+ continue
317
+ A = sig_data['A'].to(h.device, dtype=torch.float32) # (r, d_model)
318
+ _, _, Vh_A = torch.linalg.svd(A, full_matrices=False)
319
+ U_A = Vh_A[:r, :]
320
+
321
+ M = torch.matmul(U_batch, U_A.T)
322
+ angles = torch.linalg.svdvals(M).clamp(-1.0, 1.0)
323
+ dist = torch.sqrt(torch.sum(torch.acos(angles)**2))
324
+ task_dists.append(dist)
325
+
326
+ if task_dists:
327
+ task_dist = torch.stack(task_dists).mean(dim=0).item()
328
+ fits.append(1.0 / (task_dist + 1e-4))
329
+ else:
330
+ fits.append(0.0)
331
+
332
+ fit_scores = torch.tensor(fits, device=h.device).unsqueeze(0).repeat(B, 1) # (B, n_tasks)
333
+ max_idx = fit_scores.argmax(dim=1, keepdim=True)
334
+ weights = torch.zeros_like(fit_scores).scatter_(1, max_idx, 1.0)
335
+ return weights.unsqueeze(2)
336
+
337
+ def compute_spectral_routing(self, h, h_norm_sq):
338
  """
339
  V9: Routing with oracle-training / spectral-inference split + calibration.
340
 
 
362
  Returns:
363
  (B, n_tasks, 1) routing weights: oracle one-hot (training) or top-1 (inference)
364
  """
365
  fits = []
366
 
367
  # === CURRENT TASK: A-row fit ===
 
537
  avg_inputs_embeds = (attention_mask.unsqueeze(-1) * inputs_embeds).sum(dim=1, keepdim=True) / _mask_count
538
 
539
  if not self.is_decoder and not self.prompt_config["run_single"]:
540
+ if self.routing_mode == "learned":
541
+ key_attention_weights = self.compute_learned_routing(avg_inputs_embeds, batch_size)
542
+
543
+ if self.is_inference and self.previous_prompts_keys is not None:
544
+ self.all_attn_weights.append(key_attention_weights.squeeze().mean(dim=0, keepdim=True).detach().to(torch.float).cpu().numpy())
545
+ elif self.is_inference:
546
+ self.all_attn_weights.append(key_attention_weights.squeeze(2).mean(dim=0, keepdim=True).detach().to(torch.float).cpu().numpy())
547
  else:
548
+ if len(self.spectral_signatures) > 0:
549
+ h_norm_sq = (avg_inputs_embeds ** 2).sum(dim=-1) + 1e-8 # (B, 1)
550
+ if self.routing_mode == "grassmann":
551
+ key_attention_weights = self.compute_grassmann_routing(avg_inputs_embeds, h_norm_sq)
552
+ else:
553
+ key_attention_weights = self.compute_spectral_routing(avg_inputs_embeds, h_norm_sq)
554
+
555
+ key_attention_weights = key_attention_weights.detach()
556
+
557
+ if self.is_inference:
558
+ self.all_attn_weights.append(
559
+ key_attention_weights.squeeze().mean(dim=0, keepdim=True).detach().to(torch.float).cpu().numpy()
560
+ )
561
+ else:
562
+ # First task or no previous info: single LoRA, weight = 1
563
+ key_attention_weights = torch.ones(
564
+ batch_size, 1, 1, device=inputs_embeds.device, dtype=inputs_embeds.dtype
565
  )
566
+ if self.is_inference:
567
+ self.all_attn_weights.append(
568
+ key_attention_weights.squeeze(2).mean(dim=0, keepdim=True).detach().to(torch.float).cpu().numpy()
569
+ )
570
  self.key_attention_weights = key_attention_weights
571
  else:
572
  # Decoder or run_single: use whatever was passed (from encoder)
results/experiment_versions.md CHANGED
@@ -360,4 +360,20 @@ V8 fails on imdb/sst2/yahoo because B_t does not learn (gradient is blocked). V9 oracle rou
  | - | V5 | **59.55** | **62.19** | Prototype routing + entropy + preconditioning |
  | - | V6 | ~27.4 | ~35.5 | SVD + C4 only (no prototypes) — **FAILED** |
  | - | V8 | 35.78 | 43.73 | C5 Data-Informed Init + C4 precond + A-row routing (no β) — PARTIAL |
- | - | V9 | (pending) | (pending) | Oracle routing (training) + calibrated Top-1 (inference) — bug fix |
+ | - | V9 | 43.14 | 51.55 | Oracle routing (training) + calibrated Top-1 (inference) — bug fix |
+ | - | V10a | (pending) | (pending) | Learned Routing + GPM + C5 + C4 |
+ | - | V10b | (pending) | (pending) | Grassmannian Distance Routing + C5 + C4 |
+
+ ---
+
+ ## V10 — Duality of Routing Mechanisms
+
+ **Motivation**: V9 showed that Top-1 A-row routing struggles to isolate orthogonal subspaces despite C4+C5. V10 explores two distinct routing modes that aim to improve routing precision while preserving C5's benefits.
+
+ ### V10a (Learned Routing - The Practical Baseline)
+ - **Method**: Reintroduces ROOT's `Trans_input` MLP and `prompt_key` gating, with exact GPM constraints applied to their weights after each optimizer step.
+ - **Why**: Tests whether C5 initialization and C4 preconditioning can synergize with explicit function approximation for routing. It sacrifices the "parameter-free" claim but is intended as a strong upper-bound baseline.
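A minimal sketch of what a post-step GPM-style constraint could look like, written in NumPy rather than torch so it runs standalone. It is not the code in `cl_trainer_specroute.py`: the helper name `project_update`, the `basis` matrix, and the random stand-in for an optimizer step are all assumptions made for illustration. The idea shown is that the weight change from one step is projected onto the orthogonal complement of a protected input subspace, so outputs on inputs from that subspace do not move.

```python
import numpy as np

def project_update(w_before, w_after, basis):
    """Strip from the step (w_after - w_before) any component acting on span(basis).

    w_before, w_after: (out_features, in_features) weights before/after a step.
    basis: (in_features, k) matrix with orthonormal columns (hypothetical
           stand-in for the stored feature basis of earlier tasks).
    """
    delta = w_after - w_before
    delta = delta - (delta @ basis) @ basis.T  # delta @ (I - basis @ basis.T)
    return w_before + delta

rng = np.random.default_rng(0)
in_features, out_features, k = 16, 8, 4

# Orthonormal basis of the protected input subspace.
basis, _ = np.linalg.qr(rng.standard_normal((in_features, k)))

w_before = rng.standard_normal((out_features, in_features))
# Stand-in for an optimizer step: an arbitrary unconstrained update.
w_after = w_before - 0.1 * rng.standard_normal((out_features, in_features))
w_proj = project_update(w_before, w_after, basis)

# Inputs drawn from the protected subspace should see unchanged outputs.
x_old = rng.standard_normal((5, k)) @ basis.T
max_drift = np.abs(x_old @ w_proj.T - x_old @ w_before.T).max()
update_norm = np.linalg.norm(w_proj - w_before)
```

Projecting after the step, rather than projecting the gradient before it, keeps the constraint exact even when the optimizer rescales or adds momentum to the raw gradient, which may be why the description stresses "exact".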
+
+ ### V10b (Grassmannian Distance Routing - The Zero-Replay Ideal)
+ - **Method**: Evaluates similarity by computing the Grassmannian distance (via principal angles) between the batch's local principal subspace $U_{batch}$ and each expert's orthogonal projection $U_A$.
+ - **Why**: Directly measures geometric alignment between subspaces, entirely bypassing scale-based similarity issues (the GPM-Routing paradox). The batch-level SVD aggregates representations across the whole batch. Valid for batched inference ($B \ge 8$); smaller batches fall back to A-row routing.
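The Grassmannian routing idea can be sketched as follows, again in NumPy so it runs standalone. This is an illustration under assumptions, not the `compute_grassmann_routing` implementation: the function names, the choice of the geodesic distance (L2 norm of the principal angles), the uncentered SVD, and the equal subspace rank `k` on both sides are all hypothetical. The diff turns distances into fit scores via `1.0 / (task_dist + 1e-4)` followed by `argmax`, which selects the same expert as the `argmin` over distances used here.

```python
import numpy as np

def principal_subspace(h, k):
    """Top-k right singular vectors of a (B, d) batch as a (d, k) orthonormal basis."""
    _, _, vt = np.linalg.svd(h, full_matrices=False)
    return vt[:k].T

def grassmann_distance(u, v):
    """L2 norm of the principal angles between span(u) and span(v).

    u, v: (d, k) matrices with orthonormal columns.
    """
    s = np.linalg.svd(u.T @ v, compute_uv=False)  # cosines of principal angles
    theta = np.arccos(np.clip(s, -1.0, 1.0))      # clip guards against rounding
    return float(np.linalg.norm(theta))

def route_top1(h, expert_bases, k):
    """Return the index of the expert whose subspace is closest to the batch's."""
    u_batch = principal_subspace(h, k)
    dists = [grassmann_distance(u_batch, u_a) for u_a in expert_bases]
    return int(np.argmin(dists)), dists

rng = np.random.default_rng(0)
d, k, B = 32, 4, 16

# Two mutually orthogonal expert subspaces (toy stand-ins for U_A).
q, _ = np.linalg.qr(rng.standard_normal((d, 2 * k)))
expert_bases = [q[:, :k], q[:, k:]]

# A batch lying mostly in expert 1's subspace, plus small noise.
h = rng.standard_normal((B, k)) @ expert_bases[1].T + 0.01 * rng.standard_normal((B, d))
chosen, dists = route_top1(h, expert_bases, k)
```

Because the distance depends only on the spans, rescaling `h` leaves the routing decision unchanged, which is the scale-invariance property the V10b bullet appeals to. A batch smaller than `k` cannot supply `k` principal directions, which is consistent with the stated fallback to A-row routing for small batches.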