# MoR-deep / Initial_Train_MoR.py
# Author: gate369
# Commit message: "Create Initial_Train_MoR.py"
# Commit d0e6871 (verified)
################################################
#Mixture of Recursions w/ Expert Choice Routing#
################################################
# This is the script I used for the initial training of this model; training was later continued with 'Continue_Training_MoR.py'.
from re import M
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.utils.data import Dataset, DataLoader
from torch.utils.checkpoint import checkpoint
from tokenizers import Tokenizer, models, trainers, pre_tokenizers
from tqdm import tqdm
import matplotlib.pyplot as plt
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import os
from safetensors.torch import save_file
import json
import os
from transformers import PreTrainedTokenizerFast
# Make CUDA kernel launches synchronous so device-side errors are reported at
# the call that caused them (a debugging aid that slows training down).
# setdefault() keeps this as the default while letting a value exported in the
# shell (e.g. CUDA_LAUNCH_BLOCKING=0) take precedence.
os.environ.setdefault('CUDA_LAUNCH_BLOCKING', '1')
def save_huggingface_model(model, tokenizer, folder_path="MoR-v1"):
    """Export ``model`` and ``tokenizer`` to a Hugging Face-style folder.

    Files written to ``folder_path`` (created if needed):

    * ``model.safetensors`` - the model's ``state_dict``;
    * ``config.json`` - architecture hyper-parameters;
    * tokenizer files - via ``PreTrainedTokenizerFast.save_pretrained``;
    * ``model.safetensors.index.json`` - a single-shard weight index.

    Args:
        model: Module to export. ``dim``, ``max_recursion`` and
            ``num_experts`` are read from the model when it has those
            attributes; otherwise the module-level constants are used.
        tokenizer: ``tokenizers.Tokenizer`` to wrap and save.
        folder_path: Output directory.
    """
    os.makedirs(folder_path, exist_ok=True)

    # 1. Model weights.
    weights = model.state_dict()
    save_file(weights, os.path.join(folder_path, "model.safetensors"))

    # 2. Architecture config. Values are taken from the model where possible
    #    so the config describes the weights actually saved (the model is
    #    built with NUM_EXPERTS experts, not MAX_RECURSIONS).
    config = {
        "vocab_size": VOCAB_SIZE,
        "dim": getattr(model, "dim", DIM),
        "num_layers": NUM_LAYERS,
        "num_heads": HEADS,
        "max_recursion": getattr(model, "max_recursion", MAX_RECURSIONS),
        "num_experts": getattr(model, "num_experts", NUM_EXPERTS),
        "ffn_expansion": 4,
        "max_position_embeddings": 2048,
        "model_type": "MoR",
        "architecture": "MixtureOfRecursions",
        "hidden_act": "gelu",
    }
    with open(os.path.join(folder_path, "config.json"), "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)

    # 3. Tokenizer files.
    hf_tokenizer = PreTrainedTokenizerFast(
        tokenizer_object=tokenizer,
        unk_token="[UNK]",
        pad_token="[PAD]",
        bos_token="[BOS]",
        eos_token="[EOS]",
    )
    hf_tokenizer.save_pretrained(folder_path)

    # 4. Safetensors index. total_size is computed from the same tensors that
    #    weight_map lists (the state_dict), so the two always agree.
    index = {
        "metadata": {
            "total_size": sum(t.numel() * t.element_size() for t in weights.values())
        },
        "weight_map": {name: "model.safetensors" for name in weights},
    }
    with open(os.path.join(folder_path, "model.safetensors.index.json"), "w", encoding="utf-8") as f:
        json.dump(index, f, indent=2)
    print(f"Model saved in Hugging Face format to {folder_path}/")
# --- Tokenizer / model shape -------------------------------------------
VOCAB_SIZE = 10_000      # BPE vocabulary size, special tokens included
DIM = 1536               # embedding / hidden width
NUM_LAYERS = 6           # the model builds NUM_LAYERS - 4 recursion routers
HEADS = 8                # attention heads per transformer block
MAX_RECURSIONS = 4       # depth choices per recursion router
NUM_EXPERTS = 12         # experts scored by each expert router

# --- Data / optimisation -----------------------------------------------
BATCH_SIZE = 32          # samples per micro-batch
SEQ_LEN = 512            # tokens per training sample
learn_rate = 5e-5        # peak learning rate of the warmup + cosine schedule
EPOCHS = 3
GRAD_ACCUM_STEPS = 4     # micro-batches per optimizer update
# ----------------------
# Byte-Level BPE Tokenizer
# ----------------------
def train_tokenizer(file_path, vocab_size=VOCAB_SIZE):
    """Train a byte-level BPE tokenizer on the text file at ``file_path``.

    The tokenizer is trained directly on the raw file for every corpus size,
    so it sees exactly the text that ``prepare_datasets`` later encodes
    (including newlines and non-ASCII characters).

    Args:
        file_path: Path to a UTF-8 text corpus.
        vocab_size: Target vocabulary size, including the special tokens
            ``[PAD]``, ``[UNK]``, ``[BOS]`` and ``[EOS]``.

    Returns:
        The trained ``tokenizers.Tokenizer``.
    """
    # Local import: the file-level tokenizers import does not include decoders.
    from tokenizers import decoders

    print("Training tokenizer...")
    tokenizer = Tokenizer(models.BPE(unk_token="[UNK]"))
    tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
    # Without the matching decoder, decode() would return byte-level
    # placeholder symbols (e.g. "Ġ" for a space) instead of readable text.
    tokenizer.decoder = decoders.ByteLevel()

    trainer = trainers.BpeTrainer(
        vocab_size=vocab_size,
        special_tokens=["[PAD]", "[UNK]", "[BOS]", "[EOS]"],
        min_frequency=2,
        # Seed the vocabulary with every byte symbol so bytes that never
        # occur in the corpus still encode to a real token, not [UNK].
        initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
    )

    # The tokenizers library reads the file itself, so no Python-side
    # preprocessing (which would make the training text differ from the
    # text encoded for the datasets) is done here.
    tokenizer.train([file_path], trainer=trainer)
    print("Tokenizer successfully trained")
    return tokenizer
def prepare_datasets(file_path, tokenizer, seq_len=SEQ_LEN, val_split=0.05):
    """Tokenize a text corpus and split it into train / validation datasets.

    The whole file is encoded into one flat tensor of token ids. The first
    ``1 - val_split`` fraction of tokens becomes the training split and the
    rest the validation split, each wrapped in ``TextDataset``.

    When CUDA is available the id tensor is built on the GPU (encoding itself
    runs on the CPU, in 500k-character chunks so progress can be shown);
    otherwise it stays on the CPU. The datasets return slices on that device.

    Args:
        file_path: Path to a UTF-8 text file.
        tokenizer: Object with an ``encode(text).ids`` API, e.g. the
            ``tokenizers.Tokenizer`` returned by ``train_tokenizer``.
        seq_len: Tokens per sample.
        val_split: Fraction of tokens held out for validation.

    Returns:
        ``(train_dataset, val_dataset)``.

    Raises:
        ValueError: if the file is empty, or if either split is too short to
            yield a single sample.
    """
    print("Preparing datasets with GPU acceleration...")
    with open(file_path, 'r', encoding='utf-8') as f:
        text = f.read()
    if not text:
        raise ValueError(f"Corpus file {file_path!r} is empty")

    if torch.cuda.is_available():
        print("Using GPU for tokenization pipeline...")
        # NOTE: chunks are cut at fixed character offsets, so a word that
        # straddles a boundary is tokenized as two separate pieces.
        chunk_size = 500000  # characters per chunk
        chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
        encoded_chunks = []
        for chunk in tqdm(chunks, desc="Tokenizing on GPU"):
            ids = tokenizer.encode(chunk).ids  # encoding runs on the CPU
            encoded_chunks.append(torch.tensor(ids, dtype=torch.long, device='cuda'))
        encoded = torch.cat(encoded_chunks)
    else:
        encoded = torch.tensor(tokenizer.encode(text).ids, dtype=torch.long, device='cpu')

    total_tokens = len(encoded)
    split_idx = int(total_tokens * (1 - val_split))
    train_dataset = TextDataset(encoded[:split_idx], seq_len)
    val_dataset = TextDataset(encoded[split_idx:], seq_len)
    # Fail here with a clear message instead of dividing by zero later when
    # averaging losses over an empty DataLoader.
    if len(train_dataset) == 0 or len(val_dataset) == 0:
        raise ValueError(
            f"Corpus too small: {total_tokens} tokens give {len(train_dataset)} training "
            f"and {len(val_dataset)} validation samples of {seq_len} tokens "
            f"(val_split={val_split})"
        )

    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")
    print(f"Total tokens: {total_tokens}")
    return train_dataset, val_dataset
class TextDataset(Dataset):
    """Fixed-length next-token-prediction samples over a flat id tensor.

    Sample ``idx`` covers the ``seq_len + 1`` ids starting at
    ``idx * seq_len`` and is returned as ``(inputs, targets)``, where
    ``targets`` is ``inputs`` shifted left by one position. Both are slices
    of the stored tensor, so they live on the same device as the data.
    """

    def __init__(self, encoded_data, seq_len=SEQ_LEN):
        # Keep the data on whatever device it already lives on (GPU or CPU).
        self.encoded = encoded_data
        self.seq_len = seq_len
        self.device = encoded_data.device

    def __len__(self):
        # A sample needs seq_len + 1 ids (inputs plus shifted targets). Only
        # count windows whose last target id exists; otherwise, when
        # len(encoded) is a multiple of seq_len, the final sample is one token
        # short and stacking it into a batch fails.
        return max(0, (len(self.encoded) - 1) // self.seq_len)

    def __getitem__(self, idx):
        start = idx * self.seq_len
        segment = self.encoded[start:start + self.seq_len + 1]
        return segment[:-1], segment[1:]
# ----------------------
# MoR Model Components
# ----------------------
# Progress message printed once when this module runs, before the model
# classes below are defined.
print("Defining components...")
class ExpertChoiceRouter(nn.Module):
    """Top-k router over ``num_experts`` experts.

    Despite the class name, selection is *token-choice*: every token picks
    its ``k`` highest-scoring experts (``topk`` over the expert axis). The
    name is kept so existing checkpoints and call sites keep working.

    Args:
        dim: Width of the input features.
        num_experts: Number of experts to score.
        k: Experts selected per token; must satisfy ``1 <= k <= num_experts``.

    Raises:
        ValueError: if ``k`` is outside ``[1, num_experts]``.
    """

    def __init__(self, dim, num_experts, k=2):
        super().__init__()
        if not 1 <= k <= num_experts:
            raise ValueError(f"k must be in [1, num_experts={num_experts}], got {k}")
        self.num_experts = num_experts
        self.k = k
        self.gate = nn.Linear(dim, num_experts, bias=False)

    def forward(self, x):
        """Route each token to its top-k experts.

        Args:
            x: Tensor whose last axis has size ``dim`` (the model passes
               ``(batch, seq_len, dim)``).

        Returns:
            ``(weights, indices)``, both shaped ``x.shape[:-1] + (k,)``.
            ``indices`` are the chosen expert ids; ``weights`` are the softmax
            of their gate scores, so they sum to 1 for each token.
        """
        scores = self.gate(x)
        top_scores, top_experts = torch.topk(scores, self.k, dim=-1)
        return top_scores.softmax(dim=-1), top_experts
# ----------------------
# 4-bit Quantization Utilities
# ----------------------
# Per-tensor symmetric 4-bit quantization (codes stored as int8; no special gradient handling)
class Quantizer4Bit(nn.Module):
    """Per-tensor symmetric 4-bit quantize / dequantize helpers.

    Values are mapped to integer codes in ``[-8, 7]`` (stored as ``int8``)
    using one scale per tensor, ``max(|tensor|) / 7.5``. No gradient handling
    is done here: the integer cast is non-differentiable, so callers that
    need gradients through quantized values must add a straight-through
    estimator themselves.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def quantize(tensor):
        """Quantize ``tensor`` to 4-bit integer codes.

        Returns:
            ``(codes, scale)``: ``codes`` is an ``int8`` tensor with values in
            ``[-8, 7]``; ``scale`` is a 0-d tensor such that
            ``codes * scale`` approximates ``tensor``. A tensor whose largest
            magnitude is at most 1e-8 uses ``scale == 1``.
        """
        max_val = tensor.abs().max()
        # torch.where keeps the zero-guard on the tensor's device; a Python
        # ``if`` on a tensor would force a host/device sync on every call.
        scale = torch.where(max_val > 1e-8, max_val / 7.5, torch.ones_like(max_val))
        codes = torch.clamp(torch.round(tensor / scale), -8, 7)
        return codes.to(torch.int8), scale

    @staticmethod
    def dequantize(quantized, scale):
        """Map integer codes back to floating point: ``quantized.float() * scale``."""
        return quantized.float() * scale
# Weight initialization function
def init_weights(module):
    """Initialise one module in place; meant for ``model.apply(init_weights)``.

    * ``nn.Linear``: Xavier-uniform weight, zero bias.
    * ``nn.Embedding``: weight drawn from normal(mean=0, std=0.02).
    * ``nn.LayerNorm``: weight 1, bias 0.

    Parameters a module was built without (``bias=False``,
    ``elementwise_affine=False``) are ``None`` and are skipped. Other module
    types are left untouched.
    """
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
    elif isinstance(module, nn.LayerNorm):
        if module.weight is not None:
            nn.init.ones_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
# ----------------------
# MoR Model Components with Quantization
# ----------------------
class QuantizedRecursiveTransformerBlock(nn.Module):
    """Pre-norm causal self-attention + FFN block with 4-bit fake-quantized K/V.

    ``x -> x + Attn(LN1(x)) -> x + FFN(LN2(x))``.

    Keys and values pass through 4-bit quantize/dequantize with a
    straight-through estimator, so ``k_proj`` / ``v_proj`` still receive
    gradients even though rounding has zero gradient. Attention is causal
    (position ``t`` attends only to positions ``<= t``), matching the
    next-token training objective. ``forward`` is gradient-checkpointed.

    Args:
        dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        ffn_expansion: Hidden-width multiplier of the feed-forward network.

    Raises:
        ValueError: if ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim, num_heads, ffn_expansion=4):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by num_heads ({num_heads})")
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Attention projections
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.attn_out = nn.Linear(dim, dim)
        # Position-wise feed-forward network
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_expansion * dim),
            nn.GELU(),
            nn.Linear(ffn_expansion * dim, dim),
        )
        # Pre-attention / pre-FFN normalisation
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        # Recompute this block's activations during backward instead of
        # storing them (gradient checkpointing).
        return checkpoint(self._forward, x, use_reentrant=False)

    @staticmethod
    def _fake_quantize(t):
        """4-bit quantize/dequantize ``t`` with a straight-through gradient.

        The forward value is the dequantized tensor (cast back to ``t``'s
        dtype); in backward the quantization step acts as the identity.
        """
        codes, scale = Quantizer4Bit.quantize(t)
        t_q = Quantizer4Bit.dequantize(codes, scale).to(t.dtype)
        return t + (t_q - t).detach()

    def _forward(self, x):
        # assumes x is (batch, seq_len, dim), as unpacked below
        residual = x
        h = self.norm1(x)

        q = self.q_proj(h)
        k = self._fake_quantize(self.k_proj(h))
        v = self._fake_quantize(self.v_proj(h))

        # Split heads: (B, T, dim) -> (B, heads, T, head_dim)
        B, T, _ = q.shape
        q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention with a causal mask: the training
        # targets are the next tokens, so attending to later positions
        # would leak the answer.
        attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        positions = torch.arange(T, device=x.device)
        future = positions[None, :] > positions[:, None]  # (T, T), True above the diagonal
        attn = attn.masked_fill(future, float('-inf'))
        attn = attn.softmax(dim=-1)
        attn_out = (attn @ v).transpose(1, 2).contiguous().view(B, T, self.dim)

        x = residual + self.attn_out(attn_out)
        x = x + self.ffn(self.norm2(x))
        return x
