|
|
|
|
|
|
|
|
from re import M |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import math |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
from torch.utils.checkpoint import checkpoint |
|
|
from tqdm import tqdm |
|
|
import matplotlib.pyplot as plt |
|
|
from torch.cuda.amp import autocast, GradScaler |
|
|
import numpy as np |
|
|
import os |
|
|
from safetensors.torch import save_file, load_file |
|
|
import json |
|
|
from transformers import PreTrainedTokenizerFast |
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
|
|
|
# Sets CUDA_LAUNCH_BLOCKING=1 in this process's environment.
# NOTE(review): this is CUDA's debugging switch for synchronous kernel
# launches and is expected to slow training - confirm it is still wanted.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1' |
|
|
# Hugging Face Hub repo id: default `repo_id` of load_model_from_hub() and the
# source of the tokenizer loaded in train_model().
MODEL = "liminerity/MoR-deep" |
|
|
def save_huggingface_model(model, tokenizer, folder_path="MoR-v1"):
    """Write the model weights, config and tokenizer files to ``folder_path``.

    Files written: ``model.safetensors`` (the state dict), ``config.json``,
    whatever the tokenizer's ``save_pretrained`` emits, and
    ``model.safetensors.index.json``.

    Args:
        model: Module exposing ``state_dict()``, ``parameters()`` and a
            ``num_experts`` attribute.
        tokenizer: Either an object that already provides ``save_pretrained``
            (such as the ``PreTrainedTokenizerFast`` that ``train_model``
            loads) or a raw backend tokenizer, which is wrapped in
            ``PreTrainedTokenizerFast`` with the ``[UNK]``/``[PAD]``/``[BOS]``/
            ``[EOS]`` special tokens before saving.
        folder_path: Output directory; created if it does not exist.
    """
    os.makedirs(folder_path, exist_ok=True)

    weights = model.state_dict()
    save_file(weights, os.path.join(folder_path, "model.safetensors"))

    # Architecture hyper-parameters are read from the module-level globals;
    # only the expert count is taken from the model instance.
    config = {
        "vocab_size": VOCAB_SIZE,
        "dim": DIM,
        "num_layers": NUM_LAYERS,
        "num_heads": HEADS,
        "max_recursion": MAX_RECURSIONS,
        "num_experts": model.num_experts,
        "ffn_expansion": 4,
        "max_position_embeddings": 2048,
        "model_type": "MoR",
        "architecture": "MixtureOfRecursions",
        "hidden_act": "gelu",
    }
    with open(os.path.join(folder_path, "config.json"), "w") as f:
        json.dump(config, f, indent=2)

    # `tokenizer_object` expects a raw backend tokenizer. A tokenizer that is
    # already a transformers wrapper must not be wrapped a second time, so it
    # is saved as-is.
    if hasattr(tokenizer, "save_pretrained"):
        hf_tokenizer = tokenizer
    else:
        hf_tokenizer = PreTrainedTokenizerFast(
            tokenizer_object=tokenizer,
            unk_token="[UNK]",
            pad_token="[PAD]",
            bos_token="[BOS]",
            eos_token="[EOS]",
        )
    hf_tokenizer.save_pretrained(folder_path)

    # Single-shard index: every tensor name maps to the one weights file.
    index = {
        "metadata": {
            "total_size": sum(p.numel() * p.element_size() for p in model.parameters())
        },
        "weight_map": {name: "model.safetensors" for name in weights},
    }
    with open(os.path.join(folder_path, "model.safetensors.index.json"), "w") as f:
        json.dump(index, f, indent=2)

    print(f"Model saved in Hugging Face format to {folder_path}/")
