Translation
Transformers
Safetensors
Kannada
English
controlmt
text2text-generation
machine-translation
kannada
english
indic
low-resource
code-mix
encoder-decoder
custom_code
Eval Results (legacy)
Instructions to use anandkaman/controlmt-v2.3 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use anandkaman/controlmt-v2.3 with Transformers:
# Use a pipeline as a high-level helper # Warning: Pipeline type "translation" is no longer supported in transformers v5. # You must load the model directly (see below) or downgrade to v4.x with: # 'pip install "transformers<5.0.0' from transformers import pipeline pipe = pipeline("translation", model="anandkaman/controlmt-v2.3", trust_remote_code=True)# Load model directly from transformers import AutoModelForSeq2SeqLM model = AutoModelForSeq2SeqLM.from_pretrained("anandkaman/controlmt-v2.3", trust_remote_code=True, dtype="auto") - Notebooks
- Google Colab
- Kaggle
| """ | |
| ControlMT Model — Modular Encoder-Decoder Transformer with Explicit Routing | |
| Trained by: Anand Kaman | |
| Architecture: | |
| - Shared Core (6 layers, ~40M params) — language-agnostic "brain" | |
| - Per-language Encoder (2 layers, ~10M each) — KN, EN | |
| - Per-language Decoder (2 layers, ~10M each) — KN, EN | |
| - Control Embeddings — style/format vectors injected into core | |
| - Explicit routing — code selects encoder/decoder, not learned | |
| """ | |
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| # Model hyperparameters | |
| D_MODEL = 512 | |
| N_HEADS = 8 | |
| D_FF = 2048 | |
| DROPOUT = 0.1 | |
| ENCODER_LAYERS = 2 | |
| DECODER_LAYERS = 2 | |
| CORE_LAYERS = 6 | |
| MAX_SEQ_LEN = 320 # 256 max content tokens + room for [BOS, direction, style, EOS] prefix + buffer | |
| # Token IDs | |
| PAD_ID = 0 | |
| BOS_ID = 1 | |
| EOS_ID = 2 | |
| # Control token IDs (style/register, set per training example from style_labels.jsonl) | |
| CONTROL_TOKENS = { | |
| "strict": 6, "natural": 7, "formal": 8, | |
| "casual": 9, "json": 10, "text": 11, | |
| } | |
| # Default for pairs without a style label (translit, synth, etc.) | |
| DEFAULT_CONTROL_ID = CONTROL_TOKENS["natural"] | |
| # Direction token IDs. | |
| # v2: kn2en (4), en2kn (5) | |
| # v2.1: + rkn2kn (12) for Aksharantar word-level transliteration data | |
| # v3 reservations: rkn2en (13), hi2en (14), en2hi (15) | |
| DIRECTION_TOKENS = { | |
| "kn2en": 4, "en2kn": 5, | |
| "rkn2kn": 12, "rkn2en": 13, | |
| "hi2en": 14, "en2hi": 15, | |
| } | |
| class PositionalEncoding(nn.Module): | |
| """Sinusoidal positional encoding.""" | |
| def __init__(self, d_model, max_len=MAX_SEQ_LEN, dropout=DROPOUT): | |
| super().__init__() | |
| self.dropout = nn.Dropout(dropout) | |
| pe = torch.zeros(max_len, d_model) | |
| position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) | |
| div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) | |
| pe[:, 0::2] = torch.sin(position * div_term) | |
| pe[:, 1::2] = torch.cos(position * div_term) | |
| pe = pe.unsqueeze(0) # (1, max_len, d_model) | |
| self.register_buffer("pe", pe) | |
| def forward(self, x): | |
| x = x + self.pe[:, :x.size(1)] | |
| return self.dropout(x) | |
| class TransformerEncoderBlock(nn.Module): | |
| """Single transformer encoder layer: self-attention + FFN.""" | |
| def __init__(self, d_model=D_MODEL, n_heads=N_HEADS, d_ff=D_FF, dropout=DROPOUT): | |
| super().__init__() | |
| self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True) | |
| self.norm1 = nn.LayerNorm(d_model) | |
| self.ffn = nn.Sequential( | |
| nn.Linear(d_model, d_ff), | |
| nn.GELU(), | |
| nn.Dropout(dropout), | |
| nn.Linear(d_ff, d_model), | |
| nn.Dropout(dropout), | |
| ) | |
| self.norm2 = nn.LayerNorm(d_model) | |
| self.dropout = nn.Dropout(dropout) | |
| def forward(self, x, src_mask=None, src_key_padding_mask=None): | |
| # Self-attention with residual | |
| attn_out, _ = self.self_attn(x, x, x, key_padding_mask=src_key_padding_mask) | |
| x = self.norm1(x + self.dropout(attn_out)) | |
| # FFN with residual | |
| x = self.norm2(x + self.ffn(x)) | |
| return x | |
| class TransformerDecoderBlock(nn.Module): | |
| """Single transformer decoder layer: self-attention + cross-attention + FFN.""" | |
| def __init__(self, d_model=D_MODEL, n_heads=N_HEADS, d_ff=D_FF, dropout=DROPOUT): | |
| super().__init__() | |
| self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True) | |
| self.norm1 = nn.LayerNorm(d_model) | |
| self.cross_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True) | |
| self.norm2 = nn.LayerNorm(d_model) | |
| self.ffn = nn.Sequential( | |
| nn.Linear(d_model, d_ff), | |
| nn.GELU(), | |
| nn.Dropout(dropout), | |
| nn.Linear(d_ff, d_model), | |
| nn.Dropout(dropout), | |
| ) | |
| self.norm3 = nn.LayerNorm(d_model) | |
| self.dropout = nn.Dropout(dropout) | |
| def forward(self, x, memory, tgt_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None): | |
| # Masked self-attention | |
| attn_out, _ = self.self_attn(x, x, x, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask) | |
| x = self.norm1(x + self.dropout(attn_out)) | |
| # Cross-attention over encoder output | |
| cross_out, _ = self.cross_attn(x, memory, memory, key_padding_mask=memory_key_padding_mask) | |
| x = self.norm2(x + self.dropout(cross_out)) | |
| # FFN | |
| x = self.norm3(x + self.ffn(x)) | |
| return x | |
| class LanguageEncoder(nn.Module): | |
| """Per-language encoder module (2 layers).""" | |
| def __init__(self, d_model=D_MODEL, n_layers=ENCODER_LAYERS, n_heads=N_HEADS, d_ff=D_FF, dropout=DROPOUT): | |
| super().__init__() | |
| self.layers = nn.ModuleList([ | |
| TransformerEncoderBlock(d_model, n_heads, d_ff, dropout) | |
| for _ in range(n_layers) | |
| ]) | |
| def forward(self, x, src_key_padding_mask=None): | |
| for layer in self.layers: | |
| x = layer(x, src_key_padding_mask=src_key_padding_mask) | |
| return x | |
| class LanguageDecoder(nn.Module): | |
| """Per-language decoder module (2 layers).""" | |
| def __init__(self, d_model=D_MODEL, n_layers=DECODER_LAYERS, n_heads=N_HEADS, d_ff=D_FF, dropout=DROPOUT): | |
| super().__init__() | |
| self.layers = nn.ModuleList([ | |
| TransformerDecoderBlock(d_model, n_heads, d_ff, dropout) | |
| for _ in range(n_layers) | |
| ]) | |
| def forward(self, x, memory, tgt_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None): | |
| for layer in self.layers: | |
| x = layer(x, memory, tgt_mask=tgt_mask, | |
| tgt_key_padding_mask=tgt_key_padding_mask, | |
| memory_key_padding_mask=memory_key_padding_mask) | |
| return x | |
| class SharedCore(nn.Module): | |
| """Shared core — the brain. 6 encoder layers + 6 decoder layers. | |
| The core processes encoder output through its encoder layers, | |
| then the decoder side uses cross-attention to attend to core encoder output. | |
| Control embeddings are prepended to the encoder sequence. | |
| """ | |
| def __init__(self, d_model=D_MODEL, n_layers=CORE_LAYERS, n_heads=N_HEADS, d_ff=D_FF, dropout=DROPOUT): | |
| super().__init__() | |
| self.encoder_layers = nn.ModuleList([ | |
| TransformerEncoderBlock(d_model, n_heads, d_ff, dropout) | |
| for _ in range(n_layers) | |
| ]) | |
| self.decoder_layers = nn.ModuleList([ | |
| TransformerDecoderBlock(d_model, n_heads, d_ff, dropout) | |
| for _ in range(n_layers) | |
| ]) | |
| def encode(self, x, src_key_padding_mask=None): | |
| """Process encoder output through core encoder layers.""" | |
| for layer in self.encoder_layers: | |
| x = layer(x, src_key_padding_mask=src_key_padding_mask) | |
| return x | |
| def decode(self, x, memory, tgt_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None): | |
| """Process decoder through core decoder layers with cross-attention to core encoder output.""" | |
| for layer in self.decoder_layers: | |
| x = layer(x, memory, tgt_mask=tgt_mask, | |
| tgt_key_padding_mask=tgt_key_padding_mask, | |
| memory_key_padding_mask=memory_key_padding_mask) | |
| return x | |
| class ControlMT(nn.Module): | |
| """ | |
| ControlMT — Modular Encoder-Decoder Transformer | |
| Flow: Input -> Lang Encoder -> Shared Core Encoder -> Shared Core Decoder <- Lang Decoder -> Output | |
| Control embeddings prepended to encoder sequence for style/format control. | |
| """ | |
| def __init__(self, vocab_size, d_model=D_MODEL, n_heads=N_HEADS, d_ff=D_FF, | |
| dropout=DROPOUT, max_seq_len=MAX_SEQ_LEN, n_control_tokens=6): | |
| super().__init__() | |
| self.d_model = d_model | |
| self.vocab_size = vocab_size | |
| # Shared token embedding (all languages share vocabulary) | |
| self.token_embedding = nn.Embedding(vocab_size, d_model, padding_idx=PAD_ID) | |
| self.pos_encoding = PositionalEncoding(d_model, max_seq_len, dropout) | |
| # Control embeddings (style/format — injected into encoder sequence) | |
| self.control_embedding = nn.Embedding(n_control_tokens, d_model) | |
| # Per-language encoders | |
| self.encoders = nn.ModuleDict({ | |
| "kn": LanguageEncoder(d_model), | |
| "en": LanguageEncoder(d_model), | |
| }) | |
| # Shared core (the brain) | |
| self.core = SharedCore(d_model) | |
| # Per-language decoders | |
| self.decoders = nn.ModuleDict({ | |
| "kn": LanguageDecoder(d_model), | |
| "en": LanguageDecoder(d_model), | |
| }) | |
| # Output projection (shared across languages) | |
| self.output_proj = nn.Linear(d_model, vocab_size) | |
| # Tie embedding weights with output projection | |
| self.output_proj.weight = self.token_embedding.weight | |
| # Init weights | |
| self._init_weights() | |
| def _init_weights(self): | |
| """Xavier uniform initialization.""" | |
| for p in self.parameters(): | |
| if p.dim() > 1: | |
| nn.init.xavier_uniform_(p) | |
| def _get_lang(self, direction_id): | |
| """Get src/tgt language from direction ID. | |
| v2.1 supports: | |
| kn2en (4) → kn enc, en dec | |
| en2kn (5) → en enc, kn dec | |
| rkn2kn (12) → en enc, kn dec (romanized Kannada uses EN encoder — it's Latin script) | |
| v3 reservations (rkn2en, hi2en, en2hi) not yet wired. | |
| """ | |
| if direction_id == DIRECTION_TOKENS["kn2en"]: | |
| return "kn", "en" | |
| elif direction_id == DIRECTION_TOKENS["en2kn"]: | |
| return "en", "kn" | |
| elif direction_id == DIRECTION_TOKENS["rkn2kn"]: | |
| # Romanized Kannada is Latin-script — route through the EN encoder. | |
| # Target is Kannada script → KN decoder. | |
| return "en", "kn" | |
| else: | |
| raise ValueError(f"Unknown direction ID: {direction_id}") | |
| def generate_square_subsequent_mask(sz, device): | |
| """Causal mask for decoder self-attention.""" | |
| mask = torch.triu(torch.ones(sz, sz, device=device), diagonal=1).bool() | |
| return mask | |
| def encode(self, src_ids, src_mask, direction_id, control_id=CONTROL_TOKENS["strict"]): | |
| """ | |
| Encode source sequence. | |
| Args: | |
| src_ids: (batch, src_len) — source token IDs | |
| src_mask: (batch, src_len) — 1 for real tokens, 0 for padding | |
| direction_id: int — direction token ID (4=KN2EN, 5=EN2KN) | |
| control_id: int — control token ID (6=strict, etc.) | |
| Returns: | |
| memory: (batch, src_len+1, d_model) — encoded representation | |
| memory_key_padding_mask: (batch, src_len+1) — padding mask | |
| """ | |
| src_lang, _ = self._get_lang(direction_id) | |
| # Embed tokens | |
| x = self.token_embedding(src_ids) * math.sqrt(self.d_model) | |
| x = self.pos_encoding(x) | |
| # Create padding mask (True = ignore) | |
| src_key_padding_mask = (src_mask == 0) | |
| # Pass through language-specific encoder | |
| x = self.encoders[src_lang](x, src_key_padding_mask=src_key_padding_mask) | |
| # Prepend control embedding | |
| batch_size = x.size(0) | |
| ctrl = self.control_embedding(torch.tensor([control_id - 6], device=x.device)) # offset by first control ID | |
| ctrl = ctrl.unsqueeze(0).expand(batch_size, -1, -1) # (batch, 1, d_model) | |
| x = torch.cat([ctrl, x], dim=1) # (batch, src_len+1, d_model) | |
| # Extend padding mask for control token (always attend to it) | |
| ctrl_mask = torch.zeros(batch_size, 1, dtype=torch.bool, device=x.device) | |
| memory_key_padding_mask = torch.cat([ctrl_mask, src_key_padding_mask], dim=1) | |
| # Pass through shared core encoder | |
| memory = self.core.encode(x, src_key_padding_mask=memory_key_padding_mask) | |
| return memory, memory_key_padding_mask | |
| def decode(self, tgt_ids, tgt_mask, memory, memory_key_padding_mask, direction_id): | |
| """ | |
| Decode target sequence. | |
| Args: | |
| tgt_ids: (batch, tgt_len) | |
| tgt_mask: (batch, tgt_len) — 1 for real tokens, 0 for padding | |
| memory: (batch, src_len+1, d_model) | |
| memory_key_padding_mask: (batch, src_len+1) | |
| direction_id: int | |
| Returns: | |
| logits: (batch, tgt_len, vocab_size) | |
| """ | |
| _, tgt_lang = self._get_lang(direction_id) | |
| # Embed target tokens | |
| x = self.token_embedding(tgt_ids) * math.sqrt(self.d_model) | |
| x = self.pos_encoding(x) | |
| # Causal mask for decoder | |
| tgt_len = tgt_ids.size(1) | |
| causal_mask = self.generate_square_subsequent_mask(tgt_len, tgt_ids.device) | |
| tgt_key_padding_mask = (tgt_mask == 0) | |
| # Pass through shared core decoder | |
| x = self.core.decode(x, memory, tgt_mask=causal_mask, | |
| tgt_key_padding_mask=tgt_key_padding_mask, | |
| memory_key_padding_mask=memory_key_padding_mask) | |
| # Pass through language-specific decoder | |
| x = self.decoders[tgt_lang](x, memory, tgt_mask=causal_mask, | |
| tgt_key_padding_mask=tgt_key_padding_mask, | |
| memory_key_padding_mask=memory_key_padding_mask) | |
| # Project to vocabulary | |
| logits = self.output_proj(x) | |
| return logits | |
| def forward(self, src_ids, src_mask, tgt_ids, tgt_mask, direction_id, control_id=CONTROL_TOKENS["strict"]): | |
| """ | |
| Full forward pass for training. | |
| Args: | |
| src_ids: (batch, src_len) | |
| src_mask: (batch, src_len) | |
| tgt_ids: (batch, tgt_len) | |
| tgt_mask: (batch, tgt_len) | |
| direction_id: int — single direction for the batch | |
| control_id: int | |
| Returns: | |
| logits: (batch, tgt_len, vocab_size) | |
| """ | |
| memory, memory_key_padding_mask = self.encode(src_ids, src_mask, direction_id, control_id) | |
| logits = self.decode(tgt_ids, tgt_mask, memory, memory_key_padding_mask, direction_id) | |
| return logits | |
| def count_parameters(model): | |
| """Count trainable parameters.""" | |
| total = sum(p.numel() for p in model.parameters() if p.requires_grad) | |
| breakdown = {} | |
| for name, module in model.named_children(): | |
| params = sum(p.numel() for p in module.parameters() if p.requires_grad) | |
| if params > 0: | |
| breakdown[name] = params | |
| return total, breakdown | |
| if __name__ == "__main__": | |
| # Test model with dummy data | |
| VOCAB_SIZE = 64000 | |
| BATCH_SIZE = 4 | |
| SRC_LEN = 20 | |
| TGT_LEN = 15 | |
| model = ControlMT(vocab_size=VOCAB_SIZE) | |
| total, breakdown = count_parameters(model) | |
| print(f"ControlMT — Total parameters: {total:,} ({total/1e6:.1f}M)") | |
| print(f"\nParameter breakdown:") | |
| for name, params in breakdown.items(): | |
| print(f" {name}: {params:,} ({params/1e6:.1f}M)") | |
| # Dummy forward pass | |
| src_ids = torch.randint(4, VOCAB_SIZE, (BATCH_SIZE, SRC_LEN)) | |
| tgt_ids = torch.randint(4, VOCAB_SIZE, (BATCH_SIZE, TGT_LEN)) | |
| src_mask = torch.ones(BATCH_SIZE, SRC_LEN, dtype=torch.long) | |
| tgt_mask = torch.ones(BATCH_SIZE, TGT_LEN, dtype=torch.long) | |
| # Test KN->EN | |
| logits = model(src_ids, src_mask, tgt_ids, tgt_mask, direction_id=4) | |
| print(f"\nForward pass (KN->EN):") | |
| print(f" Input: src={src_ids.shape}, tgt={tgt_ids.shape}") | |
| print(f" Output logits: {logits.shape}") | |
| print(f" Expected: ({BATCH_SIZE}, {TGT_LEN}, {VOCAB_SIZE})") | |
| # Test EN->KN | |
| logits = model(src_ids, src_mask, tgt_ids, tgt_mask, direction_id=5) | |
| print(f"\nForward pass (EN->KN):") | |
| print(f" Output logits: {logits.shape}") | |
| print("\nModel architecture test PASSED!") | |