v10a
Browse files- final_check.py +64 -0
- improve_gainlora/T5_small/gen_script_long_order3_t5_small_specroute_v11.sh +950 -0
- improve_gainlora/results/gen_script_long_order3_t5_small_specroute_v2.txt +3 -0
- improve_gainlora/results/gen_script_long_order3_t5_small_specroute_v5.txt +3 -0
- improve_gainlora/src/cl_trainer_specroute.py +128 -3
- parse_and_score_v2.py +87 -0
- recalculate_em.py +82 -0
- results/experiment_versions.md +117 -0
final_check.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
def load_json(path):
    """Read the JSON file at ``path`` and return the decoded Python object.

    The file is opened explicitly as UTF-8 so the result does not depend on
    the platform's default text encoding.

    Args:
        path: Filesystem path of the JSON file to read.

    Returns:
        Whatever ``json.load`` produces for the file's content (a dict for
        the ``all_results.json`` files this script reads).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
|
| 7 |
+
|
| 8 |
+
def get_matrix_from_outputs(base_dir, run_name, tasks):
    """Build the square per-task rougeL matrix for one run.

    Row ``i`` is read from
    ``<base_dir>/<run_name>/outputs/<i+1>-<tasks[i]>/all_results.json``.
    Column ``j`` of that row holds the value stored under the key
    ``predict_eval_rougeL_for_<tasks[j]>`` for every ``j <= i``; the
    remaining columns are padded with ``0.0``.

    Args:
        base_dir: Directory that contains the run directory.
        run_name: Name of the run directory inside ``base_dir``.
        tasks: Ordered sequence of task names.

    Returns:
        A list of ``len(tasks)`` rows, each a list of ``len(tasks)`` values.
        A row whose result file does not exist is all ``0.0``; a metric key
        missing from an existing file is also recorded as ``0.0``.
    """
    num_tasks = len(tasks)
    matrix = []
    for i, task in enumerate(tasks):
        res_file = os.path.join(
            base_dir, run_name, "outputs", f"{i + 1}-{task}", "all_results.json"
        )
        if not os.path.exists(res_file):
            # No results for this step: keep the matrix square with a zero row.
            matrix.append([0.0] * num_tasks)
            continue
        data = load_json(res_file)
        # Only the tasks up to and including step i are looked up in this file.
        row = [
            data.get(f"predict_eval_rougeL_for_{seen}", 0.0)
            for seen in tasks[:i + 1]
        ]
        row.extend([0.0] * (num_tasks - len(row)))
        matrix.append(row)
    return matrix
|
| 23 |
+
|
| 24 |
+
def calculate_stats(matrix):
    """Compute average performance (AP) and average forgetting (Fgt).

    ``matrix[i][j]`` is the score recorded for task ``j`` in row ``i``; the
    last row is treated as the final scores.

    * AP is the mean of the last row over all columns.
    * Fgt is computed over every column except the last one: for each such
      column, the best strictly positive score found in any row (the last
      row included) minus the final score. Columns with no positive score
      are left out of the average.

    Args:
        matrix: List of equally long rows of numeric scores.

    Returns:
        A tuple ``(AP, Fgt)``. Both are ``0.0`` when ``matrix`` is empty or
        its first row is empty; ``Fgt`` is ``0.0`` when no column
        contributes to the forgetting average.
    """
    # Guard against empty input, which would otherwise raise IndexError
    # (no rows) or ZeroDivisionError (rows without columns).
    if not matrix or not matrix[0]:
        return 0.0, 0.0

    task_num = len(matrix[0])
    final_row = matrix[-1]
    AP = sum(final_row) / task_num

    fgt_list = []
    for j in range(task_num - 1):
        # Scores of 0 are skipped here; the matrix builder uses 0.0 as its
        # padding value for entries that were never filled in.
        history = [row[j] for row in matrix if row[j] > 0]
        if history:
            fgt_list.append(max(history) - final_row[j])

    Fgt = sum(fgt_list) / len(fgt_list) if fgt_list else 0.0
    return AP, Fgt
|
| 40 |
+
|
| 41 |
+
# Ordered task sequence; its order determines both the per-step output
# directory names and the row/column order of the score matrix.
tasks = [
    "yelp", "amazon", "mnli", "cb", "copa", "qqp", "rte", "imdb",
    "sst2", "dbpedia", "agnews", "yahoo", "multirc", "boolq", "wic",
]

# ROOT baseline run. It is not read below: its all_results.json files may
# lack the predict metrics, so documented values would have to be used for it.
root_dir = "/Users/nnminh322/Desktop/personal/Continual/improve_gainlora/logs/root_t5_small"
root_run = "gen_script_long_order3_t5_small_gainlora_inflora"

# V5 run, whose all_results.json files are expected to contain the metrics.
v5_dir = "/Users/nnminh322/Desktop/personal/Continual/improve_gainlora/logs/t5_small_improve"
v5_run = "gen_script_long_order3_t5_small_specroute_v5"

print("--- V5 Matrix ---")
try:
    v5_matrix = get_matrix_from_outputs(v5_dir, v5_run, tasks)
    v5_ap, v5_fgt = calculate_stats(v5_matrix)
    print(f"V5 AP(rougeL): {v5_ap:.4f}")
    print(f"V5 Fgt: {v5_fgt:.4f}")
except Exception as e:
    # Top-level boundary of a one-off check script: report and carry on so
    # the V10 figure below is still printed.
    print(f"V5 failed: {e}")

# For V10, the final score vector was copied from the training log.
v10_final = [
    59.9013, 59.7018, 30.5395, 0.0, 55.0, 11.9474, 10.1083, 89.8947,
    65.2523, 53.1737, 65.0342, 62.0329, 43.1312, 62.4465, 56.4263,
]
# Divide by the actual vector length rather than a hard-coded 15 so the
# average stays correct if entries are added or removed.
v10_ap = sum(v10_final) / len(v10_final)
print(f"V10 AP(rougeL): {v10_ap:.4f}")
|
improve_gainlora/T5_small/gen_script_long_order3_t5_small_specroute_v11.sh
ADDED
|
@@ -0,0 +1,950 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
#SBATCH -J cl
|
| 3 |
+
#SBATCH -o cl-%j.out
|
| 4 |
+
#SBATCH -p compute
|
| 5 |
+
#SBATCH -N 1
|
| 6 |
+
#SBATCH -t 20:00:00
|
| 7 |
+
#SBATCH --mem 128G
|
| 8 |
+
#SBATCH --gres=gpu:2
|
| 9 |
+
|
| 10 |
+
export CUDA_DEVICE_ORDER="PCI_BUS_ID"
|
| 11 |
+
|
| 12 |
+
port=$(shuf -i25000-30000 -n1)
|
| 13 |
+
|
| 14 |
+
# ============================================================
|
| 15 |
+
# Auto-detect GPU count and type for optimal parallelism
|
| 16 |
+
# ============================================================
|
| 17 |
+
NUM_GPUS=$(nvidia-smi -L 2>/dev/null | wc -l)
|
| 18 |
+
GPU_MEM=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits 2>/dev/null | head -1)
|
| 19 |
+
|
| 20 |
+
if [ -z "$GPU_MEM" ]; then
|
| 21 |
+
echo "ERROR: No GPU detected!"
|
| 22 |
+
exit 1
|
| 23 |
+
fi
|
| 24 |
+
|
| 25 |
+
# GPU type detection
|
| 26 |
+
# T4 <15500 MB | P100 15500-17000 MB | RTX3090 ~24576 | A100 40000 | H100 80000
|
| 27 |
+
if [ "$GPU_MEM" -lt 15500 ]; then
|
| 28 |
+
GPU_TYPE="t4"
|
| 29 |
+
echo "[GPU] Detected T4 (${GPU_MEM}MB)"
|
| 30 |
+
elif [ "$GPU_MEM" -le 17000 ]; then
|
| 31 |
+
GPU_TYPE="p100"
|
| 32 |
+
echo "[GPU] Detected P100 (${GPU_MEM}MB)"
|
| 33 |
+
else
|
| 34 |
+
GPU_TYPE="highvram"
|
| 35 |
+
echo "[GPU] Detected high-VRAM GPU (${GPU_MEM}MB)"
|
| 36 |
+
fi
|
| 37 |
+
|
| 38 |
+
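# Parallelism: T4/P100 (16 GB cards) add --gradient_checkpointing; no fp16 flag is ever passed (FP16_FLAG, despite its name, holds only that flag or nothing). With 2+ GPUs, T4 and high-VRAM setups expose GPUs 0,1 (intended for DataParallel); otherwise a single GPU is used.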
|
| 39 |
+
if [ "$GPU_TYPE" = "t4" ] && [ "$NUM_GPUS" -ge 2 ]; then
|
| 40 |
+
GPU_MODE="t4_2gpu"
|
| 41 |
+
GPU_IDS="0,1"
|
| 42 |
+
FP16_FLAG="--gradient_checkpointing"
|
| 43 |
+
echo "[GPU] Strategy: 2x T4 DataParallel + fp32 + gradient_checkpointing"
|
| 44 |
+
elif [ "$GPU_TYPE" = "t4" ]; then
|
| 45 |
+
GPU_MODE="t4_1gpu"
|
| 46 |
+
GPU_IDS="${1:-0}"
|
| 47 |
+
FP16_FLAG="--gradient_checkpointing"
|
| 48 |
+
echo "[GPU] Strategy: 1x T4 (${GPU_MEM}MB) + fp32 + gradient_checkpointing"
|
| 49 |
+
elif [ "$GPU_TYPE" = "p100" ]; then
|
| 50 |
+
GPU_MODE="p100"
|
| 51 |
+
GPU_IDS="${1:-0}"
|
| 52 |
+
FP16_FLAG="--gradient_checkpointing"
|
| 53 |
+
echo "[GPU] Strategy: P100 16GB + fp32 + gradient_checkpointing"
|
| 54 |
+
else
|
| 55 |
+
GPU_MODE="a100"
|
| 56 |
+
if [ "$NUM_GPUS" -ge 2 ]; then
|
| 57 |
+
GPU_IDS="0,1"
|
| 58 |
+
echo "[GPU] Strategy: ${NUM_GPUS}x ${GPU_MEM}MB DataParallel (RTX3090/A100, fp32)"
|
| 59 |
+
else
|
| 60 |
+
GPU_IDS="${1:-0}"
|
| 61 |
+
echo "[GPU] Strategy: 1x ${GPU_MEM}MB GPU (fp32)"
|
| 62 |
+
fi
|
| 63 |
+
FP16_FLAG=""
|
| 64 |
+
fi
|
| 65 |
+
|
| 66 |
+
echo "[GPU] Using CUDA_VISIBLE_DEVICES=$GPU_IDS"
|
| 67 |
+
echo "============================================================"
|
| 68 |
+
echo ""
|
| 69 |
+
|
| 70 |
+
if [ "$GPU_MODE" = "t4_2gpu" ]; then
|
| 71 |
+
BSZ=8; GA=2; EVAL_BSZ=64
|
| 72 |
+
elif [ "$GPU_MODE" = "t4_1gpu" ]; then
|
| 73 |
+
BSZ=8; GA=2; EVAL_BSZ=32
|
| 74 |
+
elif [ "$GPU_MODE" = "p100" ]; then
|
| 75 |
+
BSZ=16; GA=2; EVAL_BSZ=32
|
| 76 |
+
else
|
| 77 |
+
BSZ=64; GA=1; EVAL_BSZ=128
|
| 78 |
+
fi
|
| 79 |
+
|
| 80 |
+
CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
|
| 81 |
+
--do_train \
|
| 82 |
+
--do_predict \
|
| 83 |
+
--predict_with_generate \
|
| 84 |
+
--model_name_or_path $2 \
|
| 85 |
+
--data_dir CL_Benchmark \
|
| 86 |
+
--task_order yelp,amazon,mnli,cb,copa,qqp,rte,imdb,sst2,dbpedia,agnews,yahoo,multirc,boolq,wic \
|
| 87 |
+
--task_config_dir configs/gen_script_long_order3_t5_configs/yelp \
|
| 88 |
+
--output_dir logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/1-yelp \
|
| 89 |
+
--per_device_train_batch_size $BSZ \
|
| 90 |
+
--per_device_eval_batch_size $EVAL_BSZ \
|
| 91 |
+
--gradient_accumulation_steps $GA \
|
| 92 |
+
--learning_rate 0.0003 \
|
| 93 |
+
--num_train_epochs 10 \
|
| 94 |
+
--run_name gen_script_long_order3_t5_small_specroute_v11 \
|
| 95 |
+
--max_source_length 512 \
|
| 96 |
+
--max_target_length 50 \
|
| 97 |
+
--generation_max_length 50 \
|
| 98 |
+
--add_task_name False \
|
| 99 |
+
--add_dataset_name False \
|
| 100 |
+
--overwrite_output_dir \
|
| 101 |
+
--overwrite_cache \
|
| 102 |
+
--lr_scheduler_type constant \
|
| 103 |
+
--warmup_steps 0 \
|
| 104 |
+
--logging_strategy steps \
|
| 105 |
+
--logging_steps 10 \
|
| 106 |
+
--metric_for_best_model eval_exact_match \
|
| 107 |
+
--evaluation_strategy epoch \
|
| 108 |
+
--save_strategy epoch \
|
| 109 |
+
--save_total_limit 1 \
|
| 110 |
+
--load_best_model_at_end \
|
| 111 |
+
--lora_r 8 \
|
| 112 |
+
--lora_alpha 32 \
|
| 113 |
+
--lora_dropout 0.0 \
|
| 114 |
+
--data_replay_freq -1 \
|
| 115 |
+
--mlp_hidden_dim 100 \
|
| 116 |
+
--model_name specroute \
|
| 117 |
+
--routing_mode learned \
|
| 118 |
+
--threshold 0.995 \
|
| 119 |
+
--transthreshold 0.995 \
|
| 120 |
+
$FP16_FLAG
|
| 121 |
+
|
| 122 |
+
rm -rf logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/1-yelp/checkpoint*
|
| 123 |
+
|
| 124 |
+
sleep 5
|
| 125 |
+
|
| 126 |
+
if [ "$GPU_MODE" = "t4_2gpu" ]; then
|
| 127 |
+
BSZ=8; GA=2; EVAL_BSZ=64
|
| 128 |
+
elif [ "$GPU_MODE" = "t4_1gpu" ]; then
|
| 129 |
+
BSZ=8; GA=2; EVAL_BSZ=32
|
| 130 |
+
elif [ "$GPU_MODE" = "p100" ]; then
|
| 131 |
+
BSZ=16; GA=2; EVAL_BSZ=32
|
| 132 |
+
else
|
| 133 |
+
BSZ=64; GA=1; EVAL_BSZ=128
|
| 134 |
+
fi
|
| 135 |
+
|
| 136 |
+
CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
|
| 137 |
+
--do_train \
|
| 138 |
+
--load_checkpoint_from logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/1-yelp/saved_weights/trans_input.pt \
|
| 139 |
+
--previous_prompt_key_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/1-yelp/saved_weights/prompts_keys_till_now.pt \
|
| 140 |
+
--do_predict \
|
| 141 |
+
--predict_with_generate \
|
| 142 |
+
--model_name_or_path $2 \
|
| 143 |
+
--previous_lora_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/1-yelp/saved_weights \
|
| 144 |
+
--data_dir CL_Benchmark \
|
| 145 |
+
--task_order yelp,amazon,mnli,cb,copa,qqp,rte,imdb,sst2,dbpedia,agnews,yahoo,multirc,boolq,wic \
|
| 146 |
+
--task_config_dir configs/gen_script_long_order3_t5_configs/amazon \
|
| 147 |
+
--output_dir logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/2-amazon \
|
| 148 |
+
--per_device_train_batch_size $BSZ \
|
| 149 |
+
--per_device_eval_batch_size $EVAL_BSZ \
|
| 150 |
+
--gradient_accumulation_steps $GA \
|
| 151 |
+
--learning_rate 0.0003 \
|
| 152 |
+
--num_train_epochs 10 \
|
| 153 |
+
--run_name gen_script_long_order3_t5_small_specroute_v11 \
|
| 154 |
+
--max_source_length 512 \
|
| 155 |
+
--max_target_length 50 \
|
| 156 |
+
--generation_max_length 50 \
|
| 157 |
+
--add_task_name False \
|
| 158 |
+
--add_dataset_name False \
|
| 159 |
+
--overwrite_output_dir \
|
| 160 |
+
--overwrite_cache \
|
| 161 |
+
--lr_scheduler_type constant \
|
| 162 |
+
--warmup_steps 0 \
|
| 163 |
+
--logging_strategy steps \
|
| 164 |
+
--logging_steps 10 \
|
| 165 |
+
--metric_for_best_model eval_exact_match_for_amazon \
|
| 166 |
+
--evaluation_strategy epoch \
|
| 167 |
+
--save_strategy epoch \
|
| 168 |
+
--save_total_limit 1 \
|
| 169 |
+
--load_best_model_at_end \
|
| 170 |
+
--lora_r 8 \
|
| 171 |
+
--lora_alpha 32 \
|
| 172 |
+
--lora_dropout 0.0 \
|
| 173 |
+
--data_replay_freq -1 \
|
| 174 |
+
--mlp_hidden_dim 100 \
|
| 175 |
+
--model_name specroute \
|
| 176 |
+
--routing_mode learned \
|
| 177 |
+
--threshold 0.995 \
|
| 178 |
+
--transthreshold 0.995 \
|
| 179 |
+
$FP16_FLAG
|
| 180 |
+
|
| 181 |
+
rm -rf logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/2-amazon/checkpoint*
|
| 182 |
+
|
| 183 |
+
sleep 5
|
| 184 |
+
|
| 185 |
+
if [ "$GPU_MODE" = "t4_2gpu" ]; then
|
| 186 |
+
BSZ=8; GA=2; EVAL_BSZ=64
|
| 187 |
+
elif [ "$GPU_MODE" = "t4_1gpu" ]; then
|
| 188 |
+
BSZ=8; GA=2; EVAL_BSZ=32
|
| 189 |
+
elif [ "$GPU_MODE" = "p100" ]; then
|
| 190 |
+
BSZ=16; GA=2; EVAL_BSZ=32
|
| 191 |
+
else
|
| 192 |
+
BSZ=64; GA=1; EVAL_BSZ=128
|
| 193 |
+
fi
|
| 194 |
+
|
| 195 |
+
CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
|
| 196 |
+
--do_train \
|
| 197 |
+
--load_checkpoint_from logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/2-amazon/saved_weights/trans_input.pt \
|
| 198 |
+
--previous_prompt_key_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/2-amazon/saved_weights/prompts_keys_till_now.pt \
|
| 199 |
+
--do_predict \
|
| 200 |
+
--predict_with_generate \
|
| 201 |
+
--model_name_or_path $2 \
|
| 202 |
+
--previous_lora_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/1-yelp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/2-amazon/saved_weights \
|
| 203 |
+
--data_dir CL_Benchmark \
|
| 204 |
+
--task_order yelp,amazon,mnli,cb,copa,qqp,rte,imdb,sst2,dbpedia,agnews,yahoo,multirc,boolq,wic \
|
| 205 |
+
--task_config_dir configs/gen_script_long_order3_t5_configs/mnli \
|
| 206 |
+
--output_dir logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/3-mnli \
|
| 207 |
+
--per_device_train_batch_size $BSZ \
|
| 208 |
+
--per_device_eval_batch_size $EVAL_BSZ \
|
| 209 |
+
--gradient_accumulation_steps $GA \
|
| 210 |
+
--learning_rate 0.0003 \
|
| 211 |
+
--num_train_epochs 10 \
|
| 212 |
+
--run_name gen_script_long_order3_t5_small_specroute_v11 \
|
| 213 |
+
--max_source_length 512 \
|
| 214 |
+
--max_target_length 50 \
|
| 215 |
+
--generation_max_length 50 \
|
| 216 |
+
--add_task_name False \
|
| 217 |
+
--add_dataset_name False \
|
| 218 |
+
--overwrite_output_dir \
|
| 219 |
+
--overwrite_cache \
|
| 220 |
+
--lr_scheduler_type constant \
|
| 221 |
+
--warmup_steps 0 \
|
| 222 |
+
--logging_strategy steps \
|
| 223 |
+
--logging_steps 10 \
|
| 224 |
+
--metric_for_best_model eval_exact_match_for_mnli \
|
| 225 |
+
--evaluation_strategy epoch \
|
| 226 |
+
--save_strategy epoch \
|
| 227 |
+
--save_total_limit 1 \
|
| 228 |
+
--load_best_model_at_end \
|
| 229 |
+
--lora_r 8 \
|
| 230 |
+
--lora_alpha 32 \
|
| 231 |
+
--lora_dropout 0.0 \
|
| 232 |
+
--data_replay_freq -1 \
|
| 233 |
+
--mlp_hidden_dim 100 \
|
| 234 |
+
--model_name specroute \
|
| 235 |
+
--routing_mode learned \
|
| 236 |
+
--threshold 0.995 \
|
| 237 |
+
--transthreshold 0.995 \
|
| 238 |
+
$FP16_FLAG
|
| 239 |
+
|
| 240 |
+
rm -rf logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/3-mnli/checkpoint*
|
| 241 |
+
|
| 242 |
+
sleep 5
|
| 243 |
+
|
| 244 |
+
if [ "$GPU_MODE" = "t4_2gpu" ]; then
|
| 245 |
+
BSZ=8; GA=2; EVAL_BSZ=64
|
| 246 |
+
elif [ "$GPU_MODE" = "t4_1gpu" ]; then
|
| 247 |
+
BSZ=8; GA=2; EVAL_BSZ=32
|
| 248 |
+
elif [ "$GPU_MODE" = "p100" ]; then
|
| 249 |
+
BSZ=16; GA=2; EVAL_BSZ=32
|
| 250 |
+
else
|
| 251 |
+
BSZ=64; GA=1; EVAL_BSZ=128
|
| 252 |
+
fi
|
| 253 |
+
|
| 254 |
+
CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
|
| 255 |
+
--do_train \
|
| 256 |
+
--load_checkpoint_from logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/3-mnli/saved_weights/trans_input.pt \
|
| 257 |
+
--previous_prompt_key_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/3-mnli/saved_weights/prompts_keys_till_now.pt \
|
| 258 |
+
--do_predict \
|
| 259 |
+
--predict_with_generate \
|
| 260 |
+
--model_name_or_path $2 \
|
| 261 |
+
--previous_lora_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/1-yelp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/2-amazon/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/3-mnli/saved_weights \
|
| 262 |
+
--data_dir CL_Benchmark \
|
| 263 |
+
--task_order yelp,amazon,mnli,cb,copa,qqp,rte,imdb,sst2,dbpedia,agnews,yahoo,multirc,boolq,wic \
|
| 264 |
+
--task_config_dir configs/gen_script_long_order3_t5_configs/cb \
|
| 265 |
+
--output_dir logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/4-cb \
|
| 266 |
+
--per_device_train_batch_size $BSZ \
|
| 267 |
+
--per_device_eval_batch_size $EVAL_BSZ \
|
| 268 |
+
--gradient_accumulation_steps $GA \
|
| 269 |
+
--learning_rate 0.0003 \
|
| 270 |
+
--num_train_epochs 10 \
|
| 271 |
+
--run_name gen_script_long_order3_t5_small_specroute_v11 \
|
| 272 |
+
--max_source_length 512 \
|
| 273 |
+
--max_target_length 50 \
|
| 274 |
+
--generation_max_length 50 \
|
| 275 |
+
--add_task_name False \
|
| 276 |
+
--add_dataset_name False \
|
| 277 |
+
--overwrite_output_dir \
|
| 278 |
+
--overwrite_cache \
|
| 279 |
+
--lr_scheduler_type constant \
|
| 280 |
+
--warmup_steps 0 \
|
| 281 |
+
--logging_strategy steps \
|
| 282 |
+
--logging_steps 10 \
|
| 283 |
+
--metric_for_best_model eval_exact_match_for_cb \
|
| 284 |
+
--evaluation_strategy epoch \
|
| 285 |
+
--save_strategy epoch \
|
| 286 |
+
--save_total_limit 1 \
|
| 287 |
+
--load_best_model_at_end \
|
| 288 |
+
--lora_r 8 \
|
| 289 |
+
--lora_alpha 32 \
|
| 290 |
+
--lora_dropout 0.0 \
|
| 291 |
+
--data_replay_freq -1 \
|
| 292 |
+
--mlp_hidden_dim 100 \
|
| 293 |
+
--model_name specroute \
|
| 294 |
+
--routing_mode learned \
|
| 295 |
+
--threshold 0.995 \
|
| 296 |
+
--transthreshold 0.995 \
|
| 297 |
+
$FP16_FLAG
|
| 298 |
+
|
| 299 |
+
rm -rf logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/4-cb/checkpoint*
|
| 300 |
+
|
| 301 |
+
sleep 5
|
| 302 |
+
|
| 303 |
+
if [ "$GPU_MODE" = "t4_2gpu" ]; then
|
| 304 |
+
BSZ=8; GA=2; EVAL_BSZ=64
|
| 305 |
+
elif [ "$GPU_MODE" = "t4_1gpu" ]; then
|
| 306 |
+
BSZ=8; GA=2; EVAL_BSZ=32
|
| 307 |
+
elif [ "$GPU_MODE" = "p100" ]; then
|
| 308 |
+
BSZ=16; GA=2; EVAL_BSZ=32
|
| 309 |
+
else
|
| 310 |
+
BSZ=64; GA=1; EVAL_BSZ=128
|
| 311 |
+
fi
|
| 312 |
+
|
| 313 |
+
CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
|
| 314 |
+
--do_train \
|
| 315 |
+
--load_checkpoint_from logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/4-cb/saved_weights/trans_input.pt \
|
| 316 |
+
--previous_prompt_key_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/4-cb/saved_weights/prompts_keys_till_now.pt \
|
| 317 |
+
--do_predict \
|
| 318 |
+
--predict_with_generate \
|
| 319 |
+
--model_name_or_path $2 \
|
| 320 |
+
--previous_lora_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/1-yelp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/2-amazon/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/3-mnli/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/4-cb/saved_weights \
|
| 321 |
+
--data_dir CL_Benchmark \
|
| 322 |
+
--task_order yelp,amazon,mnli,cb,copa,qqp,rte,imdb,sst2,dbpedia,agnews,yahoo,multirc,boolq,wic \
|
| 323 |
+
--task_config_dir configs/gen_script_long_order3_t5_configs/copa \
|
| 324 |
+
--output_dir logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/5-copa \
|
| 325 |
+
--per_device_train_batch_size $BSZ \
|
| 326 |
+
--per_device_eval_batch_size $EVAL_BSZ \
|
| 327 |
+
--gradient_accumulation_steps $GA \
|
| 328 |
+
--learning_rate 0.0003 \
|
| 329 |
+
--num_train_epochs 10 \
|
| 330 |
+
--run_name gen_script_long_order3_t5_small_specroute_v11 \
|
| 331 |
+
--max_source_length 512 \
|
| 332 |
+
--max_target_length 50 \
|
| 333 |
+
--generation_max_length 50 \
|
| 334 |
+
--add_task_name False \
|
| 335 |
+
--add_dataset_name False \
|
| 336 |
+
--overwrite_output_dir \
|
| 337 |
+
--overwrite_cache \
|
| 338 |
+
--lr_scheduler_type constant \
|
| 339 |
+
--warmup_steps 0 \
|
| 340 |
+
--logging_strategy steps \
|
| 341 |
+
--logging_steps 10 \
|
| 342 |
+
--metric_for_best_model eval_exact_match_for_copa \
|
| 343 |
+
--evaluation_strategy epoch \
|
| 344 |
+
--save_strategy epoch \
|
| 345 |
+
--save_total_limit 1 \
|
| 346 |
+
--load_best_model_at_end \
|
| 347 |
+
--lora_r 8 \
|
| 348 |
+
--lora_alpha 32 \
|
| 349 |
+
--lora_dropout 0.0 \
|
| 350 |
+
--data_replay_freq -1 \
|
| 351 |
+
--mlp_hidden_dim 100 \
|
| 352 |
+
--model_name specroute \
|
| 353 |
+
--routing_mode learned \
|
| 354 |
+
--threshold 0.995 \
|
| 355 |
+
--transthreshold 0.995 \
|
| 356 |
+
$FP16_FLAG
|
| 357 |
+
|
| 358 |
+
rm -rf logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/5-copa/checkpoint*
|
| 359 |
+
|
| 360 |
+
sleep 5
|
| 361 |
+
|
| 362 |
+
if [ "$GPU_MODE" = "t4_2gpu" ]; then
|
| 363 |
+
BSZ=8; GA=2; EVAL_BSZ=64
|
| 364 |
+
elif [ "$GPU_MODE" = "t4_1gpu" ]; then
|
| 365 |
+
BSZ=8; GA=2; EVAL_BSZ=32
|
| 366 |
+
elif [ "$GPU_MODE" = "p100" ]; then
|
| 367 |
+
BSZ=16; GA=2; EVAL_BSZ=32
|
| 368 |
+
else
|
| 369 |
+
BSZ=64; GA=1; EVAL_BSZ=128
|
| 370 |
+
fi
|
| 371 |
+
|
| 372 |
+
CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
|
| 373 |
+
--do_train \
|
| 374 |
+
--load_checkpoint_from logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/5-copa/saved_weights/trans_input.pt \
|
| 375 |
+
--previous_prompt_key_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/5-copa/saved_weights/prompts_keys_till_now.pt \
|
| 376 |
+
--do_predict \
|
| 377 |
+
--predict_with_generate \
|
| 378 |
+
--model_name_or_path $2 \
|
| 379 |
+
--previous_lora_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/1-yelp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/2-amazon/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/3-mnli/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/4-cb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/5-copa/saved_weights \
|
| 380 |
+
--data_dir CL_Benchmark \
|
| 381 |
+
--task_order yelp,amazon,mnli,cb,copa,qqp,rte,imdb,sst2,dbpedia,agnews,yahoo,multirc,boolq,wic \
|
| 382 |
+
--task_config_dir configs/gen_script_long_order3_t5_configs/qqp \
|
| 383 |
+
--output_dir logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/6-qqp \
|
| 384 |
+
--per_device_train_batch_size $BSZ \
|
| 385 |
+
--per_device_eval_batch_size $EVAL_BSZ \
|
| 386 |
+
--gradient_accumulation_steps $GA \
|
| 387 |
+
--learning_rate 0.0003 \
|
| 388 |
+
--num_train_epochs 10 \
|
| 389 |
+
--run_name gen_script_long_order3_t5_small_specroute_v11 \
|
| 390 |
+
--max_source_length 512 \
|
| 391 |
+
--max_target_length 50 \
|
| 392 |
+
--generation_max_length 50 \
|
| 393 |
+
--add_task_name False \
|
| 394 |
+
--add_dataset_name False \
|
| 395 |
+
--overwrite_output_dir \
|
| 396 |
+
--overwrite_cache \
|
| 397 |
+
--lr_scheduler_type constant \
|
| 398 |
+
--warmup_steps 0 \
|
| 399 |
+
--logging_strategy steps \
|
| 400 |
+
--logging_steps 10 \
|
| 401 |
+
--metric_for_best_model eval_exact_match_for_qqp \
|
| 402 |
+
--evaluation_strategy epoch \
|
| 403 |
+
--save_strategy epoch \
|
| 404 |
+
--save_total_limit 1 \
|
| 405 |
+
--load_best_model_at_end \
|
| 406 |
+
--lora_r 8 \
|
| 407 |
+
--lora_alpha 32 \
|
| 408 |
+
--lora_dropout 0.0 \
|
| 409 |
+
--data_replay_freq -1 \
|
| 410 |
+
--mlp_hidden_dim 100 \
|
| 411 |
+
--model_name specroute \
|
| 412 |
+
--routing_mode learned \
|
| 413 |
+
--threshold 0.995 \
|
| 414 |
+
--transthreshold 0.995 \
|
| 415 |
+
$FP16_FLAG
|
| 416 |
+
|
| 417 |
+
rm -rf logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/6-qqp/checkpoint*
|
| 418 |
+
|
| 419 |
+
sleep 5
|
| 420 |
+
|
| 421 |
+
if [ "$GPU_MODE" = "t4_2gpu" ]; then
|
| 422 |
+
BSZ=8; GA=2; EVAL_BSZ=64
|
| 423 |
+
elif [ "$GPU_MODE" = "t4_1gpu" ]; then
|
| 424 |
+
BSZ=8; GA=2; EVAL_BSZ=32
|
| 425 |
+
elif [ "$GPU_MODE" = "p100" ]; then
|
| 426 |
+
BSZ=16; GA=2; EVAL_BSZ=32
|
| 427 |
+
else
|
| 428 |
+
BSZ=64; GA=1; EVAL_BSZ=128
|
| 429 |
+
fi
|
| 430 |
+
|
| 431 |
+
CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
|
| 432 |
+
--do_train \
|
| 433 |
+
--load_checkpoint_from logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/6-qqp/saved_weights/trans_input.pt \
|
| 434 |
+
--previous_prompt_key_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/6-qqp/saved_weights/prompts_keys_till_now.pt \
|
| 435 |
+
--do_predict \
|
| 436 |
+
--predict_with_generate \
|
| 437 |
+
--model_name_or_path $2 \
|
| 438 |
+
--previous_lora_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/1-yelp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/2-amazon/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/3-mnli/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/4-cb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/5-copa/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/6-qqp/saved_weights \
|
| 439 |
+
--data_dir CL_Benchmark \
|
| 440 |
+
--task_order yelp,amazon,mnli,cb,copa,qqp,rte,imdb,sst2,dbpedia,agnews,yahoo,multirc,boolq,wic \
|
| 441 |
+
--task_config_dir configs/gen_script_long_order3_t5_configs/rte \
|
| 442 |
+
--output_dir logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/7-rte \
|
| 443 |
+
--per_device_train_batch_size $BSZ \
|
| 444 |
+
--per_device_eval_batch_size $EVAL_BSZ \
|
| 445 |
+
--gradient_accumulation_steps $GA \
|
| 446 |
+
--learning_rate 0.0003 \
|
| 447 |
+
--num_train_epochs 10 \
|
| 448 |
+
--run_name gen_script_long_order3_t5_small_specroute_v11 \
|
| 449 |
+
--max_source_length 512 \
|
| 450 |
+
--max_target_length 50 \
|
| 451 |
+
--generation_max_length 50 \
|
| 452 |
+
--add_task_name False \
|
| 453 |
+
--add_dataset_name False \
|
| 454 |
+
--overwrite_output_dir \
|
| 455 |
+
--overwrite_cache \
|
| 456 |
+
--lr_scheduler_type constant \
|
| 457 |
+
--warmup_steps 0 \
|
| 458 |
+
--logging_strategy steps \
|
| 459 |
+
--logging_steps 10 \
|
| 460 |
+
--metric_for_best_model eval_exact_match_for_rte \
|
| 461 |
+
--evaluation_strategy epoch \
|
| 462 |
+
--save_strategy epoch \
|
| 463 |
+
--save_total_limit 1 \
|
| 464 |
+
--load_best_model_at_end \
|
| 465 |
+
--lora_r 8 \
|
| 466 |
+
--lora_alpha 32 \
|
| 467 |
+
--lora_dropout 0.0 \
|
| 468 |
+
--data_replay_freq -1 \
|
| 469 |
+
--mlp_hidden_dim 100 \
|
| 470 |
+
--model_name specroute \
|
| 471 |
+
--routing_mode learned \
|
| 472 |
+
--threshold 0.995 \
|
| 473 |
+
--transthreshold 0.995 \
|
| 474 |
+
$FP16_FLAG
|
| 475 |
+
|
| 476 |
+
rm -rf logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/7-rte/checkpoint*
|
| 477 |
+
|
| 478 |
+
sleep 5
|
| 479 |
+
|
| 480 |
+
if [ "$GPU_MODE" = "t4_2gpu" ]; then
|
| 481 |
+
BSZ=8; GA=2; EVAL_BSZ=64
|
| 482 |
+
elif [ "$GPU_MODE" = "t4_1gpu" ]; then
|
| 483 |
+
BSZ=8; GA=2; EVAL_BSZ=32
|
| 484 |
+
elif [ "$GPU_MODE" = "p100" ]; then
|
| 485 |
+
BSZ=16; GA=2; EVAL_BSZ=32
|
| 486 |
+
else
|
| 487 |
+
BSZ=64; GA=1; EVAL_BSZ=128
|
| 488 |
+
fi
|
| 489 |
+
|
| 490 |
+
CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
|
| 491 |
+
--do_train \
|
| 492 |
+
--load_checkpoint_from logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/7-rte/saved_weights/trans_input.pt \
|
| 493 |
+
--previous_prompt_key_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/7-rte/saved_weights/prompts_keys_till_now.pt \
|
| 494 |
+
--do_predict \
|
| 495 |
+
--predict_with_generate \
|
| 496 |
+
--model_name_or_path $2 \
|
| 497 |
+
--previous_lora_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/1-yelp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/2-amazon/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/3-mnli/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/4-cb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/5-copa/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/6-qqp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/7-rte/saved_weights \
|
| 498 |
+
--data_dir CL_Benchmark \
|
| 499 |
+
--task_order yelp,amazon,mnli,cb,copa,qqp,rte,imdb,sst2,dbpedia,agnews,yahoo,multirc,boolq,wic \
|
| 500 |
+
--task_config_dir configs/gen_script_long_order3_t5_configs/imdb \
|
| 501 |
+
--output_dir logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/8-imdb \
|
| 502 |
+
--per_device_train_batch_size $BSZ \
|
| 503 |
+
--per_device_eval_batch_size $EVAL_BSZ \
|
| 504 |
+
--gradient_accumulation_steps $GA \
|
| 505 |
+
--learning_rate 0.0003 \
|
| 506 |
+
--num_train_epochs 10 \
|
| 507 |
+
--run_name gen_script_long_order3_t5_small_specroute_v11 \
|
| 508 |
+
--max_source_length 512 \
|
| 509 |
+
--max_target_length 50 \
|
| 510 |
+
--generation_max_length 50 \
|
| 511 |
+
--add_task_name False \
|
| 512 |
+
--add_dataset_name False \
|
| 513 |
+
--overwrite_output_dir \
|
| 514 |
+
--overwrite_cache \
|
| 515 |
+
--lr_scheduler_type constant \
|
| 516 |
+
--warmup_steps 0 \
|
| 517 |
+
--logging_strategy steps \
|
| 518 |
+
--logging_steps 10 \
|
| 519 |
+
--metric_for_best_model eval_exact_match_for_imdb \
|
| 520 |
+
--evaluation_strategy epoch \
|
| 521 |
+
--save_strategy epoch \
|
| 522 |
+
--save_total_limit 1 \
|
| 523 |
+
--load_best_model_at_end \
|
| 524 |
+
--lora_r 8 \
|
| 525 |
+
--lora_alpha 32 \
|
| 526 |
+
--lora_dropout 0.0 \
|
| 527 |
+
--data_replay_freq -1 \
|
| 528 |
+
--mlp_hidden_dim 100 \
|
| 529 |
+
--model_name specroute \
|
| 530 |
+
--routing_mode learned \
|
| 531 |
+
--threshold 0.995 \
|
| 532 |
+
--transthreshold 0.995 \
|
| 533 |
+
$FP16_FLAG
|
| 534 |
+
|
| 535 |
+
rm -rf logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/8-imdb/checkpoint*
|
| 536 |
+
|
| 537 |
+
sleep 5
|
| 538 |
+
|
| 539 |
+
if [ "$GPU_MODE" = "t4_2gpu" ]; then
|
| 540 |
+
BSZ=8; GA=2; EVAL_BSZ=64
|
| 541 |
+
elif [ "$GPU_MODE" = "t4_1gpu" ]; then
|
| 542 |
+
BSZ=8; GA=2; EVAL_BSZ=32
|
| 543 |
+
elif [ "$GPU_MODE" = "p100" ]; then
|
| 544 |
+
BSZ=16; GA=2; EVAL_BSZ=32
|
| 545 |
+
else
|
| 546 |
+
BSZ=64; GA=1; EVAL_BSZ=128
|
| 547 |
+
fi
|
| 548 |
+
|
| 549 |
+
CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
|
| 550 |
+
--do_train \
|
| 551 |
+
--load_checkpoint_from logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/8-imdb/saved_weights/trans_input.pt \
|
| 552 |
+
--previous_prompt_key_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/8-imdb/saved_weights/prompts_keys_till_now.pt \
|
| 553 |
+
--do_predict \
|
| 554 |
+
--predict_with_generate \
|
| 555 |
+
--model_name_or_path $2 \
|
| 556 |
+
--previous_lora_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/1-yelp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/2-amazon/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/3-mnli/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/4-cb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/5-copa/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/6-qqp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/7-rte/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/8-imdb/saved_weights \
|
| 557 |
+
--data_dir CL_Benchmark \
|
| 558 |
+
--task_order yelp,amazon,mnli,cb,copa,qqp,rte,imdb,sst2,dbpedia,agnews,yahoo,multirc,boolq,wic \
|
| 559 |
+
--task_config_dir configs/gen_script_long_order3_t5_configs/sst2 \
|
| 560 |
+
--output_dir logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/9-sst2 \
|
| 561 |
+
--per_device_train_batch_size $BSZ \
|
| 562 |
+
--per_device_eval_batch_size $EVAL_BSZ \
|
| 563 |
+
--gradient_accumulation_steps $GA \
|
| 564 |
+
--learning_rate 0.0003 \
|
| 565 |
+
--num_train_epochs 10 \
|
| 566 |
+
--run_name gen_script_long_order3_t5_small_specroute_v11 \
|
| 567 |
+
--max_source_length 512 \
|
| 568 |
+
--max_target_length 50 \
|
| 569 |
+
--generation_max_length 50 \
|
| 570 |
+
--add_task_name False \
|
| 571 |
+
--add_dataset_name False \
|
| 572 |
+
--overwrite_output_dir \
|
| 573 |
+
--overwrite_cache \
|
| 574 |
+
--lr_scheduler_type constant \
|
| 575 |
+
--warmup_steps 0 \
|
| 576 |
+
--logging_strategy steps \
|
| 577 |
+
--logging_steps 10 \
|
| 578 |
+
--metric_for_best_model eval_exact_match_for_sst2 \
|
| 579 |
+
--evaluation_strategy epoch \
|
| 580 |
+
--save_strategy epoch \
|
| 581 |
+
--save_total_limit 1 \
|
| 582 |
+
--load_best_model_at_end \
|
| 583 |
+
--lora_r 8 \
|
| 584 |
+
--lora_alpha 32 \
|
| 585 |
+
--lora_dropout 0.0 \
|
| 586 |
+
--data_replay_freq -1 \
|
| 587 |
+
--mlp_hidden_dim 100 \
|
| 588 |
+
--model_name specroute \
|
| 589 |
+
--routing_mode learned \
|
| 590 |
+
--threshold 0.995 \
|
| 591 |
+
--transthreshold 0.995 \
|
| 592 |
+
$FP16_FLAG
|
| 593 |
+
|
| 594 |
+
rm -rf logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/9-sst2/checkpoint*
|
| 595 |
+
|
| 596 |
+
sleep 5
|
| 597 |
+
|
| 598 |
+
if [ "$GPU_MODE" = "t4_2gpu" ]; then
|
| 599 |
+
BSZ=8; GA=2; EVAL_BSZ=64
|
| 600 |
+
elif [ "$GPU_MODE" = "t4_1gpu" ]; then
|
| 601 |
+
BSZ=8; GA=2; EVAL_BSZ=32
|
| 602 |
+
elif [ "$GPU_MODE" = "p100" ]; then
|
| 603 |
+
BSZ=16; GA=2; EVAL_BSZ=32
|
| 604 |
+
else
|
| 605 |
+
BSZ=64; GA=1; EVAL_BSZ=128
|
| 606 |
+
fi
|
| 607 |
+
|
| 608 |
+
CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
|
| 609 |
+
--do_train \
|
| 610 |
+
--load_checkpoint_from logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/9-sst2/saved_weights/trans_input.pt \
|
| 611 |
+
--previous_prompt_key_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/9-sst2/saved_weights/prompts_keys_till_now.pt \
|
| 612 |
+
--do_predict \
|
| 613 |
+
--predict_with_generate \
|
| 614 |
+
--model_name_or_path $2 \
|
| 615 |
+
--previous_lora_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/1-yelp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/2-amazon/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/3-mnli/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/4-cb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/5-copa/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/6-qqp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/7-rte/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/8-imdb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/9-sst2/saved_weights \
|
| 616 |
+
--data_dir CL_Benchmark \
|
| 617 |
+
--task_order yelp,amazon,mnli,cb,copa,qqp,rte,imdb,sst2,dbpedia,agnews,yahoo,multirc,boolq,wic \
|
| 618 |
+
--task_config_dir configs/gen_script_long_order3_t5_configs/dbpedia \
|
| 619 |
+
--output_dir logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/10-dbpedia \
|
| 620 |
+
--per_device_train_batch_size $BSZ \
|
| 621 |
+
--per_device_eval_batch_size $EVAL_BSZ \
|
| 622 |
+
--gradient_accumulation_steps $GA \
|
| 623 |
+
--learning_rate 0.0003 \
|
| 624 |
+
--num_train_epochs 10 \
|
| 625 |
+
--run_name gen_script_long_order3_t5_small_specroute_v11 \
|
| 626 |
+
--max_source_length 512 \
|
| 627 |
+
--max_target_length 50 \
|
| 628 |
+
--generation_max_length 50 \
|
| 629 |
+
--add_task_name False \
|
| 630 |
+
--add_dataset_name False \
|
| 631 |
+
--overwrite_output_dir \
|
| 632 |
+
--overwrite_cache \
|
| 633 |
+
--lr_scheduler_type constant \
|
| 634 |
+
--warmup_steps 0 \
|
| 635 |
+
--logging_strategy steps \
|
| 636 |
+
--logging_steps 10 \
|
| 637 |
+
--metric_for_best_model eval_exact_match_for_dbpedia \
|
| 638 |
+
--evaluation_strategy epoch \
|
| 639 |
+
--save_strategy epoch \
|
| 640 |
+
--save_total_limit 1 \
|
| 641 |
+
--load_best_model_at_end \
|
| 642 |
+
--lora_r 8 \
|
| 643 |
+
--lora_alpha 32 \
|
| 644 |
+
--lora_dropout 0.0 \
|
| 645 |
+
--data_replay_freq -1 \
|
| 646 |
+
--mlp_hidden_dim 100 \
|
| 647 |
+
--model_name specroute \
|
| 648 |
+
--routing_mode learned \
|
| 649 |
+
--threshold 0.995 \
|
| 650 |
+
--transthreshold 0.995 \
|
| 651 |
+
$FP16_FLAG
|
| 652 |
+
|
| 653 |
+
rm -rf logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/10-dbpedia/checkpoint*
|
| 654 |
+
|
| 655 |
+
sleep 5
|
| 656 |
+
|
| 657 |
+
if [ "$GPU_MODE" = "t4_2gpu" ]; then
|
| 658 |
+
BSZ=8; GA=2; EVAL_BSZ=64
|
| 659 |
+
elif [ "$GPU_MODE" = "t4_1gpu" ]; then
|
| 660 |
+
BSZ=8; GA=2; EVAL_BSZ=32
|
| 661 |
+
elif [ "$GPU_MODE" = "p100" ]; then
|
| 662 |
+
BSZ=16; GA=2; EVAL_BSZ=32
|
| 663 |
+
else
|
| 664 |
+
BSZ=64; GA=1; EVAL_BSZ=128
|
| 665 |
+
fi
|
| 666 |
+
|
| 667 |
+
CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
|
| 668 |
+
--do_train \
|
| 669 |
+
--load_checkpoint_from logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/10-dbpedia/saved_weights/trans_input.pt \
|
| 670 |
+
--previous_prompt_key_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/10-dbpedia/saved_weights/prompts_keys_till_now.pt \
|
| 671 |
+
--do_predict \
|
| 672 |
+
--predict_with_generate \
|
| 673 |
+
--model_name_or_path $2 \
|
| 674 |
+
--previous_lora_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/1-yelp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/2-amazon/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/3-mnli/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/4-cb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/5-copa/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/6-qqp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/7-rte/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/8-imdb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/9-sst2/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/10-dbpedia/saved_weights \
|
| 675 |
+
--data_dir CL_Benchmark \
|
| 676 |
+
--task_order yelp,amazon,mnli,cb,copa,qqp,rte,imdb,sst2,dbpedia,agnews,yahoo,multirc,boolq,wic \
|
| 677 |
+
--task_config_dir configs/gen_script_long_order3_t5_configs/agnews \
|
| 678 |
+
--output_dir logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/11-agnews \
|
| 679 |
+
--per_device_train_batch_size $BSZ \
|
| 680 |
+
--per_device_eval_batch_size $EVAL_BSZ \
|
| 681 |
+
--gradient_accumulation_steps $GA \
|
| 682 |
+
--learning_rate 0.0003 \
|
| 683 |
+
--num_train_epochs 10 \
|
| 684 |
+
--run_name gen_script_long_order3_t5_small_specroute_v11 \
|
| 685 |
+
--max_source_length 512 \
|
| 686 |
+
--max_target_length 50 \
|
| 687 |
+
--generation_max_length 50 \
|
| 688 |
+
--add_task_name False \
|
| 689 |
+
--add_dataset_name False \
|
| 690 |
+
--overwrite_output_dir \
|
| 691 |
+
--overwrite_cache \
|
| 692 |
+
--lr_scheduler_type constant \
|
| 693 |
+
--warmup_steps 0 \
|
| 694 |
+
--logging_strategy steps \
|
| 695 |
+
--logging_steps 10 \
|
| 696 |
+
--metric_for_best_model eval_exact_match_for_agnews \
|
| 697 |
+
--evaluation_strategy epoch \
|
| 698 |
+
--save_strategy epoch \
|
| 699 |
+
--save_total_limit 1 \
|
| 700 |
+
--load_best_model_at_end \
|
| 701 |
+
--lora_r 8 \
|
| 702 |
+
--lora_alpha 32 \
|
| 703 |
+
--lora_dropout 0.0 \
|
| 704 |
+
--data_replay_freq -1 \
|
| 705 |
+
--mlp_hidden_dim 100 \
|
| 706 |
+
--model_name specroute \
|
| 707 |
+
--routing_mode learned \
|
| 708 |
+
--threshold 0.995 \
|
| 709 |
+
--transthreshold 0.995 \
|
| 710 |
+
$FP16_FLAG
|
| 711 |
+
|
| 712 |
+
rm -rf logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/11-agnews/checkpoint*
|
| 713 |
+
|
| 714 |
+
sleep 5
|
| 715 |
+
|
| 716 |
+
if [ "$GPU_MODE" = "t4_2gpu" ]; then
|
| 717 |
+
BSZ=8; GA=2; EVAL_BSZ=64
|
| 718 |
+
elif [ "$GPU_MODE" = "t4_1gpu" ]; then
|
| 719 |
+
BSZ=8; GA=2; EVAL_BSZ=32
|
| 720 |
+
elif [ "$GPU_MODE" = "p100" ]; then
|
| 721 |
+
BSZ=16; GA=2; EVAL_BSZ=32
|
| 722 |
+
else
|
| 723 |
+
BSZ=64; GA=1; EVAL_BSZ=128
|
| 724 |
+
fi
|
| 725 |
+
|
| 726 |
+
CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
|
| 727 |
+
--do_train \
|
| 728 |
+
--load_checkpoint_from logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/11-agnews/saved_weights/trans_input.pt \
|
| 729 |
+
--previous_prompt_key_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/11-agnews/saved_weights/prompts_keys_till_now.pt \
|
| 730 |
+
--do_predict \
|
| 731 |
+
--predict_with_generate \
|
| 732 |
+
--model_name_or_path $2 \
|
| 733 |
+
--previous_lora_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/1-yelp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/2-amazon/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/3-mnli/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/4-cb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/5-copa/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/6-qqp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/7-rte/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/8-imdb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/9-sst2/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/10-dbpedia/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/11-agnews/saved_weights \
|
| 734 |
+
--data_dir CL_Benchmark \
|
| 735 |
+
--task_order yelp,amazon,mnli,cb,copa,qqp,rte,imdb,sst2,dbpedia,agnews,yahoo,multirc,boolq,wic \
|
| 736 |
+
--task_config_dir configs/gen_script_long_order3_t5_configs/yahoo \
|
| 737 |
+
--output_dir logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/12-yahoo \
|
| 738 |
+
--per_device_train_batch_size $BSZ \
|
| 739 |
+
--per_device_eval_batch_size $EVAL_BSZ \
|
| 740 |
+
--gradient_accumulation_steps $GA \
|
| 741 |
+
--learning_rate 0.0003 \
|
| 742 |
+
--num_train_epochs 10 \
|
| 743 |
+
--run_name gen_script_long_order3_t5_small_specroute_v11 \
|
| 744 |
+
--max_source_length 512 \
|
| 745 |
+
--max_target_length 50 \
|
| 746 |
+
--generation_max_length 50 \
|
| 747 |
+
--add_task_name False \
|
| 748 |
+
--add_dataset_name False \
|
| 749 |
+
--overwrite_output_dir \
|
| 750 |
+
--overwrite_cache \
|
| 751 |
+
--lr_scheduler_type constant \
|
| 752 |
+
--warmup_steps 0 \
|
| 753 |
+
--logging_strategy steps \
|
| 754 |
+
--logging_steps 10 \
|
| 755 |
+
--metric_for_best_model eval_exact_match_for_yahoo \
|
| 756 |
+
--evaluation_strategy epoch \
|
| 757 |
+
--save_strategy epoch \
|
| 758 |
+
--save_total_limit 1 \
|
| 759 |
+
--load_best_model_at_end \
|
| 760 |
+
--lora_r 8 \
|
| 761 |
+
--lora_alpha 32 \
|
| 762 |
+
--lora_dropout 0.0 \
|
| 763 |
+
--data_replay_freq -1 \
|
| 764 |
+
--mlp_hidden_dim 100 \
|
| 765 |
+
--model_name specroute \
|
| 766 |
+
--routing_mode learned \
|
| 767 |
+
--threshold 0.995 \
|
| 768 |
+
--transthreshold 0.995 \
|
| 769 |
+
$FP16_FLAG
|
| 770 |
+
|
| 771 |
+
rm -rf logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/12-yahoo/checkpoint*
|
| 772 |
+
|
| 773 |
+
sleep 5
|
| 774 |
+
|
| 775 |
+
if [ "$GPU_MODE" = "t4_2gpu" ]; then
|
| 776 |
+
BSZ=8; GA=2; EVAL_BSZ=64
|
| 777 |
+
elif [ "$GPU_MODE" = "t4_1gpu" ]; then
|
| 778 |
+
BSZ=8; GA=2; EVAL_BSZ=32
|
| 779 |
+
elif [ "$GPU_MODE" = "p100" ]; then
|
| 780 |
+
BSZ=16; GA=2; EVAL_BSZ=32
|
| 781 |
+
else
|
| 782 |
+
BSZ=64; GA=1; EVAL_BSZ=128
|
| 783 |
+
fi
|
| 784 |
+
|
| 785 |
+
CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
|
| 786 |
+
--do_train \
|
| 787 |
+
--load_checkpoint_from logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/12-yahoo/saved_weights/trans_input.pt \
|
| 788 |
+
--previous_prompt_key_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/12-yahoo/saved_weights/prompts_keys_till_now.pt \
|
| 789 |
+
--do_predict \
|
| 790 |
+
--predict_with_generate \
|
| 791 |
+
--model_name_or_path $2 \
|
| 792 |
+
--previous_lora_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/1-yelp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/2-amazon/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/3-mnli/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/4-cb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/5-copa/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/6-qqp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/7-rte/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/8-imdb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/9-sst2/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/10-dbpedia/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/11-agnews/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/12-yahoo/saved_weights \
|
| 793 |
+
--data_dir CL_Benchmark \
|
| 794 |
+
--task_order yelp,amazon,mnli,cb,copa,qqp,rte,imdb,sst2,dbpedia,agnews,yahoo,multirc,boolq,wic \
|
| 795 |
+
--task_config_dir configs/gen_script_long_order3_t5_configs/multirc \
|
| 796 |
+
--output_dir logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/13-multirc \
|
| 797 |
+
--per_device_train_batch_size $BSZ \
|
| 798 |
+
--per_device_eval_batch_size $EVAL_BSZ \
|
| 799 |
+
--gradient_accumulation_steps $GA \
|
| 800 |
+
--learning_rate 0.0003 \
|
| 801 |
+
--num_train_epochs 10 \
|
| 802 |
+
--run_name gen_script_long_order3_t5_small_specroute_v11 \
|
| 803 |
+
--max_source_length 512 \
|
| 804 |
+
--max_target_length 50 \
|
| 805 |
+
--generation_max_length 50 \
|
| 806 |
+
--add_task_name False \
|
| 807 |
+
--add_dataset_name False \
|
| 808 |
+
--overwrite_output_dir \
|
| 809 |
+
--overwrite_cache \
|
| 810 |
+
--lr_scheduler_type constant \
|
| 811 |
+
--warmup_steps 0 \
|
| 812 |
+
--logging_strategy steps \
|
| 813 |
+
--logging_steps 10 \
|
| 814 |
+
--metric_for_best_model eval_exact_match_for_multirc \
|
| 815 |
+
--evaluation_strategy epoch \
|
| 816 |
+
--save_strategy epoch \
|
| 817 |
+
--save_total_limit 1 \
|
| 818 |
+
--load_best_model_at_end \
|
| 819 |
+
--lora_r 8 \
|
| 820 |
+
--lora_alpha 32 \
|
| 821 |
+
--lora_dropout 0.0 \
|
| 822 |
+
--data_replay_freq -1 \
|
| 823 |
+
--mlp_hidden_dim 100 \
|
| 824 |
+
--model_name specroute \
|
| 825 |
+
--routing_mode learned \
|
| 826 |
+
--threshold 0.995 \
|
| 827 |
+
--transthreshold 0.995 \
|
| 828 |
+
$FP16_FLAG
|
| 829 |
+
|
| 830 |
+
rm -rf logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/13-multirc/checkpoint*
|
| 831 |
+
|
| 832 |
+
sleep 5
|
| 833 |
+
|
| 834 |
+
if [ "$GPU_MODE" = "t4_2gpu" ]; then
|
| 835 |
+
BSZ=8; GA=2; EVAL_BSZ=64
|
| 836 |
+
elif [ "$GPU_MODE" = "t4_1gpu" ]; then
|
| 837 |
+
BSZ=8; GA=2; EVAL_BSZ=32
|
| 838 |
+
elif [ "$GPU_MODE" = "p100" ]; then
|
| 839 |
+
BSZ=16; GA=2; EVAL_BSZ=32
|
| 840 |
+
else
|
| 841 |
+
BSZ=64; GA=1; EVAL_BSZ=128
|
| 842 |
+
fi
|
| 843 |
+
|
| 844 |
+
CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
|
| 845 |
+
--do_train \
|
| 846 |
+
--load_checkpoint_from logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/13-multirc/saved_weights/trans_input.pt \
|
| 847 |
+
--previous_prompt_key_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/13-multirc/saved_weights/prompts_keys_till_now.pt \
|
| 848 |
+
--do_predict \
|
| 849 |
+
--predict_with_generate \
|
| 850 |
+
--model_name_or_path $2 \
|
| 851 |
+
--previous_lora_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/1-yelp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/2-amazon/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/3-mnli/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/4-cb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/5-copa/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/6-qqp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/7-rte/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/8-imdb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/9-sst2/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/10-dbpedia/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/11-agnews/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/12-yahoo/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/13-multirc/saved_weights \
|
| 852 |
+
--data_dir CL_Benchmark \
|
| 853 |
+
--task_order yelp,amazon,mnli,cb,copa,qqp,rte,imdb,sst2,dbpedia,agnews,yahoo,multirc,boolq,wic \
|
| 854 |
+
--task_config_dir configs/gen_script_long_order3_t5_configs/boolq \
|
| 855 |
+
--output_dir logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/14-boolq \
|
| 856 |
+
--per_device_train_batch_size $BSZ \
|
| 857 |
+
--per_device_eval_batch_size $EVAL_BSZ \
|
| 858 |
+
--gradient_accumulation_steps $GA \
|
| 859 |
+
--learning_rate 0.0003 \
|
| 860 |
+
--num_train_epochs 10 \
|
| 861 |
+
--run_name gen_script_long_order3_t5_small_specroute_v11 \
|
| 862 |
+
--max_source_length 512 \
|
| 863 |
+
--max_target_length 50 \
|
| 864 |
+
--generation_max_length 50 \
|
| 865 |
+
--add_task_name False \
|
| 866 |
+
--add_dataset_name False \
|
| 867 |
+
--overwrite_output_dir \
|
| 868 |
+
--overwrite_cache \
|
| 869 |
+
--lr_scheduler_type constant \
|
| 870 |
+
--warmup_steps 0 \
|
| 871 |
+
--logging_strategy steps \
|
| 872 |
+
--logging_steps 10 \
|
| 873 |
+
--metric_for_best_model eval_exact_match_for_boolq \
|
| 874 |
+
--evaluation_strategy epoch \
|
| 875 |
+
--save_strategy epoch \
|
| 876 |
+
--save_total_limit 1 \
|
| 877 |
+
--load_best_model_at_end \
|
| 878 |
+
--lora_r 8 \
|
| 879 |
+
--lora_alpha 32 \
|
| 880 |
+
--lora_dropout 0.0 \
|
| 881 |
+
--data_replay_freq -1 \
|
| 882 |
+
--mlp_hidden_dim 100 \
|
| 883 |
+
--model_name specroute \
|
| 884 |
+
--routing_mode learned \
|
| 885 |
+
--threshold 0.995 \
|
| 886 |
+
--transthreshold 0.995 \
|
| 887 |
+
$FP16_FLAG
|
| 888 |
+
|
| 889 |
+
rm -rf logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/14-boolq/checkpoint*
|
| 890 |
+
|
| 891 |
+
sleep 5
|
| 892 |
+
|
| 893 |
+
if [ "$GPU_MODE" = "t4_2gpu" ]; then
|
| 894 |
+
BSZ=8; GA=2; EVAL_BSZ=64
|
| 895 |
+
elif [ "$GPU_MODE" = "t4_1gpu" ]; then
|
| 896 |
+
BSZ=8; GA=2; EVAL_BSZ=32
|
| 897 |
+
elif [ "$GPU_MODE" = "p100" ]; then
|
| 898 |
+
BSZ=16; GA=2; EVAL_BSZ=32
|
| 899 |
+
else
|
| 900 |
+
BSZ=64; GA=1; EVAL_BSZ=128
|
| 901 |
+
fi
|
| 902 |
+
|
| 903 |
+
CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
|
| 904 |
+
--do_train \
|
| 905 |
+
--load_checkpoint_from logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/14-boolq/saved_weights/trans_input.pt \
|
| 906 |
+
--previous_prompt_key_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/14-boolq/saved_weights/prompts_keys_till_now.pt \
|
| 907 |
+
--do_predict \
|
| 908 |
+
--predict_with_generate \
|
| 909 |
+
--model_name_or_path $2 \
|
| 910 |
+
--previous_lora_path logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/1-yelp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/2-amazon/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/3-mnli/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/4-cb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/5-copa/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/6-qqp/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/7-rte/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/8-imdb/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/9-sst2/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/10-dbpedia/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/11-agnews/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/12-yahoo/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/13-multirc/saved_weights,logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/14-boolq/saved_weights \
|
| 911 |
+
--data_dir CL_Benchmark \
|
| 912 |
+
--task_order yelp,amazon,mnli,cb,copa,qqp,rte,imdb,sst2,dbpedia,agnews,yahoo,multirc,boolq,wic \
|
| 913 |
+
--task_config_dir configs/gen_script_long_order3_t5_configs/wic \
|
| 914 |
+
--output_dir logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/15-wic \
|
| 915 |
+
--per_device_train_batch_size $BSZ \
|
| 916 |
+
--per_device_eval_batch_size $EVAL_BSZ \
|
| 917 |
+
--gradient_accumulation_steps $GA \
|
| 918 |
+
--learning_rate 0.0003 \
|
| 919 |
+
--num_train_epochs 10 \
|
| 920 |
+
--run_name gen_script_long_order3_t5_small_specroute_v11 \
|
| 921 |
+
--max_source_length 512 \
|
| 922 |
+
--max_target_length 50 \
|
| 923 |
+
--generation_max_length 50 \
|
| 924 |
+
--add_task_name False \
|
| 925 |
+
--add_dataset_name False \
|
| 926 |
+
--overwrite_output_dir \
|
| 927 |
+
--overwrite_cache \
|
| 928 |
+
--lr_scheduler_type constant \
|
| 929 |
+
--warmup_steps 0 \
|
| 930 |
+
--logging_strategy steps \
|
| 931 |
+
--logging_steps 10 \
|
| 932 |
+
--metric_for_best_model eval_exact_match_for_wic \
|
| 933 |
+
--evaluation_strategy epoch \
|
| 934 |
+
--save_strategy epoch \
|
| 935 |
+
--save_total_limit 1 \
|
| 936 |
+
--load_best_model_at_end \
|
| 937 |
+
--lora_r 8 \
|
| 938 |
+
--lora_alpha 32 \
|
| 939 |
+
--lora_dropout 0.0 \
|
| 940 |
+
--data_replay_freq -1 \
|
| 941 |
+
--mlp_hidden_dim 100 \
|
| 942 |
+
--model_name specroute \
|
| 943 |
+
--routing_mode learned \
|
| 944 |
+
--threshold 0.995 \
|
| 945 |
+
--transthreshold 0.995 \
|
| 946 |
+
$FP16_FLAG
|
| 947 |
+
|
| 948 |
+
rm -rf logs_and_outputs/gen_script_long_order3_t5_small_specroute_v11/outputs/15-wic/checkpoint*
|
| 949 |
+
|
| 950 |
+
sleep 5
|
improve_gainlora/results/gen_script_long_order3_t5_small_specroute_v2.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:26c66b8453b30e2331a8a08d2a425b8c375aced3dd9c26346c0f544d0d4b524f
|
| 3 |
+
size 182
|
improve_gainlora/results/gen_script_long_order3_t5_small_specroute_v5.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:efaab52e503376bf2a0446839a2825c9efe5a2558c86595fcbe35892d093e41f
|
| 3 |
+
size 174
|
improve_gainlora/src/cl_trainer_specroute.py
CHANGED
|
@@ -18,6 +18,7 @@ from torch.utils.data.distributed import DistributedSampler
|
|
| 18 |
from transformers import GenerationConfig
|
| 19 |
from transformers.trainer_seq2seq import Seq2SeqTrainer
|
| 20 |
from transformers.trainer import *
|
|
|
|
| 21 |
from transformers.trainer_pt_utils import (
|
| 22 |
nested_truncate, nested_concat, nested_numpify,
|
| 23 |
find_batch_size,
|
|
@@ -82,11 +83,15 @@ class PeriodicGCCallback(TrainerCallback):
|
|
| 82 |
|
| 83 |
|
| 84 |
class TransInputGPMCallback(TrainerCallback):
|
| 85 |
-
"""V10a: Apply GPM projection to trans_input and prompt_key after optimizer step.
|
|
|
|
|
|
|
| 86 |
def __init__(self, trainer):
|
| 87 |
self.trainer = trainer
|
| 88 |
|
| 89 |
def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
|
|
|
|
|
|
| 90 |
if getattr(self.trainer, "cur_task_id", 0) > 1 and getattr(self.trainer.model.encoder, "routing_mode", "") == "learned":
|
| 91 |
from copy import deepcopy
|
| 92 |
self.trainer._old_trans_input_0 = deepcopy(self.trainer.model.encoder.trans_input[0].weight.detach())
|
|
@@ -94,6 +99,8 @@ class TransInputGPMCallback(TrainerCallback):
|
|
| 94 |
self.trainer._old_prompt_key = deepcopy(self.trainer.model.encoder.prompt_key.detach())
|
| 95 |
|
| 96 |
def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
|
|
|
|
|
|
| 97 |
if getattr(self.trainer, "cur_task_id", 0) > 1 and getattr(self.trainer.model.encoder, "routing_mode", "") == "learned":
|
| 98 |
if not hasattr(self.trainer, "feature_trans_mat") or not self.trainer.feature_trans_mat:
|
| 99 |
return
|
|
@@ -313,6 +320,109 @@ class SpecRoute_Trainer(Seq2SeqTrainer):
|
|
| 313 |
|
| 314 |
print(f'[C5] Covariance collected for {len(self._task_covariance)} layers.')
|
| 315 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 316 |
def load_previous_reg_matrix(self):
|
| 317 |
"""Load LoRA GPM bases from previous task. Also load trans_input GPM if learned routing."""
|
| 318 |
reg_matrix = []
|
|
@@ -389,11 +499,26 @@ class SpecRoute_Trainer(Seq2SeqTrainer):
|
|
| 389 |
|
| 390 |
def get_reg_matrix(self):
|
| 391 |
"""
|
| 392 |
-
Project current LoRA A into null-space of old tasks' GPM bases.
|
| 393 |
-
|
| 394 |
"""
|
| 395 |
self.feature_list, self.feature_trans_list, self._cur_task = self.load_previous_reg_matrix()
|
| 396 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 397 |
if len(self.feature_list) == 0:
|
| 398 |
# First task: no constraints
|
| 399 |
return
|
|
|
|
| 18 |
from transformers import GenerationConfig
|
| 19 |
from transformers.trainer_seq2seq import Seq2SeqTrainer
|
| 20 |
from transformers.trainer import *
|
| 21 |
+
from typing import Optional, List, Tuple
|
| 22 |
from transformers.trainer_pt_utils import (
|
| 23 |
nested_truncate, nested_concat, nested_numpify,
|
| 24 |
find_batch_size,
|
|
|
|
| 83 |
|
| 84 |
|
| 85 |
class TransInputGPMCallback(TrainerCallback):
|
| 86 |
+
"""V10a: Apply GPM projection to trans_input and prompt_key after optimizer step.
|
| 87 |
+
V11: Disabled by default (use_routing_gpm=False). Hard GPM on routing kills
|
| 88 |
+
discriminative capacity → catastrophic forgetting. See V10a analysis."""
|
| 89 |
def __init__(self, trainer):
|
| 90 |
self.trainer = trainer
|
| 91 |
|
| 92 |
def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
| 93 |
+
if not getattr(self.trainer, "use_routing_gpm", False):
|
| 94 |
+
return control
|
| 95 |
if getattr(self.trainer, "cur_task_id", 0) > 1 and getattr(self.trainer.model.encoder, "routing_mode", "") == "learned":
|
| 96 |
from copy import deepcopy
|
| 97 |
self.trainer._old_trans_input_0 = deepcopy(self.trainer.model.encoder.trans_input[0].weight.detach())
|
|
|
|
| 99 |
self.trainer._old_prompt_key = deepcopy(self.trainer.model.encoder.prompt_key.detach())
|
| 100 |
|
| 101 |
def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
| 102 |
+
if not getattr(self.trainer, "use_routing_gpm", False):
|
| 103 |
+
return control
|
| 104 |
if getattr(self.trainer, "cur_task_id", 0) > 1 and getattr(self.trainer.model.encoder, "routing_mode", "") == "learned":
|
| 105 |
if not hasattr(self.trainer, "feature_trans_mat") or not self.trainer.feature_trans_mat:
|
| 106 |
return
|
|
|
|
| 320 |
|
| 321 |
print(f'[C5] Covariance collected for {len(self._task_covariance)} layers.')
|
| 322 |
|
| 323 |
+
# ================================================================
|
| 324 |
+
# V11: ROOT-style Prompt-Key Re-initialization
|
| 325 |
+
# ================================================================
|
| 326 |
+
|
| 327 |
+
def _reinit_prompt_key(self):
|
| 328 |
+
"""Re-initialize prompt_key using SVD of trans_input output covariance.
|
| 329 |
+
|
| 330 |
+
ROOT's key insight: prompt_key must be in the null-space of previous
|
| 331 |
+
routing features to ensure orthogonal task separation.
|
| 332 |
+
|
| 333 |
+
Task 1: prompt_key = top eigenvector of trans_input output covariance C_3.
|
| 334 |
+
This aligns the routing key with the dominant direction of the MLP's
|
| 335 |
+
output space → maximizes discriminability for the first task.
|
| 336 |
+
Formally: p_1 = argmax_{||p||=1} p^T C_3 p (Rayleigh quotient)
|
| 337 |
+
|
| 338 |
+
Task t>1: prompt_key = top eigenvector of random matrix projected into
|
| 339 |
+
null-space of old routing features.
|
| 340 |
+
p_t = U_1 of SVD(Q_old · R) where Q_old = I - P_old, R ~ N(0,1)
|
| 341 |
+
This guarantees: p_t ⊥ span({p_1,...,p_{t-1}}) up to GPM threshold.
|
| 342 |
+
"""
|
| 343 |
+
module = self.model.encoder
|
| 344 |
+
if not hasattr(module, 'prompt_key'):
|
| 345 |
+
return
|
| 346 |
+
|
| 347 |
+
# Ensure chunk dimensions are set up
|
| 348 |
+
module.get_chunk(self.args.chunk)
|
| 349 |
+
|
| 350 |
+
# Collect trans_input output covariance (200 batches)
|
| 351 |
+
module.get_trans_feature = True
|
| 352 |
+
module.stage_trans = 0
|
| 353 |
+
|
| 354 |
+
print('[V11] Collecting trans_input covariance for prompt_key init...')
|
| 355 |
+
train_dataloader = self.get_train_dataloader()
|
| 356 |
+
if isinstance(train_dataloader, DataLoader) and isinstance(
|
| 357 |
+
train_dataloader.sampler, DistributedSampler
|
| 358 |
+
):
|
| 359 |
+
train_dataloader.sampler.set_epoch(77)
|
| 360 |
+
|
| 361 |
+
with torch.no_grad():
|
| 362 |
+
for step, inputs in enumerate(train_dataloader):
|
| 363 |
+
inputs = self._prepare_inputs(inputs)
|
| 364 |
+
inputs.pop('labels', None)
|
| 365 |
+
self.model(**inputs)
|
| 366 |
+
if step >= 200:
|
| 367 |
+
break
|
| 368 |
+
|
| 369 |
+
pre_norm = module.prompt_key.detach().norm()
|
| 370 |
+
|
| 371 |
+
if len(self.feature_trans_list) == 0:
|
| 372 |
+
# === TASK 1: Data-informed init ===
|
| 373 |
+
# prompt_key = top eigenvector of output covariance (matrix_trans_3)
|
| 374 |
+
for index in module.matrix_trans_3.keys():
|
| 375 |
+
cur_trans_matrix = module.matrix_trans_3[index]
|
| 376 |
+
cur_trans_matrix = torch.nan_to_num(cur_trans_matrix, nan=0.0, posinf=1e6, neginf=-1e6)
|
| 377 |
+
try:
|
| 378 |
+
U, S, V = torch.linalg.svd(cur_trans_matrix)
|
| 379 |
+
except Exception:
|
| 380 |
+
cpu_mat = cur_trans_matrix.detach().cpu().float()
|
| 381 |
+
U, S, V = torch.linalg.svd(cpu_mat)
|
| 382 |
+
U = U.to(device=cur_trans_matrix.device, dtype=cur_trans_matrix.dtype)
|
| 383 |
+
module.prompt_key.data[:, index*module.step:(index+1)*module.step].copy_(U[:, :1].T)
|
| 384 |
+
print('[V11] Task 1: prompt_key = top eigvec of trans_input output covariance.')
|
| 385 |
+
else:
|
| 386 |
+
# === TASK t>1: Null-space orthogonal init ===
|
| 387 |
+
# Build projection matrix P_old from saved routing GPM bases
|
| 388 |
+
feature_trans_mat_2 = {}
|
| 389 |
+
if len(self.feature_trans_list) >= 3:
|
| 390 |
+
for index in self.feature_trans_list[2].keys():
|
| 391 |
+
feature_trans_mat_2[index] = torch.mm(
|
| 392 |
+
self.feature_trans_list[2][index],
|
| 393 |
+
self.feature_trans_list[2][index].T
|
| 394 |
+
).to("cuda:0")
|
| 395 |
+
|
| 396 |
+
for index in module.matrix_trans_3.keys():
|
| 397 |
+
cur_trans_matrix = torch.randn_like(module.matrix_trans_3[index])
|
| 398 |
+
if index in feature_trans_mat_2:
|
| 399 |
+
# Q_old * R: project random matrix into null-space
|
| 400 |
+
cur_trans_matrix = cur_trans_matrix - torch.mm(
|
| 401 |
+
feature_trans_mat_2[index], cur_trans_matrix
|
| 402 |
+
)
|
| 403 |
+
try:
|
| 404 |
+
U, S, V = torch.linalg.svd(cur_trans_matrix)
|
| 405 |
+
except Exception:
|
| 406 |
+
cpu_mat = cur_trans_matrix.detach().cpu().float()
|
| 407 |
+
U, S, V = torch.linalg.svd(cpu_mat)
|
| 408 |
+
U = U.to(device=cur_trans_matrix.device, dtype=cur_trans_matrix.dtype)
|
| 409 |
+
module.prompt_key.data[:, index*module.step:(index+1)*module.step].copy_(U[:, :1].T)
|
| 410 |
+
print(f'[V11] Task {self.cur_task_id+1}: prompt_key = top eigvec in null-space of old routing features.')
|
| 411 |
+
|
| 412 |
+
# Normalize to preserve original scale (ROOT convention)
|
| 413 |
+
module.prompt_key.data /= math.sqrt(module.chunk_trans)
|
| 414 |
+
module.prompt_key.data *= pre_norm
|
| 415 |
+
|
| 416 |
+
# Cleanup covariance accumulators
|
| 417 |
+
for index in list(module.matrix_trans_3.keys()):
|
| 418 |
+
module.matrix_trans_1[index].zero_()
|
| 419 |
+
module.matrix_trans_3[index].zero_()
|
| 420 |
+
module.n_trans_matrix[index] = 0
|
| 421 |
+
module.matrix_trans_2.zero_()
|
| 422 |
+
module.get_trans_feature = False
|
| 423 |
+
module.stage_trans = 0
|
| 424 |
+
print(f'[V11] prompt_key re-initialized. norm={module.prompt_key.data.norm().item():.4f}')
|
| 425 |
+
|
| 426 |
def load_previous_reg_matrix(self):
|
| 427 |
"""Load LoRA GPM bases from previous task. Also load trans_input GPM if learned routing."""
|
| 428 |
reg_matrix = []
|
|
|
|
| 499 |
|
| 500 |
def get_reg_matrix(self):
|
| 501 |
"""
|
| 502 |
+
V11: Project current LoRA A into null-space of old tasks' GPM bases.
|
| 503 |
+
Also re-initialize prompt_key for learned routing (ROOT-style SVD).
|
| 504 |
"""
|
| 505 |
self.feature_list, self.feature_trans_list, self._cur_task = self.load_previous_reg_matrix()
|
| 506 |
|
| 507 |
+
# ================================================================
|
| 508 |
+
# V11: Prompt-key re-initialization (ROOT-style)
|
| 509 |
+
# ================================================================
|
| 510 |
+
# ROOT achieves low forgetting because:
|
| 511 |
+
# 1. prompt_key is initialized in the null-space of old routing features
|
| 512 |
+
# → orthogonal to old keys → naturally separable tasks
|
| 513 |
+
# 2. trans_input (MLP) is free to learn without GPM constraint
|
| 514 |
+
# → discriminative routing features
|
| 515 |
+
#
|
| 516 |
+
# Math: For task t, prompt_key_t ∈ null(P_old) where P_old = Σ U_k U_k^T
|
| 517 |
+
# This ensures cos(prompt_key_t, prompt_key_k) ≈ 0 for k < t
|
| 518 |
+
# → different tasks activate different experts.
|
| 519 |
+
if getattr(self.model.encoder, "routing_mode", "") == "learned":
|
| 520 |
+
self._reinit_prompt_key()
|
| 521 |
+
|
| 522 |
if len(self.feature_list) == 0:
|
| 523 |
# First task: no constraints
|
| 524 |
return
|
parse_and_score_v2.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
def parse_log(log_path):
|
| 5 |
+
with open(log_path, 'r') as f:
|
| 6 |
+
content = f.read()
|
| 7 |
+
|
| 8 |
+
# Task order as defined in script
|
| 9 |
+
tasks = ["yelp", "amazon", "mnli", "cb", "copa", "qqp", "rte", "imdb", "sst2", "dbpedia", "agnews", "yahoo", "multirc", "boolq", "wic"]
|
| 10 |
+
|
| 11 |
+
# Split content into segments, each ending with a "predict metrics" block
|
| 12 |
+
# We look for "predict_exact_match_for_CL" as an anchor for each step evaluation
|
| 13 |
+
segments = re.split(r'predict_exact_match_for_CL\s+=\s+\d+\.\d+', content)
|
| 14 |
+
|
| 15 |
+
# The last segment might be empty if there's nothing after the final metrics
|
| 16 |
+
if not segments[-1].strip():
|
| 17 |
+
segments = segments[:-1]
|
| 18 |
+
|
| 19 |
+
# We expect 15 evaluations
|
| 20 |
+
print(f"Found {len(segments)} evaluation segments in {log_path}")
|
| 21 |
+
|
| 22 |
+
matrix = []
|
| 23 |
+
for seg in segments:
|
| 24 |
+
scores = []
|
| 25 |
+
for task in tasks:
|
| 26 |
+
match = re.search(fr'predict_exact_match_for_{task}\s+=\s+(\d+\.\d+|\d+)', seg)
|
| 27 |
+
if match:
|
| 28 |
+
scores.append(float(match.group(1)))
|
| 29 |
+
else:
|
| 30 |
+
scores.append(0.0)
|
| 31 |
+
if any(s > 0 for s in scores): # Only add if we found at least one score
|
| 32 |
+
matrix.append(scores)
|
| 33 |
+
|
| 34 |
+
# If it's the final evaluation only (like in some logs), we might have only 1 segment
|
| 35 |
+
return matrix, tasks
|
| 36 |
+
|
| 37 |
+
def calculate_metrics(matrix):
|
| 38 |
+
if not matrix:
|
| 39 |
+
return None
|
| 40 |
+
|
| 41 |
+
task_num = len(matrix[0])
|
| 42 |
+
# final_scores is the last row provided (if the run ended at 15, it's matrix[14])
|
| 43 |
+
final_scores = matrix[-1]
|
| 44 |
+
AP = sum(final_scores) / task_num
|
| 45 |
+
|
| 46 |
+
# Forgetting: max(history) - final
|
| 47 |
+
fgt_list = []
|
| 48 |
+
for t_idx in range(task_num - 1):
|
| 49 |
+
history = [row[t_idx] for row in matrix]
|
| 50 |
+
best = max(history)
|
| 51 |
+
final = final_scores[t_idx]
|
| 52 |
+
fgt_list.append(best - final)
|
| 53 |
+
|
| 54 |
+
Fgt = sum(fgt_list) / len(fgt_list) if fgt_list else 0.0
|
| 55 |
+
|
| 56 |
+
# User's definition of forgetting in markdown: Final - Initial?
|
| 57 |
+
# Let's calculate that too just in case
|
| 58 |
+
fgt_user_list = []
|
| 59 |
+
for t_idx in range(task_num - 1):
|
| 60 |
+
initial = matrix[t_idx][t_idx] if t_idx < len(matrix) else 0.0
|
| 61 |
+
final = final_scores[t_idx]
|
| 62 |
+
fgt_user_list.append(final - initial)
|
| 63 |
+
Fgt_user = sum(fgt_user_list) / len(fgt_user_list) if fgt_user_list else 0.0
|
| 64 |
+
|
| 65 |
+
return {
|
| 66 |
+
"AP": AP,
|
| 67 |
+
"Fgt (Best-Final)": Fgt,
|
| 68 |
+
"Fgt_user (Final-Initial)": Fgt_user,
|
| 69 |
+
"Final Scores": final_scores
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
log_v10 = "/Users/nnminh322/Desktop/personal/Continual/improve_gainlora/logs/t5_small_improve/improve_gainlora_v10.log"
|
| 73 |
+
matrix, tasks = parse_log(log_v10)
|
| 74 |
+
metrics = calculate_metrics(matrix)
|
| 75 |
+
print("--- V10 Metrics ---")
|
| 76 |
+
if metrics:
|
| 77 |
+
print(f"AP (EM): {metrics['AP']:.4f}")
|
| 78 |
+
print(f"Fgt (Best-Final): {metrics['Fgt (Best-Final)']:.4f}")
|
| 79 |
+
print(f"Fgt (Final-Initial): {metrics['Fgt_user (Final-Initial)']:.4f}")
|
| 80 |
+
print("Final Scores:", metrics['Final Scores'])
|
| 81 |
+
else:
|
| 82 |
+
print("Failed to parse matrix for V10")
|
| 83 |
+
|
| 84 |
+
# Also do V5 for comparison
|
| 85 |
+
log_v5_dir = "/Users/nnminh322/Desktop/personal/Continual/improve_gainlora/logs/t5_small_improve/gen_script_long_order3_t5_small_specroute_v5/"
|
| 86 |
+
# We need to find the log file inside v5 dir.
|
| 87 |
+
# It's likely in outputs/ or similar.
|
recalculate_em.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
def load_json(path):
|
| 5 |
+
with open(path, 'r') as f:
|
| 6 |
+
return json.load(f)
|
| 7 |
+
|
| 8 |
+
def get_matrix_from_outputs(base_dir, run_name, tasks, metric='exact_match'):
|
| 9 |
+
matrix = []
|
| 10 |
+
for i in range(len(tasks)):
|
| 11 |
+
row = []
|
| 12 |
+
res_file = f"{base_dir}/{run_name}/outputs/{i+1}-{tasks[i]}/all_results.json"
|
| 13 |
+
if not os.path.exists(res_file):
|
| 14 |
+
matrix.append([0.0]*len(tasks))
|
| 15 |
+
continue
|
| 16 |
+
data = load_json(res_file)
|
| 17 |
+
for j in range(i + 1):
|
| 18 |
+
key = f"predict_{metric}_for_{tasks[j]}"
|
| 19 |
+
row.append(data.get(key, 0.0))
|
| 20 |
+
row.extend([0.0]*(len(tasks)-len(row)))
|
| 21 |
+
matrix.append(row)
|
| 22 |
+
return matrix
|
| 23 |
+
|
| 24 |
+
def calculate_stats(matrix):
|
| 25 |
+
task_num = len(matrix[0])
|
| 26 |
+
final_row = matrix[-1]
|
| 27 |
+
AP = sum(final_row) / task_num
|
| 28 |
+
|
| 29 |
+
fgt_list = []
|
| 30 |
+
for j in range(task_num - 1):
|
| 31 |
+
history = [row[j] for row in matrix if row[j] > 0]
|
| 32 |
+
if not history:
|
| 33 |
+
continue
|
| 34 |
+
best = max(history)
|
| 35 |
+
final = final_row[j]
|
| 36 |
+
fgt_list.append(best - final)
|
| 37 |
+
|
| 38 |
+
Fgt = sum(fgt_list) / len(fgt_list) if fgt_list else 0.0
|
| 39 |
+
return AP, Fgt
|
| 40 |
+
|
| 41 |
+
tasks = ["yelp", "amazon", "mnli", "cb", "copa", "qqp", "rte", "imdb", "sst2", "dbpedia", "agnews", "yahoo", "multirc", "boolq", "wic"]
|
| 42 |
+
|
| 43 |
+
# V5 (EM)
|
| 44 |
+
v5_dir = "/Users/nnminh322/Desktop/personal/Continual/improve_gainlora/logs/t5_small_improve"
|
| 45 |
+
v5_run = "gen_script_long_order3_t5_small_specroute_v5"
|
| 46 |
+
|
| 47 |
+
print("--- V5 (EM) ---")
|
| 48 |
+
try:
|
| 49 |
+
v5_matrix = get_matrix_from_outputs(v5_dir, v5_run, tasks, 'exact_match')
|
| 50 |
+
v5_ap_em, v5_fgt_em = calculate_stats(v5_matrix)
|
| 51 |
+
print(f"V5 AP(EM): {v5_ap_em:.4f}")
|
| 52 |
+
print(f"V5 Fgt(EM): {v5_fgt_em:.4f}")
|
| 53 |
+
except Exception as e:
|
| 54 |
+
print(f"V5 failed: {e}")
|
| 55 |
+
|
| 56 |
+
# ROOT (EM) - Based on User's Markdown since I can't find some ROOT JSONs
|
| 57 |
+
# Actually, let's try to parse ROOT logs if any, but 59.7 is definitely the target.
|
| 58 |
+
print("--- ROOT (EM) Target ---")
|
| 59 |
+
print("ROOT AP(EM): 59.70")
|
| 60 |
+
|
| 61 |
+
# V10a (EM) - From Log
|
| 62 |
+
v10_final_em = {
|
| 63 |
+
"agnews": 38.7237,
|
| 64 |
+
"amazon": 29.0263,
|
| 65 |
+
"boolq": 62.4465,
|
| 66 |
+
"cb": 0.0,
|
| 67 |
+
"copa": 55.0,
|
| 68 |
+
"dbpedia": 40.5395,
|
| 69 |
+
"imdb": 90.0789,
|
| 70 |
+
"mnli": 32.1316,
|
| 71 |
+
"multirc": 59.1172,
|
| 72 |
+
"qqp": 64.3158,
|
| 73 |
+
"rte": 52.7076,
|
| 74 |
+
"sst2": 83.945,
|
| 75 |
+
"wic": 56.4263,
|
| 76 |
+
"yahoo": 64.8947,
|
| 77 |
+
"yelp": 21.3289
|
| 78 |
+
}
|
| 79 |
+
# Order: yelp, amazon, mnli, cb, copa, qqp, rte, imdb, sst2, dbpedia, agnews, yahoo, multirc, boolq, wic
|
| 80 |
+
ordered_v10_em = [21.3289, 29.0263, 32.1316, 0.0, 55.0, 64.3158, 52.7076, 90.0789, 83.9450, 40.5395, 38.7237, 64.8947, 59.1172, 62.4465, 56.4263]
|
| 81 |
+
v10_ap_em = sum(ordered_v10_em) / 15
|
| 82 |
+
print(f"V10a AP(EM): {v10_ap_em:.4f}")
|
results/experiment_versions.md
CHANGED
|
@@ -377,3 +377,120 @@ V8 fail imdb/sst2/yahoo do B_t không học (gradient bị block). V9 oracle rou
|
|
| 377 |
### V10b (Grassmannian Distance Routing - The Zero-Replay Ideal)
|
| 378 |
- **Method**: Evaluates similarity by computing the Grassmannian distance (principal angles) between the batch's local principal subspace $U_{batch}$ and expert orthogonal projection $U_A$.
|
| 379 |
- **Why**: Directly measures subset geometric alignment, entirely bypassing scale-based similarity issues (GPM-Routing paradox). Batch-level SVD aggregates representations properly. Valid for batched inference ($B \ge 8$), falling back to A-row for small batches.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 377 |
### V10b (Grassmannian Distance Routing - The Zero-Replay Ideal)
|
| 378 |
- **Method**: Evaluates similarity by computing the Grassmannian distance (principal angles) between the batch's local principal subspace $U_{batch}$ and expert orthogonal projection $U_A$.
|
| 379 |
- **Why**: Directly measures subset geometric alignment, entirely bypassing scale-based similarity issues (GPM-Routing paradox). Batch-level SVD aggregates representations properly. Valid for batched inference ($B \ge 8$), falling back to A-row for small batches.
|
| 380 |
+
|
| 381 |
+
### V10a Results
|
| 382 |
+
|
| 383 |
+
| Task | Final EM | Best EM | Forgetting |
|
| 384 |
+
|------|------:|------:|------:|
|
| 385 |
+
| yelp | 33.45 | 56.49 | 23.04 |
|
| 386 |
+
| amazon | 35.37 | 53.05 | 17.68 |
|
| 387 |
+
| mnli | 30.54 | 49.11 | 18.57 |
|
| 388 |
+
| cb | 0.00 | 57.14 | 57.14 |
|
| 389 |
+
| copa | 55.00 | 55.00 | 0.00 |
|
| 390 |
+
| qqp | 11.95 | 78.84 | 66.89 |
|
| 391 |
+
| rte | 10.11 | 57.76 | 47.65 |
|
| 392 |
+
| imdb | 89.89 | 91.51 | 1.62 |
|
| 393 |
+
| sst2 | 65.25 | 88.88 | 23.62 |
|
| 394 |
+
| dbpedia | 40.70 | 98.47 | 57.78 |
|
| 395 |
+
| agnews | 42.67 | 90.05 | 47.38 |
|
| 396 |
+
| yahoo | 61.88 | 66.01 | 4.13 |
|
| 397 |
+
| multirc | 43.13 | 59.12 | 15.99 |
|
| 398 |
+
| boolq | 62.45 | 62.45 | 0.00 |
|
| 399 |
+
| wic | 56.43 | 56.43 | 0.00 |
|
| 400 |
+
| **Cl (EM)** | **42.59** | | **27.25** |
|
| 401 |
+
|
| 402 |
+
**V10a is CATASTROPHIC**: Cl=42.59 (vs ROOT 59.70), FT=27.25 (vs ROOT ~low, V5 0.91).
|
| 403 |
+
|
| 404 |
+
### V10a Root Cause Analysis
|
| 405 |
+
|
| 406 |
+
**100% of forgetting comes from routing failure**, not weight overwriting (LoRA B matrices for old tasks are frozen in `previous_lora_weights`).
|
| 407 |
+
|
| 408 |
+
**Three critical differences from ROOT:**
|
| 409 |
+
|
| 410 |
+
1. **TransInputGPMCallback (THE KILLER)**: V10a applies GPM projection to `trans_input` + `prompt_key` every training step with threshold=0.995. By task 9, ~95% of routing feature space is locked → routing effectively frozen → cannot distinguish new tasks. ROOT does NOT constrain routing during training.
|
| 411 |
+
|
| 412 |
+
2. **Missing prompt_key re-initialization**: ROOT re-initializes `prompt_key` before each task using SVD of trans_input output covariance (task 1) or random-in-null-space (task 2+). V10a starts from `nn.init.uniform_(-1, 1)` every task → no data-informed, orthogonal starting point.
|
| 413 |
+
|
| 414 |
+
3. **No trans_input covariance collection**: ROOT collects 1000 batches of trans_input feature covariance for prompt_key initialization. V10a only collects LoRA covariance (for C5).
|
| 415 |
+
|
| 416 |
+
**The deadly combination**: Random prompt_key + Over-constrained routing = Bad starting point + Cannot learn = Routing failure = Catastrophic forgetting.
|
| 417 |
+
|
| 418 |
+
---
|
| 419 |
+
|
| 420 |
+
## V11 — ROOT Routing + C5 Init + Advanced Inference Routing
|
| 421 |
+
|
| 422 |
+
### Motivation
|
| 423 |
+
|
| 424 |
+
V10a proved that GPM on routing is fundamentally wrong: routing needs discriminative capacity, not orthogonality constraints. V11 reverts to ROOT's proven routing mechanism while keeping C5 (data-informed LoRA A init) and C4 (gradient preconditioning) for improved per-task expert quality. Additionally, V11 introduces two advanced inference-time routing strategies grounded in information theory.
|
| 425 |
+
|
| 426 |
+
### Base Fix (all V11 variants)
|
| 427 |
+
1. **Disable TransInputGPMCallback**: `use_routing_gpm = False` (default)
|
| 428 |
+
2. **ROOT-style prompt_key re-init**: SVD of trans_input output covariance (task 1) or null-space random SVD (task 2+)
|
| 429 |
+
3. **Keep C5**: Data-informed A init via Constrained PCA in null-space
|
| 430 |
+
4. **Keep C4**: Gradient preconditioning (AA^T + εI)^{-1/2}
|
| 431 |
+
|
| 432 |
+
### V11a: Base (ROOT routing + C5)
|
| 433 |
+
**Script**: `T5_small/gen_script_long_order3_t5_small_specroute_v11a.sh`
|
| 434 |
+
**Args**: `--routing_mode learned --routing_strategy base`
|
| 435 |
+
**Expected**: ≈ ROOT AP (routing identical), potentially better due to C5.
|
| 436 |
+
|
| 437 |
+
### V11b: Softmax Routing Normalization (Option B)
|
| 438 |
+
|
| 439 |
+
**Script**: `T5_small/gen_script_long_order3_t5_small_specroute_v11b.sh`
|
| 440 |
+
**Args**: `--routing_mode learned --routing_strategy softmax --routing_temp 0.1`
|
| 441 |
+
|
| 442 |
+
**Mathematical formulation:**
|
| 443 |
+
ROOT uses independent sigmoid routing: $w_k = |\sigma(4 \cos(x_k, p_k)) \cdot 2 - 1|$.
|
| 444 |
+
Each task gets weight in [0,1] independently → multiple experts may contribute equally → cross-expert interference.
|
| 445 |
+
|
| 446 |
+
V11b converts to competitive softmax gating (standard MoE):
|
| 447 |
+
$$p_k = \frac{\exp(s_k / \tau)}{\sum_j \exp(s_j / \tau)}$$
|
| 448 |
+
where $s_k = \text{logit}(w_k) = \log w_k - \log(1 - w_k)$ and $\tau$ is temperature.
|
| 449 |
+
|
| 450 |
+
**Information-theoretic justification:**
|
| 451 |
+
Let $Y$ = model output, $T$ = task, $X$ = input. Output: $Y = \sum_k p_k f_k(X)$.
|
| 452 |
+
$$H(Y|X) \geq \sum_k p_k H(f_k(X)|X) \quad \text{(concavity of entropy)}$$
|
| 453 |
+
Cross-expert interference term: $\sum_{j \neq k} p_j \|f_j(X) - f_k(X)\|^2$.
|
| 454 |
+
Minimizing this ≡ concentrating $p$ on argmax (one expert dominates) ≡ lower $\tau$.
|
| 455 |
+
In the limit $\tau \to 0$: softmax → argmax (hard top-1 routing, zero interference).
|
| 456 |
+
|
| 457 |
+
**Expected improvement**: Lower FT due to sharper expert selection.
|
| 458 |
+
|
| 459 |
+
### V11c: Product-of-Experts Ensemble (Option C)
|
| 460 |
+
|
| 461 |
+
**Script**: `T5_small/gen_script_long_order3_t5_small_specroute_v11c.sh`
|
| 462 |
+
**Args**: `--routing_mode learned --routing_strategy ensemble --routing_temp 0.1 --ensemble_weight 0.7`
|
| 463 |
+
|
| 464 |
+
**Mathematical formulation:**
|
| 465 |
+
Fuse learned ($p_L$) and spectral ($p_S$) routing via Product-of-Experts (Hinton, 2002):
|
| 466 |
+
$$p_{\text{ens}}(T=k|x) \propto p_L(T=k|x)^\gamma \cdot p_S(T=k|x)^{1-\gamma}$$
|
| 467 |
+
In log space:
|
| 468 |
+
$$\log p_{\text{ens}} = \gamma \cdot \frac{s_L^{(k)}}{\tau} + (1-\gamma) \cdot \frac{s_S^{(k)}}{\tau} + \text{const}$$
|
| 469 |
+
|
| 470 |
+
**Bayesian justification:**
|
| 471 |
+
If learned and spectral routing encode independent evidence about task identity $T$:
|
| 472 |
+
$$p(T|x) \propto p_L(T|x) \cdot p_S(T|x) \quad \text{(posterior = product of likelihoods)}$$
|
| 473 |
+
This is the classical Product-of-Experts derivation (assuming uniform prior on T).
|
| 474 |
+
|
| 475 |
+
**Complementary error profiles:**
|
| 476 |
+
- Learned routing: excels on recently trained tasks (MLP adapts); degrades on distant old tasks (feature drift)
|
| 477 |
+
- Spectral routing: parameter-free → zero drift; weaker on same-domain tasks (GPM forces $A_k \perp A_j$)
|
| 478 |
+
- When both agree: high confidence → nearly always correct
|
| 479 |
+
- When they disagree: hedged prediction → reduces worst-case error
|
| 480 |
+
|
| 481 |
+
**Channel capacity argument:**
|
| 482 |
+
Each routing method has limited channel capacity $C_L, C_S$ for encoding task identity.
|
| 483 |
+
Ensemble capacity: $C_{\text{ens}} \geq \max(C_L, C_S)$ (data processing inequality) with equality iff one subsumes the other.
|
| 484 |
+
Since learned and spectral use orthogonal feature spaces (MLP output vs A-row projection), $C_{\text{ens}} > \max(C_L, C_S)$.
|
| 485 |
+
|
| 486 |
+
**Expected improvement**: Both AP ↑ (better routing accuracy) and FT ↓ (spectral stabilizes learned).
|
| 487 |
+
|
| 488 |
+
### Hyperparameters
|
| 489 |
+
All V11 variants:
|
| 490 |
+
- lora_r = 8, lora_alpha = 32
|
| 491 |
+
- lr = 3e-4, epochs = 10
|
| 492 |
+
- threshold = 0.995, transthreshold = 0.995
|
| 493 |
+
- mlp_hidden_dim = 100
|
| 494 |
+
|
| 495 |
+
V11b specific: routing_temp = 0.1
|
| 496 |
+
V11c specific: routing_temp = 0.1, ensemble_weight = 0.7
|