| import importlib |
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import PreTrainedModel, GenerationMixin |
| from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast |
| from .configuration_tiny_gpt import TinyGPTConfig |
|
|
# Process-wide caches for the hub-hosted FlashAttention kernels, filled lazily
# by the getters below. ``None`` means "not loaded (yet)".
_FLASH2_KERNEL = None
_FLASH3_KERNEL = None

# Repo ids whose load already failed. Each attention layer asks for its kernel
# while being constructed, so remembering failures avoids repeating the import /
# hub download (and its warning) once per layer.
_FAILED_KERNEL_REPOS = set()


def _load_hub_kernel(repo_id):
    """Load ``repo_id`` through the optional ``kernels`` package (best effort).

    Returns the loaded kernel object, or None when the ``kernels`` package is
    not installed or the kernel cannot be fetched/loaded. No exception escapes:
    callers decide whether a missing kernel is fatal (see ``torch_fallback``).
    """
    if repo_id in _FAILED_KERNEL_REPOS:
        return None
    try:
        kernels = importlib.import_module("kernels")
    except ImportError:
        # Optional dependency not installed: quietly report "unavailable".
        _FAILED_KERNEL_REPOS.add(repo_id)
        return None
    try:
        return kernels.get_kernel(repo_id, version=1)
    except Exception as err:  # download/build/load errors are not limited to ImportError
        import warnings

        warnings.warn(
            f"Could not load attention kernel {repo_id!r}: {err!r}",
            RuntimeWarning,
            stacklevel=2,
        )
        _FAILED_KERNEL_REPOS.add(repo_id)
        return None


def _get_flash2_kernel():
    """Return the FlashAttention-2 hub kernel (cached), or None if unavailable."""
    global _FLASH2_KERNEL
    if _FLASH2_KERNEL is None:
        _FLASH2_KERNEL = _load_hub_kernel("kernels-community/flash-attn2")
    return _FLASH2_KERNEL


def _get_flash3_kernel():
    """Return the FlashAttention-3 hub kernel (cached), or None if unavailable."""
    global _FLASH3_KERNEL
    if _FLASH3_KERNEL is None:
        _FLASH3_KERNEL = _load_hub_kernel("kernels-community/flash-attn3")
    return _FLASH3_KERNEL
|
|
def _get_sageattn():
    """Return ``sageattention.sageattn`` when the package imports, else None."""
    try:
        return importlib.import_module("sageattention").sageattn
    except ImportError:
        return None
