File size: 7,317 Bytes
eb05668 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 |
import torch
import torch.nn as nn
from transformers import AutoTokenizer
from utils import load_config
from tokenizers import Tokenizer
import os
import json
class TransformerBlock(nn.Module):
    """Single transformer block with self-attention and feed-forward layers.

    Each sub-layer (attention, then feed-forward) has its output passed
    through dropout, added back to its input, and then layer-normalised.

    Args:
        n_embd: Embedding width used by the attention, feed-forward and
            layer-norm layers.
        n_head: Number of attention heads given to ``nn.MultiheadAttention``.
        dropout: Dropout probability used inside the attention layer and on
            both residual branches.
    """

    def __init__(self, n_embd, n_head, dropout=0.1):
        super().__init__()
        # batch_first=True: inputs are laid out as (batch, sequence, n_embd).
        self.attention = nn.MultiheadAttention(n_embd, n_head, dropout=dropout, batch_first=True)
        self.feed_forward = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
        )
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run the block on ``x``.

        Args:
            x: Input tensor laid out as (batch, sequence, n_embd).
            mask: Optional attention mask forwarded as ``attn_mask``. A
                boolean mask (True = position may NOT be attended) is passed
                through unchanged; any other mask is treated as additive and
                cast to the dtype of ``x``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Only non-boolean masks are cast. Casting a boolean mask to float
        # would turn "blocked" (True) into an additive 1.0 and "allowed"
        # (False) into 0.0, which no longer masks anything.
        if mask is not None and mask.dtype != torch.bool:
            mask = mask.to(dtype=x.dtype)

        # Self-attention with residual connection. The attention weights are
        # never used, so do not ask for them.
        attn_out, _ = self.attention(x, x, x, attn_mask=mask, need_weights=False)
        x = x + self.dropout(attn_out)
        x = self.ln1(x)

        # Feed-forward with residual connection.
        ff_out = self.feed_forward(x)
        x = x + self.dropout(ff_out)
        x = self.ln2(x)
        return x
class CustomLanguageModel(nn.Module):
    """Custom transformer-based language model.

    The model sums token and learned position embeddings, runs the result
    through ``n_layer`` ``TransformerBlock`` modules under a causal mask,
    applies a final layer norm and projects to vocabulary logits. The output
    projection shares its weight with the token embedding.

    Args:
        config: Mapping with a ``"model"`` section providing ``vocab_size``,
            ``n_embd``, ``n_head``, ``n_layer`` and ``n_positions``, plus an
            optional ``gradient_checkpointing`` flag (default ``False``).
    """

    def __init__(self, config):
        super().__init__()
        model_cfg = config["model"]
        self.vocab_size = model_cfg["vocab_size"]
        self.n_embd = model_cfg["n_embd"]
        self.n_head = model_cfg["n_head"]
        self.n_layer = model_cfg["n_layer"]
        self.n_positions = model_cfg["n_positions"]

        # Token and position embeddings
        self.token_embedding = nn.Embedding(self.vocab_size, self.n_embd)
        self.position_embedding = nn.Embedding(self.n_positions, self.n_embd)

        # Transformer blocks
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(self.n_embd, self.n_head) for _ in range(self.n_layer)]
        )

        # Output layer
        self.ln_f = nn.LayerNorm(self.n_embd)
        self.lm_head = nn.Linear(self.n_embd, self.vocab_size, bias=False)

        # Tie weights between token embedding and output layer
        self.token_embedding.weight = self.lm_head.weight

        # Initialize weights
        self.apply(self._init_weights)

        # Set gradient checkpointing flag based on config
        self.gradient_checkpointing_enable = model_cfg.get("gradient_checkpointing", False)

    def _init_weights(self, module):
        """Initialise one sub-module; invoked for every module via ``apply``.

        Linear and Embedding weights are drawn from N(0, 0.02); Linear biases
        are zeroed; LayerNorm is reset to weight 1 and bias 0.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def forward(self, input_ids, labels=None):
        """Compute logits (and optionally the cross-entropy loss).

        Args:
            input_ids: Integer tensor of shape (batch, sequence).
            labels: Optional integer tensor with one target per logit
                position. It is compared position-for-position with the
                logits; no shifting is done here.
                NOTE(review): presumably the caller supplies already-shifted
                targets - confirm against the training loop.

        Returns:
            ``{"logits": ...}``, plus ``"loss"`` when ``labels`` is given.

        Raises:
            ValueError: If the sequence is longer than ``n_positions``.
        """
        batch_size, seq_length = input_ids.shape
        # Fail early with a readable message instead of an out-of-range index
        # inside the position embedding lookup.
        if seq_length > self.n_positions:
            raise ValueError(
                f"Sequence length {seq_length} exceeds the maximum of "
                f"{self.n_positions} positions supported by the position embedding."
            )

        # Create position indices
        positions = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        positions = positions.unsqueeze(0).expand(batch_size, -1)

        # Sum token and position embeddings
        x = self.token_embedding(input_ids) + self.position_embedding(positions)

        # Causal mask: -inf strictly above the diagonal, 0 elsewhere, in the
        # same dtype as the embeddings.
        mask = torch.triu(
            torch.full((seq_length, seq_length), float("-inf"), device=input_ids.device),
            diagonal=1,
        )
        mask = mask.to(dtype=x.dtype)

        # Process through transformer blocks (use gradient checkpointing only if enabled)
        if self.training and self.gradient_checkpointing_enable:
            # Explicit import: `import torch` alone does not guarantee that
            # the `torch.utils.checkpoint` submodule has been loaded.
            from torch.utils.checkpoint import checkpoint

            for block in self.transformer_blocks:
                x = checkpoint(block, x, mask, use_reentrant=False)
        else:
            for block in self.transformer_blocks:
                x = block(x, mask=mask)

        x = self.ln_f(x)
        logits = self.lm_head(x)

        if labels is None:
            return {"logits": logits}

        loss_fct = nn.CrossEntropyLoss()
        # reshape (not view) so non-contiguous labels, e.g. slices, are accepted.
        loss = loss_fct(logits.reshape(-1, self.vocab_size), labels.reshape(-1))
        return {"loss": loss, "logits": logits}

    def num_parameters(self):
        """Returns the number of trainable parameters in the model."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
