"""V4 decoder config.

Three variants for the ablation:
- A: full LASER2 per-layer cross-attention
- B: no LASER2 (pure decoder baseline)
- C: LASER2 input-only (first layer only)
"""
from dataclasses import dataclass


@dataclass
class V4Config:
    # Vocab: LASER2 SPM 50K
    vocab_size: int = 50004  # LASER2 fairseq dict: bos/pad/eos/unk + 50000 SPM pieces
    pad_token_id: int = 1
    bos_token_id: int = 0
    eos_token_id: int = 2

    # Model shape
    n_layer: int = 28
    n_embd: int = 1024
    n_head: int = 16  # query heads
    n_kv_head: int = 2  # GQA, 8:1 query-to-KV head ratio
    head_dim: int = 64  # n_embd // n_head

    # FFN
    ffn_mult: float = 5.4  # SwiGLU, ~5.4x n_embd (1024 * 5.4 ≈ 5530, rounded to 5504)
    ffn_hidden: int = 5504  # explicit hidden width

    # Context
    max_seq_len: int = 2048

    # Positional
    rope_theta: float = 10000.0

    # Cross-attention to LASER2
    cross_attention_mode: str = "per_layer"  # "per_layer" | "input_only" | "none"
    laser_dim: int = 1024  # LASER2 BiLSTM output: 512 per direction * 2 = 1024

    # Training
    dropout: float = 0.0
    tied_embeddings: bool = True

    # Init
    init_std: float = 0.02
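
    def __post_init__(self) -> None:
        # Illustrative consistency checks, not part of the original config;
        # they only restate invariants already implied by the fields above
        # (embedding width factors into heads, GQA head counts divide evenly,
        # and the mode string is one of the three documented values).
        assert self.n_embd == self.n_head * self.head_dim, (
            "n_embd must equal n_head * head_dim"
        )
        assert self.n_head % self.n_kv_head == 0, (
            "GQA requires n_head to be divisible by n_kv_head"
        )
        assert self.cross_attention_mode in ("per_layer", "input_only", "none")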

def config_variant_a() -> V4Config:
    """Variant A: full LASER2 per-layer cross-attention."""
    return V4Config(cross_attention_mode="per_layer")


def config_variant_b() -> V4Config:
    """Variant B: pure decoder baseline (no LASER2)."""
    return V4Config(cross_attention_mode="none")


def config_variant_c() -> V4Config:
    """Variant C: LASER2 input-only (first layer)."""
    return V4Config(cross_attention_mode="input_only")

def config_test() -> V4Config:
    """Small config for unit tests."""
    return V4Config(
        vocab_size=256,
        n_layer=2,
        n_embd=64,
        n_head=4,
        n_kv_head=2,
        head_dim=16,
        ffn_mult=4.0,
        ffn_hidden=256,
        max_seq_len=128,
        cross_attention_mode="per_layer",
    )
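
# Minimal usage sketch (illustrative, not part of the original file): build
# each ablation variant from its factory and print the field that
# distinguishes it. Variant names A/B/C follow the module docstring.
if __name__ == "__main__":
    for name, factory in [
        ("A", config_variant_a),
        ("B", config_variant_b),
        ("C", config_variant_c),
    ]:
        cfg = factory()
        print(f"Variant {name}: cross_attention_mode={cfg.cross_attention_mode}")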