natmin322 committed on
Commit
9f858c3
·
1 Parent(s): 2c4cffd

rls t5 large

Browse files
improve_gainlora/src/run_t5.py CHANGED
@@ -1132,7 +1132,19 @@ def main():
1132
  all_metrics.update(metrics)
1133
 
1134
  if training_args.model_name in ['inflora', 'gainlora_inflora', 'gainlora_olora', 'specroute']:
1135
- trainer.get_repsentation()
 
 
 
 
 
 
 
 
 
 
 
 
1136
 
1137
  # Evaluation
1138
  results = {}
 
1132
  all_metrics.update(metrics)
1133
 
1134
  if training_args.model_name in ['inflora', 'gainlora_inflora', 'gainlora_olora', 'specroute']:
1135
+ try:
1136
+ print("[GPM] Starting get_repsentation()...")
1137
+ sys.stdout.flush()
1138
+ trainer.get_repsentation()
1139
+ print("[GPM] get_repsentation() completed successfully.")
1140
+ sys.stdout.flush()
1141
+ except Exception as _gpm_exc:
1142
+ import traceback
1143
+ print(f"\n[GPM ERROR] get_repsentation() FAILED (task {cur_task_id}, {cur_task}):")
1144
+ print(f" {type(_gpm_exc).__name__}: {_gpm_exc}")
1145
+ traceback.print_exc(file=sys.stdout)
1146
+ print("[GPM] Continuing to predict block despite GPM error...\n")
1147
+ sys.stdout.flush()
1148
 
1149
  # Evaluation
1150
  results = {}
improve_gainlora/src/t5_specroute.py CHANGED
@@ -308,8 +308,13 @@ class T5Stack(T5PreTrainedModel):
308
 
309
  if self.routing_mode == "rls":
310
  # V11: Analytical Ridge Regression Routing
311
- rls_expansion_dim = prompt_config.get("rls_expansion_dim", 2048)
 
 
 
312
  rls_lambda = prompt_config.get("rls_lambda", 0.1)
 
 
313
  self.rls_router = RLSRouter(
314
  d_model=config.d_model,
315
  expansion_dim=rls_expansion_dim,
 
308
 
309
  if self.routing_mode == "rls":
310
  # V11: Analytical Ridge Regression Routing
311
+ # expansion_dim is floored at 4 * d_model, so the expansion ratio is at least 4x
312
+ # for every model size (T5-small: 512->2048, T5-large: 1024->4096, etc.); a larger user-requested value is kept as is.
313
+ _user_expansion_dim = prompt_config.get("rls_expansion_dim", 2048)
314
+ rls_expansion_dim = max(_user_expansion_dim, 4 * config.d_model)
315
  rls_lambda = prompt_config.get("rls_lambda", 0.1)
316
+ print(f"[RLS] d_model={config.d_model}, expansion_dim={rls_expansion_dim} "
317
+ f"(ratio={rls_expansion_dim/config.d_model:.1f}x, user_requested={_user_expansion_dim})")
318
  self.rls_router = RLSRouter(
319
  d_model=config.d_model,
320
  expansion_dim=rls_expansion_dim,