controlmt-v2.3 / model.py
anandkaman's picture
ControlMT v2.2 initial release
4b22190 verified
Raw
History Blame Contribute Delete
15.8 kB
"""
ControlMT Model — Modular Encoder-Decoder Transformer with Explicit Routing
Trained by: Anand Kaman
Architecture:
- Shared Core (6 encoder + 6 decoder layers, ~40M params) — language-agnostic "brain"
- Per-language Encoder (2 layers, ~10M each) — KN, EN
- Per-language Decoder (2 layers, ~10M each) — KN, EN
- Control Embeddings — style/format vectors prepended to the sequence fed into the core encoder
- Explicit routing — code selects encoder/decoder, not learned
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
# ---------------------------------------------------------------------------
# Architecture hyperparameters (used as constructor defaults below)
# ---------------------------------------------------------------------------
D_MODEL = 512
N_HEADS = 8
D_FF = 2048
DROPOUT = 0.1
ENCODER_LAYERS = 2
DECODER_LAYERS = 2
CORE_LAYERS = 6
# At most 256 content tokens, plus room for the [BOS, direction, style, EOS]
# prefix and a little buffer.
MAX_SEQ_LEN = 320

# ---------------------------------------------------------------------------
# Special token IDs
# ---------------------------------------------------------------------------
PAD_ID, BOS_ID, EOS_ID = 0, 1, 2

# Style/register control tokens; chosen per training example from
# style_labels.jsonl.
CONTROL_TOKENS = {
    "strict": 6,
    "natural": 7,
    "formal": 8,
    "casual": 9,
    "json": 10,
    "text": 11,
}

# Fallback style for pairs that carry no style label (translit, synth, etc.).
DEFAULT_CONTROL_ID = CONTROL_TOKENS["natural"]

# Translation-direction tokens.
#   v2:   kn2en (4), en2kn (5)
#   v2.1: adds rkn2kn (12) for Aksharantar word-level transliteration data
#   v3:   reserves rkn2en (13), hi2en (14), en2hi (15)
DIRECTION_TOKENS = {
    "kn2en": 4,
    "en2kn": 5,
    "rkn2kn": 12,
    "rkn2en": 13,
    "hi2en": 14,
    "en2hi": 15,
}
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position signals to embeddings, then apply dropout.

    The table is precomputed once for ``max_len`` positions and stored as a
    buffer named ``pe`` with shape ``(1, max_len, d_model)``.

    Args:
        d_model: Size of the last dimension of the inputs (even or odd).
        max_len: Number of positions the precomputed table covers.
        dropout: Dropout probability applied after the position signal is added.
    """

    def __init__(self, d_model, max_len=MAX_SEQ_LEN, dropout=DROPOUT):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd (cosine) column than even
        # (sine) columns, so drop the surplus frequency instead of crashing.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Return ``dropout(x + pe[:, :seq_len])``.

        Args:
            x: Tensor whose dimension 1 is the sequence axis and whose last
                dimension is ``d_model``, e.g. ``(batch, seq_len, d_model)``.

        Raises:
            ValueError: If the sequence is longer than the precomputed table.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            # Fail with an actionable message rather than an opaque broadcast error.
            raise ValueError(
                f"Sequence length {seq_len} exceeds positional encoding max_len {max_len}"
            )
        x = x + self.pe[:, :seq_len]
        return self.dropout(x)
