vigneshwar234 commited on
Commit
d266df1
·
verified ·
1 Parent(s): 8f6eed4

Add source: tmt/model/exit_gate.py

Browse files
Files changed (1) hide show
  1. tmt/model/exit_gate.py +61 -0
tmt/model/exit_gate.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ exit_gate.py — ExitGate: per-token adaptive depth routing.
3
+
4
+ Novel vs standard: every token in a standard transformer passes through all N
5
+ layers unconditionally. ExitGate computes a confidence scalar after each layer
6
+ norm. If confidence > exit_threshold the token's representation is frozen and
7
+ it skips all remaining layers, halving average compute on easy tokens.
8
+
9
+ The auxiliary training loss encourages the gate to be decisive (push toward 0
10
+ or 1) without forcing early exits — the model learns when it is confident.
11
+ """
12
+ from __future__ import annotations
13
+
14
+ from typing import Tuple
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ from torch import Tensor
19
+
20
+ from .config import TMTConfig
21
+
22
+
23
+ class ExitGate(nn.Module):
24
+ """Single linear → sigmoid confidence gate per token."""
25
+
26
+ def __init__(self, cfg: TMTConfig) -> None:
27
+ super().__init__()
28
+ self.threshold = cfg.exit_threshold
29
+ # Single scalar projection: d_model → 1
30
+ self.gate_proj = nn.Linear(cfg.d_model, 1)
31
+ nn.init.zeros_(self.gate_proj.weight)
32
+ nn.init.constant_(self.gate_proj.bias, -2.0) # start pessimistic
33
+
34
+ def forward(self, x: Tensor, exit_mask: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
35
+ """
36
+ Args:
37
+ x: (B, S, D) current token representations
38
+ exit_mask: (B, S) bool — True where token has already exited
39
+
40
+ Returns:
41
+ x: (B, S, D) unchanged (gating is applied in TMTLayer)
42
+ new_mask: (B, S) bool — updated exit mask
43
+ confidence: (B, S) float confidence scores for auxiliary loss
44
+ """
45
+ confidence = torch.sigmoid(self.gate_proj(x)).squeeze(-1) # (B, S)
46
+
47
+ # New exits: not already exited AND confidence above threshold
48
+ newly_exited = (~exit_mask) & (confidence > self.threshold)
49
+ new_mask = exit_mask | newly_exited
50
+ return x, new_mask, confidence
51
+
52
+ def auxiliary_loss(self, confidence: Tensor) -> Tensor:
53
+ """
54
+ Encourage decisive gates — push confidence toward 0 or 1.
55
+ Loss = -E[|conf - 0.5|] so the model is penalised for being uncertain.
56
+ """
57
+ return -(confidence - 0.5).abs().mean()
58
+
59
+ def __repr__(self) -> str:
60
+ p = sum(p.numel() for p in self.parameters())
61
+ return f"ExitGate(threshold={self.threshold}, params={p:,})"