|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
from transformers import AutoTokenizer
|
|
|
from utils import load_config
|
|
|
from tokenizers import Tokenizer
|
|
|
import os
|
|
|
import json
|
|
|
|
|
|
class TransformerBlock(nn.Module):
    """One post-LayerNorm transformer layer: self-attention followed by an MLP.

    Each sub-layer is wrapped in a residual connection with dropout, and
    LayerNorm is applied *after* the residual add (post-LN).
    """

    def __init__(self, n_embd, n_head, dropout=0.1):
        """
        Args:
            n_embd: embedding / hidden dimension of the model.
            n_head: number of attention heads (must divide n_embd evenly).
            dropout: dropout probability used in attention and on both
                residual branches.
        """
        super().__init__()
        # batch_first=True -> all tensors are (batch, seq, n_embd).
        self.attention = nn.MultiheadAttention(n_embd, n_head, dropout=dropout, batch_first=True)
        hidden_dim = 4 * n_embd  # conventional 4x MLP expansion
        self.feed_forward = nn.Sequential(
            nn.Linear(n_embd, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, n_embd),
        )
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run the block over activations x of shape (batch, seq, n_embd).

        Args:
            x: input activations.
            mask: optional additive attention mask (e.g. -inf above the
                diagonal for causal attention); cast here to x's dtype so a
                float mask matches the activations.

        Returns:
            Tensor of the same shape as x.
        """
        if mask is not None:
            mask = mask.to(dtype=x.dtype)

        # Self-attention sub-layer: residual add, then post-norm.
        attended, _ = self.attention(x, x, x, attn_mask=mask)
        x = self.ln1(x + self.dropout(attended))

        # Position-wise feed-forward sub-layer: same residual + post-norm shape.
        mlp_out = self.feed_forward(x)
        return self.ln2(x + self.dropout(mlp_out))
|
|
|
|
|
|
class CustomLanguageModel(nn.Module):
    """Decoder-only (causal) transformer language model.

    Built from a token + learned position embedding, a stack of
    TransformerBlock layers, a final LayerNorm, and a weight-tied LM head.

    Expects ``config["model"]`` to provide: vocab_size, n_embd, n_head,
    n_layer, n_positions, and optionally gradient_checkpointing (bool).
    """

    def __init__(self, config):
        super().__init__()
        self.vocab_size = config["model"]["vocab_size"]
        self.n_embd = config["model"]["n_embd"]
        self.n_head = config["model"]["n_head"]
        self.n_layer = config["model"]["n_layer"]
        self.n_positions = config["model"]["n_positions"]

        self.token_embedding = nn.Embedding(self.vocab_size, self.n_embd)
        self.position_embedding = nn.Embedding(self.n_positions, self.n_embd)

        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(self.n_embd, self.n_head)
            for _ in range(self.n_layer)
        ])

        self.ln_f = nn.LayerNorm(self.n_embd)
        self.lm_head = nn.Linear(self.n_embd, self.vocab_size, bias=False)

        # Weight tying: the input embedding and the output projection share
        # one Parameter (GPT-style). _init_weights below touches this shared
        # tensor twice (once as Embedding, once as Linear), but both draws
        # are N(0, 0.02), so the final distribution is unchanged.
        self.token_embedding.weight = self.lm_head.weight

        self.apply(self._init_weights)

        # Boolean flag (despite the method-like name, kept for compatibility
        # with existing callers that may set it).
        self.gradient_checkpointing_enable = config["model"].get("gradient_checkpointing", False)

    def _init_weights(self, module):
        """GPT-2-style initialization: N(0, 0.02) weights, zero biases."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def forward(self, input_ids, labels=None):
        """Run the model.

        Args:
            input_ids: LongTensor of token ids, shape (batch, seq).
            labels: optional LongTensor of the same shape. Assumed to be the
                *unshifted* target token ids (HF convention, typically
                ``labels=input_ids``); the shift to next-token targets is
                performed here. If the dataloader pre-shifts labels, this
                internal shift must be removed.

        Returns:
            dict with "logits" (batch, seq, vocab_size), plus "loss" (scalar)
            when labels are given.

        Raises:
            ValueError: if seq exceeds n_positions (position table size).
        """
        batch_size, seq_length = input_ids.shape
        if seq_length > self.n_positions:
            raise ValueError(
                f"Sequence length {seq_length} exceeds model maximum {self.n_positions}"
            )

        # Learned absolute positions, broadcast across the batch.
        positions = torch.arange(0, seq_length, dtype=torch.long, device=input_ids.device)
        positions = positions.unsqueeze(0).expand(batch_size, -1)

        token_embeddings = self.token_embedding(input_ids)
        position_embeddings = self.position_embedding(positions)
        x = token_embeddings + position_embeddings

        # Additive causal mask: -inf strictly above the diagonal so position t
        # can only attend to positions <= t.
        mask = torch.triu(torch.ones((seq_length, seq_length), device=input_ids.device) * float('-inf'), diagonal=1)
        mask = mask.to(dtype=x.dtype)

        if self.training and self.gradient_checkpointing_enable:
            # Trade compute for memory: recompute activations in backward.
            for block in self.transformer_blocks:
                x = torch.utils.checkpoint.checkpoint(block, x, mask, use_reentrant=False)
        else:
            for block in self.transformer_blocks:
                x = block(x, mask=mask)

        x = self.ln_f(x)
        logits = self.lm_head(x)

        if labels is not None:
            # Causal-LM objective: logits at position t predict token t+1.
            # Without this shift the model would be trained to reproduce the
            # token it has already seen (trivial copy task).
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, self.vocab_size), shift_labels.view(-1))
            return {"loss": loss, "logits": logits}

        return {"logits": logits}

    def num_parameters(self):
        """Returns the number of trainable parameters in the model."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
|
|
|
|
|
def create_model(config):
    """Factory: build an untrained CustomLanguageModel from a config dict.

    Args:
        config: nested configuration with a "model" section (vocab_size,
            n_embd, n_head, n_layer, n_positions, ...).

    Returns:
        A freshly initialized CustomLanguageModel instance.
    """
    return CustomLanguageModel(config)
|
|
|
|
|
|
def get_tokenizer(config):
    """Loads a trained ByteLevelBPE tokenizer and augments it with helpers.

    Reads vocab.json / merges.txt from ``config["tokenizer"]["model_path"]``,
    registers the project's special tokens, and monkey-patches two helpers
    onto the instance: a ``get_vocab_size`` that reflects the added special
    tokens, and an HF-style ``batch_encode``.

    Raises:
        ValueError: if no trained tokenizer files exist at the model path.
    """
    from tokenizers import ByteLevelBPETokenizer

    model_path = config["tokenizer"]["model_path"]
    if not os.path.exists(os.path.join(model_path, "vocab.json")):
        raise ValueError(f"No tokenizer found at {model_path}. Please train the tokenizer first.")

    tokenizer = ByteLevelBPETokenizer(
        os.path.join(model_path, "vocab.json"),
        os.path.join(model_path, "merges.txt")
    )

    special_tokens = {
        "eos_token": "<|endoftext|>",
        "pad_token": "<|pad|>",
        "unk_token": "<|unk|>",
        "mask_token": "<|mask|>"
    }
    tokenizer.add_special_tokens(list(special_tokens.values()))

    # Override so the reported size includes the just-added special tokens.
    tokenizer.get_vocab_size = lambda: len(tokenizer.get_vocab())

    def batch_encode(texts, padding=True, truncation=True, max_length=None, return_tensors=None):
        """HF-style batch encoding with optional truncation/padding.

        Returns a dict with "input_ids" and "attention_mask" (1 for real
        tokens, 0 for padding); tensors when return_tensors == "pt".
        """
        encodings = tokenizer.encode_batch(texts)
        token_ids = [enc.ids for enc in encodings]

        if max_length and truncation:
            token_ids = [ids[:max_length] for ids in token_ids]

        # Record real-token lengths *before* padding so the attention mask
        # can zero out pad positions (previously the mask was all ones,
        # which let downstream attention attend to padding).
        lengths = [len(ids) for ids in token_ids]

        if padding and token_ids:  # guard: max() would raise on an empty batch
            max_len = max(lengths)
            pad_token_id = tokenizer.token_to_id("<|pad|>")
            token_ids = [ids + [pad_token_id] * (max_len - len(ids)) for ids in token_ids]

        attention_mask = [
            [1] * length + [0] * (len(ids) - length)
            for length, ids in zip(lengths, token_ids)
        ]

        if return_tensors == "pt":
            return {
                "input_ids": torch.tensor(token_ids),
                "attention_mask": torch.tensor(attention_mask)
            }
        return {"input_ids": token_ids, "attention_mask": attention_mask}

    tokenizer.batch_encode = batch_encode

    print(f"ByteLevelBPE tokenizer loaded successfully. Vocab size: {tokenizer.get_vocab_size()}")
    return tokenizer
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build tokenizer + model from config and report parameter count.
    config = load_config()
    tokenizer = get_tokenizer(config)
    # The model's embedding table must match the tokenizer's vocabulary,
    # including the special tokens added during loading.
    config["model"]["vocab_size"] = tokenizer.get_vocab_size()
    model = create_model(config)
    print(f"Model created with {model.num_parameters():,} parameters.")
    print(model)
|
|
|
|