# SHOREKEEPER/src/council/sentinel.py
# Commit 73400c8 (geoore): Restructure to src/ layout with attention,
# per-layer MoE, and working chat
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional
class Sentinel(nn.Module):
    """Top-k expert router with per-expert usage tracking.

    Each token is scored against ``n_experts`` experts by a bias-free linear
    gate plus a learned per-expert bias. Optional caller-supplied
    ``role_hints`` are added to the scores, the result is divided by a learned
    (clamped) temperature, and the ``n_activated`` best-scoring experts are
    selected. In training mode every selection is counted so that balance
    statistics can be queried afterwards.

    Args:
        dim: Size of the last dimension of the input features.
        n_experts: Number of experts to route between.
        n_activated: Number of experts selected for each token.

    Raises:
        ValueError: If ``n_activated`` is not within ``[1, n_experts]``.
    """

    def __init__(self, dim: int, n_experts: int = 12, n_activated: int = 2):
        super().__init__()
        # Fail early: topk() in forward() cannot pick more experts than exist.
        if not 1 <= n_activated <= n_experts:
            raise ValueError(
                f"n_activated must be between 1 and n_experts ({n_experts}), "
                f"got {n_activated}"
            )
        self.n_experts = n_experts
        self.n_activated = n_activated
        self.gate = nn.Linear(dim, n_experts, bias=False)
        self.expert_bias = nn.Parameter(torch.zeros(n_experts))
        # Running statistics; buffers so they follow the module across
        # devices and are saved with it, but are never optimised.
        self.register_buffer("usage_counts", torch.zeros(n_experts))
        self.register_buffer("total_tokens", torch.tensor(0.0))
        self.temperature = nn.Parameter(torch.ones(1) * 1.0)

    def forward(
        self, x: torch.Tensor, role_hints: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Select the top ``n_activated`` experts for every token.

        Args:
            x: Input features of shape ``(..., dim)``.
            role_hints: Optional additive adjustment to the expert scores;
                must broadcast against ``(..., n_experts)``.

        Returns:
            A ``(weights, indices)`` pair, both of shape
            ``(..., n_activated)``. ``weights`` are the softmax of the
            selected scores (they sum to 1 over the last dimension) and
            ``indices`` are the ids of the selected experts.
        """
        logits = self.gate(x) + self.expert_bias
        if role_hints is not None:
            logits = logits + role_hints
        # abs() + clamp keep the divisor positive and inside [0.1, 2.0] no
        # matter where the learned parameter drifts.
        temperature = self.temperature.abs().clamp(min=0.1, max=2.0)
        logits = logits / temperature
        weights, indices = logits.topk(self.n_activated, dim=-1)
        # Softmax over the selected scores only, not over all experts.
        weights = F.softmax(weights, dim=-1)
        if self.training:
            self._update_usage(indices)
        return weights, indices

    def _update_usage(self, indices: torch.Tensor) -> None:
        """Accumulate expert selection counts from one routing decision.

        Args:
            indices: Expert ids of shape ``(..., k)`` as returned by
                :meth:`forward`; any number of leading dimensions is accepted.
        """
        # Flatten so batched inputs such as (batch, seq, k) are counted too.
        flat = indices.detach().reshape(-1)
        with torch.no_grad():
            # Match the buffer dtype so a cast module (e.g. .half()) works.
            ones = torch.ones_like(flat, dtype=self.usage_counts.dtype)
            self.usage_counts.scatter_add_(0, flat, ones)
            # One token contributes indices.shape[-1] selections.
            self.total_tokens += flat.numel() // indices.shape[-1]

    def _usage_fractions(self) -> Optional[torch.Tensor]:
        """Return each expert's share of all recorded selections.

        Returns:
            A tensor of shape ``(n_experts,)`` that sums to 1, or ``None``
            when nothing has been recorded.
        """
        total_selections = self.usage_counts.sum()
        if total_selections == 0:
            return None
        # Normalise by selections, not tokens: every token adds several
        # selections, so dividing by the token count would not sum to 1.
        return self.usage_counts / total_selections

    def get_load_balance_loss(self) -> torch.Tensor:
        """Return a scaled penalty for uneven expert usage and reset counters.

        The penalty is the mean squared deviation of the usage fractions from
        the uniform share ``1 / n_experts``, multiplied by ``0.01``. It is
        exactly zero when all experts were selected equally often.

        Side effect: ``usage_counts`` and ``total_tokens`` are zeroed, so call
        :meth:`get_role_entropy` first if both statistics are wanted.

        NOTE: the value is derived from the counter buffers only, so it does
        not carry gradients back to the router parameters.

        Returns:
            A scalar tensor; ``0.0`` when no usage has been recorded.
        """
        fractions = self._usage_fractions()
        if fractions is None:
            return torch.tensor(0.0, device=self.usage_counts.device)
        uniform_share = 1.0 / self.n_experts
        loss = ((fractions - uniform_share) ** 2).mean()
        self.usage_counts.zero_()
        self.total_tokens.zero_()
        return loss * 0.01

    def get_role_entropy(self) -> torch.Tensor:
        """Return the entropy (in nats) of the recorded expert usage.

        The maximum, ``log(n_experts)``, is reached when usage is uniform.
        The counters are left untouched.

        Returns:
            A scalar tensor on the module's device; ``0.0`` when no usage has
            been recorded.
        """
        fractions = self._usage_fractions()
        if fractions is None:
            return torch.tensor(0.0, device=self.usage_counts.device)
        # The epsilon keeps log() finite for experts that were never chosen.
        return -(fractions * torch.log(fractions + 1e-8)).sum()