import gc
import math

import torch
from datasets import load_dataset, interleave_datasets
from transformers import AutoTokenizer, Trainer, TrainingArguments, DataCollatorForLanguageModeling

from mixture_of_recursion import RecursiveLanguageModel, RecursiveLanguageModelConfig
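
# Training hyperparameters: with a per-device batch size of 1 and 32 gradient
# accumulation steps, the effective batch size is 32 sequences per update.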
TOTAL_SAMPLES = 50000
BATCH_SIZE = 1
GRAD_ACCUM = 32
EPOCHS = 3
LEARNING_RATE = 3e-4
MAX_LENGTH = 384

print("Starting training with 50K premium samples")
print("-" * 60)
print("\nLoading tokenizer...") |
|
|
tokenizer = AutoTokenizer.from_pretrained("gpt2") |
|
|
tokenizer.pad_token = tokenizer.eos_token |
|
|
|
|
|
print(f"Tokenizer vocab size: {len(tokenizer)}") |
|
|
print(f"Pad token ID: {tokenizer.pad_token_id}") |
|
|
|
|
|
|
|
|
print("\nLoading datasets...") |
|
|
print(" FineWeb-Edu (45%)") |
|
|
fineweb = load_dataset( |
|
|
"HuggingFaceFW/fineweb-edu", |
|
|
name="sample-10BT", |
|
|
split="train", |
|
|
streaming=True |
|
|
).shuffle(seed=42).take(int(TOTAL_SAMPLES * 0.45)) |
|
|
|
|
|
print(" Cosmopedia (30%)") |
|
|
cosmopedia = load_dataset( |
|
|
"HuggingFaceTB/cosmopedia", |
|
|
"web_samples_v1", |
|
|
split="train", |
|
|
streaming=True |
|
|
).shuffle(seed=42).take(int(TOTAL_SAMPLES * 0.30)) |
|
|
|
|
|
print(" OpenWebText (25%)") |
|
|
openwebtext = load_dataset( |
|
|
"openwebtext", |
|
|
split="train", |
|
|
streaming=True |
|
|
).shuffle(seed=42).take(int(TOTAL_SAMPLES * 0.25)) |
|
|
|
|
|
|
|
|
print("\nMixing datasets...") |
|
|
train_dataset = interleave_datasets( |
|
|
[fineweb, cosmopedia, openwebtext], |
|
|
probabilities=[0.45, 0.30, 0.25], |
|
|
seed=42 |
|
|
) |
|
|
|
|
|
|
|
|


def tokenize(examples):
    # The three corpora expose their text under different column names.
    if 'text' in examples:
        texts = examples['text']
    elif 'content' in examples:
        texts = examples['content']
    else:
        texts = list(examples.values())[0]

    return tokenizer(
        texts,
        truncation=True,
        max_length=MAX_LENGTH,
        padding=False
    )
print("Tokenizing...") |
|
|
tokenized_train = train_dataset.map( |
|
|
tokenize, |
|
|
batched=True, |
|
|
remove_columns=train_dataset.column_names |
|
|
).filter(lambda x: len(x['input_ids']) >= 128) |
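
# Validation set: the first 1,000 FineWeb-Edu documents (unshuffled). Since the
# training mix is drawn from the same split, some overlap with training data is possible.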
val_dataset = load_dataset(
    "HuggingFaceFW/fineweb-edu",
    name="sample-10BT",
    split="train",
    streaming=True,
).take(1000)

val_tokenized = val_dataset.map(
    tokenize,
    batched=True,
    remove_columns=val_dataset.column_names
).filter(lambda x: len(x['input_ids']) >= 128)
print("\nBuilding model...") |
|
|
config = RecursiveLanguageModelConfig( |
|
|
vocab_size=len(tokenizer), |
|
|
embedding_dim=512, |
|
|
num_layers=6, |
|
|
num_attention_heads=8, |
|
|
max_recursion_steps=5, |
|
|
max_position_embeddings=512, |
|
|
intermediate_size=2048, |
|
|
pad_token_id=tokenizer.pad_token_id, |
|
|
bos_token_id=tokenizer.pad_token_id, |
|
|
eos_token_id=tokenizer.pad_token_id, |
|
|
simple_recursion_steps=1, |
|
|
medium_recursion_steps=3, |
|
|
complex_recursion_steps=5, |
|
|
use_adaptive_stopping=True, |
|
|
hidden_dropout_prob=0.1, |
|
|
attention_dropout_prob=0.1 |
|
|
) |
|
|
|
|
|

model = RecursiveLanguageModel(config)

params = sum(p.numel() for p in model.parameters()) / 1e6
print(f"Model parameters: {params:.1f}M")

# Free any cached memory before training starts.
torch.cuda.empty_cache()
gc.collect()

# Causal-LM collator: pads each batch and derives labels from input_ids (mlm=False).
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False
)
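
# The interleaved streaming dataset has no length, so the training schedule is
# fixed by max_steps, derived from the sample budget and the effective batch size.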
steps_per_epoch = TOTAL_SAMPLES // (BATCH_SIZE * GRAD_ACCUM)
max_steps = steps_per_epoch * EPOCHS

print(f"\nTraining steps: {max_steps}")
print(f"Effective batch size: {BATCH_SIZE * GRAD_ACCUM}")

training_args = TrainingArguments(
    output_dir="./checkpoints",
    max_steps=max_steps,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=LEARNING_RATE,
    weight_decay=0.01,
    warmup_steps=500,
    fp16=True,
    logging_steps=100,
    eval_strategy="steps",
    eval_steps=1000,
    save_steps=1000,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    report_to="none",
    max_grad_norm=1.0,
    save_safetensors=False,
)
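
# Subclass Trainer to work with the custom model's outputs, report perplexity
# during evaluation, and periodically release cached GPU memory.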
class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        outputs = model(**inputs)
        return (outputs.loss, outputs) if return_outputs else outputs.loss

    def evaluation_loop(self, dataloader, description, prediction_loss_only=None,
                        ignore_keys=None, metric_key_prefix="eval"):
        output = super().evaluation_loop(
            dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
        )

        # Log perplexity alongside the evaluation loss.
        if output.metrics.get(f"{metric_key_prefix}_loss") is not None:
            try:
                perplexity = math.exp(output.metrics[f"{metric_key_prefix}_loss"])
                output.metrics[f"{metric_key_prefix}_perplexity"] = perplexity
            except OverflowError:
                output.metrics[f"{metric_key_prefix}_perplexity"] = float("inf")

        return output

    def training_step(self, model, inputs, num_items_in_batch=None):
        loss = super().training_step(model, inputs, num_items_in_batch)

        # Release cached GPU memory every 50 steps.
        if self.state.global_step % 50 == 0:
            torch.cuda.empty_cache()

        return loss


trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=val_tokenized,
    data_collator=data_collator
)

print("\nStarting training...")
print("-" * 60)

try:
    trainer.train()

    print("\nFinal evaluation...")
    metrics = trainer.evaluate()

    print("\n" + "=" * 60)
    print("FINAL RESULTS:")
    print("=" * 60)
    print(f"Evaluation Loss: {metrics['eval_loss']:.4f}")

    if 'eval_perplexity' in metrics:
        print(f"Perplexity: {metrics['eval_perplexity']:.2f}")
    else:
        try:
            perplexity = math.exp(metrics['eval_loss'])
            print(f"Perplexity: {perplexity:.2f}")
        except OverflowError:
            print("Perplexity: inf (loss too high)")
    print("=" * 60 + "\n")

    print("Saving model...")
    model.save_pretrained("./recursive-lm")
    tokenizer.save_pretrained("./recursive-lm")
    print("Model saved successfully!")

except KeyboardInterrupt:
    print("\n\nTraining interrupted by user")
    print("Saving current model state...")
    model.save_pretrained("./recursive-lm-interrupted")
    tokenizer.save_pretrained("./recursive-lm-interrupted")

except Exception as e:
    print(f"\n\nTraining stopped due to: {e}")
    import traceback
    traceback.print_exc()

    # Best-effort save so the partially trained weights are not lost.
    try:
        print("\nAttempting to save model...")
        model.save_pretrained("./recursive-lm-error")
        tokenizer.save_pretrained("./recursive-lm-error")
        print("Model saved!")
    except Exception:
        print("Could not save model")

print("\nTraining complete!")