|
|
|
|
|
def load_model_from_hub(repo_id=MODEL):
    """Download (or reuse) a checkpoint and build a ``QuantizedMoRModel`` from it.

    ``config.json`` and ``model.safetensors`` are stored directly under
    ``./models/<repo_id>/``. They are fetched from the Hub only when one of
    them is missing there, so a later run finds exactly the paths it checks.

    Side effects: overwrites the module globals ``VOCAB_SIZE``, ``DIM``,
    ``NUM_LAYERS``, ``HEADS`` and ``MAX_RECURSIONS`` with the values from the
    downloaded config.

    Args:
        repo_id: Hugging Face Hub repository id.

    Returns:
        The model with the checkpoint weights loaded (device unchanged).
    """
    global VOCAB_SIZE, DIM, NUM_LAYERS, HEADS, MAX_RECURSIONS

    local_dir = f"./models/{repo_id}"
    config_path = os.path.join(local_dir, "config.json")
    safetensors_path = os.path.join(local_dir, "model.safetensors")

    # Check for the files themselves, not just the directory: an interrupted
    # download (or a directory populated in hub-cache layout) would otherwise
    # be mistaken for a usable local copy.
    if os.path.isfile(config_path) and os.path.isfile(safetensors_path):
        print(f"Using cached model from {local_dir}")
    else:
        print(f"Downloading model from {repo_id}...")
        os.makedirs(local_dir, exist_ok=True)
        # `local_dir=` places the file at <local_dir>/<filename>; `cache_dir=`
        # would nest it under models--<org>--<name>/snapshots/<rev>/ instead.
        config_path = hf_hub_download(repo_id, "config.json", local_dir=local_dir)
        safetensors_path = hf_hub_download(repo_id, "model.safetensors", local_dir=local_dir)

    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)

    weights = load_file(safetensors_path)

    # The expert count is read from the checkpoint itself: each router gate
    # has one output row per expert.
    num_experts = weights['expert_routers.0.gate.weight'].shape[0]
    print(f"Inferred number of experts from checkpoint: {num_experts}")
    config['num_experts'] = num_experts

    VOCAB_SIZE = config['vocab_size']
    DIM = config['dim']
    NUM_LAYERS = config['num_layers']
    HEADS = config['num_heads']
    MAX_RECURSIONS = config['max_recursion']

    model = QuantizedMoRModel(
        vocab_size=VOCAB_SIZE,
        dim=DIM,
        num_layers=NUM_LAYERS,
        num_heads=HEADS,
        max_recursion=MAX_RECURSIONS,
        num_experts=num_experts,
    )
    model.load_state_dict(weights)
    return model
|
|
|
|
|
|
|
|
# Default hyper-parameters. load_model_from_hub() overwrites VOCAB_SIZE, DIM,
# NUM_LAYERS, HEADS and MAX_RECURSIONS with the values from the checkpoint's
# config.json.
# Vocabulary size; train_model() also resets it from the tokenizer.
VOCAB_SIZE = 10000 |
|
|
# Default embedding / hidden width of QuantizedMoRModel.
DIM = 1536 |
|
|
# QuantizedMoRModel creates (num_layers - 4) recursion routers from this.
NUM_LAYERS = 6 |
|
|
# Default number of attention heads per transformer block.
HEADS = 8 |
|
|
# Batch size of the training and validation DataLoaders.
BATCH_SIZE = 32 |
|
|
# Tokens per training sample.
SEQ_LEN = 512 |
|
|
# Number of depth options per recursion router and number of expert routers.
MAX_RECURSIONS = 4 |
|
|
# Peak learning rate: passed to AdamW and to get_lr() as max_lr.
learn_rate = 5e-5 |
|
|
# Number of passes over the training set.
EPOCHS = 3 |
|
|
# Default expert count; load_model_from_hub() infers the real count from the
# checkpoint instead of using this value.
NUM_EXPERTS = 12 |
|
|
# Micro-batches accumulated before each optimizer step.
GRAD_ACCUM_STEPS = 4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def prepare_datasets(file_path, tokenizer, seq_len=SEQ_LEN, val_split=0.05):
    """Tokenize a UTF-8 text file and split it into train/validation datasets.

    The text is cut into 500000-character pieces, each piece is encoded
    without special tokens, and the resulting id tensors are concatenated.
    The leading ``1 - val_split`` fraction of the token stream becomes the
    training set; the remainder becomes the validation set.

    Args:
        file_path: Path of the text file to read.
        tokenizer: Object whose ``encode(text, add_special_tokens=False)``
            returns a sequence of token ids.
        seq_len: Sample length forwarded to ``TextDataset``.
        val_split: Fraction of tokens held out for validation.

    Returns:
        ``(train_dataset, val_dataset)``.
    """
    print("Preparing datasets with tokenizer...")

    with open(file_path, 'r', encoding='utf-8') as source:
        corpus = source.read()

    # Encode piecewise so the tokenizer never receives the whole corpus at once.
    piece_len = 500000
    pieces = [
        corpus[offset:offset + piece_len]
        for offset in range(0, len(corpus), piece_len)
    ]
    token_tensors = [
        torch.tensor(tokenizer.encode(piece, add_special_tokens=False))
        for piece in tqdm(pieces, desc="Tokenizing text chunks")
    ]

    token_stream = torch.cat(token_tensors)
    total_tokens = len(token_stream)
    boundary = int(total_tokens * (1 - val_split))

    train_dataset = TextDataset(token_stream[:boundary], seq_len)
    val_dataset = TextDataset(token_stream[boundary:], seq_len)

    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")
    print(f"Total tokens: {total_tokens}")
    return train_dataset, val_dataset
