"""TinyGPT: a small decoder-only transformer language model.

The model uses learned absolute position embeddings, pre-norm residual blocks
(RMSNorm -> causal self-attention, RMSNorm -> GELU MLP) and an untied linear
language-model head. No key/value cache is implemented.
"""

import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel

from .configuration_tinygpt import TinyGPTConfig


class TinyGPTRMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learned scale."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        """Create the norm.

        Args:
            hidden_size: Size of the last (normalized) dimension.
            eps: Constant added to the mean square before the reciprocal
                square root, for numerical stability.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` by its RMS along the last dimension and apply the scale.

        The statistics are computed in float32, and the normalized tensor is
        cast back to the dtype of ``x`` before the learned scale is applied,
        so a half-precision input does not leak a float32 tensor into the
        (half-precision) layers that follow.
        """
        input_dtype = x.dtype
        variance = x.float().pow(2).mean(dim=-1, keepdim=True)
        normalized = (x * torch.rsqrt(variance + self.eps)).to(input_dtype)
        return normalized * self.weight


class TinyGPTAttention(nn.Module):
    """Multi-head causal self-attention with separate q/k/v/output projections."""

    def __init__(self, config: TinyGPTConfig):
        """Build the projections.

        Raises:
            ValueError: If ``config.hidden_size`` is not divisible by
                ``config.num_attention_heads``.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        if self.head_dim * self.num_heads != self.hidden_size:
            raise ValueError("hidden_size must be divisible by num_attention_heads")

        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.dropout = nn.Dropout(config.dropout)

    def _shape(self, x: torch.Tensor) -> torch.Tensor:
        """Split ``(batch, seq, hidden)`` into ``(batch, heads, seq, head_dim)``."""
        batch, seq_len, _ = x.size()
        return x.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply causal self-attention.

        Args:
            hidden_states: Input of shape ``(batch, seq_len, hidden_size)``.
            attention_mask: Optional per-key mask indexed as ``(batch, seq_len)``;
                entries that convert to ``False`` are excluded as keys.

        Returns:
            Tensor of shape ``(batch, seq_len, hidden_size)``.
        """
        batch_size, seq_len = hidden_states.size(0), hidden_states.size(1)

        q = self._shape(self.q_proj(hidden_states))
        k = self._shape(self.k_proj(hidden_states))
        v = self._shape(self.v_proj(hidden_states))

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Use the most negative finite value rather than -inf so that a row
        # whose keys are all masked still yields a finite softmax (no NaNs).
        mask_value = torch.finfo(attn_scores.dtype).min

        # Strict upper triangle marks the future positions each query must not see.
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=hidden_states.device, dtype=torch.bool),
            diagonal=1,
        )
        attn_scores = attn_scores.masked_fill(causal_mask, mask_value)

        if attention_mask is not None:
            # Broadcast the key mask over the head and query dimensions.
            key_mask = attention_mask[:, None, None, :].to(torch.bool)
            attn_scores = attn_scores.masked_fill(~key_mask, mask_value)

        # Softmax in float32 for stability, then return to the working dtype.
        attn_probs = F.softmax(attn_scores, dim=-1, dtype=torch.float32).to(
            hidden_states.dtype
        )
        attn_probs = self.dropout(attn_probs)

        attn_output = torch.matmul(attn_probs, v)
        # Merge the heads back: (batch, heads, seq, head_dim) -> (batch, seq, hidden).
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.hidden_size
        )
        return self.out_proj(attn_output)


class TinyGPTMLP(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear -> Dropout."""

    def __init__(self, config: TinyGPTConfig):
        super().__init__()
        self.fc_in = nn.Linear(config.hidden_size, config.intermediate_size, bias=True)
        self.fc_out = nn.Linear(config.intermediate_size, config.hidden_size, bias=True)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project up, apply GELU, project back down and apply dropout."""
        hidden_states = self.fc_in(hidden_states)
        hidden_states = F.gelu(hidden_states)
        hidden_states = self.fc_out(hidden_states)
        return self.dropout(hidden_states)


class TinyGPTBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP, each in a residual branch."""

    def __init__(self, config: TinyGPTConfig):
        super().__init__()
        self.attn_norm = TinyGPTRMSNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.attn = TinyGPTAttention(config)
        self.mlp_norm = TinyGPTRMSNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mlp = TinyGPTMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the attention and MLP sub-layers, each added to its own input."""
        hidden_states = hidden_states + self.attn(
            self.attn_norm(hidden_states), attention_mask
        )
        hidden_states = hidden_states + self.mlp(self.mlp_norm(hidden_states))
        return hidden_states


class TinyGPTPreTrainedModel(PreTrainedModel):
    """Shared base class providing the config binding and weight initialization."""

    config_class = TinyGPTConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = False
    _no_split_modules = ["TinyGPTBlock"]

    def _init_weights(self, module):
        """Initialize ``nn.Linear`` and ``nn.Embedding`` modules.

        Weights are drawn from N(0, 0.02); linear biases are zeroed. Other
        module types are left as constructed.
        """
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)


class TinyGPTModel(TinyGPTPreTrainedModel):
    """Transformer backbone: embeddings, a stack of blocks and a final norm."""

    def __init__(self, config: TinyGPTConfig):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        # Learned absolute position table, one row per position; starts at zero.
        self.position_embeddings = nn.Parameter(
            torch.zeros(config.max_position_embeddings, config.hidden_size)
        )
        self.dropout = nn.Dropout(config.dropout)
        self.layers = nn.ModuleList(
            [TinyGPTBlock(config) for _ in range(config.num_hidden_layers)]
        )
        self.final_norm = TinyGPTRMSNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode token ids into final hidden states.

        Args:
            input_ids: Token ids indexed as ``(batch, seq_len)``.
            attention_mask: Optional key mask forwarded to every block.

        Returns:
            Hidden states of shape ``(batch, seq_len, hidden_size)``.

        Raises:
            ValueError: If ``seq_len`` exceeds the number of rows in the
                position embedding table.
        """
        seq_len = input_ids.size(1)
        max_positions = self.position_embeddings.size(0)
        if seq_len > max_positions:
            # Slicing would silently truncate the table and then fail (or
            # broadcast incorrectly) in the addition below.
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_position_embeddings "
                f"({max_positions})"
            )

        hidden_states = self.embed_tokens(input_ids) + self.position_embeddings[:seq_len]
        hidden_states = self.dropout(hidden_states)
        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask=attention_mask)
        return self.final_norm(hidden_states)


class TinyGPTForCausalLM(TinyGPTPreTrainedModel):
    """TinyGPT backbone with an untied linear language-modeling head."""

    # The LM head does not share weights with the input embedding.
    _tied_weights_keys = []

    def __init__(self, config: TinyGPTConfig):
        super().__init__(config)
        self.model = TinyGPTModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """Return the language-modeling head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the language-modeling head."""
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        """Return the forward inputs for a generation step.

        The full ``input_ids`` are passed through unchanged because this model
        keeps no key/value cache; extra keyword arguments are ignored.
        """
        return {"input_ids": input_ids, "attention_mask": attention_mask}

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Compute next-token logits and, when ``labels`` is given, the LM loss.

        Args:
            input_ids: Token ids indexed as ``(batch, seq_len)``. Required.
            attention_mask: Optional key mask forwarded to the backbone.
            labels: Optional target ids aligned with ``input_ids``. They are
                shifted internally so position ``t`` predicts token ``t + 1``.
                Targets equal to ``-100`` or to ``config.pad_token_id`` (when
                that is set) do not contribute to the loss.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            ``CausalLMOutputWithPast`` carrying ``loss`` (or ``None``) and
            ``logits``; the remaining fields are ``None``.

        Raises:
            ValueError: If ``input_ids`` is ``None``.
        """
        if input_ids is None:
            raise ValueError("input_ids must be provided")

        hidden_states = self.model(input_ids=input_ids, attention_mask=attention_mask)
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous().to(shift_logits.device)

            # Fold the pad id into the conventional -100 sentinel so that both
            # are ignored, and so a config without a pad token still works.
            ignore_index = -100
            pad_token_id = self.config.pad_token_id
            if pad_token_id is not None:
                shift_labels = shift_labels.masked_fill(
                    shift_labels == pad_token_id, ignore_index
                )

            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=ignore_index,
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )