natmin322 committed on
Commit
2c4cffd
·
1 Parent(s): f666767

rls t5 large

Browse files
improve_gainlora/gen_script_long_order3_t5_large_rls.sh ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ #SBATCH -J cl-rls-large
3
+ #SBATCH -o cl-rls-large-%j.out
4
+ #SBATCH -p compute
5
+ #SBATCH -N 1
6
+ #SBATCH -t 30:00:00
7
+ #SBATCH --mem 128G
8
+ #SBATCH --gres=gpu:2
9
+
10
+ # ============================================================
11
+ # SpecRoute V11: RLS Analytical Router + InfLoRA/CPI/GPM
12
+ # Long Sequence Order 3 — T5-LARGE (optimized, no redundancy)
13
+ # ============================================================
14
+
15
# Enumerate GPUs by PCI bus ID so the device indices chosen below are stable.
export CUDA_DEVICE_ORDER="PCI_BUS_ID"

# Auto-detect GPU count and per-GPU memory (value reported for the first GPU).
# Only lines that start with "GPU " are counted, so any other text in the
# `nvidia-smi -L` output (sub-device listings, driver error messages) is not
# mistaken for a device.
NUM_GPUS=$(nvidia-smi -L 2>/dev/null | grep -c '^GPU ')
GPU_MEM=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits 2>/dev/null | head -n 1)
GPU_MEM=${GPU_MEM//[[:space:]]/}

# memory.total must be a plain integer. An empty value means no GPU; a
# non-numeric value (for example "[N/A]" or a driver error line) would make
# every numeric comparison below fail and silently select the large-GPU
# settings, so both cases are rejected here.
if [[ ! "$GPU_MEM" =~ ^[0-9]+$ ]]; then
  echo "ERROR: No GPU detected!" >&2
  exit 1
fi

# Determine GPU type and parallelism strategy from the reported memory size.
# $1 (optional) selects the device index whenever a single GPU is used.
if (( GPU_MEM < 15500 )); then
  # T5-large won't fit well on T4 16GB with training
  BSZ=1; GA=16; EVAL_BSZ=8
  if (( NUM_GPUS >= 2 )); then
    GPU_MODE="t4_2gpu"
    GPU_IDS="0,1"
    STRAT="2x T4 DataParallel"
  else
    # Only one card present: do not advertise or request a second device.
    GPU_MODE="t4_1gpu"
    GPU_IDS="${1:-0}"
    STRAT="1x T4"
  fi
elif (( GPU_MEM <= 17000 )); then
  GPU_MODE="p100"
  GPU_IDS="${1:-0}"
  STRAT="P100 16GB"
  BSZ=4; GA=8; EVAL_BSZ=16
else
  GPU_MODE="a100"
  if (( NUM_GPUS >= 2 )); then
    # Only devices 0 and 1 are used, even when more GPUs are present.
    GPU_IDS="0,1"
    STRAT="${NUM_GPUS}x ${GPU_MEM}MB DataParallel"
  else
    GPU_IDS="${1:-0}"
    STRAT="1x ${GPU_MEM}MB GPU"
  fi
  # T5-large: higher batch sizes with extra VRAM
  if (( GPU_MEM >= 40000 )); then
    BSZ=32; GA=1; EVAL_BSZ=64
  else
    BSZ=16; GA=2; EVAL_BSZ=32
  fi
fi

echo "[GPU] Detected (~${GPU_MEM}MB per GPU): $STRAT"
echo "[HP] BSZ=$BSZ, GA=$GA, EVAL_BSZ=$EVAL_BSZ"
echo "[GPU] Using CUDA_VISIBLE_DEVICES=$GPU_IDS"
echo "============================================================"
59
+
60
+ # ============================================================
61
+ # Configuration
62
+ # ============================================================
63
# RLS router settings; forwarded to src/run_t5.py through COMMON_ARGS.
RLS_EXPANSION_DIM=2048
RLS_LAMBDA=0.1
ROUTING_MODE=rls

# Comma-separated task sequence passed as --task_order. It must list the same
# tasks, in the same order, as the TASKS array used by the task loop.
TASK_ORDER=yelp,amazon,mnli,cb,copa,qqp,rte,imdb,sst2,dbpedia,agnews,yahoo,multirc,boolq,wic
RUN_NAME=gen_script_long_order3_t5_large_rls
CONFIG_BASE=configs/gen_script_long_order3_t5_configs
OUTPUT_BASE=logs_and_outputs/${RUN_NAME}/outputs

# The model name/path is the second positional argument. Fail fast when it is
# missing: otherwise --model_name_or_path would be emitted without a value and
# the following flag would be consumed in its place.
MODEL_PATH=${2:-}
if [[ -z "$MODEL_PATH" ]]; then
  echo "Usage: $0 <gpu_id> <model_name_or_path>" >&2
  exit 1
fi

# Common hyperparameters (all tasks).
# COMMON_ARGS is deliberately a flat string: the task loop expands it unquoted
# and relies on word splitting, so no value in it may contain whitespace.
COMMON_ARGS="
  --do_train
  --predict_with_generate
  --model_name_or_path ${MODEL_PATH}
  --data_dir CL_Benchmark
  --task_order ${TASK_ORDER}
  --per_device_train_batch_size $BSZ
  --per_device_eval_batch_size $EVAL_BSZ
  --gradient_accumulation_steps $GA
  --learning_rate 0.0003
  --num_train_epochs 10
  --run_name ${RUN_NAME}
  --max_source_length 512
  --max_target_length 50
  --generation_max_length 50
  --add_task_name False
  --add_dataset_name False
  --overwrite_output_dir
  --overwrite_cache
  --lr_scheduler_type constant
  --warmup_steps 0
  --logging_strategy steps
  --logging_steps 10
  --evaluation_strategy steps
  --save_strategy steps
  --save_total_limit 1
  --load_best_model_at_end
  --lora_r 8
  --lora_alpha 32
  --lora_dropout 0.0
  --run_single False
  --n_batches_c5 100
  --data_replay_freq -1
  --mlp_hidden_dim 100
  --model_name specroute
  --routing_mode ${ROUTING_MODE}
  --rls_expansion_dim ${RLS_EXPANSION_DIM}
  --rls_lambda ${RLS_LAMBDA}
  --cpi_gamma 0.5
  --oap_eta 0.5
  --oap_beta_min 0.3
  --oap_warmup 3
  --threshold 0.995
  --transthreshold 0.995
  --do_predict
"
119
+
120
+ # ============================================================
121
+ # Generate previous_lora_path string for each task
122
+ # ============================================================
123
#######################################
# Print the comma-separated saved_weights paths of every task that precedes
# the given task, in task order.
# Globals:   OUTPUT_BASE (read), TASKS (read; 0-indexed array of task names)
# Arguments: $1 - 1-based number of the current task (defaults to 1)
# Outputs:   the joined path list on stdout; an empty line for task 1
#######################################
build_prev_lora_list() {
  local task_num=${1:-1}
  local list=""
  local i

  # Arithmetic loop instead of `seq`: it needs no external command and runs
  # zero times for task 1, whereas some seq implementations count downwards
  # when the first value exceeds the last.
  for (( i = 1; i < task_num; i++ )); do
    # ${list:+,} adds the separator only once the list is non-empty.
    list+="${list:+,}${OUTPUT_BASE}/${i}-${TASKS[i-1]}/saved_weights"
  done

  printf '%s\n' "$list"
}
134
+
135
+ # ============================================================
136
+ # Task array: indexed from 0
137
+ # ============================================================
138
TASKS=(yelp amazon mnli cb copa qqp rte imdb sst2 dbpedia agnews yahoo multirc boolq wic)

# ============================================================
# Run all tasks in order
# ============================================================
# Each task trains and predicts with src/run_t5.py. From the second task on,
# the run receives the saved_weights directories of all earlier tasks via
# --previous_lora_path, so the tasks must complete strictly in sequence.
for task_idx in "${!TASKS[@]}"; do
  task_num=$((task_idx + 1))
  task_name=${TASKS[task_idx]}
  # Abort rather than build paths under "/" if OUTPUT_BASE is unset or empty;
  # this path is later used with rm -rf.
  task_out="${OUTPUT_BASE:?OUTPUT_BASE is not set}/${task_num}-${task_name}"

  echo ""
  echo "============================================================"
  echo "Task $task_num: $task_name"
  echo "============================================================"

  # Task 1 uses the bare metric name and has no previous LoRA weights; later
  # tasks use a per-task metric suffix and the list of earlier weights.
  if (( task_num > 1 )); then
    metric_key="eval_exact_match_for_${task_name}"
    prev_lora_arg=(--previous_lora_path "$(build_prev_lora_list "$task_num")")
  else
    metric_key="eval_exact_match"
    prev_lora_arg=()
  fi

  # Run training + prediction.
  # COMMON_ARGS is a whitespace-separated flag string and must be word-split.
  # shellcheck disable=SC2086
  CUDA_VISIBLE_DEVICES=$GPU_IDS python src/run_t5.py \
    $COMMON_ARGS \
    "${prev_lora_arg[@]}" \
    --task_config_dir "${CONFIG_BASE}/${task_name}" \
    --output_dir "$task_out" \
    --metric_for_best_model "$metric_key"
  status=$?

  # Stop at the first failure: every later task is pointed at this task's
  # saved_weights directory, and the checkpoints are kept for inspection.
  if (( status != 0 )); then
    echo "ERROR: task $task_num ($task_name) failed with exit code $status; aborting." >&2
    exit "$status"
  fi

  # Cleanup checkpoints (save space)
  rm -rf -- "$task_out"/checkpoint*

  # Brief pause before next task
  sleep 2
done

echo ""
echo "============================================================"
echo "All ${#TASKS[@]} tasks completed!"
echo "============================================================"
improve_gainlora/src/run_t5.py CHANGED
@@ -1116,6 +1116,15 @@ def main():
1116
  )
