v11
Browse files
improve_gainlora/gen_script_long_order3_t5_rls.sh
CHANGED
|
@@ -93,6 +93,7 @@ fi
|
|
| 93 |
|
| 94 |
CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
|
| 95 |
--do_train \
|
|
|
|
| 96 |
--predict_with_generate \
|
| 97 |
--model_name_or_path $2 \
|
| 98 |
--data_dir CL_Benchmark \
|
|
@@ -158,6 +159,7 @@ fi
|
|
| 158 |
|
| 159 |
CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
|
| 160 |
--do_train \
|
|
|
|
| 161 |
--predict_with_generate \
|
| 162 |
--model_name_or_path $2 \
|
| 163 |
--previous_lora_path ${OUTPUT_BASE}/1-yelp/saved_weights \
|
|
@@ -215,6 +217,7 @@ sleep 5
|
|
| 215 |
# ============================================================
|
| 216 |
CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
|
| 217 |
--do_train \
|
|
|
|
| 218 |
--predict_with_generate \
|
| 219 |
--model_name_or_path $2 \
|
| 220 |
--previous_lora_path ${OUTPUT_BASE}/1-yelp/saved_weights,${OUTPUT_BASE}/2-amazon/saved_weights \
|
|
@@ -272,6 +275,7 @@ sleep 5
|
|
| 272 |
# ============================================================
|
| 273 |
CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
|
| 274 |
--do_train \
|
|
|
|
| 275 |
--predict_with_generate \
|
| 276 |
--model_name_or_path $2 \
|
| 277 |
--previous_lora_path ${OUTPUT_BASE}/1-yelp/saved_weights,${OUTPUT_BASE}/2-amazon/saved_weights,${OUTPUT_BASE}/3-mnli/saved_weights \
|
|
@@ -329,6 +333,7 @@ sleep 5
|
|
| 329 |
# ============================================================
|
| 330 |
CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
|
| 331 |
--do_train \
|
|
|
|
| 332 |
--predict_with_generate \
|
| 333 |
--model_name_or_path $2 \
|
| 334 |
--previous_lora_path ${OUTPUT_BASE}/1-yelp/saved_weights,${OUTPUT_BASE}/2-amazon/saved_weights,${OUTPUT_BASE}/3-mnli/saved_weights,${OUTPUT_BASE}/4-cb/saved_weights \
|
|
@@ -386,6 +391,7 @@ sleep 5
|
|
| 386 |
# ============================================================
|
| 387 |
CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
|
| 388 |
--do_train \
|
|
|
|
| 389 |
--predict_with_generate \
|
| 390 |
--model_name_or_path $2 \
|
| 391 |
--previous_lora_path ${OUTPUT_BASE}/1-yelp/saved_weights,${OUTPUT_BASE}/2-amazon/saved_weights,${OUTPUT_BASE}/3-mnli/saved_weights,${OUTPUT_BASE}/4-cb/saved_weights,${OUTPUT_BASE}/5-copa/saved_weights \
|
|
@@ -443,6 +449,7 @@ sleep 5
|
|
| 443 |
# ============================================================
|
| 444 |
CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
|
| 445 |
--do_train \
|
|
|
|
| 446 |
--predict_with_generate \
|
| 447 |
--model_name_or_path $2 \
|
| 448 |
--previous_lora_path ${OUTPUT_BASE}/1-yelp/saved_weights,${OUTPUT_BASE}/2-amazon/saved_weights,${OUTPUT_BASE}/3-mnli/saved_weights,${OUTPUT_BASE}/4-cb/saved_weights,${OUTPUT_BASE}/5-copa/saved_weights,${OUTPUT_BASE}/6-qqp/saved_weights \
|
|
@@ -500,6 +507,7 @@ sleep 5
|
|
| 500 |
# ============================================================
|
| 501 |
CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
|
| 502 |
--do_train \
|
|
|
|
| 503 |
--predict_with_generate \
|
| 504 |
--model_name_or_path $2 \
|
| 505 |
--previous_lora_path ${OUTPUT_BASE}/1-yelp/saved_weights,${OUTPUT_BASE}/2-amazon/saved_weights,${OUTPUT_BASE}/3-mnli/saved_weights,${OUTPUT_BASE}/4-cb/saved_weights,${OUTPUT_BASE}/5-copa/saved_weights,${OUTPUT_BASE}/6-qqp/saved_weights,${OUTPUT_BASE}/7-rte/saved_weights \
|
|
@@ -557,6 +565,7 @@ sleep 5
|
|
| 557 |
# ============================================================
|
| 558 |
CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
|
| 559 |
--do_train \
|
|
|
|
| 560 |
--predict_with_generate \
|
| 561 |
--model_name_or_path $2 \
|
| 562 |
--previous_lora_path ${OUTPUT_BASE}/1-yelp/saved_weights,${OUTPUT_BASE}/2-amazon/saved_weights,${OUTPUT_BASE}/3-mnli/saved_weights,${OUTPUT_BASE}/4-cb/saved_weights,${OUTPUT_BASE}/5-copa/saved_weights,${OUTPUT_BASE}/6-qqp/saved_weights,${OUTPUT_BASE}/7-rte/saved_weights,${OUTPUT_BASE}/8-imdb/saved_weights \
|
|
@@ -614,6 +623,7 @@ sleep 5
|
|
| 614 |
# ============================================================
|
| 615 |
CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
|
| 616 |
--do_train \
|
|
|
|
| 617 |
--predict_with_generate \
|
| 618 |
--model_name_or_path $2 \
|
| 619 |
--previous_lora_path ${OUTPUT_BASE}/1-yelp/saved_weights,${OUTPUT_BASE}/2-amazon/saved_weights,${OUTPUT_BASE}/3-mnli/saved_weights,${OUTPUT_BASE}/4-cb/saved_weights,${OUTPUT_BASE}/5-copa/saved_weights,${OUTPUT_BASE}/6-qqp/saved_weights,${OUTPUT_BASE}/7-rte/saved_weights,${OUTPUT_BASE}/8-imdb/saved_weights,${OUTPUT_BASE}/9-sst2/saved_weights \
|
|
@@ -671,6 +681,7 @@ sleep 5
|
|
| 671 |
# ============================================================
|
| 672 |
CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
|
| 673 |
--do_train \
|
|
|
|
| 674 |
--predict_with_generate \
|
| 675 |
--model_name_or_path $2 \
|
| 676 |
--previous_lora_path ${OUTPUT_BASE}/1-yelp/saved_weights,${OUTPUT_BASE}/2-amazon/saved_weights,${OUTPUT_BASE}/3-mnli/saved_weights,${OUTPUT_BASE}/4-cb/saved_weights,${OUTPUT_BASE}/5-copa/saved_weights,${OUTPUT_BASE}/6-qqp/saved_weights,${OUTPUT_BASE}/7-rte/saved_weights,${OUTPUT_BASE}/8-imdb/saved_weights,${OUTPUT_BASE}/9-sst2/saved_weights,${OUTPUT_BASE}/10-dbpedia/saved_weights \
|
|
@@ -728,6 +739,7 @@ sleep 5
|
|
| 728 |
# ============================================================
|
| 729 |
CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
|
| 730 |
--do_train \
|
|
|
|
| 731 |
--predict_with_generate \
|
| 732 |
--model_name_or_path $2 \
|
| 733 |
--previous_lora_path ${OUTPUT_BASE}/1-yelp/saved_weights,${OUTPUT_BASE}/2-amazon/saved_weights,${OUTPUT_BASE}/3-mnli/saved_weights,${OUTPUT_BASE}/4-cb/saved_weights,${OUTPUT_BASE}/5-copa/saved_weights,${OUTPUT_BASE}/6-qqp/saved_weights,${OUTPUT_BASE}/7-rte/saved_weights,${OUTPUT_BASE}/8-imdb/saved_weights,${OUTPUT_BASE}/9-sst2/saved_weights,${OUTPUT_BASE}/10-dbpedia/saved_weights,${OUTPUT_BASE}/11-agnews/saved_weights \
|
|
@@ -785,6 +797,7 @@ sleep 5
|
|
| 785 |
# ============================================================
|
| 786 |
CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
|
| 787 |
--do_train \
|
|
|
|
| 788 |
--predict_with_generate \
|
| 789 |
--model_name_or_path $2 \
|
| 790 |
--previous_lora_path ${OUTPUT_BASE}/1-yelp/saved_weights,${OUTPUT_BASE}/2-amazon/saved_weights,${OUTPUT_BASE}/3-mnli/saved_weights,${OUTPUT_BASE}/4-cb/saved_weights,${OUTPUT_BASE}/5-copa/saved_weights,${OUTPUT_BASE}/6-qqp/saved_weights,${OUTPUT_BASE}/7-rte/saved_weights,${OUTPUT_BASE}/8-imdb/saved_weights,${OUTPUT_BASE}/9-sst2/saved_weights,${OUTPUT_BASE}/10-dbpedia/saved_weights,${OUTPUT_BASE}/11-agnews/saved_weights,${OUTPUT_BASE}/12-yahoo/saved_weights \
|
|
@@ -842,6 +855,7 @@ sleep 5
|
|
| 842 |
# ============================================================
|
| 843 |
CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
|
| 844 |
--do_train \
|
|
|
|
| 845 |
--predict_with_generate \
|
| 846 |
--model_name_or_path $2 \
|
| 847 |
--previous_lora_path ${OUTPUT_BASE}/1-yelp/saved_weights,${OUTPUT_BASE}/2-amazon/saved_weights,${OUTPUT_BASE}/3-mnli/saved_weights,${OUTPUT_BASE}/4-cb/saved_weights,${OUTPUT_BASE}/5-copa/saved_weights,${OUTPUT_BASE}/6-qqp/saved_weights,${OUTPUT_BASE}/7-rte/saved_weights,${OUTPUT_BASE}/8-imdb/saved_weights,${OUTPUT_BASE}/9-sst2/saved_weights,${OUTPUT_BASE}/10-dbpedia/saved_weights,${OUTPUT_BASE}/11-agnews/saved_weights,${OUTPUT_BASE}/12-yahoo/saved_weights,${OUTPUT_BASE}/13-multirc/saved_weights \
|
|
@@ -899,6 +913,7 @@ sleep 5
|
|
| 899 |
# ============================================================
|
| 900 |
CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
|
| 901 |
--do_train \
|
|
|
|
| 902 |
--predict_with_generate \
|
| 903 |
--model_name_or_path $2 \
|
| 904 |
--previous_lora_path ${OUTPUT_BASE}/1-yelp/saved_weights,${OUTPUT_BASE}/2-amazon/saved_weights,${OUTPUT_BASE}/3-mnli/saved_weights,${OUTPUT_BASE}/4-cb/saved_weights,${OUTPUT_BASE}/5-copa/saved_weights,${OUTPUT_BASE}/6-qqp/saved_weights,${OUTPUT_BASE}/7-rte/saved_weights,${OUTPUT_BASE}/8-imdb/saved_weights,${OUTPUT_BASE}/9-sst2/saved_weights,${OUTPUT_BASE}/10-dbpedia/saved_weights,${OUTPUT_BASE}/11-agnews/saved_weights,${OUTPUT_BASE}/12-yahoo/saved_weights,${OUTPUT_BASE}/13-multirc/saved_weights,${OUTPUT_BASE}/14-boolq/saved_weights \
|
|
|
|
| 93 |
|
| 94 |
CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
|
| 95 |
--do_train \
|
| 96 |
+
--do_predict \
|
| 97 |
--predict_with_generate \
|
| 98 |
--model_name_or_path $2 \
|
| 99 |
--data_dir CL_Benchmark \
|
|
|
|
| 159 |
|
| 160 |
CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
|
| 161 |
--do_train \
|
| 162 |
+
--do_predict \
|
| 163 |
--predict_with_generate \
|
| 164 |
--model_name_or_path $2 \
|
| 165 |
--previous_lora_path ${OUTPUT_BASE}/1-yelp/saved_weights \
|
|
|
|
| 217 |
# ============================================================
|
| 218 |
CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
|
| 219 |
--do_train \
|
| 220 |
+
--do_predict \
|
| 221 |
--predict_with_generate \
|
| 222 |
--model_name_or_path $2 \
|
| 223 |
--previous_lora_path ${OUTPUT_BASE}/1-yelp/saved_weights,${OUTPUT_BASE}/2-amazon/saved_weights \
|
|
|
|
| 275 |
# ============================================================
|
| 276 |
CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
|
| 277 |
--do_train \
|
| 278 |
+
--do_predict \
|
| 279 |
--predict_with_generate \
|
| 280 |
--model_name_or_path $2 \
|
| 281 |
--previous_lora_path ${OUTPUT_BASE}/1-yelp/saved_weights,${OUTPUT_BASE}/2-amazon/saved_weights,${OUTPUT_BASE}/3-mnli/saved_weights \
|
|
|
|
| 333 |
# ============================================================
|
| 334 |
CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
|
| 335 |
--do_train \
|
| 336 |
+
--do_predict \
|
| 337 |
--predict_with_generate \
|
| 338 |
--model_name_or_path $2 \
|
| 339 |
--previous_lora_path ${OUTPUT_BASE}/1-yelp/saved_weights,${OUTPUT_BASE}/2-amazon/saved_weights,${OUTPUT_BASE}/3-mnli/saved_weights,${OUTPUT_BASE}/4-cb/saved_weights \
|
|
|
|
| 391 |
# ============================================================
|
| 392 |
CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
|
| 393 |
--do_train \
|
| 394 |
+
--do_predict \
|
| 395 |
--predict_with_generate \
|
| 396 |
--model_name_or_path $2 \
|
| 397 |
--previous_lora_path ${OUTPUT_BASE}/1-yelp/saved_weights,${OUTPUT_BASE}/2-amazon/saved_weights,${OUTPUT_BASE}/3-mnli/saved_weights,${OUTPUT_BASE}/4-cb/saved_weights,${OUTPUT_BASE}/5-copa/saved_weights \
|
|
|
|
| 449 |
# ============================================================
|
| 450 |
CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
|
| 451 |
--do_train \
|
| 452 |
+
--do_predict \
|
| 453 |
--predict_with_generate \
|
| 454 |
--model_name_or_path $2 \
|
| 455 |
--previous_lora_path ${OUTPUT_BASE}/1-yelp/saved_weights,${OUTPUT_BASE}/2-amazon/saved_weights,${OUTPUT_BASE}/3-mnli/saved_weights,${OUTPUT_BASE}/4-cb/saved_weights,${OUTPUT_BASE}/5-copa/saved_weights,${OUTPUT_BASE}/6-qqp/saved_weights \
|
|
|
|
| 507 |
# ============================================================
|
| 508 |
CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
|
| 509 |
--do_train \
|
| 510 |
+
--do_predict \
|
| 511 |
--predict_with_generate \
|
| 512 |
--model_name_or_path $2 \
|
| 513 |
--previous_lora_path ${OUTPUT_BASE}/1-yelp/saved_weights,${OUTPUT_BASE}/2-amazon/saved_weights,${OUTPUT_BASE}/3-mnli/saved_weights,${OUTPUT_BASE}/4-cb/saved_weights,${OUTPUT_BASE}/5-copa/saved_weights,${OUTPUT_BASE}/6-qqp/saved_weights,${OUTPUT_BASE}/7-rte/saved_weights \
|
|
|
|
| 565 |
# ============================================================
|
| 566 |
CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
|
| 567 |
--do_train \
|
| 568 |
+
--do_predict \
|
| 569 |
--predict_with_generate \
|
| 570 |
--model_name_or_path $2 \
|
| 571 |
--previous_lora_path ${OUTPUT_BASE}/1-yelp/saved_weights,${OUTPUT_BASE}/2-amazon/saved_weights,${OUTPUT_BASE}/3-mnli/saved_weights,${OUTPUT_BASE}/4-cb/saved_weights,${OUTPUT_BASE}/5-copa/saved_weights,${OUTPUT_BASE}/6-qqp/saved_weights,${OUTPUT_BASE}/7-rte/saved_weights,${OUTPUT_BASE}/8-imdb/saved_weights \
|
|
|
|
| 623 |
# ============================================================
|
| 624 |
CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
|
| 625 |
--do_train \
|
| 626 |
+
--do_predict \
|
| 627 |
--predict_with_generate \
|
| 628 |
--model_name_or_path $2 \
|
| 629 |
--previous_lora_path ${OUTPUT_BASE}/1-yelp/saved_weights,${OUTPUT_BASE}/2-amazon/saved_weights,${OUTPUT_BASE}/3-mnli/saved_weights,${OUTPUT_BASE}/4-cb/saved_weights,${OUTPUT_BASE}/5-copa/saved_weights,${OUTPUT_BASE}/6-qqp/saved_weights,${OUTPUT_BASE}/7-rte/saved_weights,${OUTPUT_BASE}/8-imdb/saved_weights,${OUTPUT_BASE}/9-sst2/saved_weights \
|
|
|
|
| 681 |
# ============================================================
|
| 682 |
CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
|
| 683 |
--do_train \
|
| 684 |
+
--do_predict \
|
| 685 |
--predict_with_generate \
|
| 686 |
--model_name_or_path $2 \
|
| 687 |
--previous_lora_path ${OUTPUT_BASE}/1-yelp/saved_weights,${OUTPUT_BASE}/2-amazon/saved_weights,${OUTPUT_BASE}/3-mnli/saved_weights,${OUTPUT_BASE}/4-cb/saved_weights,${OUTPUT_BASE}/5-copa/saved_weights,${OUTPUT_BASE}/6-qqp/saved_weights,${OUTPUT_BASE}/7-rte/saved_weights,${OUTPUT_BASE}/8-imdb/saved_weights,${OUTPUT_BASE}/9-sst2/saved_weights,${OUTPUT_BASE}/10-dbpedia/saved_weights \
|
|
|
|
| 739 |
# ============================================================
|
| 740 |
CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
|
| 741 |
--do_train \
|
| 742 |
+
--do_predict \
|
| 743 |
--predict_with_generate \
|
| 744 |
--model_name_or_path $2 \
|
| 745 |
--previous_lora_path ${OUTPUT_BASE}/1-yelp/saved_weights,${OUTPUT_BASE}/2-amazon/saved_weights,${OUTPUT_BASE}/3-mnli/saved_weights,${OUTPUT_BASE}/4-cb/saved_weights,${OUTPUT_BASE}/5-copa/saved_weights,${OUTPUT_BASE}/6-qqp/saved_weights,${OUTPUT_BASE}/7-rte/saved_weights,${OUTPUT_BASE}/8-imdb/saved_weights,${OUTPUT_BASE}/9-sst2/saved_weights,${OUTPUT_BASE}/10-dbpedia/saved_weights,${OUTPUT_BASE}/11-agnews/saved_weights \
|
|
|
|
| 797 |
# ============================================================
|
| 798 |
CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
|
| 799 |
--do_train \
|
| 800 |
+
--do_predict \
|
| 801 |
--predict_with_generate \
|
| 802 |
--model_name_or_path $2 \
|
| 803 |
--previous_lora_path ${OUTPUT_BASE}/1-yelp/saved_weights,${OUTPUT_BASE}/2-amazon/saved_weights,${OUTPUT_BASE}/3-mnli/saved_weights,${OUTPUT_BASE}/4-cb/saved_weights,${OUTPUT_BASE}/5-copa/saved_weights,${OUTPUT_BASE}/6-qqp/saved_weights,${OUTPUT_BASE}/7-rte/saved_weights,${OUTPUT_BASE}/8-imdb/saved_weights,${OUTPUT_BASE}/9-sst2/saved_weights,${OUTPUT_BASE}/10-dbpedia/saved_weights,${OUTPUT_BASE}/11-agnews/saved_weights,${OUTPUT_BASE}/12-yahoo/saved_weights \
|
|
|
|
| 855 |
# ============================================================
|
| 856 |
CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
|
| 857 |
--do_train \
|
| 858 |
+
--do_predict \
|
| 859 |
--predict_with_generate \
|
| 860 |
--model_name_or_path $2 \
|
| 861 |
--previous_lora_path ${OUTPUT_BASE}/1-yelp/saved_weights,${OUTPUT_BASE}/2-amazon/saved_weights,${OUTPUT_BASE}/3-mnli/saved_weights,${OUTPUT_BASE}/4-cb/saved_weights,${OUTPUT_BASE}/5-copa/saved_weights,${OUTPUT_BASE}/6-qqp/saved_weights,${OUTPUT_BASE}/7-rte/saved_weights,${OUTPUT_BASE}/8-imdb/saved_weights,${OUTPUT_BASE}/9-sst2/saved_weights,${OUTPUT_BASE}/10-dbpedia/saved_weights,${OUTPUT_BASE}/11-agnews/saved_weights,${OUTPUT_BASE}/12-yahoo/saved_weights,${OUTPUT_BASE}/13-multirc/saved_weights \
|
|
|
|
| 913 |
# ============================================================
|
| 914 |
CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
|
| 915 |
--do_train \
|
| 916 |
+
--do_predict \
|
| 917 |
--predict_with_generate \
|
| 918 |
--model_name_or_path $2 \
|
| 919 |
--previous_lora_path ${OUTPUT_BASE}/1-yelp/saved_weights,${OUTPUT_BASE}/2-amazon/saved_weights,${OUTPUT_BASE}/3-mnli/saved_weights,${OUTPUT_BASE}/4-cb/saved_weights,${OUTPUT_BASE}/5-copa/saved_weights,${OUTPUT_BASE}/6-qqp/saved_weights,${OUTPUT_BASE}/7-rte/saved_weights,${OUTPUT_BASE}/8-imdb/saved_weights,${OUTPUT_BASE}/9-sst2/saved_weights,${OUTPUT_BASE}/10-dbpedia/saved_weights,${OUTPUT_BASE}/11-agnews/saved_weights,${OUTPUT_BASE}/12-yahoo/saved_weights,${OUTPUT_BASE}/13-multirc/saved_weights,${OUTPUT_BASE}/14-boolq/saved_weights \
|
improve_gainlora/src/t5_specroute.py
CHANGED
|
@@ -748,11 +748,17 @@ class T5Stack(T5PreTrainedModel):
|
|
| 748 |
# So: model[0]=RLS[-1], model[1]=RLS[-2], ..., model[T]=RLS[0]
|
| 749 |
key_attention_weights = rls_weights.flip(dims=[1])
|
| 750 |
else:
|
| 751 |
-
#
|
| 752 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 753 |
batch_size, n_lora_experts, 1,
|
| 754 |
device=inputs_embeds.device, dtype=inputs_embeds.dtype
|
| 755 |
-
)
|
|
|
|
| 756 |
else:
|
| 757 |
# Training or first task: oracle routing (always current = index 0)
|
| 758 |
# Just create correct-sized tensor; oracle override below will set it
|
|
|
|
| 748 |
# So: model[0]=RLS[-1], model[1]=RLS[-2], ..., model[T]=RLS[0]
|
| 749 |
key_attention_weights = rls_weights.flip(dims=[1])
|
| 750 |
else:
|
| 751 |
+
# Mismatch: RLS fitted for N-1 tasks but model has N experts.
|
| 752 |
+
# This occurs during training-time eval (trainer.evaluate calls model.eval())
|
| 753 |
+
# before update_rls_router is called for the current task.
|
| 754 |
+
# Fall back to oracle current-task routing (index 0) to preserve
|
| 755 |
+
# valid eval metrics during training — otherwise uniform 1/N weight
|
| 756 |
+
# dilutes current-task signal to near zero as N grows.
|
| 757 |
+
key_attention_weights = torch.zeros(
|
| 758 |
batch_size, n_lora_experts, 1,
|
| 759 |
device=inputs_embeds.device, dtype=inputs_embeds.dtype
|
| 760 |
+
)
|
| 761 |
+
key_attention_weights[:, 0, 0] = 1.0
|
| 762 |
else:
|
| 763 |
# Training or first task: oracle routing (always current = index 0)
|
| 764 |
# Just create correct-sized tensor; oracle override below will set it
|