class TransformerEncoderBlock(nn.Module):
    """Single post-norm transformer encoder layer: self-attention + FFN.

    Each sub-layer is applied as ``LayerNorm(x + sublayer(x))``.

    Args:
        d_model: Hidden size of the inputs and outputs.
        n_heads: Number of attention heads.
        d_ff: Inner width of the feed-forward network.
        dropout: Dropout probability used in attention, FFN and residual paths.
    """

    def __init__(self, d_model=D_MODEL, n_heads=N_HEADS, d_ff=D_FF, dropout=DROPOUT):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, src_mask=None, src_key_padding_mask=None):
        """Run self-attention and the feed-forward network over ``x``.

        Args:
            x: Input of shape ``(batch, seq_len, d_model)`` (batch-first).
            src_mask: Optional attention mask, forwarded to self-attention as
                ``attn_mask``. ``None`` means unrestricted attention.
            src_key_padding_mask: Optional ``(batch, seq_len)`` mask forwarded
                as ``key_padding_mask`` (True marks positions to ignore).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Self-attention with residual. src_mask used to be accepted but
        # silently dropped; it is now honoured (None keeps the old behaviour).
        attn_out, _ = self.self_attn(
            x, x, x,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
        )
        x = self.norm1(x + self.dropout(attn_out))
        # FFN with residual (the FFN ends with its own dropout layer)
        x = self.norm2(x + self.ffn(x))
        return x
class TransformerDecoderBlock(nn.Module):
    """One post-norm decoder layer.

    Sub-layers, in order: masked self-attention, cross-attention over the
    encoder memory, and a position-wise feed-forward network. Each one is
    wrapped as ``LayerNorm(residual + sublayer_output)``.
    """

    def __init__(self, d_model=D_MODEL, n_heads=N_HEADS, d_ff=D_FF, dropout=DROPOUT):
        super().__init__()
        # NOTE: attribute names and creation order are kept stable so that
        # state_dict keys and parameter ordering do not change.
        self.self_attn = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.cross_attn = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True
        )
        self.norm2 = nn.LayerNorm(d_model)
        feed_forward = [
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        ]
        self.ffn = nn.Sequential(*feed_forward)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, memory, tgt_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        """Apply the three sub-layers to ``x`` and return the result.

        ``tgt_mask`` / ``tgt_key_padding_mask`` restrict the self-attention;
        ``memory_key_padding_mask`` restricts the cross-attention over ``memory``.
        """
        # 1) masked self-attention over the target sequence
        self_ctx = self.self_attn(
            x, x, x,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask,
        )[0]
        hidden = self.norm1(x + self.dropout(self_ctx))

        # 2) cross-attention: queries from the target, keys/values from memory
        mem_ctx = self.cross_attn(
            hidden, memory, memory,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        hidden = self.norm2(hidden + self.dropout(mem_ctx))

        # 3) feed-forward network
        return self.norm3(hidden + self.ffn(hidden))
class LanguageEncoder(nn.Module):
    """Language-specific encoder: a stack of ``n_layers`` encoder blocks."""

    def __init__(self, d_model=D_MODEL, n_layers=ENCODER_LAYERS, n_heads=N_HEADS, d_ff=D_FF, dropout=DROPOUT):
        super().__init__()
        self.layers = nn.ModuleList()
        for _ in range(n_layers):
            self.layers.append(TransformerEncoderBlock(d_model, n_heads, d_ff, dropout))

    def forward(self, x, src_key_padding_mask=None):
        """Feed ``x`` through every block in order, reusing the same padding mask."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, src_key_padding_mask=src_key_padding_mask)
        return hidden
class LanguageDecoder(nn.Module):
    """Language-specific decoder: a stack of ``n_layers`` decoder blocks."""

    def __init__(self, d_model=D_MODEL, n_layers=DECODER_LAYERS, n_heads=N_HEADS, d_ff=D_FF, dropout=DROPOUT):
        super().__init__()
        self.layers = nn.ModuleList()
        for _ in range(n_layers):
            self.layers.append(TransformerDecoderBlock(d_model, n_heads, d_ff, dropout))

    def forward(self, x, memory, tgt_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        """Feed ``x`` through every block in order; each block attends to ``memory``."""
        # The same masks apply to every layer, so bundle them once.
        mask_kwargs = {
            "tgt_mask": tgt_mask,
            "tgt_key_padding_mask": tgt_key_padding_mask,
            "memory_key_padding_mask": memory_key_padding_mask,
        }
        hidden = x
        for block in self.layers:
            hidden = block(hidden, memory, **mask_kwargs)
        return hidden
class SharedCore(nn.Module):
    """Language-agnostic core shared by all directions.

    Holds ``n_layers`` encoder blocks and ``n_layers`` decoder blocks.
    ``encode`` refines whatever sequence it is given (the caller prepends the
    control embedding before calling it); ``decode`` runs the target sequence
    through the decoder blocks, cross-attending to the ``memory`` passed in.
    """

    def __init__(self, d_model=D_MODEL, n_layers=CORE_LAYERS, n_heads=N_HEADS, d_ff=D_FF, dropout=DROPOUT):
        super().__init__()
        # Encoder stack is registered before the decoder stack; keep this order.
        self.encoder_layers = nn.ModuleList()
        for _ in range(n_layers):
            self.encoder_layers.append(TransformerEncoderBlock(d_model, n_heads, d_ff, dropout))
        self.decoder_layers = nn.ModuleList()
        for _ in range(n_layers):
            self.decoder_layers.append(TransformerDecoderBlock(d_model, n_heads, d_ff, dropout))

    def encode(self, x, src_key_padding_mask=None):
        """Run ``x`` through the core encoder blocks, in order."""
        hidden = x
        for block in self.encoder_layers:
            hidden = block(hidden, src_key_padding_mask=src_key_padding_mask)
        return hidden

    def decode(self, x, memory, tgt_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        """Run ``x`` through the core decoder blocks, each cross-attending to ``memory``."""
        # Masks are identical for every layer, so bundle them once.
        mask_kwargs = {
            "tgt_mask": tgt_mask,
            "tgt_key_padding_mask": tgt_key_padding_mask,
            "memory_key_padding_mask": memory_key_padding_mask,
        }
        hidden = x
        for block in self.decoder_layers:
            hidden = block(hidden, memory, **mask_kwargs)
        return hidden
class ControlMT(nn.Module):
    """ControlMT — modular encoder-decoder transformer with explicit routing.

    Data flow:
        source ids -> token embedding + positions -> language encoder
                   -> [control vector prepended] -> shared core encoder -> memory
        target ids -> token embedding + positions -> shared core decoder
                   -> language decoder -> output projection -> logits
    Both decoder stages cross-attend to ``memory``. The encoder and decoder
    are selected in code from the direction ID (see ``_get_lang``), not learned.

    Args:
        vocab_size: Size of the shared vocabulary.
        d_model: Hidden size.
        n_heads: Attention heads used by every transformer block.
        d_ff: Feed-forward inner width used by every transformer block.
        dropout: Dropout probability used throughout the model.
        max_seq_len: Number of positions covered by the positional encoding.
        n_control_tokens: Number of rows in the control embedding table.
            Control IDs are mapped to rows by subtracting the smallest ID in
            ``CONTROL_TOKENS``.
    """

    def __init__(self, vocab_size, d_model=D_MODEL, n_heads=N_HEADS, d_ff=D_FF,
                 dropout=DROPOUT, max_seq_len=MAX_SEQ_LEN, n_control_tokens=6):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        # Shared token embedding (all languages share one vocabulary)
        self.token_embedding = nn.Embedding(vocab_size, d_model, padding_idx=PAD_ID)
        self.pos_encoding = PositionalEncoding(d_model, max_seq_len, dropout)
        # Control embeddings (style/format — prepended to the encoder sequence)
        self.control_embedding = nn.Embedding(n_control_tokens, d_model)
        # n_heads / d_ff / dropout were previously accepted but never forwarded
        # to the sub-stacks, which silently fell back to the module defaults.
        block_kwargs = {"n_heads": n_heads, "d_ff": d_ff, "dropout": dropout}
        # Per-language encoders
        self.encoders = nn.ModuleDict({
            "kn": LanguageEncoder(d_model, **block_kwargs),
            "en": LanguageEncoder(d_model, **block_kwargs),
        })
        # Shared core (the brain)
        self.core = SharedCore(d_model, **block_kwargs)
        # Per-language decoders
        self.decoders = nn.ModuleDict({
            "kn": LanguageDecoder(d_model, **block_kwargs),
            "en": LanguageDecoder(d_model, **block_kwargs),
        })
        # Output projection (shared across languages), weight-tied to the
        # token embedding; only the bias is a separate parameter.
        self.output_proj = nn.Linear(d_model, vocab_size)
        self.output_proj.weight = self.token_embedding.weight
        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform init for every parameter with more than one dimension."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        # nn.Embedding zeroes the padding_idx row at construction, but the loop
        # above overwrites it with random values; restore the zero PAD vector.
        with torch.no_grad():
            self.token_embedding.weight[PAD_ID].zero_()

    def _get_lang(self, direction_id):
        """Map a direction ID to ``(source_language, target_language)`` keys.

        Supported routes:
            kn2en (4)   -> "kn" encoder, "en" decoder
            en2kn (5)   -> "en" encoder, "kn" decoder
            rkn2kn (12) -> "en" encoder, "kn" decoder
        The v3 reservations (rkn2en, hi2en, en2hi) are not wired yet.

        Raises:
            ValueError: If ``direction_id`` is not one of the wired routes.
        """
        if direction_id == DIRECTION_TOKENS["kn2en"]:
            return "kn", "en"
        elif direction_id == DIRECTION_TOKENS["en2kn"]:
            return "en", "kn"
        elif direction_id == DIRECTION_TOKENS["rkn2kn"]:
            # Romanized Kannada is Latin-script — route through the EN encoder.
            # Target is Kannada script → KN decoder.
            return "en", "kn"
        else:
            raise ValueError(f"Unknown direction ID: {direction_id}")

    @staticmethod
    def generate_square_subsequent_mask(sz, device):
        """Return a ``(sz, sz)`` boolean causal mask.

        Entries strictly above the diagonal are True (future positions, which
        the decoder self-attention must not look at).
        """
        mask = torch.triu(torch.ones(sz, sz, device=device), diagonal=1).bool()
        return mask

    def encode(self, src_ids, src_mask, direction_id, control_id=CONTROL_TOKENS["strict"]):
        """Encode a source batch into the memory used by the decoder.

        Args:
            src_ids: (batch, src_len) — source token IDs
            src_mask: (batch, src_len) — 1 for real tokens, 0 for padding
            direction_id: int — direction token ID (4=kn2en, 5=en2kn, 12=rkn2kn)
            control_id: int — control token ID from ``CONTROL_TOKENS``. Note
                that the default here is "strict"; callers that want the
                module-level ``DEFAULT_CONTROL_ID`` must pass it explicitly.

        Returns:
            memory: (batch, src_len+1, d_model) — encoded representation
            memory_key_padding_mask: (batch, src_len+1) — True where padded

        Raises:
            ValueError: If ``direction_id`` is not wired, or ``control_id``
                does not map to a row of the control embedding table.
        """
        src_lang, _ = self._get_lang(direction_id)

        # Control IDs start at the smallest CONTROL_TOKENS value; shift them to
        # 0-based rows and validate up front. An out-of-range embedding index
        # otherwise surfaces as an opaque index error (a device-side assert on
        # CUDA).
        control_offset = min(CONTROL_TOKENS.values())
        ctrl_index = int(control_id) - control_offset
        n_control = self.control_embedding.num_embeddings
        if not 0 <= ctrl_index < n_control:
            raise ValueError(
                f"Unknown control ID: {control_id} "
                f"(expected {control_offset}..{control_offset + n_control - 1})"
            )

        # Embed tokens
        x = self.token_embedding(src_ids) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)
        # Create padding mask (True = ignore)
        src_key_padding_mask = (src_mask == 0)
        # Pass through language-specific encoder
        x = self.encoders[src_lang](x, src_key_padding_mask=src_key_padding_mask)

        # Prepend the control embedding; one vector shared by the whole batch
        batch_size = x.size(0)
        ctrl = self.control_embedding(torch.tensor([ctrl_index], device=x.device))
        ctrl = ctrl.unsqueeze(0).expand(batch_size, -1, -1)  # (batch, 1, d_model)
        x = torch.cat([ctrl, x], dim=1)  # (batch, src_len+1, d_model)

        # Extend padding mask for control token (always attend to it)
        ctrl_mask = torch.zeros(batch_size, 1, dtype=torch.bool, device=x.device)
        memory_key_padding_mask = torch.cat([ctrl_mask, src_key_padding_mask], dim=1)

        # Pass through shared core encoder
        memory = self.core.encode(x, src_key_padding_mask=memory_key_padding_mask)
        return memory, memory_key_padding_mask

    def decode(self, tgt_ids, tgt_mask, memory, memory_key_padding_mask, direction_id):
        """Decode a target batch against ``memory`` and return vocabulary logits.

        Args:
            tgt_ids: (batch, tgt_len)
            tgt_mask: (batch, tgt_len) — 1 for real tokens, 0 for padding
            memory: (batch, src_len+1, d_model)
            memory_key_padding_mask: (batch, src_len+1)
            direction_id: int — selects the language decoder

        Returns:
            logits: (batch, tgt_len, vocab_size)

        Raises:
            ValueError: If ``direction_id`` is not a wired route.
        """
        _, tgt_lang = self._get_lang(direction_id)
        # Embed target tokens
        x = self.token_embedding(tgt_ids) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)
        # Causal mask for decoder
        tgt_len = tgt_ids.size(1)
        causal_mask = self.generate_square_subsequent_mask(tgt_len, tgt_ids.device)
        tgt_key_padding_mask = (tgt_mask == 0)
        # Pass through shared core decoder
        x = self.core.decode(x, memory, tgt_mask=causal_mask,
                             tgt_key_padding_mask=tgt_key_padding_mask,
                             memory_key_padding_mask=memory_key_padding_mask)
        # Pass through language-specific decoder
        x = self.decoders[tgt_lang](x, memory, tgt_mask=causal_mask,
                                    tgt_key_padding_mask=tgt_key_padding_mask,
                                    memory_key_padding_mask=memory_key_padding_mask)
        # Project to vocabulary
        logits = self.output_proj(x)
        return logits

    def forward(self, src_ids, src_mask, tgt_ids, tgt_mask, direction_id, control_id=CONTROL_TOKENS["strict"]):
        """Full forward pass for training: ``encode`` followed by ``decode``.

        Args:
            src_ids: (batch, src_len)
            src_mask: (batch, src_len)
            tgt_ids: (batch, tgt_len)
            tgt_mask: (batch, tgt_len)
            direction_id: int — single direction for the whole batch
            control_id: int — single control token for the whole batch

        Returns:
            logits: (batch, tgt_len, vocab_size)
        """
        memory, memory_key_padding_mask = self.encode(src_ids, src_mask, direction_id, control_id)
        logits = self.decode(tgt_ids, tgt_mask, memory, memory_key_padding_mask, direction_id)
        return logits