1117
  metrics["train_samples"] = min(max_train_samples, len(train_dataset))
1118
 
 
 
 
 
 
 
 
 
 
1119
  trainer.log_metrics("train", metrics)
1120
  trainer.save_metrics("train", metrics)
1121
  trainer.save_state()
@@ -1141,6 +1150,15 @@ def main():
1141
  print("*** Prediction ***")
1142
  logger.info("*** Prediction ***")
1143
  logger.info("*** Loading CheckPoint ***")
 
 
 
 
 
 
 
 
 
1144
 
1145
  if data_args.max_predict_samples is not None:
1146
  predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
@@ -1169,20 +1187,47 @@ def main():
1169
  trainer.model.encoder.is_inference = False
1170
 
1171
  if training_args.do_predict:
1172
- predict_results = trainer.predict(
1173
- predict_dataset,
1174
- metric_key_prefix="predict",
1175
- max_new_tokens=max_new_tokens,
1176
- num_beams=num_beams,
1177
- repetition_penalty=repetition_penalty,
1178
- pad_token_id=tokenizer.pad_token_id
1179
- )
1180
- metrics = predict_results.metrics
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1181
  max_predict_samples = (
1182
  data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
1183
  )
1184
  metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
1185
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1186
  trainer.log(metrics)
1187
  trainer.log_metrics("predict", metrics)