|
|
|
|
|
class TextDataset(Dataset):
    """Fixed-length next-token-prediction samples cut from one token stream.

    Sample ``i`` covers tokens ``[i * seq_len, i * seq_len + seq_len]``
    (``seq_len + 1`` tokens); the first ``seq_len`` are the inputs and the
    last ``seq_len`` are the targets, i.e. the inputs shifted by one.

    Args:
        encoded_data: 1-D sequence of token ids supporting slicing and
            ``.clone()`` on the slices (a tensor in this file).
        seq_len: Number of input tokens per sample.
    """

    def __init__(self, encoded_data, seq_len=SEQ_LEN):
        self.encoded = encoded_data
        self.seq_len = seq_len

    def __len__(self):
        # Every sample needs one token beyond its inputs for the shifted
        # target, so only windows with seq_len + 1 tokens available count.
        # Plain `len // seq_len` yields a short final sample whenever the
        # stream length is an exact multiple of seq_len.
        return max(0, (len(self.encoded) - 1) // self.seq_len)

    def __getitem__(self, idx):
        start = idx * self.seq_len
        segment = self.encoded[start:start + self.seq_len + 1]
        # Clone so the returned tensors do not alias the shared token stream.
        return segment[:-1].clone(), segment[1:].clone()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ExpertChoiceRouter(nn.Module):
    """Bias-free linear gate that keeps the ``k`` highest-scoring experts.

    Despite the class name, the top-k is taken over the expert axis (the last
    dimension of the gate output), so the selection is made separately for
    every input position.

    Args:
        dim: Size of the last dimension of the input.
        num_experts: Number of gate outputs.
        k: How many experts to keep per position.
    """

    def __init__(self, dim, num_experts, k=2):
        super().__init__()
        self.k = k
        self.num_experts = num_experts
        self.gate = nn.Linear(dim, num_experts, bias=False)

    def forward(self, x):
        """Return ``(weights, indices)`` for the selected experts.

        ``weights`` are the selected gate scores normalised with a softmax
        over the ``k`` kept entries; ``indices`` are the matching expert ids.
        """
        gate_logits = self.gate(x)
        best = gate_logits.topk(self.k, dim=-1)
        return torch.softmax(best.values, dim=-1), best.indices
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Quantizer4Bit(nn.Module):
    """Per-tensor quantization onto the integer levels -8..7, stored as int8."""

    def __init__(self):
        super().__init__()

    @staticmethod
    def quantize(tensor):
        """Map ``tensor`` to integer levels and return ``(levels, scale)``.

        The scale is ``max(|tensor|) / 7.5``; when that maximum is not above
        1e-8 the scale falls back to the Python float ``1.0``. Values are
        divided by the scale, rounded, and clamped to ``[-8, 7]``.
        """
        peak = tensor.abs().max()
        if peak > 1e-8:
            scale = peak / 7.5
        else:
            scale = 1.0
        levels = (tensor / scale).round().clamp(-8, 7)
        return levels.to(torch.int8), scale

    @staticmethod
    def dequantize(quantized, scale):
        """Return the float32 approximation ``quantized * scale``."""
        return quantized.to(torch.float32) * scale
|
|
|
|
|
def init_weights(module):
    """Re-initialise one module in place according to its type.

    * ``nn.Linear``: Xavier-uniform weight, zero bias (when a bias exists).
    * ``nn.Embedding``: normal weight with mean 0.0 and std 0.02.
    * ``nn.LayerNorm``: weight of ones, bias of zeros.

    Any other module type is left untouched.
    """
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
        return

    if isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
        return

    if isinstance(module, nn.LayerNorm):
        nn.init.ones_(module.weight)
        nn.init.zeros_(module.bias)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class QuantizedRecursiveTransformerBlock(nn.Module):
    """Pre-norm transformer block whose keys and values are fake-quantized.

    Keys and values are passed through ``Quantizer4Bit`` (quantize, then
    dequantize) before attention. The block body runs under activation
    checkpointing, so intermediate activations are recomputed in backward.

    Args:
        dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        ffn_expansion: Hidden-width multiplier of the feed-forward network.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim, num_heads, ffn_expansion=4):
        super().__init__()
        if dim % num_heads != 0:
            # Otherwise the per-head reshape in _forward fails much later.
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})"
            )
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.attn_out = nn.Linear(dim, dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_expansion * dim),
            nn.GELU(),
            nn.Linear(ffn_expansion * dim, dim),
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        """Run the block on ``x`` of shape (batch, seq, dim) with checkpointing."""
        return checkpoint(self._forward, x, use_reentrant=False)

    @staticmethod
    def _fake_quantize(t):
        """Quantize then dequantize ``t`` with a straight-through gradient.

        The forward value is exactly the dequantized tensor. Rounding and the
        int8 cast are not differentiable, so without the straight-through
        term the only gradient reaching ``t`` is the one through the
        per-tensor scale. Adding ``t - t.detach()`` (numerically zero) makes
        backward treat the quantizer as the identity.
        """
        levels, scale = Quantizer4Bit.quantize(t)
        dequantized = Quantizer4Bit.dequantize(levels, scale)
        return dequantized.detach() + (t - t.detach())

    def _forward(self, x):
        residual = x
        x = self.norm1(x)

        q = self.q_proj(x)
        k = self._fake_quantize(self.k_proj(x))
        v = self._fake_quantize(self.v_proj(x))

        B, T, _ = q.shape
        # (B, T, dim) -> (B, heads, T, head_dim)
        q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        # NOTE(review): no attention mask is applied, so every position
        # attends to all positions, including later ones. Confirm this is
        # intended for a model trained on shifted next-token targets.
        attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        attn = attn.softmax(dim=-1)

        attn_out = (attn @ v).transpose(1, 2).contiguous().view(B, T, self.dim)
        attn_out = self.attn_out(attn_out)

        x = residual + attn_out
        x = x + self.ffn(self.norm2(x))
        return x
|
|
|
|
|
class RecursionDepthRouter(nn.Module):
    """Two-layer MLP that turns a pooled input into a distribution over depths.

    Args:
        dim: Size of the last dimension of the input.
        max_depth: Number of depth choices (length of the output vector).
    """

    def __init__(self, dim, max_depth=4):
        super().__init__()
        self.max_depth = max_depth

        hidden = nn.Linear(dim, dim)
        out = nn.Linear(dim, max_depth)
        self.router = nn.Sequential(hidden, nn.ReLU(), out)

        # Both linear layers: Xavier-uniform weights, zero biases.
        for linear in (hidden, out):
            nn.init.xavier_uniform_(linear.weight)
            nn.init.zeros_(linear.bias)

    def forward(self, x):
        """Return a ``(max_depth,)`` probability vector.

        The input is averaged over its first two dimensions, so a single
        distribution is produced for the whole input tensor.
        """
        pooled = x.mean(dim=(0, 1))
        return torch.softmax(self.router(pooled), dim=-1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class QuantizedMoRModel(nn.Module):
    """Language model that routes tokens through shared recursive blocks.

    Layout: token + learned position embeddings, two initial blocks, a
    routed recursive section, two final blocks, a LayerNorm and a bias-free
    output head.

    Args:
        vocab_size: Number of token ids (embedding rows and head outputs).
        dim: Model width.
        num_layers: ``num_layers - 4`` recursion routers are created; values
            below 5 therefore skip the recursive section entirely.
        num_heads: Attention heads per block.
        max_recursion: Number of depth choices and of expert routers.
        num_experts: Number of experts scored by each expert router.
    """

    def __init__(self, vocab_size, dim=DIM, num_layers=NUM_LAYERS,
                 num_heads=HEADS, max_recursion=MAX_RECURSIONS, num_experts=NUM_EXPERTS):
        super().__init__()
        self.dim = dim
        self.max_recursion = max_recursion
        self.num_experts = num_experts

        self.embedding = nn.Embedding(vocab_size, dim)
        # Learned absolute positions; caps the usable sequence length at 2048.
        self.pos_embed = nn.Embedding(2048, dim)

        self.init_layers = nn.ModuleList([
            QuantizedRecursiveTransformerBlock(dim, num_heads)
            for _ in range(2)
        ])

        # Pool of blocks shared by all recursion steps; a sampled depth d
        # uses block d % cycle_depth.
        self.cycle_depth = 3
        self.recursive_blocks = nn.ModuleList([
            QuantizedRecursiveTransformerBlock(dim, num_heads)
            for _ in range(self.cycle_depth)
        ])
        self.recursion_routers = nn.ModuleList([
            RecursionDepthRouter(dim, max_depth=max_recursion)
            for _ in range(num_layers - 4)
        ])
        self.expert_routers = nn.ModuleList([
            ExpertChoiceRouter(dim, num_experts)
            for _ in range(max_recursion)
        ])

        self.final_layers = nn.ModuleList([
            QuantizedRecursiveTransformerBlock(dim, num_heads)
            for _ in range(2)
        ])
        self.ln_f = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, vocab_size, bias=False)

    def forward(self, x):
        """Compute logits for a batch of token ids.

        Args:
            x: Integer tensor of shape (batch, seq) holding token ids.

        Returns:
            Logits of shape (batch, seq, vocab_size).

        Raises:
            ValueError: If the sequence is longer than the position table.
        """
        seq_len = x.shape[1]
        max_positions = self.pos_embed.num_embeddings
        if seq_len > max_positions:
            # Fail with a readable message instead of an embedding index error.
            raise ValueError(
                f"Sequence length {seq_len} exceeds the maximum of "
                f"{max_positions} positions supported by the model"
            )

        pos = torch.arange(0, seq_len, device=x.device)
        x = self.embedding(x) * 0.02
        x = x + self.pos_embed(pos)

        for layer in self.init_layers:
            x = layer(x) * 0.8

        batch_size, seq_len, _ = x.shape
        recursion_outputs = []

        for router in self.recursion_routers:
            depth_probs = router(x)
            # One depth is sampled for the whole batch. NOTE(review): the
            # draw also happens in eval mode, so validation is stochastic.
            depth = torch.multinomial(depth_probs, 1).item()

            expert_weights, expert_indices = self.expert_routers[depth](x)

            # Dense (batch, seq, experts) weight table: zero everywhere
            # except each position's selected experts. Created in the dtype
            # of the routing weights so scatter_ also works for non-fp32 models.
            full_weights = torch.zeros(
                (batch_size, seq_len, self.num_experts),
                device=x.device,
                dtype=expert_weights.dtype,
            )
            full_weights.scatter_(2, expert_indices, expert_weights)

            # Every expert applies the same shared block to its re-weighted
            # copy of the input.
            block = self.recursive_blocks[depth % self.cycle_depth]
            expert_outputs = []
            for expert_idx in range(self.num_experts):
                expert_x = x * full_weights[:, :, expert_idx].unsqueeze(-1)
                expert_outputs.append(block(expert_x))

            x = sum(expert_outputs)
            recursion_outputs.append(x)

        if recursion_outputs:
            # Average the states produced after every recursion step.
            x = torch.stack(recursion_outputs).mean(dim=0)

        for layer in self.final_layers:
            x = layer(x)

        x = self.ln_f(x)
        logits = self.head(x)
        return logits
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_lr(current_step, total_steps, warmup_steps, max_lr):
    """Learning rate for ``current_step``: linear warmup, then cosine decay.

    Args:
        current_step: Zero-based step index.
        total_steps: Step at which the cosine decay reaches zero.
        warmup_steps: Number of steps over which the rate rises linearly
            from 0 to ``max_lr``.
        max_lr: Peak learning rate.

    Returns:
        ``max_lr * current_step / warmup_steps`` during warmup, otherwise the
        half-cosine decay from ``max_lr`` down to 0 at ``total_steps``. Steps
        beyond ``total_steps`` stay at 0, and a schedule without any decay
        phase (``total_steps <= warmup_steps``) holds ``max_lr`` after warmup.
    """
    if warmup_steps > 0 and current_step < warmup_steps:
        return max_lr * (current_step / warmup_steps)

    decay_steps = total_steps - warmup_steps
    if decay_steps <= 0:
        # No room for a decay phase; avoids dividing by zero.
        return max_lr

    # Clamp so the cosine never runs past its minimum and climbs back up.
    decay_ratio = (current_step - warmup_steps) / decay_steps
    decay_ratio = min(max(decay_ratio, 0.0), 1.0)
    return max_lr * 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _lm_loss(logits, targets):
    """Cross entropy over all positions; targets equal to 0 are ignored.

    The class dimension is taken from ``logits`` itself rather than from the
    ``VOCAB_SIZE`` global, so the reshape always matches the model's head.
    """
    return F.cross_entropy(
        logits.view(-1, logits.size(-1)),
        targets.view(-1),
        ignore_index=0
    )


def train_model():
    """Continue training the hub checkpoint on ``input.txt``.

    Steps: load the tokenizer and checkpoint named by ``MODEL``, build the
    train/validation loaders, train for ``EPOCHS`` epochs with mixed
    precision, gradient accumulation, gradient clipping and a warmup/cosine
    learning-rate schedule, validate after every epoch, and plot the losses
    to ``training_validation_loss_continued.png``.

    Checkpoints: the best-validation model goes to
    ``MoR-v1-continued-best`` and the final model to ``MoR-v1-continued``.
    They are kept apart so the final save cannot overwrite the best one.

    Raises:
        ValueError: If either split yields no batches.
    """
    global VOCAB_SIZE

    tokenizer = PreTrainedTokenizerFast.from_pretrained(MODEL)
    VOCAB_SIZE = tokenizer.vocab_size

    train_dataset, val_dataset = prepare_datasets("input.txt", tokenizer, SEQ_LEN, val_split=0.05)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=True)

    steps_per_epoch = len(train_loader)
    if steps_per_epoch == 0 or len(val_loader) == 0:
        # Fail before downloading the model rather than with a
        # ZeroDivisionError when the epoch averages are computed.
        raise ValueError(
            "input.txt does not contain enough tokens to build both a training "
            f"and a validation batch at SEQ_LEN={SEQ_LEN}"
        )

    model = load_model_from_hub(MODEL)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model Parameters: {total_params/1e6:.2f}M")

    optimizer = torch.optim.AdamW(model.parameters(), lr=learn_rate, weight_decay=0.01)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    scaler = GradScaler()

    # The schedule advances once per micro-batch, not per optimizer step.
    total_steps = EPOCHS * steps_per_epoch
    warmup_steps = int(0.1 * total_steps)
    print(f"Total training steps: {total_steps}, Warmup steps: {warmup_steps}")

    train_losses = []
    val_losses = []
    best_val_loss = float('inf')

    for epoch in range(EPOCHS):
        model.train()
        epoch_train_loss = 0
        accumulated_loss = 0
        optimizer.zero_grad()

        for step, (inputs, targets) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1} Training")):
            global_step = epoch * steps_per_epoch + step
            current_lr = get_lr(global_step, total_steps, warmup_steps, learn_rate)
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_lr

            inputs, targets = inputs.to(device, non_blocking=True), targets.to(device, non_blocking=True)

            with autocast():
                logits = model(inputs)
                # Scale down so the accumulated gradient is an average.
                loss = _lm_loss(logits, targets) / GRAD_ACCUM_STEPS

            scaler.scale(loss).backward()
            # Track the un-scaled loss for reporting.
            accumulated_loss += loss.item() * GRAD_ACCUM_STEPS

            if step % 100 == 0:
                print(f"Step {global_step}: Batch Loss={accumulated_loss:.4f}, LR={current_lr:.2e}")

            # Step after every GRAD_ACCUM_STEPS micro-batches, and flush the
            # partial group at the end of the epoch.
            if (step + 1) % GRAD_ACCUM_STEPS == 0 or step == steps_per_epoch - 1:
                # Gradients must be unscaled before clipping by norm.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                epoch_train_loss += accumulated_loss
                accumulated_loss = 0

        avg_train_loss = epoch_train_loss / steps_per_epoch
        train_losses.append(avg_train_loss)

        model.eval()
        epoch_val_loss = 0
        with torch.no_grad():
            for inputs, targets in tqdm(val_loader, desc=f"Epoch {epoch+1} Validation"):
                inputs, targets = inputs.to(device, non_blocking=True), targets.to(device, non_blocking=True)
                with autocast():
                    logits = model(inputs)
                    loss = _lm_loss(logits, targets)
                epoch_val_loss += loss.item()

        avg_val_loss = epoch_val_loss / len(val_loader)
        val_losses.append(avg_val_loss)

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            save_huggingface_model(model, tokenizer, "MoR-v1-continued-best")
            print(f"Saved new best model with val loss: {best_val_loss:.4f}")

        print(f"Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | LR: {current_lr:.2e}")

    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.title("Training and Validation Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig("training_validation_loss_continued.png")
    plt.close()

    # Final-epoch weights; the best-validation weights live in their own folder.
    save_huggingface_model(model, tokenizer, "MoR-v1-continued")
    print("Training complete. Models saved.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Script entry point: run the continued-training pipeline only when this file
# is executed directly, not when it is imported.
if __name__ == "__main__": |
|
|
train_model() |