from math import sqrt
from functools import partial
from typing import Self
from collections.abc import Generator
from collections import deque

import torch

from torch import Tensor
from torch.nn import (
    Module,
    ModuleList,
    Sequential,
    Embedding,
    Linear,
    SiLU,
    RMSNorm,
    Dropout1d,
    CrossEntropyLoss,
    Parameter,
)
from torch.nn.functional import softmax, scaled_dot_product_attention
from torch.nn.utils.parametrize import register_parametrization, remove_parametrizations
from torch.utils.checkpoint import checkpoint as torch_checkpoint

from transformers import PretrainedConfig, PreTrainedModel

from caching import KVCache, DynamicKVBlock

from data import IGNORE_INDEX


class NoPEGPT(Module):
    """A generative pretrained transformer with no positional embeddings."""

    def __init__(
        self,
        vocabulary_size: int,
        embedding_dimensions: int,
        num_heads: int,
        num_layers: int,
        feed_forward_ratio: int,
        dropout: float,
    ):
        super().__init__()

        if vocabulary_size <= 0:
            raise ValueError(
                f"Vocabulary size must be greater than 0, {vocabulary_size} given."
            )

        if num_layers <= 0:
            raise ValueError(f"Num layers must be greater than 0, {num_layers} given.")

        token_embeddings = Embedding(vocabulary_size, embedding_dimensions)

        output_layer = Linear(embedding_dimensions, vocabulary_size, bias=False)

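        # Tie the output projection weights to the token embeddings so the same
        # matrix is shared for input and output representations.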
        output_layer.weight = token_embeddings.weight

        self.token_embeddings = token_embeddings

        self.body = ModuleList(
            [
                DecoderBlock(
                    embedding_dimensions,
                    num_heads,
                    feed_forward_ratio,
                    dropout,
                )
                for _ in range(num_layers)
            ]
        )

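        # By default, each decoder block is called directly; activation
        # checkpointing can swap this for a recomputing wrapper.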
        self.checkpoint = lambda layer, x: layer(x)

        self.output_norm = RMSNorm(embedding_dimensions)
        self.output_layer = output_layer

        self.loss_function = CrossEntropyLoss(ignore_index=IGNORE_INDEX)

        self.vocabulary_size: int = vocabulary_size
        self.embedding_dimensions: int = embedding_dimensions
        self.num_heads: int = num_heads
        self.num_layers: int = num_layers

    @property
    def num_trainable_params(self) -> int:
        return sum(param.numel() for param in self.parameters() if param.requires_grad)

    def enable_activation_checkpointing(self) -> None:
        """Instead of storing the activations of the forward pass, recompute them at various checkpoints."""

        self.checkpoint = partial(torch_checkpoint, use_reentrant=False)

    def freeze_model_parameters(self) -> None:
        """Freeze all model parameters to prevent them from being updated during training."""

        for param in self.parameters():
            param.requires_grad = False

    @torch.no_grad()
    def resize_token_embeddings(self, vocabulary_size: int) -> None:
        """Resize the token embeddings to accommodate a new vocabulary size."""

        if vocabulary_size <= 0:
            raise ValueError(
                f"Vocabulary size must be greater than 0, {vocabulary_size} given."
            )

        new_embeddings = Embedding(vocabulary_size, self.embedding_dimensions)

        new_embeddings = new_embeddings.to(self.token_embeddings.weight.device)

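        # Copy over the existing embeddings and randomly initialize any new rows.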
        num_tokens_to_copy = min(vocabulary_size, self.token_embeddings.num_embeddings)

        new_embeddings.weight[:num_tokens_to_copy, :] = self.token_embeddings.weight[
            :num_tokens_to_copy, :
        ]

        for i in range(num_tokens_to_copy, vocabulary_size):
            new_embeddings.weight[i] = torch.randn(self.embedding_dimensions) / sqrt(
                self.embedding_dimensions
            )

        self.token_embeddings.weight = new_embeddings.weight
        self.token_embeddings.num_embeddings = new_embeddings.num_embeddings

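        # Re-tie the output layer to the resized embedding weights.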
        self.output_layer.weight = self.token_embeddings.weight

        self.vocabulary_size = vocabulary_size

    def unfreeze_token_embeddings(self) -> None:
        """Unfreeze the token embeddings to allow for fine-tuning."""

        self.token_embeddings.weight.requires_grad = True

    def add_lora_parameters(self, rank: int, alpha: float, dropout: float) -> None:
        """Reparameterize the weights of the model using LoRA adapters."""

        for module in self.body:
            register_parametrization(
                module.attention.qkv_proj,
                "weight",
                LoRA.from_linear(module.attention.qkv_proj, 3 * rank, alpha, dropout),
            )

            register_parametrization(
                module.attention.out_proj,
                "weight",
                LoRA.from_linear(module.attention.out_proj, rank, alpha, dropout),
            )

            register_parametrization(
                module.mlp.layers[0],
                "weight",
                LoRA.from_linear(module.mlp.layers[0], rank, alpha, dropout),
            )

            register_parametrization(
                module.mlp.layers[2],
                "weight",
                LoRA.from_linear(module.mlp.layers[2], rank, alpha, dropout),
            )

    def lora_state_dict(self) -> dict[str, Tensor]:
        """Return a state dict containing only the LoRA parameters."""

        return {
            name: param for name, param in self.state_dict().items() if "lora" in name
        }

    def merge_lora_parameters(self) -> None:
        """Merge the LoRA parameters with the original parameters."""

        for module in self.modules():
            if hasattr(module, "parametrizations"):
                lora_params = list(module.parametrizations.keys())

                for name in lora_params:
                    remove_parametrizations(module, name)

    def forward(
        self, x: Tensor, y: Tensor | None = None
    ) -> tuple[Tensor, Tensor | None]:
        """A forward pass optimized for batch training."""

        z = self.token_embeddings(x)

        for layer in self.body:
            z = self.checkpoint(layer, z)

        z = self.output_norm(z)
        z = self.output_layer(z)

        if y is not None:
            y_pred = z.view(-1, z.size(-1))
            labels = y.view(-1)

            loss = self.loss_function(y_pred, labels)
        else:
            loss = None

        return z, loss

    @torch.no_grad()
    def predict(self, x: Tensor, kv_cache: KVCache) -> Tensor:
        """A forward pass optimized for next-token prediction."""

        z = self.token_embeddings(x)

        for layer, kv_block in zip(self.body, kv_cache):
            z = layer.predict(z, kv_block)

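        # Only the last position is needed to predict the next token.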
        z = z[:, -1, :]

        z = self.output_norm(z)
        z = self.output_layer(z)

        return z

    @torch.no_grad()
    def generate(
        self,
        prompt: Tensor,
        max_tokens: int = 1000,
        context_length: int = 1024,
        temperature: float = 1.0,
        top_k: int = 500,
        top_p: float = 0.9,
        repeat_penalty: float = 0.1,
        repeat_window: int = 50,
    ) -> Generator[tuple[Tensor, Tensor], None, int]:
        """
        Given a prompt, sample the next {max_tokens} tokens from the model, weighted
        by their predicted probabilities and filtered by the {top_k} and {top_p} parameters.
        """

        if max_tokens <= 0:
            raise ValueError(f"Max tokens must be greater than 0, {max_tokens} given.")

        if context_length <= 0:
            raise ValueError(
                f"Context length must be greater than 0, {context_length} given."
            )

        if temperature <= 0:
            raise ValueError(
                f"Temperature must be greater than 0, {temperature} given."
            )

        if top_k <= 0 or top_k > self.vocabulary_size:
            raise ValueError(
                f"Top k must be between 1 and {self.vocabulary_size}, {top_k} given."
            )

        if top_p <= 0.0 or top_p > 1.0:
            raise ValueError(f"Top p must be between 0 and 1, {top_p} given.")

        if repeat_penalty < 0.0 or repeat_penalty > 1.0:
            raise ValueError(
                f"Repeat penalty must be between 0 and 1, {repeat_penalty} given."
            )

        if repeat_window <= 0:
            raise ValueError(
                f"Repeat window must be greater than 0, {repeat_window} given."
            )

        kv_cache = KVCache(self, 1, context_length).to(prompt.device)

        prompt = prompt[-context_length:]

        previous_tokens = deque(maxlen=repeat_window)
        num_tokens = 0

        while num_tokens < max_tokens:
            logits = self.predict(prompt.unsqueeze(0), kv_cache).squeeze()

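            # Discount the logits of recently generated tokens to discourage repetition.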
            for previous_token in previous_tokens:
                logits[previous_token] -= repeat_penalty * torch.abs(
                    logits[previous_token]
                )

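            # Keep only the top k logits and apply temperature before normalizing.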
            logits, indices = torch.topk(logits, top_k, sorted=True)

            logits /= temperature

            probabilities = softmax(logits, dim=0)

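            # Nucleus (top-p) filtering: keep the candidates whose cumulative
            # probability mass stays within top_p, always retaining at least
            # the most probable token.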
            cumulative_probability_mass = torch.cumsum(probabilities, dim=0)

            min_probability_mass = cumulative_probability_mass[0]

            threshold_p = max(top_p, min_probability_mass.item())

            selected_indices = cumulative_probability_mass <= threshold_p

            logits = logits[selected_indices]
            indices = indices[selected_indices]

            probabilities = softmax(logits, dim=0)

            offset = torch.multinomial(probabilities, num_samples=1).squeeze()

            next_token = indices[offset]
            probability = probabilities[offset]

            yield next_token, probability

            num_tokens += 1

            previous_tokens.append(next_token)

            prompt = next_token.unsqueeze(0)

        return num_tokens


class NoPEGPTHuggingFaceConfig(PretrainedConfig):
    """Provide a monolithic configuration object to enable compatibility with the HuggingFace Transformers API."""

    model_type = "nope-gpt"

    def __init__(
        self,
        vocabulary_size: int = 50257,
        embedding_dimensions: int = 1024,
        num_heads: int = 16,
        num_layers: int = 24,
        feed_forward_ratio: int = 4,
        dropout: float = 0.1,
        **kwargs,
    ):
        self.vocabulary_size = vocabulary_size
        self.embedding_dimensions = embedding_dimensions
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.feed_forward_ratio = feed_forward_ratio
        self.dropout = dropout

        super().__init__(**kwargs)


class NoPEGPTHuggingFaceModel(PreTrainedModel):
    """Wrap the model to enable compatibility with the HuggingFace Transformers API."""

    config_class = NoPEGPTHuggingFaceConfig

    def __init__(self, config: NoPEGPTHuggingFaceConfig):
        super().__init__(config)

        self.model = NoPEGPT(
            config.vocabulary_size,
            config.embedding_dimensions,
            config.num_heads,
            config.num_layers,
            config.feed_forward_ratio,
            config.dropout,
        )

    def forward(self, x: Tensor, y: Tensor | None = None) -> dict[str, Tensor | None]:
        logits, loss = self.model.forward(x, y)

        return {
            "logits": logits,
            "loss": loss,
        }


class DecoderBlock(Module):
    """Decoder block with multi-head attention, multilayer perceptron, and residual connections."""

    def __init__(
        self,
        embedding_dimensions: int,
        num_heads: int,
        feed_forward_ratio: int,
        dropout: float,
    ):
        super().__init__()

        self.norm1 = RMSNorm(embedding_dimensions)
        self.attention = SelfAttention(embedding_dimensions, num_heads, dropout)

        self.norm2 = RMSNorm(embedding_dimensions)
        self.mlp = MLP(embedding_dimensions, feed_forward_ratio, dropout)

    def forward(self, x: Tensor) -> Tensor:
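        # Pre-norm residual connections: each sublayer operates on a normalized
        # copy of its input and its output is added back to the unnormalized input.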
        z = self.norm1(x)
        z = self.attention(z)

        z = x + z

        x = z

        z = self.norm2(x)
        z = self.mlp(z)

        z = x + z

        return z

    @torch.no_grad()
    def predict(self, x: Tensor, kv_block: DynamicKVBlock) -> Tensor:
        """A forward pass optimized for next-token prediction."""

        z = self.norm1(x)
        z = self.attention.predict(z, kv_block)

        z = x + z

        x = z

        z = self.norm2(x)
        z = self.mlp.predict(z)

        z = x + z

        return z


class SelfAttention(Module):
    """Multihead self-attention with causal masking."""

    def __init__(self, embedding_dimensions: int, num_heads: int, dropout: float):
        super().__init__()

        if embedding_dimensions <= 0:
            raise ValueError(
                f"Embedding dimensions must be greater than 0, {embedding_dimensions} given."
            )

        if num_heads <= 0:
            raise ValueError(f"Num heads must be greater than 0, {num_heads} given.")

        if embedding_dimensions % num_heads != 0:
            raise ValueError(
                f"Embedding dimensions must be divisible by num heads, {embedding_dimensions} and {num_heads} given."
            )

        self.qkv_proj = Linear(
            embedding_dimensions, 3 * embedding_dimensions, bias=False
        )

        self.out_proj = Linear(embedding_dimensions, embedding_dimensions, bias=False)

        head_dimensions: int = embedding_dimensions // num_heads
        scale: float = 1.0 / sqrt(head_dimensions)

        self.embedding_dimensions: int = embedding_dimensions
        self.num_heads: int = num_heads
        self.head_dimensions: int = head_dimensions
        self.scale: float = scale
        self.dropout: float = dropout

    def forward(self, x: Tensor) -> Tensor:
        b, t, d = x.size()

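        # Project into queries, keys, and values, then split them into attention heads.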
        q, k, v = self.qkv_proj(x).split(self.embedding_dimensions, dim=-1)

        q = q.view(b, t, self.num_heads, self.head_dimensions).transpose(1, 2)
        k = k.view(b, t, self.num_heads, self.head_dimensions).transpose(1, 2)
        v = v.view(b, t, self.num_heads, self.head_dimensions).transpose(1, 2)

        z = scaled_dot_product_attention(
            q,
            k,
            v,
            scale=self.scale,
            dropout_p=self.dropout if self.training else 0,
            is_causal=True,
        )

        z = z.transpose(1, 2).contiguous().view(b, t, d)

        z = self.out_proj(z)

        return z

    @torch.no_grad()
    def predict(self, x: Tensor, kv_block: DynamicKVBlock) -> Tensor:
        """A forward pass optimized for next-token prediction."""

        b, t, d = x.size()

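        # A single-token input means we are decoding autoregressively against the
        # cache; longer inputs are the prompt prefill phase and need a causal mask.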
        is_autoregressive_phase = t == 1

        q, k, v = self.qkv_proj(x).split(self.embedding_dimensions, dim=-1)

        q = q.view(b, t, self.num_heads, self.head_dimensions).transpose(1, 2)
        k = k.view(b, t, self.num_heads, self.head_dimensions).transpose(1, 2)
        v = v.view(b, t, self.num_heads, self.head_dimensions).transpose(1, 2)

        k, v = kv_block.update(k, v)

        z = scaled_dot_product_attention(
            q,
            k,
            v,
            scale=self.scale,
            is_causal=not is_autoregressive_phase,
        )

        z = z.transpose(1, 2).contiguous().view(b, t, d)

        z = self.out_proj(z)

        return z


class MLP(Module):
    """A two-layer fully-connected network with dropout."""

    def __init__(
        self, embedding_dimensions: int, feed_forward_ratio: int, dropout: float
    ):
        super().__init__()

        if feed_forward_ratio not in {1, 2, 4}:
            raise ValueError("Feed-forward ratio must be either 1, 2, or 4.")

        hidden_dimensions: int = feed_forward_ratio * embedding_dimensions

        self.layers = Sequential(
            Linear(embedding_dimensions, hidden_dimensions, bias=False),
            SiLU(),
            Linear(hidden_dimensions, embedding_dimensions, bias=False),
        )

        self.dropout = Dropout1d(p=dropout)

    def forward(self, x: Tensor) -> Tensor:
        return self.dropout(self.layers(x))

    def predict(self, x: Tensor) -> Tensor:
        return self.layers(x)


class LoRA(Module):
    """Low-rank weight decomposition transformation."""

    @classmethod
    def from_linear(
        cls, linear: Linear, rank: int, alpha: float, dropout: float
    ) -> Self:
        out_features, in_features = linear.weight.shape

        return cls(in_features, out_features, rank, alpha, dropout)

    def __init__(
        self,
        in_features: int,
        out_features: int,
        rank: int,
        alpha: float,
        dropout: float,
    ):
        super().__init__()

        if rank <= 0:
            raise ValueError(f"Rank must be greater than 0, {rank} given.")

        if alpha <= 0.0:
            raise ValueError(f"Alpha must be greater than 0, {alpha} given.")

        lora_a = torch.randn(rank, in_features) / sqrt(rank)
        lora_b = torch.zeros(out_features, rank)

        self.lora_a = Parameter(lora_a)
        self.lora_b = Parameter(lora_b)

        self.dropout = Dropout1d(dropout)

        self.alpha: float = alpha

    def forward(self, weight: Tensor) -> Tensor:
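        # The low-rank update (B @ A) is scaled by alpha and added to the original weight.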
        z = self.lora_b @ self.dropout(self.lora_a)

        z *= self.alpha

        return weight + z
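

if __name__ == "__main__":
    # Minimal usage sketch: the hyperparameters and prompt below are arbitrary
    # placeholders chosen only to illustrate the NoPEGPT API defined above.
    model = NoPEGPT(
        vocabulary_size=1000,
        embedding_dimensions=128,
        num_heads=4,
        num_layers=2,
        feed_forward_ratio=4,
        dropout=0.1,
    )

    model.eval()

    # A dummy prompt of 8 random token ids from the (untrained) vocabulary.
    prompt = torch.randint(0, 1000, (8,))

    # Sample a handful of tokens; generate() yields (token, probability) pairs.
    for token, probability in model.generate(prompt, max_tokens=10, top_k=50):
        print(token.item(), round(probability.item(), 4))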