rls
Browse files
improve_gainlora/src/t5_specroute.py
CHANGED
|
@@ -724,15 +724,38 @@ class T5Stack(T5PreTrainedModel):
|
|
| 724 |
if not self.is_decoder and not self.prompt_config["run_single"]:
|
| 725 |
if self.routing_mode == "rls":
|
| 726 |
# V11: RLS analytical routing
|
| 727 |
-
|
| 728 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 729 |
h_pool = avg_inputs_embeds.squeeze(1) # (B, d_model)
|
| 730 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 731 |
else:
|
| 732 |
-
#
|
| 733 |
-
|
| 734 |
-
|
|
|
|
|
|
|
| 735 |
)
|
|
|
|
| 736 |
elif self.routing_mode == "learned":
|
| 737 |
key_attention_weights = self.compute_learned_routing(avg_inputs_embeds, batch_size)
|
| 738 |
else:
|
|
|
|
| 724 |
if not self.is_decoder and not self.prompt_config["run_single"]:
|
| 725 |
if self.routing_mode == "rls":
|
| 726 |
# V11: RLS analytical routing
|
| 727 |
+
# Convention: index 0 = current task, index 1+ = previous tasks
|
| 728 |
+
# RLS router stores tasks 0..T-1 from previous training runs.
|
| 729 |
+
# During task T: model has T+1 LoRAs [current, prev_0, ..., prev_{T-1}]
|
| 730 |
+
# Total experts = 1 (current) + len(spectral_signatures) (previous)
|
| 731 |
+
n_lora_experts = 1 + len(self.spectral_signatures)
|
| 732 |
+
|
| 733 |
+
if self.rls_router.num_tasks > 0 and not self.training:
|
| 734 |
+
# Inference: RLS should have all tasks including current
|
| 735 |
h_pool = avg_inputs_embeds.squeeze(1) # (B, d_model)
|
| 736 |
+
rls_weights = self.rls_router.route(h_pool) # (B, num_tasks, 1)
|
| 737 |
+
n_reg = rls_weights.shape[1]
|
| 738 |
+
if n_reg == n_lora_experts:
|
| 739 |
+
# After update_rls_router: all tasks registered
|
| 740 |
+
# RLS order: [task_0, task_1, ..., task_T] (chronological)
|
| 741 |
+
# Model order: [current(task_T), prev_0(task_{T-1}), prev_1(task_{T-2}), ..., prev_{T-1}(task_0)]
|
| 742 |
+
# Previous LoRAs are loaded in REVERSE chronological order (see run_t5.py .reverse())
|
| 743 |
+
# So: model[0]=RLS[-1], model[1]=RLS[-2], ..., model[T]=RLS[0]
|
| 744 |
+
key_attention_weights = rls_weights.flip(dims=[1])
|
| 745 |
+
else:
|
| 746 |
+
# Fallback: size mismatch, use uniform
|
| 747 |
+
key_attention_weights = torch.ones(
|
| 748 |
+
batch_size, n_lora_experts, 1,
|
| 749 |
+
device=inputs_embeds.device, dtype=inputs_embeds.dtype
|
| 750 |
+
) / n_lora_experts
|
| 751 |
else:
|
| 752 |
+
# Training or first task: oracle routing (always current = index 0)
|
| 753 |
+
# Just create correct-sized tensor; oracle override below will set it
|
| 754 |
+
key_attention_weights = torch.zeros(
|
| 755 |
+
batch_size, n_lora_experts, 1,
|
| 756 |
+
device=inputs_embeds.device, dtype=inputs_embeds.dtype
|
| 757 |
)
|
| 758 |
+
key_attention_weights[:, 0, 0] = 1.0
|
| 759 |
elif self.routing_mode == "learned":
|
| 760 |
key_attention_weights = self.compute_learned_routing(avg_inputs_embeds, batch_size)
|
| 761 |
else:
|