lexiform-13m / model /model.py
Raj0pro's picture
Upload model/model.py with huggingface_hub
e657057 verified
Raw
History Blame Contribute Delete
11.5 kB
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.config import ModelConfig
from model.blocks import EncoderBlock, DecoderBlock, RMSNorm
class CopyGate(nn.Module):
"""
Pointer-generator gate.
Returns a raw logit (pre-sigmoid). Caller uses F.logsigmoid(+logit) /
F.logsigmoid(-logit) for an exact, allocation-free log-space mixture.
Input features: [context ; decoder_state ; decoder_input_emb]. The
decoder-input embedding (post-lookup, pre-stack) is concatenated alongside
the decoder hidden state — required for compatibility with the trained
checkpoint, which was produced with this 3-way input.
"""
def __init__(self, d_model: int):
super().__init__()
self.linear = nn.Linear(d_model * 3, 1, bias=True)
def forward(
self,
context: torch.Tensor, # (B, T, d) weighted sum of encoder states
decoder_state: torch.Tensor, # (B, T, d) decoder hidden state
decoder_input: torch.Tensor, # (B, T, d) raw decoder-input embedding
) -> torch.Tensor:
combined = torch.cat([context, decoder_state, decoder_input], dim=-1)
return self.linear(combined) # (B, T, 1) — raw logit
def _build_src_pad_mask(
src_ids: torch.Tensor, pad_id: int, dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""Returns additive float mask of shape (B, 1, 1, T_src) with -inf at padding."""
pad = (src_ids == pad_id) # (B, T_src)
mask = torch.zeros(pad.shape, dtype=dtype, device=pad.device)
mask = mask.masked_fill(pad, float("-inf"))
return mask.unsqueeze(1).unsqueeze(1)
class ParaphraseModel(nn.Module):
def __init__(self, config: ModelConfig):
super().__init__()
self.config = config
# Embedding and output projection are sized to include the
# `num_sentinels` ids appended after the BPE vocab (T5 span corruption).
# At num_sentinels=0 this is identical to the original architecture.
V = config.effective_vocab_size
self.embedding = nn.Embedding(V, config.d_model, padding_idx=config.pad_id)
self.enc_norm = RMSNorm(config.d_model)
self.dec_norm = RMSNorm(config.d_model)
self.encoder = nn.ModuleList([EncoderBlock(config) for _ in range(config.num_encoder_layers)])
self.decoder = nn.ModuleList([DecoderBlock(config) for _ in range(config.num_decoder_layers)])
self.output_proj = nn.Linear(config.d_model, V, bias=False)
self.output_proj.weight = self.embedding.weight # tie weights
self.drop = nn.Dropout(config.dropout)
if config.use_copy:
self.copy_gate = CopyGate(config.d_model)
self._init_weights()
def _init_weights(self):
d = self.config.d_model
for name, p in self.named_parameters():
if "copy_gate" in name:
continue # handled explicitly below
if name == "embedding.weight":
# Tied with the output projection — Xavier on (V, d) gives a
# per-element scale ~1/sqrt(V+d), starving the logit gradient.
# Normal(0, d^-0.5) matches T5/LLaMA-style weight-tied init.
nn.init.normal_(p, mean=0.0, std=d ** -0.5)
continue
if p.dim() > 1:
nn.init.xavier_uniform_(p)
# Zero-init the copy gate so logit==0 -> p_gen==0.5 at step 0,
# giving an exact 50/50 mix between vocab and copy distributions.
if self.config.use_copy:
nn.init.zeros_(self.copy_gate.linear.weight)
nn.init.zeros_(self.copy_gate.linear.bias)
# ---------- encoding ----------
def encode(self, src_ids: torch.Tensor) -> torch.Tensor:
src_attn_mask = _build_src_pad_mask(
src_ids, self.config.pad_id, dtype=self.embedding.weight.dtype,
)
# sqrt(d_model) scaling: tied embeddings double as the output
# projection, so the unscaled lookup leaves both the residual stream
# and the output logits under-magnitude.
scale = math.sqrt(self.config.d_model)
x = self.drop(self.embedding(src_ids) * scale)
for layer in self.encoder:
x = layer(x, attn_mask=src_attn_mask)
return self.enc_norm(x)
# ---------- internals: a single decoder pass, parallel or incremental ----------
def _decoder_pass(
self,
tgt_ids: torch.Tensor,
encoder_out: torch.Tensor,
src_attn_mask: torch.Tensor,
layer_caches: list[dict] | None,
start_pos: int,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]:
scale = math.sqrt(self.config.d_model)
x_emb = self.drop(self.embedding(tgt_ids) * scale)
x = x_emb
last_idx = len(self.decoder) - 1
attn_weights = None
for i, layer in enumerate(self.decoder):
layer_cache = layer_caches[i] if layer_caches is not None else None
# Only the last layer's cross-attn weights feed the copy mechanism,
# so the earlier layers can use SDPA (Flash kernels) and skip the
# manual softmax + dropout-on-probs path.
need_w = self.config.use_copy and (i == last_idx)
x, attn_weights = layer(
x, encoder_out,
src_attn_mask=src_attn_mask,
self_attn_mask=None, # rely on SDPA causal flag / cache semantics
layer_cache=layer_cache,
start_pos=start_pos,
need_weights=need_w,
)
x = self.dec_norm(x)
return x, attn_weights, x_emb # attn_weights: (B, H, T_tgt, T_src) from last layer
def _project_and_copy(
self,
dec_state: torch.Tensor, # (B, T_tgt, d)
encoder_out: torch.Tensor, # (B, T_src, d)
attn_weights: torch.Tensor, # (B, H, T_tgt, T_src)
src_ids: torch.Tensor | None,
dec_input_emb: torch.Tensor, # (B, T_tgt, d) raw decoder-input embedding
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""Returns (scores, attn_avg).
scores — log-probs when use_copy=True, raw logits otherwise.
attn_avg — (B, T_tgt, T_src) mean-over-heads cross-attention, or None
when copy is disabled / src_ids is None. Exposed so the
inference loop can read the attended source position for the
phrase-table bias without a second forward pass.
"""
vocab_logits = self.output_proj(dec_state) # (B, T, V)
if not self.config.use_copy or src_ids is None:
return vocab_logits, None # caller treats as raw logits
B, T_tgt, V = vocab_logits.shape
# average per-head attention to a single distribution for copy
attn_avg = attn_weights.mean(dim=1) # (B, T_tgt, T_src)
# log-domain mixture: avoids the log(p + eps) numerical issue
log_vocab_probs = F.log_softmax(vocab_logits, dim=-1) # (B, T, V)
src_expanded = src_ids.unsqueeze(1).expand(B, T_tgt, -1) # (B, T_tgt, T_src) — view
# accumulate copy mass in prob space then log; safe because per-token copy mass is >= 0
copy_probs = torch.zeros_like(log_vocab_probs)
copy_probs.scatter_add_(2, src_expanded, attn_avg)
log_copy_probs = torch.log(copy_probs.clamp_min(1e-20))
context = torch.bmm(attn_avg, encoder_out) # (B, T_tgt, d)
p_gen_logit = self.copy_gate(context, dec_state, dec_input_emb) # (B, T_tgt, 1) — raw
# F.logsigmoid is numerically exact for both branches — no clamp needed.
log_p_gen = F.logsigmoid( p_gen_logit)
log_p_copy = F.logsigmoid(-p_gen_logit)
final_log_probs = torch.logaddexp(
log_p_gen + log_vocab_probs,
log_p_copy + log_copy_probs,
)
return final_log_probs, attn_avg # caller treats scores as log-probs
# ---------- training-time parallel decode ----------
def decode(
self,
tgt_ids: torch.Tensor,
encoder_out: torch.Tensor,
src_ids: torch.Tensor | None = None,
) -> torch.Tensor:
if src_ids is not None:
src_attn_mask = _build_src_pad_mask(
src_ids, self.config.pad_id, dtype=encoder_out.dtype,
)
else:
src_attn_mask = None
dec_state, attn_weights, dec_input_emb = self._decoder_pass(
tgt_ids, encoder_out, src_attn_mask, layer_caches=None, start_pos=0,
)
scores, _ = self._project_and_copy(
dec_state, encoder_out, attn_weights, src_ids, dec_input_emb,
)
return scores
# ---------- inference-time incremental decode ----------
def init_caches(self, num_layers: int | None = None) -> list[dict]:
n = num_layers if num_layers is not None else self.config.num_decoder_layers
return [{"self": {}, "cross": {}} for _ in range(n)]
@torch.no_grad()
def decode_step(
self,
token: torch.Tensor, # (B, 1)
encoder_out: torch.Tensor, # (B, T_src, d)
src_attn_mask: torch.Tensor, # (B, 1, 1, T_src), additive float
src_ids: torch.Tensor | None,
layer_caches: list[dict],
step: int,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""Returns (scores, attn_avg).
attn_avg is (B, 1, T_src) when use_copy=True, else None.
The inference loop reads attn_avg to identify the attended source
position for the phrase-table bias without a second forward pass.
"""
dec_state, attn_weights, dec_input_emb = self._decoder_pass(
token, encoder_out, src_attn_mask, layer_caches=layer_caches, start_pos=step,
)
return self._project_and_copy(
dec_state, encoder_out, attn_weights, src_ids, dec_input_emb,
)
# ---------- top-level forward (training) ----------
def forward(
self,
src_ids: torch.Tensor,
dec_input: torch.Tensor,
) -> torch.Tensor:
encoder_out = self.encode(src_ids)
return self.decode(dec_input, encoder_out, src_ids=src_ids)
# ---------- simple greedy generate (KV-cached) — kept for quick smoke tests ----------
@torch.no_grad()
def generate(
self,
src_ids: torch.Tensor,
max_new_tokens: int = 128,
) -> torch.Tensor:
self.eval()
device = src_ids.device
encoder_out = self.encode(src_ids)
src_attn_mask = _build_src_pad_mask(
src_ids, self.config.pad_id, dtype=encoder_out.dtype,
)
B = src_ids.size(0)
caches = self.init_caches()
cur = torch.full((B, 1), self.config.bos_id, dtype=torch.long, device=device)
out_tokens = []
finished = torch.zeros(B, dtype=torch.bool, device=device)
for step in range(max_new_tokens):
logits_or_lp, _ = self.decode_step(cur, encoder_out, src_attn_mask, src_ids, caches, step)
next_token = logits_or_lp[:, -1, :].argmax(dim=-1, keepdim=True)
out_tokens.append(next_token)
finished = finished | (next_token.squeeze(-1) == self.config.eos_id)
if finished.all():
break
cur = next_token
return torch.cat(out_tokens, dim=1)
def param_count(self) -> int:
return sum(p.numel() for p in self.parameters())