class RecursionDepthRouter(nn.Module):
    """Score ``max_depth`` recursion depths from the input's global mean.

    A two-layer MLP (``dim -> dim -> max_depth`` with ReLU) is applied to the
    features averaged over the first two axes (batch and sequence), so one
    probability vector is produced for the whole batch.
    """

    def __init__(self, dim, max_depth=4):
        super().__init__()
        self.max_depth = max_depth
        hidden = nn.Linear(dim, dim)      # widened hidden layer
        out = nn.Linear(dim, max_depth)
        self.router = nn.Sequential(hidden, nn.ReLU(), out)
        # Xavier-uniform weights and zero biases for both projections.
        for proj in (hidden, out):
            nn.init.xavier_uniform_(proj.weight)
            nn.init.zeros_(proj.bias)

    def forward(self, x):
        """Return a ``(max_depth,)`` probability vector for ``x``.

        ``x`` is expected to be ``(batch, seq_len, dim)``; it is averaged over
        its first two axes before scoring.
        """
        pooled = torch.mean(x, dim=(0, 1))
        return torch.softmax(self.router(pooled), dim=-1)
# ----------------------
# Main MoR Architecture (with Quantization)
# ----------------------
class QuantizedMoRModel(nn.Module):
    """Mixture-of-Recursions language model with 4-bit fake-quantized K/V.

    Layout: token embedding (scaled by 0.02) + learned positional embedding
    (2048 positions) -> 2 unique blocks -> ``num_layers - 4`` routed
    recursion steps over ``cycle_depth`` (3) shared blocks -> 2 unique
    blocks -> LayerNorm -> vocabulary projection.

    At each recursion step a ``RecursionDepthRouter`` yields one depth
    distribution for the whole batch, and a single depth is *sampled* from
    it. That depth picks the expert router and (modulo ``cycle_depth``) the
    shared block. Each of the ``num_experts`` "experts" runs that same shared
    block on the input scaled by the tokens' routing weight for the expert,
    including experts with zero weight. The expert outputs are summed, and
    the outputs of all recursion steps are averaged.

    NOTE(review): the sampled depth is converted to a Python int, so no
    gradient reaches the recursion routers' parameters; confirm this is
    intended. Sampling also makes ``eval()`` forward passes non-deterministic.

    Args:
        vocab_size: Vocabulary size (embedding rows and output logits).
        dim: Model width.
        num_layers: Determines the number of recursion steps (``num_layers - 4``).
        num_heads: Attention heads per block.
        max_recursion: Number of depth choices / expert routers.
        num_experts: Experts scored by each expert router.
    """

    def __init__(self, vocab_size, dim=DIM, num_layers=NUM_LAYERS,
                 num_heads=HEADS, max_recursion=MAX_RECURSIONS, num_experts=NUM_EXPERTS):
        super().__init__()
        self.dim = dim
        self.max_recursion = max_recursion
        self.num_experts = num_experts
        # Embedding layers (unique parameters)
        self.embedding = nn.Embedding(vocab_size, dim)
        self.pos_embed = nn.Embedding(2048, dim)
        # Initial unique layers
        self.init_layers = nn.ModuleList([
            QuantizedRecursiveTransformerBlock(dim, num_heads)
            for _ in range(2)
        ])
        # Shared blocks reused across recursion steps
        self.cycle_depth = 3
        self.recursive_blocks = nn.ModuleList([
            QuantizedRecursiveTransformerBlock(dim, num_heads)
            for _ in range(self.cycle_depth)
        ])
        # One depth router per recursion step
        self.recursion_routers = nn.ModuleList([
            RecursionDepthRouter(dim, max_depth=max_recursion)
            for _ in range(num_layers - 4)
        ])
        # One expert router per possible depth
        self.expert_routers = nn.ModuleList([
            ExpertChoiceRouter(dim, num_experts)
            for _ in range(max_recursion)
        ])
        # Final unique layers
        self.final_layers = nn.ModuleList([
            QuantizedRecursiveTransformerBlock(dim, num_heads)
            for _ in range(2)
        ])
        # Output head
        self.ln_f = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, vocab_size, bias=False)

    def forward(self, x):
        """Compute next-token logits.

        Args:
            x: Integer token ids of shape ``(batch, seq_len)``.

        Returns:
            Logits of shape ``(batch, seq_len, vocab_size)``.

        Raises:
            ValueError: if ``seq_len`` exceeds the positional-embedding table.
        """
        batch_size, seq_len = x.shape[0], x.shape[1]
        max_positions = self.pos_embed.num_embeddings
        if seq_len > max_positions:
            raise ValueError(
                f"Sequence length {seq_len} exceeds the {max_positions} learned positions"
            )

        pos = torch.arange(0, seq_len, device=x.device)
        x = self.embedding(x) * 0.02  # token embeddings are scaled down
        x = x + self.pos_embed(pos)

        for layer in self.init_layers:
            x = layer(x) * 0.8  # damp each initial block's output

        recursion_outputs = []
        for router in self.recursion_routers:
            depth_probs = router(x)  # (max_recursion,), shared by the batch
            depth = torch.multinomial(depth_probs, 1).item()

            expert_weights, expert_indices = self.expert_routers[depth](x)
            # Dense (batch, seq_len, num_experts) weights, zero for experts a
            # token did not select. The dtype follows the router output so
            # scatter_ also works when the model runs in half precision.
            full_weights = torch.zeros(
                (batch_size, seq_len, self.num_experts),
                device=x.device,
                dtype=expert_weights.dtype,
            )
            full_weights.scatter_(2, expert_indices, expert_weights)

            # Every expert uses the same shared block; only the input scaling
            # differs between experts.
            block = self.recursive_blocks[depth % self.cycle_depth]
            expert_outputs = [
                block(x * full_weights[:, :, expert_idx].unsqueeze(-1))
                for expert_idx in range(self.num_experts)
            ]
            x = sum(expert_outputs)
            recursion_outputs.append(x)

        # Average the outputs of all recursion steps.
        if recursion_outputs:
            x = torch.stack(recursion_outputs).mean(dim=0)

        for layer in self.final_layers:
            x = layer(x)

        x = self.ln_f(x)
        return self.head(x)
# ----------------------
# Learning Rate Scheduler
# ----------------------
def get_lr(current_step, total_steps, warmup_steps, max_lr):
    """Learning rate at ``current_step`` for linear warmup + cosine decay.

    * ``current_step < warmup_steps``: rises linearly from 0 towards ``max_lr``.
    * Afterwards: cosine decay from ``max_lr`` down to 0 at ``total_steps``.
      Steps past ``total_steps`` stay at 0 instead of rising again.

    If the decay phase has zero length (``total_steps <= warmup_steps``), the
    schedule counts as finished once warmup is over and 0 is returned.
    """
    if current_step < warmup_steps:
        return max_lr * (current_step / warmup_steps)
    decay_steps = total_steps - warmup_steps
    if decay_steps <= 0:
        return 0.0
    decay_ratio = min(1.0, (current_step - warmup_steps) / decay_steps)
    return max_lr * 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
# ----------------------
# Training Loop with Validation
# ----------------------
def _evaluate(model, loader, device, use_amp, desc):
    """Return the mean per-batch cross-entropy of ``model`` over ``loader``.

    Runs in eval mode without gradients; target id 0 is ignored, as in
    training.
    """
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for inputs, targets in tqdm(loader, desc=desc):
            inputs, targets = inputs.to(device), targets.to(device)
            with autocast(enabled=use_amp):
                logits = model(inputs)
                loss = F.cross_entropy(
                    logits.reshape(-1, logits.size(-1)),
                    targets.reshape(-1),
                    ignore_index=0,
                )
            total_loss += loss.item()
    return total_loss / len(loader)


