| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import Tuple, Optional |
|
|
class Sentinel(nn.Module):
    """Top-k expert router (MoE gate) with usage tracking.

    Routes each token to ``n_activated`` of ``n_experts`` experts via a
    learned linear gate plus a per-expert bias, with a learnable (clamped)
    temperature. During training it accumulates per-expert selection counts
    so auxiliary statistics (load-balance loss, routing entropy) can be read
    off between optimizer steps.
    """

    def __init__(self, dim: int, n_experts: int = 12, n_activated: int = 2):
        super().__init__()
        self.n_experts = n_experts
        self.n_activated = n_activated
        self.gate = nn.Linear(dim, n_experts, bias=False)
        self.expert_bias = nn.Parameter(torch.zeros(n_experts))
        # Buffers (not parameters): move with .to()/.cuda() but take no grad.
        self.register_buffer("usage_counts", torch.zeros(n_experts))
        self.register_buffer("total_tokens", torch.tensor(0.0))
        self.temperature = nn.Parameter(torch.ones(1) * 1.0)

    def forward(self, x: torch.Tensor, role_hints: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute routing weights and expert indices for each token.

        Args:
            x: input of shape ``(..., dim)`` — any number of leading batch dims.
            role_hints: optional additive logit bias, broadcastable to
                ``(..., n_experts)``.

        Returns:
            ``(weights, indices)`` each of shape ``(..., n_activated)``;
            ``weights`` is a softmax over the selected experts' logits
            (sums to 1 along the last dim).
        """
        logits = self.gate(x) + self.expert_bias
        if role_hints is not None:
            logits = logits + role_hints
        # abs().clamp keeps the learnable temperature in a safe positive
        # range so division never explodes or flips sign.
        logits = logits / self.temperature.abs().clamp(min=0.1, max=2.0)
        weights, indices = logits.topk(self.n_activated, dim=-1)
        # Renormalize over only the selected experts.
        weights = F.softmax(weights, dim=-1)
        if self.training:
            self._update_usage(indices)
        return weights, indices

    def _update_usage(self, indices: torch.Tensor) -> None:
        """Accumulate per-expert selection counts and the token total.

        Accepts indices with any leading batch shape (e.g. ``(B, k)`` or
        ``(B, T, k)``) by flattening; the previous per-column
        ``scatter_add_`` implementation raised on anything but 2-D input.
        """
        flat = indices.reshape(-1)
        counts = torch.bincount(flat, minlength=self.n_experts)
        self.usage_counts += counts.to(self.usage_counts.dtype)
        # Each token contributes exactly n_activated entries to `flat`,
        # so this equals indices.shape[0] in the original 2-D case.
        self.total_tokens += flat.numel() / self.n_activated

    def get_load_balance_loss(self) -> torch.Tensor:
        """Return the (scaled) squared deviation of expert usage from uniform.

        Side effect: resets the usage counters, so this should be called
        once per accumulation window. NOTE(review): the statistic is built
        from integer routing counts held in buffers, so it carries no
        gradient back to the gate — it is a monitoring/penalty scalar, not
        a differentiable auxiliary loss. Confirm that is intended.
        """
        if self.total_tokens == 0:
            return torch.tensor(0.0, device=self.expert_bias.device)
        probs = self.usage_counts / self.total_tokens
        ideal = 1.0 / self.n_experts
        loss = ((probs - ideal) ** 2).mean()
        self.usage_counts.zero_()
        self.total_tokens.zero_()
        # 0.01 = auxiliary-loss scale factor.
        return loss * 0.01

    def get_role_entropy(self) -> torch.Tensor:
        """Return the entropy of the empirical expert-usage distribution.

        Reads (does not reset) the counters; call before
        get_load_balance_loss, which zeroes them.
        """
        if self.total_tokens == 0:
            # Match get_load_balance_loss: zero on the module's device
            # (the original returned a CPU tensor, breaking CUDA callers).
            return torch.tensor(0.0, device=self.expert_bias.device)
        probs = self.usage_counts / self.total_tokens
        # +1e-8 guards log(0) for never-selected experts.
        entropy = -(probs * torch.log(probs + 1e-8)).sum()
        return entropy
|
|