import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig, PreTrainedTokenizerFast


class GPTCustomConfig(PretrainedConfig):
    model_type = "gpt-custom"

    def __init__(
        self,
        vocab_size=4000,
        seq_len=512,
        Emb_dim=256,
        num_heads=4,
        num_layers=2,
        hidden_dim=512,
        dropout=0.1,
        eps=1e-5,
        dtype="float32",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.Emb_dim = Emb_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.dropout = dropout
        self.eps = eps
        self.dtype = dtype


# --- Adaptive position embedding ---
class SinusoidalPositionalEmbedding(nn.Module):
    def __init__(self, emb_dim, max_seq_len=10000):
        super().__init__()
        self.emb_dim = emb_dim
        self.max_seq_len = max_seq_len
        # Create the initial positional encoding matrix
        self._create_pe_matrix(max_seq_len, emb_dim)

    def _create_pe_matrix(self, seq_len, emb_dim):
        """Create the positional encoding matrix for the given sequence length and embedding dimension."""
        position = torch.arange(0, seq_len).unsqueeze(1).float()  # (seq_len, 1)
        div_term = torch.exp(
            torch.arange(0, emb_dim, 2).float() * -(math.log(10000.0) / emb_dim)
        )  # (ceil(emb_dim / 2),)
        pe = torch.zeros(seq_len, emb_dim)
        pe[:, 0::2] = torch.sin(position * div_term)
        if emb_dim % 2 == 1:
            # Odd embedding dim: the cosine half has one column fewer than div_term
            pe[:, 1::2] = torch.cos(position * div_term[:-1])
        else:
            pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)  # (seq_len, emb_dim)

    def forward(self, x):
        """
        Args:
            x: Input tensor of shape (batch_size, seq_len) or (batch_size, seq_len, emb_dim)
        Returns:
            Positional embeddings of shape (batch_size, seq_len, emb_dim)
        """
        if x.dim() == 2:
            batch_size, seq_len = x.shape
        elif x.dim() == 3:
            batch_size, seq_len, _ = x.shape
        else:
            raise ValueError(f"Input tensor must be 2D or 3D, got {x.dim()}D")

        # Extend the positional encoding if the input is longer than the cached matrix
        if seq_len > self.pe.size(0):
            self._create_pe_matrix(seq_len, self.emb_dim)

        # Move to the same device as the input
        self.pe = self.pe.to(x.device)

        # Return positional embeddings for the current sequence length
        pos_emb = self.pe[:seq_len].unsqueeze(0).expand(batch_size, seq_len, -1)
        return pos_emb.to(x.device)


# --- Multi Head Attention ---
class MultiHeadAttention(nn.Module):
    def __init__(self, Emb_dim, num_heads, dropout=0.1, device='cpu', dtype=torch.float32):
        super().__init__()
        assert Emb_dim % num_heads == 0, "Emb_dim must be divisible by num_heads"
        self.Emb_dim = Emb_dim
        self.device = device
        self.num_heads = num_heads
        self.head_dim = Emb_dim // num_heads
        self.key = nn.Linear(Emb_dim, Emb_dim, bias=False, dtype=dtype, device=device)
        self.query = nn.Linear(Emb_dim, Emb_dim, bias=False, dtype=dtype, device=device)
        self.value = nn.Linear(Emb_dim, Emb_dim, bias=False, dtype=dtype, device=device)
        self.scale = math.sqrt(self.head_dim)
        self.out_proj = nn.Linear(Emb_dim, Emb_dim, dtype=dtype, device=device)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        batch_size, seq_len, _ = x.shape

        # Generate Q, K, V and reshape for multi-head attention
        keys = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        queries = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        values = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Compute attention scores
        scores = (queries @ keys.transpose(-2, -1)) / self.scale

        # Create causal mask dynamically based on the current sequence length
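        # A minimal worked illustration (assuming seq_len = 4): torch.triu(..., diagonal=1)
        # marks everything above the diagonal as True, i.e. future positions a query
        # must NOT attend to:
        #   [[False,  True,  True,  True],
        #    [False, False,  True,  True],
        #    [False, False, False,  True],
        #    [False, False, False, False]]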
        causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device), diagonal=1)
        scores = scores.masked_fill(causal_mask[None, None, :, :], float('-inf'))

        # Apply softmax and dropout
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        # Apply attention to values
        out = attn @ values

        # Concatenate heads and project
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.Emb_dim)
        return self.out_proj(out)


# --- Feed Forward ---
class FF_ReLU(nn.Module):
    def __init__(self, Emb_dim, hidden_dim, dropout=0.1, device='cpu', dtype=torch.float32):
        super().__init__()
        self.relu = nn.Sequential(
            nn.Linear(Emb_dim, hidden_dim, device=device, dtype=dtype),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, Emb_dim, device=device, dtype=dtype),
        )

    def forward(self, x):
        return self.relu(x)


class LayerNorm(nn.Module):
    def __init__(self, Emb_dim, eps=1e-5, device='cpu', dtype=torch.float32):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(Emb_dim, device=device, dtype=dtype))
        self.bias = nn.Parameter(torch.zeros(Emb_dim, device=device, dtype=dtype))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x - mean) / torch.sqrt(var + self.eps)
        return norm_x * self.weight + self.bias


# --- Transformer Block ---
class Block(nn.Module):
    def __init__(self, Emb_dim, num_heads, dropout, hidden_dim, eps=1e-5, device='cpu', dtype=torch.float32):
        super().__init__()
        self.attention = MultiHeadAttention(Emb_dim, num_heads, dropout, device=device, dtype=dtype)
        self.norm1 = LayerNorm(Emb_dim, eps=eps, device=device, dtype=dtype)
        self.ff_relu = FF_ReLU(Emb_dim, hidden_dim, dropout, device=device, dtype=dtype)
        self.norm2 = LayerNorm(Emb_dim, eps=eps, device=device, dtype=dtype)

    def forward(self, x):
        # Pre-norm: normalize before attention
        residual = x
        x = self.norm1(x)
        x = self.attention(x)
        x = x + residual  # Residual connection

        # Pre-norm: normalize before FF
        residual = x
        x = self.norm2(x)
        x = self.ff_relu(x)
        x = x + residual  # Residual connection
        return x


# --- Encoder ---
class Encoder(nn.Module):
    def __init__(self, num_layers, Emb_dim, num_heads, dropout, hidden_dim, eps=1e-5, device='cpu', dtype=torch.float32):
        super().__init__()
        self.layers = nn.ModuleList([
            Block(Emb_dim, num_heads, dropout, hidden_dim, eps, device=device, dtype=dtype)
            for _ in range(num_layers)
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class GPTModel(PreTrainedModel):
    config_class = GPTCustomConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.vocab_size = config.vocab_size
        self._dtype = torch.float32
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # --- Token embedding ---
        self.embedding = nn.Embedding(config.vocab_size, config.Emb_dim, dtype=self._dtype, device=self._device)

        # --- Embedding dropout ---
        self.emb_dropout = nn.Dropout(getattr(config, 'dropout', 0.1))

        # --- Adaptive position embedding ---
        self.position_embedding = SinusoidalPositionalEmbedding(
            config.Emb_dim,
            max_seq_len=getattr(config, 'max_seq_len', config.seq_len)
        )

        # --- Encoder ---
        self.encoder = Encoder(
            num_layers=config.num_layers,
            Emb_dim=config.Emb_dim,
            num_heads=config.num_heads,
            dropout=getattr(config, 'dropout', 0.1),
            hidden_dim=config.hidden_dim,
            eps=getattr(config, 'eps', 1e-5),
            device=self._device,
            dtype=self._dtype
        )

        # --- Final norm ---
        self.final_norm = LayerNorm(
            config.Emb_dim,
            eps=getattr(config, 'eps', 1e-5),
            device=self._device,
            dtype=self._dtype
        )
        # --- Output projection ---
        self.lm_head = nn.Linear(
            config.Emb_dim,
            config.vocab_size,
            bias=False,
            dtype=self._dtype,
            device=self._device
        )

    def forward(self, x):
        # x shape: (batch_size, seq_len)
        emb = self.embedding(x)           # (batch_size, seq_len, emb_dim)
        pos = self.position_embedding(x)  # (batch_size, seq_len, emb_dim)
        x = emb + pos
        x = self.emb_dropout(x)
        x = self.encoder(x)
        x = self.final_norm(x)
        x = self.lm_head(x)
        return x

    @torch.no_grad()
    def generate(self, prompt_ids, tokenizer, max_new_tokens=50, temperature=1.0,
                 top_k=None, top_p=1.0, eos_token_id=None):
        self.eval()

        if prompt_ids.dim() == 1:
            prompt_ids = prompt_ids.unsqueeze(0)  # (1, seq_len)

        # Always prepend the fixed prefix token IDs
        prefix = torch.tensor([[1, 145, 31]], device=prompt_ids.device)
        generated = torch.cat([prefix, prompt_ids], dim=1)  # (1, total_len)

        max_context_len = getattr(self.config, 'max_seq_len', getattr(self.config, 'seq_len', 2048))

        for _ in range(max_new_tokens):
            # Use a sliding window if the sequence gets too long
            if generated.size(1) > max_context_len:
                input_ids = generated[:, -max_context_len:]
            else:
                input_ids = generated

            logits = self.forward(input_ids)  # (batch_size, seq_len, vocab_size)
            logits = logits[:, -1, :]         # (batch_size, vocab_size)

            # Temperature scaling
            if temperature != 1.0:
                logits = logits / temperature

            # Top-k filtering
            if top_k is not None and top_k > 0:
                topk_vals, topk_indices = torch.topk(logits, top_k)
                mask = torch.full_like(logits, float('-inf'))
                mask.scatter_(dim=-1, index=topk_indices, src=topk_vals)
                logits = mask

            # Top-p (nucleus) filtering
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
                probs = F.softmax(sorted_logits, dim=-1)
                cum_probs = torch.cumsum(probs, dim=-1)
                sorted_mask = cum_probs > top_p
                # Shift the mask right so the first token that crosses the threshold is kept
                sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
                sorted_mask[..., 0] = 0
                indices_to_remove = sorted_mask.scatter(dim=-1, index=sorted_indices, src=sorted_mask)
                logits = logits.masked_fill(indices_to_remove, float('-inf'))

            # Sample the next token
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)  # (batch_size, 1)
            generated = torch.cat([generated, next_token], dim=-1)

            # Stop if we've reached the end-of-sequence token
            if eos_token_id is not None and next_token.item() == eos_token_id:
                break

        # Decode generated tokens into text
        generated_text = tokenizer.decode(generated[0].tolist(), skip_special_tokens=False)
        # Map byte-level BPE whitespace markers back to real whitespace
        # (the original marker strings were lost; 'Ċ' = newline and 'Ġ' = space are assumptions)
        generated_text = generated_text.replace('Ċ', '\n').replace('Ġ', ' ')
        return generated_text
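

# A minimal usage sketch, not part of the model definition. The tokenizer file name
# "tokenizer.json", the prompt text, and the sampling settings below are assumptions
# for illustration; any PreTrainedTokenizerFast with a matching vocab_size should work,
# and an untrained model will of course produce random text.
if __name__ == "__main__":
    config = GPTCustomConfig(vocab_size=4000, seq_len=512, Emb_dim=256,
                             num_heads=4, num_layers=2, hidden_dim=512)
    model = GPTModel(config)

    # Assumed: a tokenizers-format file trained with the same 4000-token vocabulary
    tokenizer = PreTrainedTokenizerFast(tokenizer_file="tokenizer.json")

    # Encode a prompt on the same device the model was built on
    prompt_ids = torch.tensor(tokenizer.encode("Once upon a time"), device=model._device)

    text = model.generate(
        prompt_ids,
        tokenizer,
        max_new_tokens=50,
        temperature=0.8,
        top_k=50,
        top_p=0.95,
        eos_token_id=tokenizer.eos_token_id,
    )
    print(text)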