|
|
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with optional fused-kernel backends.

    ``config.attention_backend`` selects the implementation: ``"torch"`` (plain
    PyTorch), ``"sage"`` (SageAttention), ``"flash2"`` or ``"flash3"``
    (FlashAttention kernels loaded through the ``kernels`` package). Unknown
    names, and fused backends that cannot serve the configured head size or
    dropout, are downgraded to ``"torch"`` at construction time.

    ``config.torch_fallback`` (default True) decides what happens when the
    selected fused backend cannot be used. With fallback enabled the torch
    implementation is used instead. With it disabled a ``RuntimeError`` is
    raised, both at construction time (kernel not loadable) and at call time
    (non-CUDA input, or the kernel call itself raising).
    """

    def __init__(self, config: "TinyGPTConfig"):
        super().__init__()
        if config.n_embd % config.n_head != 0:
            raise ValueError("n_embd must be divisible by n_head")
        self.n_head = int(config.n_head)
        self.head_dim = int(config.n_embd // config.n_head)
        self.attention_backend = str(getattr(config, "attention_backend", "torch"))
        self.torch_fallback = bool(getattr(config, "torch_fallback", True))
        self.dropout_p = float(config.dropout) if hasattr(config, "dropout") else 0.0
        # Downgrade to the torch path whenever the requested backend cannot
        # honour this configuration: unknown name, unsupported head size, or
        # nonzero dropout on a backend this class calls without dropout.
        if self.attention_backend not in ("sage", "torch", "flash2", "flash3"):
            self.attention_backend = "torch"
        if self.attention_backend == "sage" and self.head_dim not in (64, 96, 128):
            self.attention_backend = "torch"
        if self.attention_backend in ("sage", "flash3") and self.dropout_p != 0.0:
            self.attention_backend = "torch"
        if self.attention_backend in ("flash2", "flash3") and self.head_dim % 8 != 0:
            self.attention_backend = "torch"
        self.qkv = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
        self.proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.dropout = nn.Dropout(self.dropout_p)
        # Lower-triangular (causal) mask, sliced to the sequence length per call.
        mask = torch.tril(torch.ones(config.ctx_len, config.ctx_len, dtype=torch.bool))
        self.register_buffer("mask", mask.view(1, 1, config.ctx_len, config.ctx_len), persistent=False)
        self.sageattn = None
        self.flash_kernel = None
        if self.attention_backend == "sage":
            self.sageattn = _get_sageattn()
            if self.sageattn is None and not self.torch_fallback:
                raise RuntimeError("SageAttention requested but not available")
        if self.attention_backend == "flash2":
            self.flash_kernel = _get_flash2_kernel()
            if self.flash_kernel is None and not self.torch_fallback:
                raise RuntimeError("FlashAttention2 requested but not available")
        if self.attention_backend == "flash3":
            self.flash_kernel = _get_flash3_kernel()
            if self.flash_kernel is None and not self.torch_fallback:
                raise RuntimeError("FlashAttention3 requested but not available")

    def _torch_attention(self, q, k, v, t):
        """Reference attention on ``(B, H, T, D)`` tensors; returns ``(B, H, T, D)``.

        The softmax is computed in float32 and cast back to the input dtype.
        Attention-probability dropout is applied via ``self.dropout``.
        """
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        scores = scores.masked_fill(self.mask[:, :, :t, :t] == 0, float("-inf"))
        att = F.softmax(scores.float(), dim=-1).to(q.dtype)
        att = self.dropout(att)
        return att @ v

    def _sage_attention(self, q, k, v):
        """SageAttention on ``(B, H, T, D)`` tensors, or None when it cannot run."""
        if self.sageattn is None or not q.is_cuda:
            return None
        return self.sageattn(q.contiguous(), k.contiguous(), v.contiguous(), tensor_layout="HND", is_causal=True)

    def _flash2_attention(self, q, k, v):
        """FlashAttention-2 on ``(B, H, T, D)`` tensors, or None when it cannot run."""
        if self.flash_kernel is None or not q.is_cuda:
            return None
        # The FlashAttention API takes (B, T, H, D).
        q, k, v = (tensor.transpose(1, 2).contiguous() for tensor in (q, k, v))
        dropout_p = self.dropout_p if self.training else 0.0
        y = self.flash_kernel.flash_attn_func(q, k, v, dropout_p=dropout_p, causal=True)
        # Back to (B, H, T, D); forward() transposes again, so no copy here.
        return y.transpose(1, 2)

    def _flash3_attention(self, q, k, v):
        """FlashAttention-3 on ``(B, H, T, D)`` tensors, or None when it cannot run."""
        if self.flash_kernel is None or not q.is_cuda:
            return None
        q, k, v = (tensor.transpose(1, 2).contiguous() for tensor in (q, k, v))
        y = self.flash_kernel.flash_attn_func(q, k, v, causal=True)
        # Some FlashAttention-3 builds return ``(out, softmax_lse)``; without
        # this unwrap the call would always fail and silently fall back.
        if isinstance(y, (tuple, list)):
            y = y[0]
        return y.transpose(1, 2)

    def _fused_attention(self, q, k, v):
        """Dispatch to the configured fused backend; None means "not usable"."""
        if self.attention_backend == "sage":
            return self._sage_attention(q, k, v)
        if self.attention_backend == "flash2":
            return self._flash2_attention(q, k, v)
        if self.attention_backend == "flash3":
            return self._flash3_attention(q, k, v)
        return None

    def forward(self, x):
        """Apply causal self-attention to ``x`` of shape ``(batch, seq, n_embd)``.

        Raises:
            RuntimeError: the fused backend could not serve this call and
                ``torch_fallback`` is disabled.
        """
        b, t, c = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # (B, T, C) -> (B, H, T, D)
        q = q.view(b, t, self.n_head, self.head_dim).transpose(1, 2).contiguous()
        k = k.view(b, t, self.n_head, self.head_dim).transpose(1, 2).contiguous()
        v = v.view(b, t, self.n_head, self.head_dim).transpose(1, 2).contiguous()
        y = None
        if self.attention_backend != "torch":
            try:
                y = self._fused_attention(q, k, v)
            except Exception as err:
                # Best-effort fast path: only swallow the failure when the
                # caller opted into falling back to the torch implementation.
                if not self.torch_fallback:
                    raise RuntimeError(
                        f"{self.attention_backend} attention failed and torch_fallback is disabled"
                    ) from err
            if y is None and not self.torch_fallback:
                raise RuntimeError(
                    f"{self.attention_backend} attention cannot run on this input and torch_fallback is disabled"
                )
        if y is None:
            y = self._torch_attention(q, k, v, t)
        y = y.transpose(1, 2).contiguous().view(b, t, c)
        return self.proj(y)
|
|
class MLP(nn.Module):
    """Feed-forward sublayer: expand 4x, GELU, project back, then dropout.

    Both linear layers are bias-free.
    """

    def __init__(self, config: TinyGPTConfig):
        super().__init__()
        width = config.n_embd
        hidden_width = 4 * width
        self.fc = nn.Linear(width, hidden_width, bias=False)
        self.proj = nn.Linear(hidden_width, width, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        return self.dropout(self.proj(F.gelu(self.fc(x))))
|
|
class Block(nn.Module):
    """Pre-LayerNorm transformer block: residual attention, then residual MLP."""

    def __init__(self, config: TinyGPTConfig):
        super().__init__()
        # Construction order is kept so parameter registration / init order is unchanged.
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        after_attn = x + self.attn(self.ln1(x))
        return after_attn + self.mlp(self.ln2(after_attn))
|
|
class TinyGPTPreTrainedModel(PreTrainedModel):
    """Shared ``PreTrainedModel`` plumbing (config class, prefix, weight init) for TinyGPT."""

    config_class = TinyGPTConfig
    base_model_prefix = "tiny_gpt"
    supports_gradient_checkpointing = False

    def _init_weights(self, module):
        """Initialise one submodule in place (hook invoked by transformers' weight init).

        Linear and Embedding weights are drawn from N(0, 0.02) and Linear biases
        are zeroed. LayerNorm is reset to identity (weight 1, bias 0) so that it
        is also well-defined when transformers initialises modules through this
        hook rather than through their constructors.
        """
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            # weight/bias are None when elementwise_affine=False.
            if module.weight is not None:
                nn.init.ones_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
|
|
class TinyGPTModel(TinyGPTPreTrainedModel):
    """TinyGPT decoder stack with a vocabulary head tied to the token embedding.

    The stack is: token + learned position embeddings, dropout, ``n_layer``
    blocks, then a final LayerNorm. ``head`` projects hidden states to vocab
    logits and shares its weight with ``tok_emb`` (see ``tie_weights``).
    """

    _tied_weights_keys = ["head.weight"]

    def __init__(self, config: TinyGPTConfig):
        super().__init__(config)
        # Keep this construction order: it fixes parameter registration/init order.
        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
        self.pos_emb = nn.Embedding(config.ctx_len, config.n_embd)
        self.drop = nn.Dropout(config.dropout)
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.tok_emb

    def set_input_embeddings(self, value):
        """Replace the token embedding and re-tie the output head to it."""
        self.tok_emb = value
        self.head.weight = value.weight

    def get_output_embeddings(self):
        return self.head

    def set_output_embeddings(self, new_embeddings):
        self.head = new_embeddings

    def tie_weights(self, *args, **kwargs):
        """Share the token-embedding matrix with the output head (always tied)."""
        self.head.weight = self.tok_emb.weight

    def forward(self, input_ids, attention_mask=None, return_dict=True, return_logits=False, **kwargs):
        """Encode ``input_ids`` of shape ``(batch, seq)``.

        ``attention_mask`` and extra keyword arguments are accepted but not used.
        With ``return_logits`` the result is always the tuple ``(hidden, logits)``,
        whatever ``return_dict`` says. Otherwise a ``BaseModelOutputWithPast`` is
        returned, or ``(hidden,)`` when ``return_dict`` is false.

        Raises:
            ValueError: the sequence is longer than ``config.ctx_len``.
        """
        _, seq_len = input_ids.shape
        if seq_len > self.config.ctx_len:
            raise ValueError(f"Input length {seq_len} > ctx_len {self.config.ctx_len}. Truncate before calling the model.")
        positions = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)[None, :]
        h = self.drop(self.tok_emb(input_ids) + self.pos_emb(positions))
        for block in self.blocks:
            h = block(h)
        hidden = self.ln_f(h)
        if return_logits:
            return hidden, self.head(hidden)
        if not return_dict:
            return (hidden,)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )
|
|
class TinyGPTForCausalLM(TinyGPTPreTrainedModel, GenerationMixin):
    """TinyGPT wrapped with a next-token loss and ``generate()`` support."""

    _tied_weights_keys = ["tiny_gpt.head.weight"]

    def __init__(self, config: TinyGPTConfig):
        super().__init__(config)
        self.tiny_gpt = TinyGPTModel(config)
        self.post_init()

    def get_input_embeddings(self):
        return self.tiny_gpt.tok_emb

    def set_input_embeddings(self, value):
        """Replace the token embedding and re-tie the output head to it."""
        self.tiny_gpt.tok_emb = value
        self.tiny_gpt.head.weight = self.tiny_gpt.tok_emb.weight

    def get_output_embeddings(self):
        return self.tiny_gpt.head

    def set_output_embeddings(self, new_embeddings):
        self.tiny_gpt.head = new_embeddings

    def tie_weights(self, *args, **kwargs):
        """Share the token-embedding matrix with the output head (always tied)."""
        self.tiny_gpt.head.weight = self.tiny_gpt.tok_emb.weight

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Build the model inputs for one ``generate()`` step.

        No past key/values are kept (``forward`` returns ``past_key_values=None``),
        so each step re-encodes the whole prefix. A prefix longer than
        ``config.ctx_len`` is cropped to its most recent ``ctx_len`` tokens.
        Without the crop, ``forward`` would raise once generation passes the
        context length.
        """
        ctx_len = self.config.ctx_len
        if input_ids.shape[1] > ctx_len:
            input_ids = input_ids[:, -ctx_len:]
        return {"input_ids": input_ids}

    def forward(self, input_ids, attention_mask=None, labels=None, return_dict=True, **kwargs):
        """Compute logits and, when ``labels`` is given, the shifted LM loss.

        The loss is a float32 cross-entropy of position ``i`` against label
        ``i + 1``. Labels equal to -100 (the ``F.cross_entropy`` default
        ``ignore_index``) are ignored. With ``return_dict`` false this returns
        ``(loss, logits)`` or ``(logits,)``.
        """
        _, logits = self.tiny_gpt(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
            return_logits=True,
        )
        loss = None
        if labels is not None:
            shift_logits = logits[:, :-1, :].contiguous()
            # Labels may live on a different device than the logits (e.g. a
            # model split across devices); align them before the loss.
            shift_labels = labels[:, 1:].contiguous().to(shift_logits.device)
            loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)).float(), shift_labels.view(-1))
        if not return_dict:
            return ((loss, logits) if loss is not None else (logits,))
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )