rls t5 large
Browse files
improve_gainlora/src/run_t5.py
CHANGED
|
@@ -1132,7 +1132,19 @@ def main():
|
|
| 1132 |
all_metrics.update(metrics)
|
| 1133 |
|
| 1134 |
if training_args.model_name in ['inflora', 'gainlora_inflora', 'gainlora_olora', 'specroute']:
|
| 1135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1136 |
|
| 1137 |
# Evaluation
|
| 1138 |
results = {}
|
|
|
|
| 1132 |
all_metrics.update(metrics)
|
| 1133 |
|
| 1134 |
if training_args.model_name in ['inflora', 'gainlora_inflora', 'gainlora_olora', 'specroute']:
|
| 1135 |
+
try:
|
| 1136 |
+
print("[GPM] Starting get_repsentation()...")
|
| 1137 |
+
sys.stdout.flush()
|
| 1138 |
+
trainer.get_repsentation()
|
| 1139 |
+
print("[GPM] get_repsentation() completed successfully.")
|
| 1140 |
+
sys.stdout.flush()
|
| 1141 |
+
except Exception as _gpm_exc:
|
| 1142 |
+
import traceback
|
| 1143 |
+
print(f"\n[GPM ERROR] get_repsentation() FAILED (task {cur_task_id}, {cur_task}):")
|
| 1144 |
+
print(f" {type(_gpm_exc).__name__}: {_gpm_exc}")
|
| 1145 |
+
traceback.print_exc(file=sys.stdout)
|
| 1146 |
+
print("[GPM] Continuing to predict block despite GPM error...\n")
|
| 1147 |
+
sys.stdout.flush()
|
| 1148 |
|
| 1149 |
# Evaluation
|
| 1150 |
results = {}
|
improve_gainlora/src/t5_specroute.py
CHANGED
|
@@ -308,8 +308,13 @@ class T5Stack(T5PreTrainedModel):
|
|
| 308 |
|
| 309 |
if self.routing_mode == "rls":
|
| 310 |
# V11: Analytical Ridge Regression Routing
|
| 311 |
-
|
|
|
|
|
|
|
|
|
|
| 312 |
rls_lambda = prompt_config.get("rls_lambda", 0.1)
|
|
|
|
|
|
|
| 313 |
self.rls_router = RLSRouter(
|
| 314 |
d_model=config.d_model,
|
| 315 |
expansion_dim=rls_expansion_dim,
|
|
|
|
| 308 |
|
| 309 |
if self.routing_mode == "rls":
|
| 310 |
# V11: Analytical Ridge Regression Routing
|
| 311 |
+
# expansion_dim is floored at 4 * d_model (a larger user-requested value is kept), so the expansion ratio is at least 4x
|
| 312 |
+
# across all model sizes (T5-small: 512->2048, T5-large: 1024->4096, etc.)
|
| 313 |
+
_user_expansion_dim = prompt_config.get("rls_expansion_dim", 2048)
|
| 314 |
+
rls_expansion_dim = max(_user_expansion_dim, 4 * config.d_model)
|
| 315 |
rls_lambda = prompt_config.get("rls_lambda", 0.1)
|
| 316 |
+
print(f"[RLS] d_model={config.d_model}, expansion_dim={rls_expansion_dim} "
|
| 317 |
+
f"(ratio={rls_expansion_dim/config.d_model:.1f}x, user_requested={_user_expansion_dim})")
|
| 318 |
self.rls_router = RLSRouter(
|
| 319 |
d_model=config.d_model,
|
| 320 |
expansion_dim=rls_expansion_dim,
|