# telen / src /telern /hypernetwork.py
# haidang2405's picture
# first commit
# e7cfc32 verified
# Raw
# History Blame Contribute Delete
# 6.51 kB
"""
HyperNetwork for TELEN.
Core innovation: Instead of learning fixed projection weights, the HyperNetwork
GENERATES the projection function from the current legal corpus state.
When new laws arrive → state vector changes → HyperNetwork produces new weights
→ embedding space adapts WITHOUT retraining.
Additionally outputs variance for stochastic embeddings (uncertainty-aware retrieval).
"""
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
class HyperNetwork(nn.Module):
    """
    Generates embedding projection parameters from a legal state vector.

    Given state vector s ∈ R^d, produces:
    - ΔW: low-rank projection shift (weighted sum of learned rank-1 bases)
    - Δb: bias shift (weighted sum of learned bias bases)
    - log_σ²: per-dimension log-variance for stochastic embedding

    Architecture: instead of generating giant parameter matrices directly,
    the module stores a compact set of learned basis vectors and uses the
    network to generate ONLY the combination weights.

    The constructor reads ``config.hidden_dim`` (d) and, from
    ``config.hypernetwork``: ``adaptation_rank`` (r), ``hn_hidden_dim``,
    ``dropout`` and ``output_variance``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        hn = config.hypernetwork
        d = config.hidden_dim
        r = hn.adaptation_rank
        hidden = hn.hn_hidden_dim

        # Shared trunk: state -> latent code
        self.trunk = nn.Sequential(
            nn.Linear(d, hidden),
            nn.ReLU(),
            nn.Dropout(hn.dropout),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Dropout(hn.dropout),
            nn.Linear(hidden, hidden),
            nn.LayerNorm(hidden),
        )

        # Modulator: latent -> combination weights, laid out as
        # [A_weights (r) | B_weights (r) | bias_weights (r) | var_context (1)].
        # The trailing "var_context" slot is not read by forward(); the output
        # width is kept at 3r + 1 so previously saved state dicts still load.
        self.modulator = nn.Linear(hidden, 2 * r + r + 1)

        # === Learned basis vectors (stored, not generated) ===
        # For ΔW = Σ_i w^A_i w^B_i * (u_i ⊗ v_i) where u_i, v_i ∈ R^d
        self.basis_u = nn.Parameter(torch.randn(r, d) * 0.01)  # [r, d]
        self.basis_v = nn.Parameter(torch.randn(r, d) * 0.01)  # [r, d]
        # For Δb = Σ_i w^b_i * b_i where b_i ∈ R^d
        self.basis_b = nn.Parameter(torch.randn(r, d) * 0.01)  # [r, d]

        # Variance head (optional)
        if hn.output_variance:
            self.head_logvar = nn.Sequential(
                nn.Linear(hidden, hidden),
                nn.Tanh(),
                nn.Linear(hidden, d),
            )
        else:
            self.head_logvar = None

    def forward(self, state_vector: torch.Tensor) -> dict:
        """
        Args:
            state_vector: [d] or [B, d] summarizing current legal landscape

        Returns dict with keys:
            "shift_matrix": [d, d] or [B, d, d] rank-r projection shift
            "bias": [d] or [B, d] bias shift
            "log_variance": [d] or [B, d] log variance for stochastic
                embedding; clamped to [-5, 2] when the variance head exists,
                otherwise a constant -3.0.

        The leading batch dimension is dropped from every output only when
        the input was 1-D.
        """
        squeeze = state_vector.dim() == 1
        if squeeze:
            state_vector = state_vector.unsqueeze(0)  # [1, d]
        B, d = state_vector.shape
        r = self.config.hypernetwork.adaptation_rank

        # Shared representation
        h = self.trunk(state_vector)  # [B, hidden]
        modulated = self.modulator(h)  # [B, 3r + 1]

        # Split modulation weights (the final column is intentionally unused)
        w_A = modulated[:, :r]  # [B, r]
        w_B = modulated[:, r:2 * r]  # [B, r]
        w_bias = modulated[:, 2 * r:3 * r]  # [B, r]

        # Rank-r shift: ΔW[b] = Σ_i (w_A[b,i] * u_i) ⊗ (w_B[b,i] * v_i).
        # The two scalar weights of each component are folded into one
        # coefficient and the rank axis is contracted inside the einsum, so
        # no [B, r, d, d] intermediate is ever materialised.
        coeff = w_A * w_B  # [B, r]
        scaled_u = coeff.unsqueeze(2) * self.basis_u  # [B, r, d]
        shift = torch.einsum("brd,re->bde", scaled_u, self.basis_v)  # [B, d, d]

        # Bias: Δb[b] = Σ_i w_bias[b,i] * b_i
        bias = w_bias @ self.basis_b  # [B, d]

        result = {"shift_matrix": shift, "bias": bias}

        # Log variance
        if self.head_logvar is not None:
            logvar = self.head_logvar(h)
            result["log_variance"] = torch.clamp(logvar, min=-5.0, max=2.0)
        else:
            # Match the dtype/device of the other outputs so callers can mix
            # them without an implicit cast (e.g. under .double() / .half()).
            result["log_variance"] = torch.full(
                (B, d), -3.0, device=h.device, dtype=h.dtype
            )

        if squeeze:
            result = {k: v.squeeze(0) for k, v in result.items()}
        return result
class StateEncoder(nn.Module):
    """
    Encodes the legal concept graph into a compact state vector.

    Node embeddings are pooled with softmax-normalised weights and the pooled
    vector is passed through a small MLP followed by LayerNorm.

    This is separate from the HyperNetwork so the graph computation
    can be cached and only updated when the graph changes.

    Args:
        dim: embedding dimension d of both the nodes and the state vector.
        dropout: dropout probability inside the projection MLP
            (defaults to 0.1).
    """

    def __init__(self, dim: int, dropout: float = 0.1):
        super().__init__()
        self.state_proj = nn.Sequential(
            nn.Linear(dim, dim * 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim * 2, dim),
            nn.LayerNorm(dim),
        )

    def forward(
        self,
        node_embeddings: torch.Tensor,
        node_weights: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            node_embeddings: [N, d] refined node embeddings from GNN, or a
                batch of graphs [..., N, d].
            node_weights: [N] (or [..., N]) optional unnormalised attention
                scores; they are softmax-normalised over the node axis.
                When omitted, every node gets equal weight.

        Returns:
            state_vector: [d] (or [..., d]) summarizing the legal landscape
        """
        if node_weights is None:
            # Equal scores -> uniform weights after softmax. Created in the
            # embeddings' dtype so reduced-precision inputs are not silently
            # promoted before hitting the projection layers.
            node_weights = torch.ones(
                node_embeddings.shape[:-1],
                device=node_embeddings.device,
                dtype=node_embeddings.dtype,
            )
        # Negative dims address the node axis regardless of leading batch dims.
        node_weights = F.softmax(node_weights, dim=-1)
        pooled = (node_embeddings * node_weights.unsqueeze(-1)).sum(dim=-2)
        return self.state_proj(pooled)