natmin322 commited on
Commit
9be56eb
·
1 Parent(s): 9de5c3c
improve_gainlora/src/t5_specroute.py CHANGED
@@ -724,15 +724,38 @@ class T5Stack(T5PreTrainedModel):
724
  if not self.is_decoder and not self.prompt_config["run_single"]:
725
  if self.routing_mode == "rls":
726
  # V11: RLS analytical routing
727
- if self.rls_router.num_tasks > 0:
728
- # Mean-pool for routing feature
 
 
 
 
 
 
729
  h_pool = avg_inputs_embeds.squeeze(1) # (B, d_model)
730
- key_attention_weights = self.rls_router.route(h_pool) # (B, n_tasks, 1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
731
  else:
732
- # First task: single expert
733
- key_attention_weights = torch.ones(
734
- batch_size, 1, 1, device=inputs_embeds.device, dtype=inputs_embeds.dtype
 
 
735
  )
 
736
  elif self.routing_mode == "learned":
737
  key_attention_weights = self.compute_learned_routing(avg_inputs_embeds, batch_size)
738
  else:
 
724
  if not self.is_decoder and not self.prompt_config["run_single"]:
725
  if self.routing_mode == "rls":
726
  # V11: RLS analytical routing
727
+ # Convention: index 0 = current task, index 1+ = previous tasks
728
+ # RLS router stores tasks 0..T-1 from previous training runs.
729
+ # During task T: model has T+1 LoRAs [current, prev_0, ..., prev_{T-1}]
730
+ # Total experts = 1 (current) + len(spectral_signatures) (previous)
731
+ n_lora_experts = 1 + len(self.spectral_signatures)
732
+
733
+ if self.rls_router.num_tasks > 0 and not self.training:
734
+ # Inference: RLS should have all tasks including current
735
  h_pool = avg_inputs_embeds.squeeze(1) # (B, d_model)
736
+ rls_weights = self.rls_router.route(h_pool) # (B, num_tasks, 1)
737
+ n_reg = rls_weights.shape[1]
738
+ if n_reg == n_lora_experts:
739
+ # After update_rls_router: all tasks registered
740
+ # RLS order: [task_0, task_1, ..., task_T] (chronological)
741
+ # Model order: [current(task_T), prev_0(task_{T-1}), prev_1(task_{T-2}), ..., prev_{T-1}(task_0)]
742
+ # Previous LoRAs are loaded in REVERSE chronological order (see run_t5.py .reverse())
743
+ # So: model[0]=RLS[-1], model[1]=RLS[-2], ..., model[T]=RLS[0]
744
+ key_attention_weights = rls_weights.flip(dims=[1])
745
+ else:
746
+ # Fallback: size mismatch, use uniform
747
+ key_attention_weights = torch.ones(
748
+ batch_size, n_lora_experts, 1,
749
+ device=inputs_embeds.device, dtype=inputs_embeds.dtype
750
+ ) / n_lora_experts
751
  else:
752
+ # Training or first task: oracle routing (always current = index 0)
753
+ # Just create correct-sized tensor; oracle override below will set it
754
+ key_attention_weights = torch.zeros(
755
+ batch_size, n_lora_experts, 1,
756
+ device=inputs_embeds.device, dtype=inputs_embeds.dtype
757
  )
758
+ key_attention_weights[:, 0, 0] = 1.0
759
  elif self.routing_mode == "learned":
760
  key_attention_weights = self.compute_learned_routing(avg_inputs_embeds, batch_size)
761
  else: