import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel

from .configuration_gpt import GPTConfig


class GPT(nn.Module):
    """
    The GPT language model:
    - Embeddings (token + positional)
    - Stack of Transformer blocks
    - Final LayerNorm + Linear head for output logits
    """

    def __init__(
        self,
        block_size: int = 1024,
        vocab_size: int = 50304,
        n_layer: int = 12,
        n_head: int = 12,
        n_embd: int = 768,
    ):
        super().__init__()

        self.block_size = block_size
        self.vocab_size = vocab_size
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_embd = n_embd

        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(self.vocab_size, self.n_embd),
                wpe=nn.Embedding(self.block_size, self.n_embd),
                h=nn.ModuleList(
                    [self.Block(self.n_embd, self.n_head) for _ in range(self.n_layer)]
                ),
                ln_f=nn.LayerNorm(self.n_embd),
            )
        )

        self.lm_head = nn.Linear(self.n_embd, self.vocab_size, bias=False)

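        # Weight tying: the token embedding and the LM head share one weight
        # matrix (as in GPT-2), which saves parameters and couples the input
        # and output representations.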
        self.transformer.wte.weight = self.lm_head.weight

    def forward(self, x):
        B, T = x.shape
        assert T <= self.block_size, "Cannot forward sequence longer than block size"

        tok_emb = self.transformer.wte(x)  # token embeddings: (B, T, n_embd)
        pos_emb = self.transformer.wpe(torch.arange(T, device=x.device))  # (T, n_embd)
        x = tok_emb + pos_emb.unsqueeze(0)

        for block in self.transformer.h:
            x = block(x)

        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)
        return logits

    class CausalSelfAttention(nn.Module):
        """
        Multi-head self-attention with causal masking.
        """

        def __init__(self, n_embd, n_head):
            super().__init__()
            assert (
                n_embd % n_head == 0
            ), "Embedding dimension must be divisible by number of heads"
            self.n_head = n_head
            self.n_embd = n_embd

            self.c_attn = nn.Linear(n_embd, 3 * n_embd)  # fused q, k, v projection
            self.c_proj = nn.Linear(n_embd, n_embd)  # output projection

        def forward(self, x):
            B, T, C = x.size()
            qkv = self.c_attn(x)
            q, k, v = qkv.split(self.n_embd, dim=2)

            # Split heads: (B, T, C) -> (B, n_head, T, head_dim)
            k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
            q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
            v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

            y = F.scaled_dot_product_attention(q, k, v, is_causal=True)

            # Merge heads back: (B, n_head, T, head_dim) -> (B, T, C)
            y = y.transpose(1, 2).contiguous().view(B, T, C)
            y = self.c_proj(y)
            return y

    class MLP(nn.Module):
        """
        Feed-forward network block used in Transformer architectures.
        """

        def __init__(self, n_embd):
            super().__init__()
            self.c_fc = nn.Linear(n_embd, 4 * n_embd)
            self.gelu = nn.GELU(approximate="tanh")
            self.c_proj = nn.Linear(4 * n_embd, n_embd)

        def forward(self, x):
            return self.c_proj(self.gelu(self.c_fc(x)))

    class Block(nn.Module):
        """
        A single Transformer block.
        """

        def __init__(self, n_embd, n_head):
            super().__init__()
            self.ln_1 = nn.LayerNorm(n_embd)
            self.attn = GPT.CausalSelfAttention(n_embd, n_head)
            self.ln_2 = nn.LayerNorm(n_embd)
            self.mlp = GPT.MLP(n_embd)

        def forward(self, x):
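            # Pre-norm residual connections around attention and MLP.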
            x = x + self.attn(self.ln_1(x))
            x = x + self.mlp(self.ln_2(x))
            return x


class GPTModelForTextGeneration(PreTrainedModel):
    """
    A wrapper class for GPT-based text generation.
    This integrates a Transformer model within the Hugging Face `PreTrainedModel` framework.
    """

    config_class = GPTConfig

    def __init__(self, config):
        super().__init__(config)

        self.model = GPT(
            block_size=config.block_size,
            vocab_size=config.vocab_size,
            n_layer=config.n_layer,
            n_head=config.n_head,
            n_embd=config.n_embd,
        )

    def forward(self, input_ids: torch.Tensor):
        assert isinstance(input_ids, torch.Tensor), "input_ids must be a PyTorch tensor"

        tokens = input_ids.clone()
        tokens = tokens.unsqueeze(0) if tokens.dim() == 1 else tokens

        assert (
            tokens.ndim == 2 and tokens.shape[0] == 1
        ), "input_ids must have 2 dimensions: (1, sequence_length)"

        # Token ids must lie in [0, vocab_size); the upper bound is exclusive.
        assert torch.all(
            (tokens >= 0) & (tokens < self.model.vocab_size)
        ), "input_ids contain invalid token values"

        logits = self.model(tokens)

        return {"logits": logits}

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_length: int = 50,
        do_sample: bool = True,
        top_k: int = 50,
        top_p: float = 0.95,
        temperature: float = 0.9,
        device: str = "cpu",
    ):
        """
        Generates text using autoregressive sampling with top-k, top-p, and temperature.
        """

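        # Validate the requested device string before moving any tensors.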
        if device.startswith("cuda"):
            assert torch.cuda.is_available(), "CUDA is not available, please use 'cpu'"
            if device != "cuda":
                try:
                    device_index = int(device.split(":")[1])
                    assert (
                        0 <= device_index < torch.cuda.device_count()
                    ), f"Invalid CUDA device index: {device_index}"
                except (IndexError, ValueError):
                    raise ValueError(
                        "Invalid device format. Use 'cpu', 'cuda', or 'cuda:N' where N is an integer."
                    )
        elif device != "cpu":
            raise ValueError("Invalid device. Use 'cpu', 'cuda', or 'cuda:N'.")

        input_ids = input_ids.to(device)

        assert isinstance(input_ids, torch.Tensor), "input_ids must be a PyTorch tensor"
        tokens = input_ids.clone()
        tokens = tokens.unsqueeze(0) if tokens.dim() == 1 else tokens

        assert (
            tokens.ndim == 2 and tokens.shape[0] == 1
        ), "input_ids must have 2 dimensions: (1, sequence_length)"

        assert torch.all(
            (tokens >= 0) & (tokens < self.model.vocab_size)
        ), "input_ids contain invalid token values"

        assert (
            isinstance(max_length, int) and max_length >= 1
        ), "max_length must be a positive integer"
        assert (
            max_length <= self.model.block_size
        ), f"max_length must be in range [1, {self.model.block_size}]"

        assert isinstance(top_k, int) and top_k >= 1, "top_k must be a positive integer"

        assert (
            isinstance(top_p, (int, float)) and 0.0 <= top_p <= 1.0
        ), "top_p must be in range [0, 1]"

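        # temperature must be strictly positive: the logits are divided by it
        # below, and 0 would cause a division by zero.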
        assert (
            isinstance(temperature, (int, float)) and 0.0 < temperature <= 1.0
        ), "temperature must be in range (0, 1]"

        tokens = tokens.to(device)

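        # Autoregressive decoding: sample one token at a time and append it to
        # the running sequence until max_length tokens have been produced.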
        while tokens.size(1) < max_length:
            logits = self.forward(tokens)["logits"][:, -1, :]
            logits = logits / temperature

            if do_sample:
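                # Top-k filtering: mask out everything below the k-th largest logit.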
                top_k = min(top_k, logits.size(-1))
                indices_to_remove = (
                    logits < torch.topk(logits, top_k, dim=1)[0][..., -1, None]
                )
                logits[indices_to_remove] = float("-inf")

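                # Top-p (nucleus) filtering: keep the smallest prefix of the
                # sorted distribution whose cumulative probability covers top_p.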
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(
                    F.softmax(sorted_logits, dim=-1), dim=-1
                )
                sorted_indices_to_remove = cumulative_probs > top_p

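                # Shift the mask right by one so the first token that crosses
                # the threshold is kept, and never mask the most probable token.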
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
                    ..., :-1
                ].clone()
                sorted_indices_to_remove[..., 0] = 0

                sorted_logits[sorted_indices_to_remove] = float("-inf")

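                # Scatter the filtered logits back to their original vocab
                # positions by inverting the sort permutation.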
                logits = torch.gather(sorted_logits, 1, sorted_indices.argsort(-1))

                next_tokens = torch.multinomial(F.softmax(logits, dim=-1), 1)
            else:
                next_tokens = torch.argmax(logits, dim=-1, keepdim=True)

            tokens = torch.cat((tokens, next_tokens), dim=1)

        return tokens.flatten()
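

# Usage sketch (illustrative, not executed on import): a minimal example of
# driving this module, assuming GPTConfig can be constructed from its defaults
# and that the prompt has already been tokenized elsewhere; `config` and
# `prompt_ids` are hypothetical names introduced here for illustration.
#
#     config = GPTConfig()
#     model = GPTModelForTextGeneration(config)
#     prompt_ids = torch.tensor([464, 3290, 318])  # example token ids
#     out = model.generate(prompt_ids, max_length=20, do_sample=True, top_k=50)
#     print(out.tolist())  # prompt ids followed by generated ids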