def count_parameters(model):
    """Count trainable parameters, in total and per direct child module.

    Parameters shared between children (e.g. tied embedding / output
    projection weights) are attributed to the first child that owns them, so
    the breakdown never double-counts a tensor and stays consistent with
    ``total``.

    Args:
        model: Any ``nn.Module``.

    Returns:
        ``(total, breakdown)`` where ``total`` is the number of trainable
        scalars in ``model`` and ``breakdown`` maps each direct child's name
        to the trainable scalars attributed to it. Children that contribute
        nothing are omitted.
    """
    total = sum(p.numel() for p in model.parameters() if p.requires_grad)
    breakdown = {}
    counted = set()  # ids of parameter tensors already attributed to a child
    for name, module in model.named_children():
        params = 0
        for p in module.parameters():
            if not p.requires_grad or id(p) in counted:
                continue
            counted.add(id(p))
            params += p.numel()
        if params > 0:
            breakdown[name] = params
    return total, breakdown
def _run_smoke_test():
    """Build a model, print its parameter counts, and shape-check forward passes.

    Raises:
        RuntimeError: If any forward pass returns logits of an unexpected shape.
    """
    vocab_size = 64000
    batch_size = 4
    src_len = 20
    tgt_len = 15
    expected_shape = (batch_size, tgt_len, vocab_size)

    model = ControlMT(vocab_size=vocab_size)
    total, breakdown = count_parameters(model)
    print(f"ControlMT — Total parameters: {total:,} ({total/1e6:.1f}M)")
    print("\nParameter breakdown:")
    for name, params in breakdown.items():
        print(f" {name}: {params:,} ({params/1e6:.1f}M)")

    # Dummy batch without padding; IDs start at 4 to skip the low special IDs.
    src_ids = torch.randint(4, vocab_size, (batch_size, src_len))
    tgt_ids = torch.randint(4, vocab_size, (batch_size, tgt_len))
    src_mask = torch.ones(batch_size, src_len, dtype=torch.long)
    tgt_mask = torch.ones(batch_size, tgt_len, dtype=torch.long)

    def run_direction(direction_name):
        # Only shapes are inspected, so skip building the autograd graph.
        with torch.no_grad():
            logits = model(src_ids, src_mask, tgt_ids, tgt_mask,
                           direction_id=DIRECTION_TOKENS[direction_name])
        if tuple(logits.shape) != expected_shape:
            raise RuntimeError(
                f"{direction_name}: expected logits of shape {expected_shape}, "
                f"got {tuple(logits.shape)}"
            )
        return logits

    # Test KN->EN
    logits = run_direction("kn2en")
    print("\nForward pass (KN->EN):")
    print(f" Input: src={src_ids.shape}, tgt={tgt_ids.shape}")
    print(f" Output logits: {logits.shape}")
    print(f" Expected: ({batch_size}, {tgt_len}, {vocab_size})")

    # Test EN->KN
    logits = run_direction("en2kn")
    print("\nForward pass (EN->KN):")
    print(f" Output logits: {logits.shape}")

    # Test romanized KN -> KN (wired since v2.1)
    logits = run_direction("rkn2kn")
    print("\nForward pass (RKN->KN):")
    print(f" Output logits: {logits.shape}")

    print("\nModel architecture test PASSED!")


if __name__ == "__main__":
    _run_smoke_test()