| """Configuration dataclasses for SRT Adapter.""" |
|
|
| from __future__ import annotations |
|
|
| from dataclasses import dataclass, field |
|
|
|
|
@dataclass
class MAHConfig:
    """Hyperparameters for the Metapragmatic Attention Head (MAH).

    A passive container: every field has a default and nothing here
    validates or derives values.

    NOTE(review): the per-field notes below are inferred from the field
    names only -- confirm against the module that consumes this config.
    """

    # Presumably the width of the head's internal sub-space -- TODO confirm.
    d_sub: int = 512
    # Presumably the width of the divergence representation -- TODO confirm.
    d_divergence: int = 256
    # Number of attention heads.
    num_heads: int = 4
    # Dropout rate; presumably a probability in [0, 1] -- TODO confirm.
    dropout: float = 0.1
|
|
|
|
@dataclass
class RRMConfig:
    """Hyperparameters for the Reflexive Recurrent Module (RRM).

    A passive container with defaults only; no validation is performed.

    NOTE(review): field meanings are inferred from names -- confirm against
    the consuming module.
    """

    # Presumably the width of the module's "meta" state -- TODO confirm.
    d_meta: int = 512
    # Presumably a multiplier applied to whatever the module injects back
    # into the backbone -- TODO confirm.
    inject_scale: float = 1.0
|
|
|
|
@dataclass
class BENConfig:
    """Hyperparameters for the Bifurcation Estimation Network (BEN).

    A passive container with a single defaulted field; no validation is
    performed.
    """

    # Presumably the hidden-layer width of the network -- TODO confirm.
    d_hidden: int = 256
|
|
|
|
@dataclass
class CommunityConfig:
    """Unsupervised community discovery configuration.

    Besides the plain hyperparameters, ``use_prototypes`` can be forced off
    from the environment: after construction, ``__post_init__`` inspects the
    ``SRT_USE_PROTOTYPES`` variable and disables the flag when it holds one
    of the recognised "off" spellings.

    NOTE(review): the per-field notes are inferred from names -- confirm
    against the consuming module.
    """

    # Presumably the number of community prototype vectors -- TODO confirm.
    num_prototypes: int = 32
    # Presumably the width of the community embedding -- TODO confirm.
    d_community: int = 64
    # Presumably a softmax temperature for prototype assignment -- TODO confirm.
    temperature: float = 1.0
    # Master switch; may be overridden to False by SRT_USE_PROTOTYPES.
    use_prototypes: bool = True

    def __post_init__(self) -> None:
        """Apply the ``SRT_USE_PROTOTYPES`` environment override.

        The override is one-directional: a value of ``0``, ``false``, ``no``
        or ``off`` (case-insensitive, surrounding whitespace ignored) sets
        ``use_prototypes`` to ``False``. Any other value, or an unset
        variable, leaves the constructor-supplied value untouched -- the
        environment can never switch prototypes *on*.
        """
        import os  # kept function-local so the module's import block is unchanged

        raw = os.environ.get("SRT_USE_PROTOTYPES")
        if raw is None:
            return
        # Strip before comparing: values such as "0 " (what Windows
        # `set VAR=0 && cmd` produces) or "false\n" would otherwise be
        # silently ignored and leave prototypes enabled.
        if raw.strip().lower() in ("0", "false", "no", "off"):
            self.use_prototypes = False
|
|
|
|
@dataclass
class LossConfig:
    """Weights and auxiliary settings for the individual loss terms.

    A passive container: each ``*_weight`` field carries the coefficient for
    one loss term, and the accompanying ``*_temperature`` / ``*_target``
    fields carry that term's extra setting. Two weights
    (``inject_reg_weight`` and ``archetype_supcon_weight``) default to 0.0.

    NOTE(review): what each term computes is not visible in this file; the
    group headings below are inferred from field names -- confirm against
    the loss implementation.
    """

    # --- Core terms -------------------------------------------------------
    ce_weight: float = 1.0
    chain_weight: float = 0.5
    bif_weight: float = 1.0
    regime_weight: float = 5.0
    div_alive_weight: float = 0.1

    # --- Injection regulariser (weight defaults to 0.0) -------------------
    inject_reg_weight: float = 0.0
    inject_target_norm: float = 1.0

    # --- Community terms --------------------------------------------------
    community_entropy_weight: float = 0.01
    # "supcon" presumably stands for supervised contrastive -- TODO confirm.
    community_supcon_weight: float = 2.0
    community_supcon_temperature: float = 0.1

    # --- Divergence / ranking / chain-residual terms ----------------------
    divergence_supcon_weight: float = 1.0
    divergence_supcon_temperature: float = 0.1
    # Presumably a ListNet-style listwise ranking loss -- TODO confirm.
    listnet_weight: float = 0.5
    listnet_temperature: float = 1.0
    chain_residual_aux_weight: float = 0.05
    chain_residual_aux_target: float = 0.5

    # --- Archetype contrastive term (weight defaults to 0.0) --------------
    archetype_supcon_weight: float = 0.0
    archetype_supcon_temperature: float = 0.1
