Spaces:
Sleeping
Sleeping
File size: 2,581 Bytes
8125804 f37be5a 8125804 f37be5a 8125804 f37be5a 8125804 f37be5a 8125804 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 | import torch
import torch.nn as nn
import torch.nn.functional as F
from src.model.config import ModelConfig
class BranchHead(nn.Module):
    """Single branch: projects a hidden state to temperature-scaled logits.

    Args:
        d_model: Size of the last dimension of the input hidden state.
        vocab_size: Number of logits produced per position.
        temperature: Strictly positive divisor applied to the projected logits.

    Raises:
        ValueError: If ``temperature`` is not strictly positive (zero,
            negative, or NaN).
    """

    def __init__(self, d_model: int, vocab_size: int, temperature: float = 1.0):
        super().__init__()
        # A zero temperature yields inf/NaN logits and a negative one flips the
        # ranking of the logits; `not (t > 0)` also rejects NaN.
        if not temperature > 0:
            raise ValueError(f"temperature must be > 0, got {temperature!r}")
        self.proj = nn.Linear(d_model, vocab_size, bias=False)
        # Stored as a plain Python number: it is neither trained nor part of
        # the module's state_dict.
        self.temperature = temperature

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``proj(x) / temperature``; the last dim becomes ``vocab_size``."""
        return self.proj(x) / self.temperature
class BranchRouter(nn.Module):
    """Generates multiple hypotheses via parallel branch heads.

    Every branch adds its own learned offset to the shared hidden state and
    projects it through a ``BranchHead`` with its own temperature. Temperatures
    are spaced linearly from 0.8 to 1.2 (a single branch gets 0.8).

    A diversity loss penalises every pair of branches whose Jensen-Shannon
    divergence falls below ``cfg.branch_diversity_target``, to prevent branch
    collapse.

    Args:
        cfg: Configuration providing ``n_branches``, ``d_model``,
            ``vocab_size`` and ``branch_diversity_target``.

    Raises:
        ValueError: If ``cfg.n_branches`` is smaller than 1.
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        # Fail at construction time: with zero branches ``forward`` would only
        # fail later, when stacking an empty list of logits.
        if cfg.n_branches < 1:
            raise ValueError(f"n_branches must be >= 1, got {cfg.n_branches!r}")
        self.cfg = cfg
        self.diversity_target = cfg.branch_diversity_target
        # max(..., 1) avoids a division by zero when there is a single branch.
        span = max(cfg.n_branches - 1, 1)
        temps = [0.8 + 0.4 * i / span for i in range(cfg.n_branches)]
        self.branches = nn.ModuleList([
            BranchHead(cfg.d_model, cfg.vocab_size, temperature=t) for t in temps
        ])
        # Small random per-branch shift of the hidden state, so branches do not
        # all see exactly the same input.
        self.branch_offsets = nn.Parameter(torch.randn(cfg.n_branches, cfg.d_model) * 0.02)

    @staticmethod
    def _js_divergence(pi: torch.Tensor, pj: torch.Tensor) -> torch.Tensor:
        """Return the Jensen-Shannon divergence of ``pi`` and ``pj`` as a scalar.

        Both arguments are treated as probability distributions along the last
        dimension; the per-position divergences are averaged over all remaining
        dimensions. Probabilities are clamped to ``1e-8`` before the logarithm
        so that zero entries do not produce ``-inf``.
        """
        m = 0.5 * (pi + pj)
        log_m = torch.log(m.clamp_min(1e-8))
        kl_i_m = (pi * (torch.log(pi.clamp_min(1e-8)) - log_m)).sum(dim=-1)
        kl_j_m = (pj * (torch.log(pj.clamp_min(1e-8)) - log_m)).sum(dim=-1)
        return (0.5 * (kl_i_m + kl_j_m)).mean()

    def forward(self, x: torch.Tensor) -> dict:
        """Run all branches on ``x`` and compute the diversity loss.

        Args:
            x: Hidden states whose last dimension is ``d_model``. The offsets
                are broadcast as ``(1, 1, d_model)``, so a
                ``(batch, seq, d_model)`` input is presumably expected
                (TODO confirm against the caller).

        Returns:
            A dict with:
                ``"logits"``: mean of the branch logits,
                ``"branch_logits"``: list with one logits tensor per branch,
                ``"diversity_loss"``: scalar mean over all branch pairs of
                ``relu(diversity_target - JS(p_i, p_j))``; ``0.0`` when there
                is only one branch.
        """
        branch_logits = [
            branch(x + self.branch_offsets[idx].view(1, 1, -1))
            for idx, branch in enumerate(self.branches)
        ]

        # Softmax once per branch instead of once per pair: the pairwise loop
        # below would otherwise recompute it n * (n - 1) times.
        branch_probs = [F.softmax(logits, dim=-1) for logits in branch_logits]

        diversity_loss = torch.tensor(0.0, device=x.device)
        n_pairs = 0
        for i, p_i in enumerate(branch_probs):
            for p_j in branch_probs[i + 1:]:
                js_div = self._js_divergence(p_i, p_j)
                # Hinge: only pairs less diverse than the target are penalised.
                diversity_loss = diversity_loss + F.relu(self.diversity_target - js_div)
                n_pairs += 1
        if n_pairs > 0:
            diversity_loss = diversity_loss / n_pairs

        avg_logits = torch.stack(branch_logits).mean(dim=0)
        return {
            "logits": avg_logits,
            "branch_logits": branch_logits,
            "diversity_loss": diversity_loss,
        }
|