"""V4 decoder config.

Three variants for the ablation:
- A: full LASER2 per-layer cross-attention
- B: no LASER2 (pure decoder baseline)
- C: LASER2 input-only (first layer only)
"""
| from dataclasses import dataclass, field | |
@dataclass
class V4Config:
    """Hyperparameters for the V4 decoder.

    Declared as a dataclass so that any field can be overridden by keyword
    at construction time, e.g. ``V4Config(cross_attention_mode="none")``.
    The ``config_*`` factory functions in this file rely on that generated
    keyword ``__init__``; without the decorator those calls raise
    ``TypeError``.
    """

    # Vocab — LASER2 SPM 50K
    vocab_size: int = 50004  # LASER2 fairseq dict: bos/pad/eos/unk + 50000 SPM
    pad_token_id: int = 1
    bos_token_id: int = 0
    eos_token_id: int = 2

    # Model shape
    n_layer: int = 28
    n_embd: int = 1024
    n_head: int = 16  # query heads
    n_kv_head: int = 2  # GQA 8:1 (16 query heads share 2 KV heads)
    head_dim: int = 64  # n_head * head_dim == n_embd (16 * 64 == 1024)

    # FFN
    # NOTE(review): 5.4 * 1024 is ~5530, not 5504 (5504 / 1024 == 5.375), so
    # the multiplier and the explicit width disagree slightly. Presumably
    # ffn_hidden is the authoritative value — TODO confirm against the model.
    ffn_mult: float = 5.4  # SwiGLU 5.4× = 5504 hidden
    ffn_hidden: int = 5504  # explicit

    # Context
    max_seq_len: int = 2048

    # Positional
    rope_theta: float = 10000.0

    # Cross-attention to LASER2
    cross_attention_mode: str = "per_layer"  # "per_layer" | "input_only" | "none"
    laser_dim: int = 1024  # LASER2 BiLSTM output is 512*2 = 1024

    # Training
    dropout: float = 0.0
    tied_embeddings: bool = True

    # Init
    init_std: float = 0.02
def config_variant_a() -> V4Config:
    """Build the variant-A ablation config (LASER2 cross-attention per layer).

    Only ``cross_attention_mode`` is overridden; every other field keeps its
    ``V4Config`` default.
    """
    mode = "per_layer"
    return V4Config(cross_attention_mode=mode)
def config_variant_b() -> V4Config:
    """Build the variant-B ablation config (decoder-only baseline, no LASER2).

    Only ``cross_attention_mode`` is overridden; every other field keeps its
    ``V4Config`` default.
    """
    mode = "none"
    return V4Config(cross_attention_mode=mode)
def config_variant_c() -> V4Config:
    """Build the variant-C ablation config (LASER2 at the input / first layer).

    Only ``cross_attention_mode`` is overridden; every other field keeps its
    ``V4Config`` default.
    """
    mode = "input_only"
    return V4Config(cross_attention_mode=mode)
def config_test() -> V4Config:
    """Return a deliberately tiny config intended for unit tests.

    Overrides the vocabulary, depth, width, head layout, FFN size and context
    length, and selects per-layer cross-attention. Fields not listed here
    keep their ``V4Config`` defaults.
    """
    overrides = {
        "vocab_size": 256,
        "n_layer": 2,
        "n_embd": 64,
        "n_head": 4,
        "n_kv_head": 2,
        "head_dim": 16,  # 4 heads * 16 == n_embd of 64
        "ffn_mult": 4,
        "ffn_hidden": 256,  # 4 * n_embd
        "max_seq_len": 128,
        "cross_attention_mode": "per_layer",
    }
    return V4Config(**overrides)