#!/usr/bin/env python3
"""
AETHER-Micro Decoder Layer

One transformer decoder layer: a self-attention sub-block followed by an
MoE feed-forward sub-block, with optional latent-thought and
self-evaluation stages.
"""

import torch
import torch.nn as nn

from .configuration_aether_micro import AETHERMicroConfig
from .normalization import AETHERMicroRMSNorm
from .attention import AETHERMicroAttention
from .moe import AETHERMicroMoE
from .latent_thought import AETHERMicroLatentThought
from .self_evaluation import AETHERMicroSelfEvalHead


class AETHERMicroDecoderLayer(nn.Module):
    """
    Pre-norm transformer decoder layer.

    Data flow in ``forward``:
        1. RMSNorm -> self-attention -> residual add
        2. Optional latent-thought stage (only built when
           ``config.enable_latent_thought`` is truthy)
        3. RMSNorm -> MoE feed-forward -> residual add
        4. Optional self-evaluation head (only built when
           ``config.enable_self_eval`` is truthy); its output is discarded
           and never alters the returned hidden states
    """

    def __init__(self, config: AETHERMicroConfig):
        """
        Build the sub-modules of the layer.

        Args:
            config: Model configuration. Reads ``hidden_size``,
                ``rms_norm_eps``, ``enable_latent_thought`` and
                ``enable_self_eval``; the config object is also forwarded
                to every sub-module constructor except the norms.
        """
        super().__init__()
        width = config.hidden_size
        norm_eps = config.rms_norm_eps
        self.hidden_size = width

        # Attention sub-block: pre-norm followed by self-attention.
        self.input_layernorm = AETHERMicroRMSNorm(width, eps=norm_eps)
        self.self_attn = AETHERMicroAttention(config)

        # Optional stages. The attributes are deliberately left undefined
        # when disabled, because forward() probes for them with hasattr().
        if config.enable_latent_thought:
            # Latent Thought Loop (Block 1)
            self.latent_thought = AETHERMicroLatentThought(config)
        if config.enable_self_eval:
            # Self Evaluation (Block 4)
            self.self_evaluation = AETHERMicroSelfEvalHead(config)

        # Feed-forward sub-block: pre-norm followed by the MoE FFN.
        self.post_attention_layernorm = AETHERMicroRMSNorm(width, eps=norm_eps)
        self.mlp = AETHERMicroMoE(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor = None,
        position_ids: torch.LongTensor = None,
        disable_ltl: bool = False,
    ) -> torch.Tensor:
        """
        Run the layer on a batch of hidden states.

        Args:
            hidden_states: (batch_size, seq_length, hidden_size)
            attention_mask: (batch_size, 1, seq_length, seq_length)
            position_ids: (batch_size, seq_length)
            disable_ltl: When truthy, the latent-thought stage is skipped
                even if the layer was built with it.

        Returns:
            hidden_states: (batch_size, seq_length, hidden_size)
        """
        # --- Attention sub-block with residual connection ---
        shortcut = hidden_states
        attn_out = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
        )
        hidden_states = shortcut + attn_out

        # --- Latent Thought Loop (Block 1), checkpoint-safe ---
        if hasattr(self, 'latent_thought') and not disable_ltl:
            ltl_out = self.latent_thought(hidden_states)
            # The stage may return either a bare tensor or a tuple whose
            # first item is the tensor (tuple form kept for gradient
            # checkpointing compatibility). Any trailing metrics are
            # ignored here.
            hidden_states = ltl_out[0] if isinstance(ltl_out, tuple) else ltl_out

        # --- MoE feed-forward sub-block with residual connection ---
        shortcut = hidden_states
        ffn_out = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = shortcut + ffn_out

        # --- Self Evaluation (Block 4) ---
        # NOTE: the head returns a (quality, overall) tuple intended for
        # metrics. It is invoked for that purpose only; the result is
        # dropped and hidden_states is left untouched.
        if hasattr(self, 'self_evaluation'):
            self.self_evaluation(hidden_states)

        return hidden_states