#!/usr/bin/env python3
"""
AETHER-Micro model implementation (Hugging Face standard layout).

Modular structure:
- utils.py: helper functions
- normalization.py: RMSNorm
- embeddings.py: RoPE
- attention.py: multi-head attention
- router.py: Wu-Xing router
- moe.py: heterogeneous MoE
- layers.py: decoder layer
- modeling_aether_micro.py: main model (this file)
"""

import torch
import torch.nn as nn
import torch.utils.checkpoint
from typing import Optional, Tuple, Union

from transformers import PreTrainedModel, GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPast

from .configuration_aether_micro import AETHERMicroConfig
from .normalization import AETHERMicroRMSNorm
from .layers import AETHERMicroDecoderLayer
from .quality_head import AETHERMicroQualityHead
from .mtp_loss import MTPLoss


# ========================================
# PreTrained Model Base Class
# ========================================

class AETHERMicroPreTrainedModel(PreTrainedModel):
    """
    Common base class for all AETHER-Micro models.

    Subclasses Hugging Face ``PreTrainedModel`` (which supplies
    ``save_pretrained`` / ``from_pretrained``) and declares the config class,
    the base-model prefix and the weight-initialisation rule used here.
    """

    config_class = AETHERMicroConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["AETHERMicroDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"

    def _init_weights(self, module):
        """
        Initialise the parameters of a single sub-module.

        ``nn.Linear`` and ``nn.Embedding`` weights are drawn from
        N(0, ``config.initializer_range``), falling back to 0.02 when the
        config has no such attribute. Linear biases and the embedding row at
        ``padding_idx`` are zeroed. Any other module type is left untouched.

        Args:
            module: the ``nn.Module`` to initialise.
        """
        std = getattr(self.config, 'initializer_range', 0.02)

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        """Set the ``gradient_checkpointing`` flag on ``AETHERMicroModel`` instances only."""
        if not isinstance(module, AETHERMicroModel):
            return
        module.gradient_checkpointing = value


# ========================================
# Main Transformer Model
# ========================================

class AETHERMicroModel(AETHERMicroPreTrainedModel):
    """
    Base transformer stack (no language-model head).

    Structure:
    - token embedding layer
    - ``config.num_hidden_layers`` decoder layers
    - output RMSNorm
    """

    def __init__(self, config: AETHERMicroConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # Token embedding
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)

        # Decoder stack
        decoder_stack = [AETHERMicroDecoderLayer(config) for _ in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(decoder_stack)

        # Final normalisation applied after the last decoder layer
        self.norm = AETHERMicroRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False

        # Initialise weights
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        disable_ltl: Optional[bool] = False,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """
        Run the embedding, the decoder stack and the final norm.

        Args:
            input_ids: (batch_size, sequence_length)
            attention_mask: (batch_size, sequence_length); all-ones when omitted
            position_ids: (batch_size, sequence_length); ``0..seq_len-1`` when omitted
            inputs_embeds: (batch_size, sequence_length, hidden_size)
            use_cache: resolved against the config but not otherwise used here
            output_attentions: when truthy the ``attentions`` field is an empty
                tuple (this method never appends to it)
            output_hidden_states: collect the input of every layer plus the
                normalised final state
            return_dict: return a ``BaseModelOutputWithPast`` instead of a tuple
            disable_ltl: forwarded unchanged to every decoder layer

        Returns:
            ``BaseModelOutputWithPast`` (``past_key_values`` is always ``None``)
            or, when ``return_dict`` is falsy, a tuple of the non-``None`` fields.

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds``
                are supplied.
        """
        # Resolve unset flags from the config.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if use_cache is None:
            use_cache = self.config.use_cache
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Exactly one of input_ids / inputs_embeds must be given.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        if input_ids is not None:
            batch_size, seq_length = input_ids.shape
        else:
            batch_size, seq_length, _ = inputs_embeds.shape

        # Embeddings
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        hidden_states = inputs_embeds

        # Default position ids: 0..seq_length-1 broadcast over the batch
        if position_ids is None:
            device = inputs_embeds.device if input_ids is None else input_ids.device
            base_positions = torch.arange(0, seq_length, dtype=torch.long, device=device)
            position_ids = base_positions.unsqueeze(0).expand(batch_size, -1)

        # Default padding mask: every position is attended to
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length),
                dtype=torch.bool,
                device=hidden_states.device,
            )

        # Combine the padding mask with the lower-triangular causal mask
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, (batch_size, seq_length), hidden_states, 0
        )

        # Decoder layers
        collected_states = [] if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        use_checkpointing = self.gradient_checkpointing and self.training

        for decoder_layer in self.layers:
            if output_hidden_states:
                collected_states.append(hidden_states)

            if use_checkpointing:
                # Non-reentrant checkpointing (original note: recommended for
                # PyTorch 2.7+; decoder_layer.forward() was changed to always
                # return a single tensor).
                hidden_states = torch.utils.checkpoint.checkpoint(
                    decoder_layer,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    disable_ltl,
                    use_reentrant=False
                )
            else:
                hidden_states = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    disable_ltl=disable_ltl,
                )

        # Output normalisation
        hidden_states = self.norm(hidden_states)

        # Append the final (normalised) hidden state
        if output_hidden_states:
            collected_states.append(hidden_states)
        all_hidden_states = tuple(collected_states) if collected_states is not None else None

        if not return_dict:
            fields = (hidden_states, None, all_hidden_states, all_self_attns)
            return tuple(field for field in fields if field is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
        """
        Build the additive attention mask (causal + padding).

        Args:
            attention_mask: (batch_size, seq_length) padding mask, or ``None``
            input_shape: (batch_size, seq_length)
            inputs_embeds: tensor whose dtype and device the mask adopts
            past_key_values_length: number of cached positions (the caller in
                this file always passes 0)

        Returns:
            The sum of the expanded padding mask and the causal mask, either
            one alone when the other is absent, or ``None`` when both are.
        """
        target_len = input_shape[-1]

        # A causal mask is only needed when there is more than one target position.
        causal_mask = None
        if target_len > 1:
            causal_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )

        if attention_mask is None:
            return causal_mask

        # [batch_size, seq_length] -> [batch_size, 1, tgt_seq_length, src_seq_length]
        padding_mask = _expand_mask(
            attention_mask, inputs_embeds.dtype, tgt_len=target_len
        ).to(inputs_embeds.device)

        if causal_mask is None:
            return padding_mask
        return padding_mask + causal_mask


# ========================================
# Causal Language Model
# ========================================

class AETHERMicroForCausalLM(AETHERMicroPreTrainedModel, GenerationMixin):
    """
    AETHER-Micro causal language model.

    Structure:
    - AETHERMicroModel (base transformer)
    - LM head (hidden -> vocab)
    - loss computation (MTP loss when enabled, otherwise next-token cross-entropy)
    """

    def __init__(self, config):
        super().__init__(config)
        self.model = AETHERMicroModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Quality Head (Block 3) - only built when enabled in the config
        if config.enable_quality_head:
            self.quality_head = AETHERMicroQualityHead(config)

        # MTP Loss (Block 5) - only built when enabled in the config
        if config.enable_mtp_loss:
            self.mtp_loss = MTPLoss(config)

        # Initialise weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the base model's token embedding module."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the base model's token embedding module."""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """Return the LM head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Replace the base transformer."""
        self.model = decoder

    def get_decoder(self):
        """Return the base transformer."""
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[list] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        disable_ltl: Optional[bool] = False,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Compute logits and, when ``labels`` is given, the training loss.

        Args:
            input_ids: (batch_size, sequence_length)
            labels: (batch_size, sequence_length) - for training
            past_key_values: accepted for API compatibility; not forwarded to
                the base model
            (remaining arguments are passed through to ``AETHERMicroModel``)

        Returns:
            ``CausalLMOutputWithPast`` with ``loss`` and ``logits``, or a tuple
            when ``return_dict`` is falsy.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Forward through the base model
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            disable_ltl=disable_ltl,
        )

        hidden_states = outputs.last_hidden_state if return_dict else outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        # Quality Head (Block 3): evaluated when present; the scores are not
        # part of the returned output.
        quality_scores = None
        if hasattr(self, 'quality_head'):
            quality_scores = self.quality_head(hidden_states)

        loss = None
        mtp_metrics = None
        if labels is not None:
            if hasattr(self, 'mtp_loss') and self.config.enable_mtp_loss:
                # MTP Loss (Block 5)
                loss, mtp_metrics = self.mtp_loss(hidden_states, labels)
            else:
                # Standard next-token loss: position t predicts token t + 1,
                # so drop the last logit and the first label, then flatten.
                flat_logits = logits[..., :-1, :].contiguous().view(-1, self.config.vocab_size)
                flat_labels = labels[..., 1:].contiguous().view(-1)
                # Keep labels on the logits' device (model parallelism)
                flat_labels = flat_labels.to(flat_logits.device)
                criterion = nn.CrossEntropyLoss()
                loss = criterion(flat_logits, flat_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            if loss is not None:
                return (loss,) + output
            return output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=getattr(outputs, 'past_key_values', None),
            hidden_states=getattr(outputs, 'hidden_states', None),
            attentions=getattr(outputs, 'attentions', None),
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """
        Assemble the keyword arguments for one generation step.

        When ``past_key_values`` is truthy only the last token (and last
        position) is kept. Position ids are derived from ``attention_mask``
        when not supplied through ``kwargs``.
        """
        has_past = bool(past_key_values)
        if has_past:
            input_ids = input_ids[:, -1:]

        position_ids = kwargs.get("position_ids")
        if attention_mask is not None and position_ids is None:
            # Create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if has_past:
                position_ids = position_ids[:, -1].unsqueeze(-1)

        # If `inputs_embeds` are passed, only use them in the first generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs["position_ids"] = position_ids
        model_inputs["past_key_values"] = past_key_values
        model_inputs["use_cache"] = kwargs.get("use_cache")
        model_inputs["attention_mask"] = attention_mask
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Reorder every cached tensor along dim 0 by ``beam_idx`` (beam search)."""
        return tuple(
            tuple(
                cached.index_select(0, beam_idx.to(cached.device))
                for cached in layer_past
            )
            for layer_past in past_key_values
        )


# ========================================
# Helper Functions for Attention Mask
# ========================================

def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    """
    Build an additive causal mask.

    Entry (i, j) is 0 when j <= i and ``torch.finfo(dtype).min`` otherwise.
    When ``past_key_values_length`` > 0, that many all-zero columns are
    prepended. The result is expanded to
    ``(bsz, 1, tgt_len, tgt_len + past_key_values_length)``.
    """
    bsz, tgt_len = input_ids_shape

    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    positions = torch.arange(tgt_len, device=device)
    # Column j is visible from row i exactly when j <= i.
    visible = positions.view(1, tgt_len) <= positions.view(tgt_len, 1)
    mask.masked_fill_(visible, 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        past_columns = torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device)
        mask = torch.cat([past_columns, mask], dim=-1)

    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)


def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expand attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.

    The result is additive: 0 where the input mask is 1, and
    ``torch.finfo(dtype).min`` wherever ``1 - mask`` is non-zero.
    """
    bsz, src_len = mask.size()
    if tgt_len is None:
        tgt_len = src_len

    broadcast = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    inverted = 1.0 - broadcast
    return inverted.masked_fill(inverted.to(torch.bool), torch.finfo(dtype).min)


# ========================================
# Export all classes
# ========================================

__all__ = [
    "AETHERMicroConfig",
    "AETHERMicroPreTrainedModel",
    "AETHERMicroModel",
    "AETHERMicroForCausalLM",
]