"""TinyGPT: a small GPT-style causal language model for Hugging Face ``transformers``.

Attention is computed with plain PyTorch by default, or optionally with
SageAttention or FlashAttention 2/3 kernels that are loaded lazily at runtime.
An optional backend that is unavailable, or that returns no result for a given
input, is replaced by the PyTorch implementation.
"""

import importlib
import math
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, GenerationMixin
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast

from .configuration_tiny_gpt import TinyGPTConfig

# Lazily loaded Hub kernels, shared by every attention layer in the process.
_FLASH2_KERNEL = None
_FLASH3_KERNEL = None
# repo_id -> loaded kernel, or None if loading failed. Caching failures avoids
# repeating a slow/failing load (e.g. an offline Hub lookup) once per layer.
_KERNEL_CACHE = {}
# Backends for which a runtime fallback warning has already been emitted.
_FALLBACK_WARNED = set()


def _load_hub_kernel(repo_id):
    """Return kernel ``repo_id`` via the optional ``kernels`` package, or ``None``.

    A missing ``kernels`` package yields ``None`` silently. Any other load error is
    reported as a ``RuntimeWarning`` instead of being raised, so models built with
    ``torch_fallback=True`` can still be constructed; the caller decides whether a
    missing kernel is fatal. The outcome (success or failure) is cached per repo id.
    """
    if repo_id in _KERNEL_CACHE:
        return _KERNEL_CACHE[repo_id]
    kernel = None
    try:
        kernels = importlib.import_module("kernels")
        kernel = kernels.get_kernel(repo_id, version=1)
    except ImportError:
        pass  # optional dependency not installed
    except Exception as err:  # best-effort: Hub/network/build-compatibility errors
        warnings.warn(f"Could not load kernel {repo_id!r}: {err}", RuntimeWarning)
    _KERNEL_CACHE[repo_id] = kernel
    return kernel


def _get_flash2_kernel():
    """Return the FlashAttention 2 Hub kernel, or ``None`` if it cannot be loaded."""
    global _FLASH2_KERNEL
    if _FLASH2_KERNEL is None:
        _FLASH2_KERNEL = _load_hub_kernel("kernels-community/flash-attn2")
    return _FLASH2_KERNEL


def _get_flash3_kernel():
    """Return the FlashAttention 3 Hub kernel, or ``None`` if it cannot be loaded."""
    global _FLASH3_KERNEL
    if _FLASH3_KERNEL is None:
        _FLASH3_KERNEL = _load_hub_kernel("kernels-community/flash-attn3")
    return _FLASH3_KERNEL


def _get_sageattn():
    """Return ``sageattention.sageattn`` if the package is importable, else ``None``."""
    try:
        module = importlib.import_module("sageattention")
    except ImportError:
        return None
    return module.sageattn


