# Ion-LLM-Base/src/model.py
import os
import torch
import torch.nn as nn
import torch.utils.checkpoint
from utils import load_config
class TransformerBlock(nn.Module):
"""Single transformer block with self-attention and feed-forward layers"""
def __init__(self, n_embd, n_head, dropout=0.1):
super().__init__()
self.attention = nn.MultiheadAttention(n_embd, n_head, dropout=dropout, batch_first=True)
self.feed_forward = nn.Sequential(
nn.Linear(n_embd, 4 * n_embd),
nn.GELU(),
nn.Linear(4 * n_embd, n_embd)
)
self.ln1 = nn.LayerNorm(n_embd)
self.ln2 = nn.LayerNorm(n_embd)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
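        # Shape sketch (batch_first): x is (batch, seq_len, n_embd); mask, if given,
        # is an additive (seq_len, seq_len) float mask broadcast over batch and heads.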
# Ensure mask is same dtype as input
if mask is not None:
mask = mask.to(dtype=x.dtype)
# Self-attention with residual connection
attn_out, _ = self.attention(x, x, x, attn_mask=mask)
x = x + self.dropout(attn_out)
x = self.ln1(x)
# Feed-forward with residual connection
ff_out = self.feed_forward(x)
x = x + self.dropout(ff_out)
x = self.ln2(x)
return x
class CustomLanguageModel(nn.Module):
"""Custom transformer-based language model"""
def __init__(self, config):
super().__init__()
self.vocab_size = config["model"]["vocab_size"]
self.n_embd = config["model"]["n_embd"]
self.n_head = config["model"]["n_head"]
self.n_layer = config["model"]["n_layer"]
self.n_positions = config["model"]["n_positions"]
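        # Expected config["model"] keys, with illustrative (not actual) values:
        #   vocab_size=32000, n_embd=512, n_head=8, n_layer=8, n_positions=1024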
# Token and position embeddings
self.token_embedding = nn.Embedding(self.vocab_size, self.n_embd)
self.position_embedding = nn.Embedding(self.n_positions, self.n_embd)
# Transformer blocks
self.transformer_blocks = nn.ModuleList([
TransformerBlock(self.n_embd, self.n_head)
for _ in range(self.n_layer)
])
# Output layer
self.ln_f = nn.LayerNorm(self.n_embd)
self.lm_head = nn.Linear(self.n_embd, self.vocab_size, bias=False)
# Tie weights between token embedding and output layer
self.token_embedding.weight = self.lm_head.weight
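        # Sharing one matrix between the embedding and lm_head saves
        # vocab_size * n_embd parameters without changing the interface.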
# Initialize weights
self.apply(self._init_weights)
# Set gradient checkpointing flag based on config
self.gradient_checkpointing_enable = config["model"].get("gradient_checkpointing", False)
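        # When enabled, activations inside each block are recomputed during the
        # backward pass instead of being stored, trading compute for memory.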
def _init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.LayerNorm):
torch.nn.init.zeros_(module.bias)
torch.nn.init.ones_(module.weight)
def forward(self, input_ids, labels=None):
batch_size, seq_length = input_ids.shape
# Create position indices
positions = torch.arange(0, seq_length, dtype=torch.long, device=input_ids.device)
positions = positions.unsqueeze(0).expand(batch_size, -1)
# Get embeddings and sum token & position embeddings
token_embeddings = self.token_embedding(input_ids)
position_embeddings = self.position_embedding(positions)
x = token_embeddings + position_embeddings
# Create causal mask and convert to same dtype as embeddings
mask = torch.triu(torch.ones((seq_length, seq_length), device=input_ids.device) * float('-inf'), diagonal=1)
mask = mask.to(dtype=x.dtype)
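        # For seq_length = 3 the additive mask looks like:
        #   [[0., -inf, -inf],
        #    [0.,   0., -inf],
        #    [0.,   0.,   0.]]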
# Process through transformer blocks (use gradient checkpointing only if enabled)
if self.training and self.gradient_checkpointing_enable:
for block in self.transformer_blocks:
x = torch.utils.checkpoint.checkpoint(block, x, mask, use_reentrant=False)
else:
for block in self.transformer_blocks:
x = block(x, mask=mask)
x = self.ln_f(x)
logits = self.lm_head(x)
if labels is not None:
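            # NOTE: labels are assumed to be pre-aligned with the logits by the data
            # pipeline; no next-token shift is applied here.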
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.vocab_size), labels.view(-1))
return {"loss": loss, "logits": logits}
return {"logits": logits}
def num_parameters(self):
"""Returns the number of trainable parameters in the model."""
return sum(p.numel() for p in self.parameters() if p.requires_grad)
def create_model(config):
"""Creates a custom language model from scratch based on the configuration."""
model = CustomLanguageModel(config)
return model
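# Illustrative usage (shapes only; the real values come from the config file):
#   model = create_model(load_config())
#   out = model(torch.randint(0, model.vocab_size, (2, 16)))
#   out["logits"].shape  # -> (2, 16, vocab_size)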
def get_tokenizer(config):
"""Loads a trained ByteLevelBPE tokenizer."""
from tokenizers import ByteLevelBPETokenizer
model_path = config["tokenizer"]["model_path"]
    vocab_file = os.path.join(model_path, "vocab.json")
    merges_file = os.path.join(model_path, "merges.txt")
    if not (os.path.exists(vocab_file) and os.path.exists(merges_file)):
        raise ValueError(f"No tokenizer found at {model_path}. Please train the tokenizer first.")
    tokenizer = ByteLevelBPETokenizer(vocab_file, merges_file)
# Add special tokens if they don't exist
special_tokens = {
"eos_token": "<|endoftext|>",
"pad_token": "<|pad|>",
"unk_token": "<|unk|>",
"mask_token": "<|mask|>"
}
tokenizer.add_special_tokens(list(special_tokens.values()))
# Add methods to match expected interface
tokenizer.get_vocab_size = lambda: len(tokenizer.get_vocab())
def batch_encode(texts, padding=True, truncation=True, max_length=None, return_tensors=None):
encodings = tokenizer.encode_batch(texts)
# Extract token ids from encodings
token_ids = [enc.ids for enc in encodings]
if max_length and truncation:
token_ids = [ids[:max_length] for ids in token_ids]
if padding:
max_len = max(len(ids) for ids in token_ids)
pad_token_id = tokenizer.token_to_id("<|pad|>")
padded = []
for ids in token_ids:
pad_length = max_len - len(ids)
padded.append(ids + [pad_token_id] * pad_length)
token_ids = padded
        if return_tensors == "pt":
            ids_tensor = torch.tensor(token_ids)
            pad_token_id = tokenizer.token_to_id("<|pad|>")
            # Mark real tokens with 1 and padding positions with 0
            attention_mask = (ids_tensor != pad_token_id).long()
            return {
                "input_ids": ids_tensor,
                "attention_mask": attention_mask
            }
return {"input_ids": token_ids}
tokenizer.batch_encode = batch_encode
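    # Illustrative usage of the patched-on helper:
    #   batch = tokenizer.batch_encode(["hello world", "hi"], return_tensors="pt")
    #   batch["input_ids"].shape  # -> (2, longest_sequence_in_batch)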
print(f"ByteLevelBPE tokenizer loaded successfully. Vocab size: {tokenizer.get_vocab_size()}")
return tokenizer
if __name__ == "__main__":
config = load_config()
tokenizer = get_tokenizer(config)
config["model"]["vocab_size"] = tokenizer.get_vocab_size()
model = create_model(config)
print(f"Model created with {model.num_parameters():,} parameters.")
print(model)
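    # Quick smoke test (illustrative): forward a small batch of random token ids.
    dummy_ids = torch.randint(0, config["model"]["vocab_size"], (2, 16))
    out = model(dummy_ids, labels=dummy_ids)
    print(f"Smoke test -- logits {tuple(out['logits'].shape)}, loss {out['loss'].item():.4f}")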