from math import sqrt
from functools import partial
from typing import Self
from collections.abc import Generator
from collections import deque
import torch
from torch import Tensor
from torch.nn import (
Module,
ModuleList,
Sequential,
Embedding,
Linear,
SiLU,
RMSNorm,
Dropout1d,
CrossEntropyLoss,
Parameter,
)
from torch.nn.functional import softmax, scaled_dot_product_attention
from torch.nn.utils.parametrize import register_parametrization, remove_parametrizations
from torch.utils.checkpoint import checkpoint as torch_checkpoint
from transformers import PretrainedConfig, PreTrainedModel
from caching import KVCache, DynamicKVBlock
from data import IGNORE_INDEX
class NoPEGPT(Module):
"""A generative pretrained transformer with no positional embeddings."""
def __init__(
self,
vocabulary_size: int,
embedding_dimensions: int,
num_heads: int,
num_layers: int,
feed_forward_ratio: int,
dropout: float,
):
super().__init__()
if vocabulary_size <= 0:
raise ValueError(
f"Vocabulary size must be greater than 0, {vocabulary_size} given."
)
if num_layers <= 0:
raise ValueError(f"Num layers must be greater than 0, {num_layers} given.")
token_embeddings = Embedding(vocabulary_size, embedding_dimensions)
output_layer = Linear(embedding_dimensions, vocabulary_size, bias=False)
output_layer.weight = token_embeddings.weight # Tie weights
self.token_embeddings = token_embeddings
self.body = ModuleList(
[
DecoderBlock(
embedding_dimensions,
num_heads,
feed_forward_ratio,
dropout,
)
for _ in range(num_layers)
]
)
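        # By default run each decoder block directly; enable_activation_checkpointing()
        # swaps this for torch.utils.checkpoint to trade compute for memory.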
self.checkpoint = lambda layer, x: layer(x)
self.output_norm = RMSNorm(embedding_dimensions)
self.output_layer = output_layer
self.loss_function = CrossEntropyLoss(ignore_index=IGNORE_INDEX)
self.vocabulary_size: int = vocabulary_size
self.embedding_dimensions: int = embedding_dimensions
self.num_heads: int = num_heads
self.num_layers: int = num_layers
@property
def num_trainable_params(self) -> int:
return sum(param.numel() for param in self.parameters() if param.requires_grad)
def enable_activation_checkpointing(self) -> None:
"""Instead of memorizing the activations of the forward pass, recompute them at various checkpoints."""
self.checkpoint = partial(torch_checkpoint, use_reentrant=False)
def freeze_model_parameters(self) -> None:
"""Freeze all model parameters to prevent them from being updated during training."""
for param in self.parameters():
param.requires_grad = False
@torch.no_grad()
def resize_token_embeddings(self, vocabulary_size: int) -> None:
"""Resize the token embeddings to accommodate a new vocabulary size."""
if vocabulary_size <= 0:
raise ValueError(
f"Vocabulary size must be greater than 0, {vocabulary_size} given."
)
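        # Copy the overlapping rows into a fresh embedding table so previously learned vectors are preserved.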
new_embeddings = Embedding(vocabulary_size, self.embedding_dimensions)
new_embeddings = new_embeddings.to(self.token_embeddings.weight.device)
num_tokens_to_copy = min(vocabulary_size, self.token_embeddings.num_embeddings)
new_embeddings.weight[:num_tokens_to_copy, :] = self.token_embeddings.weight[
:num_tokens_to_copy, :
]
        # Initialize the new embedding rows with a Kaiming-style normal distribution (std = 1 / sqrt(embedding_dimensions)).
for i in range(num_tokens_to_copy, vocabulary_size):
new_embeddings.weight[i] = torch.randn(self.embedding_dimensions) / sqrt(
self.embedding_dimensions
)
self.token_embeddings.weight = new_embeddings.weight
self.token_embeddings.num_embeddings = new_embeddings.num_embeddings
self.output_layer.weight = self.token_embeddings.weight # Retie weights
self.vocabulary_size = vocabulary_size
def unfreeze_token_embeddings(self) -> None:
"""Unfreeze the token embeddings to allow for fine-tuning."""
self.token_embeddings.weight.requires_grad = True
def add_lora_parameters(self, rank: int, alpha: float, dropout: float) -> None:
"""Reparameterize the weights of the model using LoRA adapters."""
for module in self.body:
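            # The fused QKV projection is given 3x the rank since its weight stacks
            # the query, key, and value projections along the output dimension.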
register_parametrization(
module.attention.qkv_proj,
"weight",
LoRA.from_linear(module.attention.qkv_proj, 3 * rank, alpha, dropout),
)
register_parametrization(
module.attention.out_proj,
"weight",
LoRA.from_linear(module.attention.out_proj, rank, alpha, dropout),
)
register_parametrization(
module.mlp.layers[0],
"weight",
LoRA.from_linear(module.mlp.layers[0], rank, alpha, dropout),
)
register_parametrization(
module.mlp.layers[2],
"weight",
LoRA.from_linear(module.mlp.layers[2], rank, alpha, dropout),
)
def lora_state_dict(self) -> dict[str, Tensor]:
"""Return a state dict containing only the LoRA parameters."""
return {
            name: tensor for name, tensor in self.state_dict().items() if "lora" in name
}
def merge_lora_parameters(self) -> None:
"""Merge the LoRA parameters with the original parameters."""
for module in self.modules():
if hasattr(module, "parametrizations"):
                lora_params = list(module.parametrizations.keys())
for name in lora_params:
remove_parametrizations(module, name)
def forward(
self, x: Tensor, y: Tensor | None = None
) -> tuple[Tensor, Tensor | None]:
"""A forward pass optimized for batch training."""
z = self.token_embeddings(x)
for layer in self.body:
z = self.checkpoint(layer, z)
z = self.output_norm(z)
z = self.output_layer(z)
if y is not None:
y_pred = z.view(-1, z.size(-1))
            labels = y.view(-1)  # Flatten the batch and time dimensions.
loss = self.loss_function(y_pred, labels)
else:
loss = None
return z, loss
@torch.no_grad()
def predict(self, x: Tensor, kv_cache: KVCache) -> Tensor:
"""A forward pass optimized for next-token prediction."""
z = self.token_embeddings(x)
for layer, kv_block in zip(self.body, kv_cache):
z = layer.predict(z, kv_block)
        z = z[:, -1, :]  # Keep only the last token embedding from each sequence in the batch.
z = self.output_norm(z)
z = self.output_layer(z)
return z
@torch.no_grad()
def generate(
self,
prompt: Tensor,
max_tokens: int = 1000,
context_length: int = 1024,
temperature: float = 1.0,
top_k: int = 500,
top_p: float = 0.9,
repeat_penalty: float = 0.1,
repeat_window: int = 50,
) -> Generator[tuple[Tensor, Tensor], None, int]:
"""
Given a prompt, sample the next {max_tokens} tokens from the model weighted
by their predicted probabilities and filtered by the {top_k} and {top_p}.
"""
if max_tokens <= 0:
raise ValueError(f"Max tokens must be greater than 0, {max_tokens} given.")
if context_length <= 0:
raise ValueError(
f"Context length must be greater than 0, {context_length} given."
)
if temperature <= 0:
raise ValueError(
f"Temperature must be greater than 0, {temperature} given."
)
if top_k <= 0 or top_k > self.vocabulary_size:
raise ValueError(
f"Top k must be between 1 and {self.vocabulary_size}, {top_k} given."
)
if top_p <= 0.0 or top_p > 1.0:
raise ValueError(f"Top p must be between 0 and 1, {top_p} given.")
if repeat_penalty < 0.0 or repeat_penalty > 1.0:
raise ValueError(
f"Repeat penalty must be between 0 and 1, {repeat_penalty} given."
)
if repeat_window <= 0:
raise ValueError(
f"Repeat window must be greater than 0, {repeat_window} given."
)
kv_cache = KVCache(self, 1, context_length).to(prompt.device)
prompt = prompt[-context_length:]
previous_tokens = deque(maxlen=repeat_window)
num_tokens = 0
while num_tokens < max_tokens:
logits = self.predict(prompt.unsqueeze(0), kv_cache).squeeze()
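            # Penalize recently generated tokens by subtracting a fraction of each
            # token's logit magnitude, discouraging short-window repetition.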
for previous_token in previous_tokens:
logits[previous_token] -= repeat_penalty * torch.abs(
logits[previous_token]
)
logits, indices = torch.topk(logits, top_k, sorted=True)
logits /= temperature
probabilities = softmax(logits, dim=0)
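            # Nucleus (top-p) filtering: keep the longest prefix of top-k candidates whose
            # cumulative probability mass does not exceed top_p, retaining at minimum the
            # single most likely token.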
cumulative_probability_mass = torch.cumsum(probabilities, dim=0)
min_probability_mass = cumulative_probability_mass[0]
threshold_p = max(top_p, min_probability_mass.item())
selected_indices = cumulative_probability_mass <= threshold_p
logits = logits[selected_indices]
indices = indices[selected_indices]
probabilities = softmax(logits, dim=0)
offset = torch.multinomial(probabilities, num_samples=1).squeeze()
next_token = indices[offset]
probability = probabilities[offset]
yield next_token, probability
num_tokens += 1
previous_tokens.append(next_token)
prompt = next_token.unsqueeze(0)
return num_tokens
class NoPEGPTHuggingFaceConfig(PretrainedConfig):
"""Provide a monolithic configuration object to enable compatibility with HuggingFace Transformers API."""
model_type = "nope-gpt"
def __init__(
self,
vocabulary_size: int = 50257,
embedding_dimensions: int = 1024,
num_heads: int = 16,
num_layers: int = 24,
feed_forward_ratio: int = 4,
dropout: float = 0.1,
**kwargs,
):
self.vocabulary_size = vocabulary_size
self.embedding_dimensions = embedding_dimensions
self.num_heads = num_heads
self.num_layers = num_layers
self.feed_forward_ratio = feed_forward_ratio
self.dropout = dropout
super().__init__(**kwargs)
class NoPEGPTHuggingFaceModel(PreTrainedModel):
"""Wrap model to enable compatibility with HuggingFace Transformers API."""
config_class = NoPEGPTHuggingFaceConfig
def __init__(self, config: NoPEGPTHuggingFaceConfig):
super().__init__(config)
self.model = NoPEGPT(
config.vocabulary_size,
config.embedding_dimensions,
config.num_heads,
config.num_layers,
config.feed_forward_ratio,
config.dropout,
)
def forward(self, x: Tensor, y: Tensor | None = None) -> dict[str, Tensor | None]:
logits, loss = self.model.forward(x, y)
return {
"logits": logits,
"loss": loss,
}
class DecoderBlock(Module):
"""Decoder block with multi-head attention, multilayer perceptron, and residual connections."""
def __init__(
self,
embedding_dimensions: int,
num_heads: int,
feed_forward_ratio: int,
dropout: float,
):
super().__init__()
self.norm1 = RMSNorm(embedding_dimensions)
self.attention = SelfAttention(embedding_dimensions, num_heads, dropout)
self.norm2 = RMSNorm(embedding_dimensions)
self.mlp = MLP(embedding_dimensions, feed_forward_ratio, dropout)
def forward(self, x: Tensor) -> Tensor:
z = self.norm1(x)
z = self.attention(z)
z = x + z # Residual connection
x = z
z = self.norm2(x)
z = self.mlp(z)
z = x + z # Residual connection
return z
@torch.no_grad()
def predict(self, x: Tensor, kv_block: DynamicKVBlock) -> Tensor:
"""A forward pass optimized for next-token prediction."""
z = self.norm1(x)
z = self.attention.predict(z, kv_block)
z = x + z # Residual connection
x = z
z = self.norm2(x)
z = self.mlp.predict(z)
z = x + z # Residual connection
return z
class SelfAttention(Module):
"""Multihead self-attention with causal masking."""
def __init__(self, embedding_dimensions: int, num_heads: int, dropout: float):
super().__init__()
if embedding_dimensions <= 0:
raise ValueError(
f"Embedding dimensions must be greater than 0, {embedding_dimensions} given."
)
if num_heads <= 0:
raise ValueError(f"Num heads must be greater than 0, {num_heads} given.")
if embedding_dimensions % num_heads != 0:
raise ValueError(
f"Embedding dimensions must be divisible by num heads, {embedding_dimensions} and {num_heads} given."
)
self.qkv_proj = Linear(
embedding_dimensions, 3 * embedding_dimensions, bias=False
)
self.out_proj = Linear(embedding_dimensions, embedding_dimensions, bias=False)
head_dimensions: int = embedding_dimensions // num_heads
scale: float = 1.0 / sqrt(head_dimensions)
self.embedding_dimensions: int = embedding_dimensions
self.num_heads: int = num_heads
self.head_dimensions: int = head_dimensions
self.scale: float = scale
self.dropout: float = dropout
def forward(self, x: Tensor) -> Tensor:
b, t, d = x.size()
q, k, v = self.qkv_proj(x).split(self.embedding_dimensions, dim=-1)
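        # Reshape each projection to (batch, heads, tokens, head_dimensions) for SDPA.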
q = q.view(b, t, self.num_heads, self.head_dimensions).transpose(1, 2)
k = k.view(b, t, self.num_heads, self.head_dimensions).transpose(1, 2)
v = v.view(b, t, self.num_heads, self.head_dimensions).transpose(1, 2)
z = scaled_dot_product_attention(
q,
k,
v,
scale=self.scale,
dropout_p=self.dropout if self.training else 0,
is_causal=True,
)
z = z.transpose(1, 2).contiguous().view(b, t, d)
z = self.out_proj(z)
return z
@torch.no_grad()
def predict(self, x: Tensor, kv_block: DynamicKVBlock) -> Tensor:
"""A forward pass optimized for next-token prediction."""
b, t, d = x.size()
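        # A single query token means we are in the autoregressive decode phase; past keys
        # and values come from the cache, so no causal mask is needed. Multi-token inputs
        # (prompt prefill) still require causal masking.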
is_autoregressive_phase = t == 1
q, k, v = self.qkv_proj(x).split(self.embedding_dimensions, dim=-1)
q = q.view(b, t, self.num_heads, self.head_dimensions).transpose(1, 2)
k = k.view(b, t, self.num_heads, self.head_dimensions).transpose(1, 2)
v = v.view(b, t, self.num_heads, self.head_dimensions).transpose(1, 2)
k, v = kv_block.update(k, v)
z = scaled_dot_product_attention(
q,
k,
v,
scale=self.scale,
is_causal=not is_autoregressive_phase,
)
z = z.transpose(1, 2).contiguous().view(b, t, d)
z = self.out_proj(z)
return z
class MLP(Module):
"""A two layer fully-connected network with dropout."""
def __init__(
self, embedding_dimensions: int, feed_forward_ratio: int, dropout: float
):
super().__init__()
if feed_forward_ratio not in {1, 2, 4}:
raise ValueError("Feed-forward ratio must be either 1, 2, or 4.")
hidden_dimensions: int = feed_forward_ratio * embedding_dimensions
self.layers = Sequential(
Linear(embedding_dimensions, hidden_dimensions, bias=False),
SiLU(),
Linear(hidden_dimensions, embedding_dimensions, bias=False),
)
self.dropout = Dropout1d(p=dropout)
def forward(self, x: Tensor) -> Tensor:
return self.dropout(self.layers(x))
def predict(self, x: Tensor) -> Tensor:
return self.layers(x)
class LoRA(Module):
"""Low rank weight decomposition transformation."""
@classmethod
def from_linear(
cls, linear: Linear, rank: int, alpha: float, dropout: float
) -> Self:
out_features, in_features = linear.weight.shape
return cls(in_features, out_features, rank, alpha, dropout)
def __init__(
self,
in_features: int,
out_features: int,
rank: int,
alpha: float,
dropout: float,
):
super().__init__()
if rank <= 0:
raise ValueError(f"Rank must be greater than 0, {rank} given.")
if alpha <= 0.0:
raise ValueError(f"Alpha must be greater than 0, {alpha} given.")
lora_a = torch.randn(rank, in_features) / sqrt(rank)
lora_b = torch.zeros(out_features, rank)
self.lora_a = Parameter(lora_a)
self.lora_b = Parameter(lora_b)
self.dropout = Dropout1d(dropout)
self.alpha: float = alpha
def forward(self, weight: Tensor) -> Tensor:
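        # Add the scaled low-rank update to the wrapped weight each time the
        # parametrized weight is accessed.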
z = self.lora_b @ self.dropout(self.lora_a)
z *= self.alpha
return weight + z
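

if __name__ == "__main__":
    # Minimal usage sketch (not part of the trained checkpoint): the hyperparameters
    # below are illustrative placeholders, and the prompt is a stand-in for token ids
    # produced by whatever tokenizer the model was trained with.
    model = NoPEGPT(
        vocabulary_size=50257,
        embedding_dimensions=256,
        num_heads=4,
        num_layers=2,
        feed_forward_ratio=4,
        dropout=0.1,
    )
    model.eval()

    prompt = torch.tensor([1, 2, 3], dtype=torch.int64)

    for token, probability in model.generate(prompt, max_tokens=10):
        print(token.item(), probability.item())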