def train_model():
    """Train a ``QuantizedMoRModel`` on ``input.txt`` and export it.

    Steps:

    1. Train a byte-level BPE tokenizer.
    2. Build 95% / 5% train / validation splits.
    3. Run ``EPOCHS`` epochs of training with gradient accumulation
       (``GRAD_ACCUM_STEPS`` micro-batches per optimizer update), gradient
       clipping at 1.0, and a per-micro-batch warmup + cosine learning-rate
       schedule. Mixed precision is used only on CUDA.

    Outputs:
        * ``MoR-v1/``: checkpoint with the lowest validation loss;
        * ``MoR-v1-final/``: weights after the last epoch;
        * ``training_validation_loss.png``: per-epoch loss curves.

    Raises:
        ValueError: if either data split yields no batches.
    """
    LR = learn_rate
    best_dir, final_dir = "MoR-v1", "MoR-v1-final"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # torch.cuda.amp autocast/GradScaler only act on CUDA; enabling them on
    # CPU merely emits warnings.
    use_amp = device.type == "cuda"

    # Tokenizer and data
    tokenizer = train_tokenizer("input.txt", VOCAB_SIZE)
    train_dataset, val_dataset = prepare_datasets("input.txt", tokenizer, SEQ_LEN, val_split=0.05)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("Training and validation splits must each contain at least one batch")

    # Model
    model = QuantizedMoRModel(
        vocab_size=VOCAB_SIZE,
        dim=DIM,
        num_layers=NUM_LAYERS,
        num_heads=HEADS,
    )
    model.apply(init_weights)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model Parameters: {total_params/1e6:.2f}M")

    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01)
    scaler = GradScaler(enabled=use_amp)

    # Schedule (counted in micro-batches)
    num_batches = len(train_loader)
    total_steps = EPOCHS * num_batches
    warmup_steps = int(0.1 * total_steps)  # 10% warmup
    print(f"Total training steps: {total_steps}, Warmup steps: {warmup_steps}")

    train_losses = []
    val_losses = []
    best_val_loss = float('inf')

    for epoch in range(EPOCHS):
        model.train()
        epoch_train_loss = 0.0
        optimizer.zero_grad()

        for step, (inputs, targets) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1} Training")):
            global_step = epoch * num_batches + step
            current_lr = get_lr(global_step, total_steps, warmup_steps, LR)
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_lr

            # Micro-batches in this accumulation window; the last window of
            # an epoch can be shorter than GRAD_ACCUM_STEPS.
            window_start = step - step % GRAD_ACCUM_STEPS
            window_size = min(GRAD_ACCUM_STEPS, num_batches - window_start)

            inputs, targets = inputs.to(device), targets.to(device)
            with autocast(enabled=use_amp):
                logits = model(inputs)
                batch_loss = F.cross_entropy(
                    logits.reshape(-1, logits.size(-1)),
                    targets.reshape(-1),
                    ignore_index=0,  # id 0 is the [PAD] special token
                )
            # Dividing by the real window size makes each update's gradient
            # the mean over the micro-batches it actually contains.
            scaler.scale(batch_loss / window_size).backward()
            batch_loss_value = batch_loss.item()
            epoch_train_loss += batch_loss_value

            # Print every 100 micro-batches (not update steps)
            if step % 100 == 0:
                print(f"Step {global_step}: Batch Loss={batch_loss_value:.4f}, LR={current_lr:.2e}")

            if (step + 1) % GRAD_ACCUM_STEPS == 0 or step == num_batches - 1:
                scaler.unscale_(optimizer)  # clip on true (unscaled) gradients
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

        avg_train_loss = epoch_train_loss / num_batches
        train_losses.append(avg_train_loss)

        avg_val_loss = _evaluate(model, val_loader, device, use_amp,
                                 desc=f"Epoch {epoch+1} Validation")
        val_losses.append(avg_val_loss)

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            save_huggingface_model(model, tokenizer, best_dir)
            print(f"Saved new best model with val loss: {best_val_loss:.4f}")
        print(f"Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | LR: {current_lr:.2e}")

    # Loss curves
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.title("Training and Validation Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig("training_validation_loss.png")
    plt.close()

    # The final weights go to their own folder so the best checkpoint in
    # best_dir is not overwritten.
    save_huggingface_model(model, tokenizer, final_dir)
    print(f"Training complete. Best model in {best_dir}/, final model in {final_dir}/.")
# ----------------------
# Execution
# ----------------------
# Run training only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    train_model()