#!/usr/bin/env python3
"""
AETHER-Micro Model Implementation (Hugging Face Standard)
Modular structure:
- utils.py: Helper functions
- normalization.py: RMSNorm
- embeddings.py: RoPE
- attention.py: Multi-Head Attention
- router.py: Wu-Xing Router
- moe.py: Heterogeneous MoE
- layers.py: Decoder Layer
- modeling_aether_micro.py: Main Model (this file)
"""
import torch
import torch.nn as nn
import torch.utils.checkpoint
from typing import Optional, Tuple, Union
from transformers import PreTrainedModel, GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPast
from .configuration_aether_micro import AETHERMicroConfig
from .normalization import AETHERMicroRMSNorm
from .layers import AETHERMicroDecoderLayer
from .quality_head import AETHERMicroQualityHead
from .mtp_loss import MTPLoss
# ========================================
# PreTrained Model Base Class
# ========================================
class AETHERMicroPreTrainedModel(PreTrainedModel):
"""
AETHER-Micro PreTrained Model Base Class
    Base class for all AETHER-Micro models.
    Provides Hugging Face save_pretrained / from_pretrained functionality.
"""
config_class = AETHERMicroConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["AETHERMicroDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
def _init_weights(self, module):
"""
Initialize weights
Args:
module: nn.Module to initialize
"""
std = self.config.initializer_range if hasattr(self.config, 'initializer_range') else 0.02
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
"""Enable gradient checkpointing"""
if isinstance(module, AETHERMicroModel):
module.gradient_checkpointing = value
# ========================================
# Main Transformer Model
# ========================================
class AETHERMicroModel(AETHERMicroPreTrainedModel):
"""
Main Transformer Model
Structure:
- Embedding layer
    - Decoder layers (config.num_hidden_layers; 24 in this model)
- Output RMSNorm
"""
def __init__(self, config: AETHERMicroConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
# Embedding
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
# Decoder layers
self.layers = nn.ModuleList([
AETHERMicroDecoderLayer(config)
for _ in range(config.num_hidden_layers)
])
# Output normalization
self.norm = AETHERMicroRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
# Initialize weights
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
disable_ltl: Optional[bool] = False,
) -> Union[Tuple, BaseModelOutputWithPast]:
"""
Args:
input_ids: (batch_size, sequence_length)
attention_mask: (batch_size, sequence_length)
position_ids: (batch_size, sequence_length)
inputs_embeds: (batch_size, sequence_length, hidden_size)
Returns:
BaseModelOutputWithPast or tuple
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
# Embeddings
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
# Position IDs
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
0, seq_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
# Attention mask (causal)
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length), dtype=torch.bool, device=hidden_states.device
)
# Causal mask: lower triangular
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), hidden_states, 0
)
# Decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
                # PyTorch 2.7+ non-reentrant checkpointing mode (recommended);
                # decoder_layer.forward() was changed to always return a single tensor
hidden_states = torch.utils.checkpoint.checkpoint(
decoder_layer,
hidden_states,
attention_mask,
position_ids,
disable_ltl,
use_reentrant=False
)
else:
hidden_states = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
disable_ltl=disable_ltl,
)
# Output normalization
hidden_states = self.norm(hidden_states)
# Add last hidden state
if output_hidden_states:
all_hidden_states += (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, None, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=None,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
"""
Prepare causal attention mask
Args:
attention_mask: (batch_size, seq_length)
input_shape: (batch_size, seq_length)
inputs_embeds: embeddings tensor
past_key_values_length: 0 for training
Returns:
combined_attention_mask: (batch_size, 1, seq_length, seq_length)
"""
# Create causal mask
# [batch_size, seq_length] -> [batch_size, 1, tgt_seq_length, src_seq_length]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [batch_size, seq_length] -> [batch_size, 1, tgt_seq_length, src_seq_length]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
inputs_embeds.device
)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
# ========================================
# Causal Language Model
# ========================================
class AETHERMicroForCausalLM(AETHERMicroPreTrainedModel, GenerationMixin):
"""
AETHER-Micro Causal Language Model
Structure:
- AETHERMicroModel (base transformer)
- LM Head (hidden → vocab)
- Loss computation
"""
def __init__(self, config):
super().__init__(config)
self.model = AETHERMicroModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Quality Head (Block 3)
if config.enable_quality_head:
self.quality_head = AETHERMicroQualityHead(config)
# MTP Loss (Block 5)
if config.enable_mtp_loss:
self.mtp_loss = MTPLoss(config)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[list] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
disable_ltl: Optional[bool] = False,
) -> Union[Tuple, CausalLMOutputWithPast]:
"""
Args:
input_ids: (batch_size, sequence_length)
labels: (batch_size, sequence_length) - for training
Returns:
CausalLMOutputWithPast with loss, logits
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Forward through base model
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
disable_ltl=disable_ltl,
)
hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state
logits = self.lm_head(hidden_states)
logits = logits.float()
# Quality Head (Block 3)
quality_scores = None
if hasattr(self, 'quality_head'):
quality_scores = self.quality_head(hidden_states)
loss = None
mtp_metrics = None
if labels is not None:
if hasattr(self, 'mtp_loss') and self.config.enable_mtp_loss:
# MTP Loss (Block 5)
loss, mtp_metrics = self.mtp_loss(hidden_states, labels)
else:
# Standard NTP Loss
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
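                # Worked example of the shift: for labels [t0, t1, t2, t3], the logits
                # at positions 0..2 are scored against t1..t3, so every position
                # predicts the following token and the last position has no target.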
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values if hasattr(outputs, 'past_key_values') else None,
hidden_states=outputs.hidden_states if hasattr(outputs, 'hidden_states') else None,
attentions=outputs.attentions if hasattr(outputs, 'attentions') else None,
)
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
"""Prepare inputs for generation"""
if past_key_values:
input_ids = input_ids[:, -1:]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
}
)
return model_inputs
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
"""Reorder cache for beam search"""
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
)
return reordered_past
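# Generation usage (illustrative sketch): GenerationMixin supplies .generate(),
# which calls prepare_inputs_for_generation() above at every decoding step.
#
#   ids = tokenizer("Hello", return_tensors="pt").input_ids
#   out = model.generate(ids, max_new_tokens=32, do_sample=True)
#   print(tokenizer.decode(out[0], skip_special_tokens=True))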
# ========================================
# Helper Functions for Attention Mask
# ========================================
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
"""
    Make the causal mask used for uni-directional (decoder) self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
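# Worked example: for tgt_len=3 and past_key_values_length=0, the (3, 3) mask is
#   [[0, -inf, -inf],
#    [0,    0, -inf],
#    [0,    0,    0]]
# where -inf stands for torch.finfo(dtype).min; added to the attention scores,
# each query position can only attend to itself and earlier positions.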
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
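# Worked example: a padding mask [[1, 1, 0]] (last token is padding) expands to a
# (1, 1, tgt_len, 3) tensor whose last source column equals torch.finfo(dtype).min
# and whose remaining entries are 0, matching the additive convention above.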
# ========================================
# Export all classes
# ========================================
__all__ = [
"AETHERMicroConfig",
"AETHERMicroPreTrainedModel",
"AETHERMicroModel",
"AETHERMicroForCausalLM",
]