import torch
import torch.nn as nn
| |
|
class OTitansTriArchRouter(nn.Module):
    """Phase 3: The Tri-Arch Router.

    Dynamically routes forward passes between the frozen base model,
    the Memory OTITANS gate, and potential Skill OTITANS gates.

    The base model is run once to obtain its final hidden states; each
    active gate produces a residual correction that is blended back in,
    scaled by its routing alpha, before the LM head projects the
    (possibly corrected) hidden states to logits.
    """

    def __init__(self, base_model, memory_gate=None, skill_gate=None):
        """
        Args:
            base_model: Causal LM accepting ``output_hidden_states`` /
                ``return_dict`` kwargs and exposing an ``lm_head``
                projection. NOTE(review): the class docstring calls it
                "frozen", but nothing here freezes its parameters —
                callers must do that themselves; confirm intent.
            memory_gate: Optional module mapping final hidden states to a
                residual memory correction. ``None`` disables the gate.
            skill_gate: Optional module mapping hidden states to a
                residual skill correction. ``None`` disables the gate.

        Raises:
            AttributeError: if ``base_model`` has no ``lm_head`` — better
                to fail here than deep inside ``forward()``.
        """
        super().__init__()
        # Fail fast: forward() needs base_model.lm_head, and surfacing a
        # missing head at construction gives a far clearer error.
        if not hasattr(base_model, "lm_head"):
            raise AttributeError(
                "base_model must expose an 'lm_head' projection"
            )
        self.base_model = base_model
        self.memory_gate = memory_gate
        self.skill_gate = skill_gate

        # Routing weights; a gate is bypassed entirely when its alpha
        # is <= 0 or its module is None.
        self.current_memory_alpha = 1.0
        self.current_skill_alpha = 0.0

    def set_routing_alphas(self, memory_alpha: float, skill_alpha: float):
        """Dynamically adjust the routing gates before a forward pass.

        Args:
            memory_alpha: Blend weight for the memory gate's residual.
            skill_alpha: Blend weight for the skill gate's residual.

        Raises:
            TypeError/ValueError: if either alpha cannot be coerced to
                ``float`` — fails here instead of inside ``forward()``.
        """
        self.current_memory_alpha = float(memory_alpha)
        self.current_skill_alpha = float(skill_alpha)

    def forward(self, input_ids, **kwargs):
        """Run the base model, blend in gate residuals, and project to logits.

        Args:
            input_ids: Token ids (or whatever the base model accepts as
                its first positional argument).
            **kwargs: Forwarded verbatim to the base model.

        Returns:
            Logits from ``base_model.lm_head`` over the gate-corrected
            final hidden states.
        """
        base_outputs = self.base_model(
            input_ids,
            output_hidden_states=True,
            return_dict=True,
            **kwargs,
        )

        # Final-layer hidden states are the substrate every gate edits.
        hidden_states = base_outputs.hidden_states[-1]

        # Memory gate: additive residual scaled by its routing alpha.
        if self.current_memory_alpha > 0.0 and self.memory_gate is not None:
            memory_states = self.memory_gate(hidden_states)
            hidden_states = hidden_states + memory_states * self.current_memory_alpha

        # Skill gate sees the memory-corrected states, so corrections stack.
        if self.current_skill_alpha > 0.0 and self.skill_gate is not None:
            skill_states = self.skill_gate(hidden_states)
            hidden_states = hidden_states + skill_states * self.current_skill_alpha

        # Re-project through the base model's LM head so the gate
        # corrections actually influence the emitted logits.
        return self.base_model.lm_head(hidden_states)
| |
|