1188
  trainer.save_metrics("predict", metrics)
@@ -1217,6 +1262,21 @@ def main():
1217
  # Reset for next eval round
1218
  trainer.model.encoder._routing_decisions = []
1219
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1220
  return results
1221
 
1222
 
 
1116
  )
1117
  metrics["train_samples"] = min(max_train_samples, len(train_dataset))
1118
 
1119
+ # Print training metrics to stdout
1120
+ print("\n[TRAIN METRICS]")
1121
+ for key, value in sorted(metrics.items()):
1122
+ if isinstance(value, float):
1123
+ print(f" {key}: {value:.6f}")
1124
+ else:
1125
+ print(f" {key}: {value}")
1126
+ sys.stdout.flush()
1127
+
1128
  trainer.log_metrics("train", metrics)
1129
  trainer.save_metrics("train", metrics)
1130
  trainer.save_state()
 
1150
  print("*** Prediction ***")
1151
  logger.info("*** Prediction ***")
1152
  logger.info("*** Loading CheckPoint ***")
1153
+
1154
+ # [DIAG] Check model device state before prediction
1155
+ try:
1156
+ model_device = next(trainer.model.parameters()).device
1157
+ print(f"[DIAG-DEVICE] Model is on device: {model_device}")
1158
+ logger.info(f"[DIAG-DEVICE] Model is on device: {model_device}")
1159
+ sys.stdout.flush()
1160
+ except Exception as e:
1161
+ print(f"[DIAG-DEVICE] Could not get model device: {e}")
1162
 
