Gemma3MoOLET / otitans_router.py
paperscarecrow's picture
Upload 47 files
9659b2b verified
import torch
import torch.nn as nn
class OTitansTriArchRouter(nn.Module):
    """Phase 3: the Tri-Arch Router.

    Wraps a base causal language model and optionally blends the output of
    two gate modules (a "memory" gate and a "skill" gate) into the base
    model's last hidden state before projecting to vocabulary logits with
    ``base_model.lm_head``.

    Each gate is applied residually: ``h = h + gate(h) * alpha``. A gate is
    skipped when it is ``None`` or its alpha is not strictly positive. The
    memory gate runs first; the skill gate then sees the memory-blended
    states.

    NOTE(review): the base model is described as frozen, but this class does
    not freeze it itself -- the caller is presumably expected to disable
    gradients on ``base_model`` if that is required.

    Args:
        base_model: Callable module that accepts ``input_ids`` plus the
            ``output_hidden_states`` / ``return_dict`` keyword arguments,
            returns an object with a ``hidden_states`` sequence, and exposes
            an ``lm_head`` attribute.
        memory_gate: Module mapping hidden states to a same-shaped update,
            or ``None`` to disable the memory lane.
        skill_gate: Optional second gate applied after the memory gate.
    """

    def __init__(self, base_model, memory_gate, skill_gate=None):
        super().__init__()
        self.base_model = base_model

        # The isolated parallel highways.
        self.memory_gate = memory_gate
        self.skill_gate = skill_gate

        # Manual alpha controls for the router. In a fully autonomous
        # deployment a lightweight classifier would set these instead.
        # They are plain Python floats, so they are not part of state_dict.
        self.current_memory_alpha = 1.0
        self.current_skill_alpha = 0.0

    def set_routing_alphas(self, memory_alpha: float, skill_alpha: float) -> None:
        """Set the blend weights used by subsequent forward passes.

        Args:
            memory_alpha: Weight of the memory gate's contribution; values
                that are not strictly positive disable the memory lane.
            skill_alpha: Weight of the skill gate's contribution; values
                that are not strictly positive disable the skill lane.
        """
        # Deliberately silent: this may be called once per generated token.
        self.current_memory_alpha = memory_alpha
        self.current_skill_alpha = skill_alpha

    @staticmethod
    def _apply_gate(hidden_states, gate, alpha):
        """Return ``hidden_states + gate(hidden_states) * alpha`` if the lane is active."""
        if alpha > 0.0 and gate is not None:
            return hidden_states + (gate(hidden_states) * alpha)
        return hidden_states

    def forward(self, input_ids, **kwargs) -> torch.Tensor:
        """Run the base model, blend in the active gates, and return logits.

        Args:
            input_ids: Token ids forwarded positionally to ``base_model``.
            **kwargs: Extra keyword arguments forwarded to ``base_model``.
                ``output_hidden_states`` and ``return_dict`` are ignored if
                supplied, because the router always needs both enabled.

        Returns:
            The output of ``base_model.lm_head`` applied to the blended
            hidden states (raw logits, not a model-output object).
        """
        # The router must force these two flags on. Drop any caller-supplied
        # copies so the call below cannot fail with "got multiple values for
        # keyword argument" (e.g. when driven by generation utilities that
        # pass return_dict themselves).
        kwargs.pop("output_hidden_states", None)
        kwargs.pop("return_dict", None)

        # 1. Base model feature extraction.
        base_outputs = self.base_model(
            input_ids,
            output_hidden_states=True,
            return_dict=True,
            **kwargs,
        )

        # Last entry of the hidden-state tuple returned by the base model.
        hidden_states = base_outputs.hidden_states[-1]

        # 2. Dynamic memory routing.
        hidden_states = self._apply_gate(
            hidden_states, self.memory_gate, self.current_memory_alpha
        )

        # 3. Dynamic skill routing (operates on the memory-blended stream).
        hidden_states = self._apply_gate(
            hidden_states, self.skill_gate, self.current_skill_alpha
        )

        # 4. Final projection back into vocabulary logits.
        return self.base_model.lm_head(hidden_states)