"""
AETHER-Micro Decoder Layer

Transformer Decoder Layer: Attention + MoE
"""
| |
|
from typing import Optional

import torch
import torch.nn as nn

from .configuration_aether_micro import AETHERMicroConfig
from .normalization import AETHERMicroRMSNorm
from .attention import AETHERMicroAttention
from .moe import AETHERMicroMoE
from .latent_thought import AETHERMicroLatentThought
from .self_evaluation import AETHERMicroSelfEvalHead
| |
|
| |
|
class AETHERMicroDecoderLayer(nn.Module):
    """
    Transformer decoder layer: pre-norm self-attention followed by an MoE FFN.

    Structure:
        1. Input RMSNorm
        2. Self-Attention (RoPE + GQA)
        3. Residual connection
        4. Optional latent-thought refinement
        5. Post-attention RMSNorm
        6. MoE FFN (heterogeneous experts)
        7. Residual connection
        8. Optional self-evaluation head (output discarded here)
    """

    def __init__(self, config: AETHERMicroConfig):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.input_layernorm = AETHERMicroRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.self_attn = AETHERMicroAttention(config)

        # Optional sub-modules are always defined (None when disabled) so the
        # forward pass can use explicit `is not None` checks instead of the
        # more fragile hasattr() probing.
        self.latent_thought = (
            AETHERMicroLatentThought(config) if config.enable_latent_thought else None
        )
        self.self_evaluation = (
            AETHERMicroSelfEvalHead(config) if config.enable_self_eval else None
        )

        self.post_attention_layernorm = AETHERMicroRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp = AETHERMicroMoE(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        disable_ltl: bool = False,
    ) -> torch.Tensor:
        """
        Run one decoder layer over a batch of hidden states.

        Args:
            hidden_states: (batch_size, seq_length, hidden_size)
            attention_mask: (batch_size, 1, seq_length, seq_length), or None
            position_ids: (batch_size, seq_length), or None
            disable_ltl: if True, skip the latent-thought refinement step
                even when the module is enabled.

        Returns:
            hidden_states: (batch_size, seq_length, hidden_size)
        """
        # --- Self-attention sub-block (pre-norm + residual) ---
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )
        hidden_states = residual + hidden_states

        # --- Optional latent-thought refinement ---
        if self.latent_thought is not None and not disable_ltl:
            result = self.latent_thought(hidden_states)
            # The module may return (hidden_states, ...aux); keep only the
            # refined hidden states and drop any auxiliary outputs.
            hidden_states = result[0] if isinstance(result, tuple) else result

        # --- MoE FFN sub-block (pre-norm + residual) ---
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        # --- Optional self-evaluation head ---
        # NOTE(review): the head's output is computed and discarded here —
        # presumably it records state internally or exists for training-time
        # side effects; confirm, otherwise this is dead compute at inference.
        if self.self_evaluation is not None:
            _ = self.self_evaluation(hidden_states)

        return hidden_states
| |
|