class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with pluggable kernels.

    ``config.attention_backend`` selects ``"torch"`` (default), ``"sage"``,
    ``"flash2"`` or ``"flash3"``. Unknown names, or backends whose constraints the
    configuration violates (head size, dropout), are downgraded to ``"torch"`` at
    construction time. If ``config.torch_fallback`` is false and the requested
    kernel cannot be loaded, construction raises ``RuntimeError``. At runtime a
    fast backend that yields no result (non-CUDA input, or the kernel raised) is
    replaced by the PyTorch path for that call.
    """

    def __init__(self, config: TinyGPTConfig):
        super().__init__()
        if config.n_embd % config.n_head != 0:
            raise ValueError("n_embd must be divisible by n_head")
        self.n_head = int(config.n_head)
        self.head_dim = int(config.n_embd // config.n_head)
        self.torch_fallback = bool(getattr(config, "torch_fallback", True))
        self.dropout_p = float(config.dropout) if hasattr(config, "dropout") else 0.0

        backend = str(getattr(config, "attention_backend", "torch"))
        if backend not in ("sage", "torch", "flash2", "flash3"):
            backend = "torch"
        # The sage and flash3 calls below apply no attention dropout, so they are
        # only used when dropout is disabled. Head-size limits reflect this module's
        # assumptions about what each kernel accepts.
        elif backend == "sage" and (self.head_dim not in (64, 96, 128) or self.dropout_p != 0.0):
            backend = "torch"
        elif backend == "flash3" and self.dropout_p != 0.0:
            backend = "torch"
        if backend in ("flash2", "flash3") and self.head_dim % 8 != 0:
            backend = "torch"
        self.attention_backend = backend

        self.qkv = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
        self.proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.dropout = nn.Dropout(self.dropout_p)
        # Lower-triangular causal mask for the PyTorch path; excluded from checkpoints.
        mask = torch.tril(torch.ones(config.ctx_len, config.ctx_len, dtype=torch.bool))
        self.register_buffer(
            "mask", mask.view(1, 1, config.ctx_len, config.ctx_len), persistent=False
        )

        self.sageattn = None
        self.flash_kernel = None
        if self.attention_backend == "sage":
            self.sageattn = _get_sageattn()
            if self.sageattn is None and not self.torch_fallback:
                raise RuntimeError("SageAttention requested but not available")
        elif self.attention_backend == "flash2":
            self.flash_kernel = _get_flash2_kernel()
            if self.flash_kernel is None and not self.torch_fallback:
                raise RuntimeError("FlashAttention2 requested but not available")
        elif self.attention_backend == "flash3":
            self.flash_kernel = _get_flash3_kernel()
            if self.flash_kernel is None and not self.torch_fallback:
                raise RuntimeError("FlashAttention3 requested but not available")

    def _warn_fallback(self, err):
        """Warn once per backend per process that a kernel raised and PyTorch is used."""
        if self.attention_backend in _FALLBACK_WARNED:
            return
        _FALLBACK_WARNED.add(self.attention_backend)
        warnings.warn(
            f"{self.attention_backend} attention failed ({type(err).__name__}: {err}); "
            "falling back to PyTorch attention.",
            RuntimeWarning,
        )

    def _torch_attention(self, q, k, v, t):
        """Reference causal attention on ``(b, n_head, t, head_dim)`` tensors.

        The softmax is evaluated in float32 and cast back to ``q.dtype``; attention
        dropout is applied through ``self.dropout`` (active only in training mode).
        """
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        scores = scores.masked_fill(self.mask[:, :, :t, :t] == 0, float("-inf"))
        att = F.softmax(scores.float(), dim=-1).to(q.dtype)
        att = self.dropout(att)
        return att @ v

    def _sage_attention(self, q, k, v):
        """Run SageAttention on head-major inputs; ``None`` requests the PyTorch path."""
        if self.sageattn is None or not q.is_cuda:
            return None
        try:
            return self.sageattn(
                q.contiguous(), k.contiguous(), v.contiguous(), tensor_layout="HND", is_causal=True
            )
        except Exception as err:  # best-effort fast path; PyTorch handles this call
            self._warn_fallback(err)
            return None

    def _run_flash_kernel(self, q, k, v, **kwargs):
        """Call ``self.flash_kernel.flash_attn_func`` causally on head-major inputs.

        Inputs are transposed to ``(b, t, n_head, head_dim)`` for the call and the
        output is transposed back. Returns ``None`` to request the PyTorch path.
        """
        if self.flash_kernel is None or not q.is_cuda:
            return None
        try:
            y = self.flash_kernel.flash_attn_func(
                q.transpose(1, 2).contiguous(),
                k.transpose(1, 2).contiguous(),
                v.transpose(1, 2).contiguous(),
                causal=True,
                **kwargs,
            )
        except Exception as err:  # e.g. unsupported dtype; PyTorch handles this call
            self._warn_fallback(err)
            return None
        # NOTE(review): some FlashAttention builds may return (out, softmax_lse);
        # keep only the attention output so the fast path is not silently skipped.
        if isinstance(y, (tuple, list)):
            y = y[0]
        return y.transpose(1, 2).contiguous()

    def _flash2_attention(self, q, k, v):
        """FlashAttention 2 path; attention dropout is applied only in training mode."""
        return self._run_flash_kernel(q, k, v, dropout_p=self.dropout_p if self.training else 0.0)

    def _flash3_attention(self, q, k, v):
        """FlashAttention 3 path (no dropout; see the backend checks in ``__init__``)."""
        return self._run_flash_kernel(q, k, v)

    def forward(self, x):
        """Apply causal self-attention to ``x`` of shape ``(b, t, n_embd)``."""
        b, t, c = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # (b, t, c) -> (b, n_head, t, head_dim)
        q = q.view(b, t, self.n_head, self.head_dim).transpose(1, 2).contiguous()
        k = k.view(b, t, self.n_head, self.head_dim).transpose(1, 2).contiguous()
        v = v.view(b, t, self.n_head, self.head_dim).transpose(1, 2).contiguous()

        y = None
        if self.attention_backend == "sage":
            y = self._sage_attention(q, k, v)
        elif self.attention_backend == "flash2":
            y = self._flash2_attention(q, k, v)
        elif self.attention_backend == "flash3":
            y = self._flash3_attention(q, k, v)
        if y is None:
            y = self._torch_attention(q, k, v, t)

        y = y.transpose(1, 2).contiguous().view(b, t, c)
        return self.proj(y)


class MLP(nn.Module):
    """Position-wise feed-forward network: Linear(4x) -> GELU -> Linear -> Dropout."""

    def __init__(self, config: TinyGPTConfig):
        super().__init__()
        self.fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
        self.proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        return self.dropout(self.proj(F.gelu(self.fc(x))))


class Block(nn.Module):
    """Pre-LayerNorm transformer block with residual attention and MLP sublayers."""

    def __init__(self, config: TinyGPTConfig):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


class TinyGPTPreTrainedModel(PreTrainedModel):
    """Shared ``PreTrainedModel`` plumbing (config class, weight init) for TinyGPT."""

    config_class = TinyGPTConfig
    base_model_prefix = "tiny_gpt"
    supports_gradient_checkpointing = False

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights ~ N(0, 0.02), biases to 0, LayerNorm to identity."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)


class TinyGPTModel(TinyGPTPreTrainedModel):
    """TinyGPT decoder stack with learned positional embeddings and a tied LM head."""

    _tied_weights_keys = ["head.weight"]

    def __init__(self, config: TinyGPTConfig):
        super().__init__(config)
        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
        self.pos_emb = nn.Embedding(config.ctx_len, config.n_embd)
        self.drop = nn.Dropout(config.dropout)
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.tok_emb

    def set_input_embeddings(self, value):
        self.tok_emb = value
        self.head.weight = self.tok_emb.weight

    def get_output_embeddings(self):
        return self.head

    def set_output_embeddings(self, new_embeddings):
        self.head = new_embeddings

    def tie_weights(self, *args, **kwargs):
        self.head.weight = self.tok_emb.weight

    def forward(self, input_ids, attention_mask=None, return_dict=True, return_logits=False, **kwargs):
        """Run the decoder over ``input_ids`` of shape ``(b, t)`` with ``t <= ctx_len``.

        NOTE: ``attention_mask`` is accepted for API compatibility but is not
        applied; every position attends to all earlier positions, padding included.

        Returns:
            ``(hidden, logits)`` when ``return_logits`` is true; otherwise
            ``(hidden,)`` if ``return_dict`` is false, else a
            ``BaseModelOutputWithPast``. ``return_dict=None`` follows
            ``config.use_return_dict``.

        Raises:
            ValueError: if ``t`` exceeds ``config.ctx_len``.
        """
        if return_dict is None:
            return_dict = getattr(self.config, "use_return_dict", True)
        b, t = input_ids.shape
        if t > self.config.ctx_len:
            raise ValueError(f"Input length {t} > ctx_len {self.config.ctx_len}. Truncate before calling the model.")
        pos = torch.arange(0, t, dtype=torch.long, device=input_ids.device).unsqueeze(0)
        x = self.drop(self.tok_emb(input_ids) + self.pos_emb(pos))
        for block in self.blocks:
            x = block(x)
        hidden = self.ln_f(x)

        if return_logits:
            return hidden, self.head(hidden)
        if not return_dict:
            return (hidden,)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )


class TinyGPTForCausalLM(TinyGPTPreTrainedModel, GenerationMixin):
    """TinyGPT with a language-modeling loss and ``generate()`` support (no KV cache)."""

    _tied_weights_keys = ["tiny_gpt.head.weight"]

    def __init__(self, config: TinyGPTConfig):
        super().__init__(config)
        self.tiny_gpt = TinyGPTModel(config)
        self.post_init()

    def get_input_embeddings(self):
        return self.tiny_gpt.tok_emb

    def set_input_embeddings(self, value):
        self.tiny_gpt.tok_emb = value
        self.tiny_gpt.head.weight = self.tiny_gpt.tok_emb.weight

    def get_output_embeddings(self):
        return self.tiny_gpt.head

    def set_output_embeddings(self, new_embeddings):
        self.tiny_gpt.head = new_embeddings

    def tie_weights(self, *args, **kwargs):
        self.tiny_gpt.head.weight = self.tiny_gpt.tok_emb.weight

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Feed the full sequence each step (no KV cache), cropped to the last ``ctx_len`` tokens.

        Cropping keeps ``generate()`` working past the context window instead of
        raising in ``TinyGPTModel.forward``; only the last position's logits are
        used for the next token.
        """
        return {"input_ids": input_ids[:, -self.config.ctx_len:]}

    def forward(self, input_ids, attention_mask=None, labels=None, return_dict=True, **kwargs):
        """Compute next-token logits and, if ``labels`` is given, the shifted LM loss.

        ``labels`` are aligned with ``input_ids``; the loss compares logits at
        position ``i`` with the label at ``i + 1`` and ignores label ``-100``
        (the ``F.cross_entropy`` default). ``return_dict=None`` follows
        ``config.use_return_dict``.
        """
        if return_dict is None:
            return_dict = getattr(self.config, "use_return_dict", True)
        _hidden, logits = self.tiny_gpt(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
            return_logits=True,
        )
        loss = None
        if labels is not None:
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous().to(shift_logits.device)
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)).float(), shift_labels.view(-1)
            )
        if not return_dict:
            return (loss, logits) if loss is not None else (logits,)
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )