"""V4 decoder config. Three variants for the ablation: - A: full LASER2 per-layer cross-attention - B: no LASER2 (pure decoder baseline) - C: LASER2 input-only (first layer only) """ from dataclasses import dataclass, field @dataclass class V4Config: # Vocab — LASER2 SPM 50K vocab_size: int = 50004 # LASER2 fairseq dict: bos/pad/eos/unk + 50000 SPM pad_token_id: int = 1 bos_token_id: int = 0 eos_token_id: int = 2 # Model shape n_layer: int = 28 n_embd: int = 1024 n_head: int = 16 # query heads n_kv_head: int = 2 # GQA 8:1 head_dim: int = 64 # FFN ffn_mult: float = 5.4 # SwiGLU 5.4× = 5504 hidden ffn_hidden: int = 5504 # explicit # Context max_seq_len: int = 2048 # Positional rope_theta: float = 10000.0 # Cross-attention to LASER2 cross_attention_mode: str = "per_layer" # "per_layer" | "input_only" | "none" laser_dim: int = 1024 # LASER2 BiLSTM output is 512*2 = 1024 # Training dropout: float = 0.0 tied_embeddings: bool = True # Init init_std: float = 0.02 def config_variant_a() -> V4Config: """Variant A: full LASER2 per-layer cross-attention.""" return V4Config(cross_attention_mode="per_layer") def config_variant_b() -> V4Config: """Variant B: pure decoder baseline (no LASER2).""" return V4Config(cross_attention_mode="none") def config_variant_c() -> V4Config: """Variant C: LASER2 input-only (first layer).""" return V4Config(cross_attention_mode="input_only") def config_test() -> V4Config: """Small config for unit tests.""" return V4Config( vocab_size=256, n_layer=2, n_embd=64, n_head=4, n_kv_head=2, head_dim=16, ffn_mult=4, ffn_hidden=256, max_seq_len=128, cross_attention_mode="per_layer", )