1163
  if data_args.max_predict_samples is not None:
1164
  predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
 
1187
  trainer.model.encoder.is_inference = False
1188
 
1189
  if training_args.do_predict:
1190
+ try:
1191
+ logger.info("Starting prediction on %d samples", len(predict_dataset))
1192
+ print(f"[PREDICT] Starting prediction on {len(predict_dataset)} samples")
1193
+ sys.stdout.flush()
1194
+
1195
+ predict_results = trainer.predict(
1196
+ predict_dataset,
1197
+ metric_key_prefix="predict",
1198
+ max_new_tokens=max_new_tokens,
1199
+ num_beams=num_beams,
1200
+ repetition_penalty=repetition_penalty,
1201
+ pad_token_id=tokenizer.pad_token_id
1202
+ )
1203
+ logger.info("Prediction completed successfully")
1204
+ print("[PREDICT] Prediction completed successfully")
1205
+ sys.stdout.flush()
1206
+ metrics = predict_results.metrics
1207
+ except Exception as e:
1208
+ logger.error(f"Error during prediction: {e}", exc_info=True)
1209
+ print(f"[ERROR] Prediction failed: {e}")
1210
+ import traceback
1211
+ traceback.print_exc()
1212
+ raise
1213
  max_predict_samples = (
1214
  data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
1215
  )
1216
  metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
1217
 
1218
+ # Print prediction metrics to stdout
1219
+ print("\n" + "="*80)
1220
+ print(f"TASK {cur_task_id}: {cur_task}")
1221
+ print("="*80)
1222
+ print("[PREDICT METRICS]")
1223
+ for key, value in sorted(metrics.items()):
1224
+ if isinstance(value, float):
1225
+ print(f" {key}: {value:.6f}")
1226
+ else:
1227
+ print(f" {key}: {value}")
1228
+ print("="*80 + "\n")
1229
+ sys.stdout.flush()
1230
+
1231
  trainer.log(metrics)
1232
  trainer.log_metrics("predict", metrics)
1233
  trainer.save_metrics("predict", metrics)
 
1262
  # Reset for next eval round
1263
  trainer.model.encoder._routing_decisions = []
1264
 
1265
+ # ===== FINAL METRICS SUMMARY =====
1266
+ print("\n" + "="*80)
1267
+ print("FINAL METRICS SUMMARY")
1268
+ print("="*80)
1269
+ if all_metrics:
1270
+ for key, value in sorted(all_metrics.items()):
1271
+ if isinstance(value, float):
1272
+ print(f" {key}: {value:.6f}")
1273
+ else:
1274
+ print(f" {key}: {value}")
1275
+ else:
1276
+ print(" (No metrics available)")
1277
+ print("="*80 + "\n")
1278
+ sys.stdout.flush()
1279
+
1280
  return results
1281
 
1282