|
|
|
|
@dataclass
class TrainingConfig:
    """Hyperparameters for the training loop.

    A passive container with defaults only; no validation is performed.

    NOTE(review): units for the interval/count fields are not visible in
    this file -- the notes below are assumptions to confirm against the
    trainer.
    """

    # --- Optimiser --------------------------------------------------------
    lr: float = 3e-4
    weight_decay: float = 0.01

    # --- Data / schedule --------------------------------------------------
    epochs: int = 3
    batch_size: int = 16
    # Presumably the maximum sequence length in tokens -- TODO confirm.
    max_seq_len: int = 512

    # --- Cadence (presumably counted in training steps -- TODO confirm) ----
    val_every: int = 1000
    log_every: int = 100
    # Presumably early-stopping patience, in validation rounds -- TODO confirm.
    patience: int = 5
    warmup_steps: int = 500

    # Presumably the gradient-norm clipping threshold -- TODO confirm.
    grad_clip: float = 1.0
|
|
|
|
@dataclass
class SRTConfig:
    """Top-level SRT Adapter configuration.

    Holds the backbone selection, the layer-placement settings, and one
    sub-config per component. Layer placement may be left unset (empty
    lists / a negative index) and filled in later by
    :meth:`resolve_layer_indices` once the backbone depth is known.
    """

    # --- Backbone ---------------------------------------------------------
    # Presumably a Hugging Face model identifier -- TODO confirm.
    backbone_id: str = "Qwen/Qwen2.5-7B"
    # Dtype name as a string; how it is mapped to a framework dtype is not
    # visible in this file.
    backbone_dtype: str = "bfloat16"

    # --- Layer placement --------------------------------------------------
    # An empty list or a negative index means "not set"; see
    # resolve_layer_indices().
    mah_layer_indices: list[int] = field(default_factory=list)
    rrm_inject_indices: list[int] = field(default_factory=list)
    community_layer_idx: int = -1

    # How many MAH indices to generate when mah_layer_indices is empty.
    num_mah_layers: int = 3

    # --- Component sub-configs (fresh instance per SRTConfig) -------------
    mah: MAHConfig = field(default_factory=MAHConfig)
    rrm: RRMConfig = field(default_factory=RRMConfig)
    ben: BENConfig = field(default_factory=BENConfig)
    community: CommunityConfig = field(default_factory=CommunityConfig)
    loss: LossConfig = field(default_factory=LossConfig)
    training: TrainingConfig = field(default_factory=TrainingConfig)

    def resolve_layer_indices(self, num_layers: int) -> None:
        """Auto-compute layer indices from backbone depth if not set.

        Only unset values are filled in; anything the caller supplied is
        left untouched. The instance is mutated in place.

        * ``mah_layer_indices`` (if empty): ``num_mah_layers`` indices spaced
          ``num_layers // (num_mah_layers + 1)`` apart, starting one step in.
        * ``rrm_inject_indices`` (if empty): every MAH index except the first.
        * ``community_layer_idx`` (if negative): ``max(1, num_layers // 7)``.

        Args:
            num_layers: Depth of the backbone; must be at least 1.

        Raises:
            ValueError: If ``num_layers`` is not positive, or if the MAH
                indices have to be auto-computed and either
                ``num_mah_layers`` is negative or the backbone has too few
                layers to give each MAH layer a distinct index.
        """
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        if not self.mah_layer_indices:
            if self.num_mah_layers < 0:
                # Previously -1 surfaced as a bare ZeroDivisionError and
                # smaller values silently produced an empty list.
                raise ValueError(
                    f"num_mah_layers must be >= 0, got {self.num_mah_layers}"
                )
            step = num_layers // (self.num_mah_layers + 1)
            if step == 0:
                # With a zero step every computed index would be 0, i.e. the
                # same layer repeated num_mah_layers times.
                raise ValueError(
                    f"cannot place {self.num_mah_layers} MAH layers in a "
                    f"backbone with {num_layers} layers; set mah_layer_indices "
                    "explicitly or lower num_mah_layers"
                )
            self.mah_layer_indices = [
                step * k for k in range(1, self.num_mah_layers + 1)
            ]

        if not self.rrm_inject_indices:
            # Slicing copies, so the two lists never alias each other.
            self.rrm_inject_indices = self.mah_layer_indices[1:]

        if self.community_layer_idx < 0:
            self.community_layer_idx = max(1, num_layers // 7)
|
|