def create_model(config):
    """Build a new ``CustomLanguageModel`` from the given configuration."""
    return CustomLanguageModel(config)
def get_tokenizer(config):
    """Loads a trained ByteLevelBPE tokenizer and attaches helper methods.

    The tokenizer is built from ``vocab.json`` and ``merges.txt`` found in
    ``config["tokenizer"]["model_path"]``. Four special tokens are added, and
    two callables are attached to the instance: ``get_vocab_size()`` and
    ``batch_encode(...)``.

    Args:
        config: Mapping with a ``"tokenizer"`` section containing
            ``model_path``.

    Returns:
        The loaded tokenizer instance.

    Raises:
        ValueError: If ``vocab.json`` or ``merges.txt`` is missing from
            ``model_path``.
    """
    model_path = config["tokenizer"]["model_path"]
    vocab_file = os.path.join(model_path, "vocab.json")
    merges_file = os.path.join(model_path, "merges.txt")
    # Both files are required by the constructor below, so check both.
    if not (os.path.exists(vocab_file) and os.path.exists(merges_file)):
        raise ValueError(f"No tokenizer found at {model_path}. Please train the tokenizer first.")

    # Imported after the path check so that a missing tokenizer directory is
    # reported before the third-party import is attempted.
    from tokenizers import ByteLevelBPETokenizer

    tokenizer = ByteLevelBPETokenizer(vocab_file, merges_file)

    # Add special tokens if they don't exist
    special_tokens = {
        "eos_token": "<|endoftext|>",
        "pad_token": "<|pad|>",
        "unk_token": "<|unk|>",
        "mask_token": "<|mask|>",
    }
    tokenizer.add_special_tokens(list(special_tokens.values()))

    # Add methods to match expected interface
    tokenizer.get_vocab_size = lambda: len(tokenizer.get_vocab())

    def batch_encode(texts, padding=True, truncation=True, max_length=None, return_tensors=None):
        """Encode ``texts`` into token ids, optionally truncated and padded.

        With ``return_tensors="pt"`` the result holds ``input_ids`` and an
        ``attention_mask`` (1 for real tokens, 0 for padding) as tensors;
        otherwise it holds ``input_ids`` as plain lists.
        """
        token_ids = [enc.ids for enc in tokenizer.encode_batch(texts)]

        if max_length and truncation:
            token_ids = [ids[:max_length] for ids in token_ids]

        # Build the mask from the real lengths before any padding is added,
        # so padded positions end up as 0.
        attention_mask = [[1] * len(ids) for ids in token_ids]

        if padding:
            # default=0 keeps an empty batch from raising inside max().
            max_len = max((len(ids) for ids in token_ids), default=0)
            pad_token_id = tokenizer.token_to_id("<|pad|>")
            attention_mask = [row + [0] * (max_len - len(row)) for row in attention_mask]
            token_ids = [ids + [pad_token_id] * (max_len - len(ids)) for ids in token_ids]

        if return_tensors == "pt":
            return {
                "input_ids": torch.tensor(token_ids, dtype=torch.long),
                "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            }
        return {"input_ids": token_ids}

    tokenizer.batch_encode = batch_encode

    print(f"ByteLevelBPE tokenizer loaded successfully. Vocab size: {tokenizer.get_vocab_size()}")
    return tokenizer
if __name__ == "__main__":
    # Manual smoke run: load the config and tokenizer, build the model and
    # report its size.
    config = load_config()
    tokenizer = get_tokenizer(config)

    # Size the model's vocabulary from the tokenizer that was just loaded.
    vocab_size = tokenizer.get_vocab_size()
    config["model"]["vocab_size"] = vocab_size

    model = create_model(config)
    param_count = model.num_parameters()
    print(f"Model created with {param_count:,} parameters.")
    print(model)
|