import torch
import torch.nn as nn
import math
from transformers import AutoTokenizer, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import load_dataset, interleave_datasets
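# RecursiveLanguageModel / RecursiveLanguageModelConfig come from the local
# mixture_of_recursion module (the model definition that lives alongside this
# script, not part of transformers).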
from mixture_of_recursion import RecursiveLanguageModel, RecursiveLanguageModelConfig
import gc
# Configuration
TOTAL_SAMPLES = 50000
BATCH_SIZE = 1
GRAD_ACCUM = 32
EPOCHS = 3
LEARNING_RATE = 3e-4
MAX_LENGTH = 384
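# Effective batch size = BATCH_SIZE * GRAD_ACCUM = 32 sequences per optimizer step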
print("Starting training with 50K premium samples")
print("-" * 60)
# Load tokenizer
print("\nLoading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
print(f"Tokenizer vocab size: {len(tokenizer)}")
print(f"Pad token ID: {tokenizer.pad_token_id}")
# Load datasets
print("\nLoading datasets...")
print(" FineWeb-Edu (45%)")
fineweb = load_dataset(
    "HuggingFaceFW/fineweb-edu",
    name="sample-10BT",
    split="train",
    streaming=True
).shuffle(seed=42).take(int(TOTAL_SAMPLES * 0.45))
print(" Cosmopedia (30%)")
cosmopedia = load_dataset(
    "HuggingFaceTB/cosmopedia",
    "web_samples_v1",
    split="train",
    streaming=True
).shuffle(seed=42).take(int(TOTAL_SAMPLES * 0.30))
print(" OpenWebText (25%)")
openwebtext = load_dataset(
    "openwebtext",
    split="train",
    streaming=True
).shuffle(seed=42).take(int(TOTAL_SAMPLES * 0.25))
# Mix datasets
print("\nMixing datasets...")
train_dataset = interleave_datasets(
    [fineweb, cosmopedia, openwebtext],
    probabilities=[0.45, 0.30, 0.25],
    seed=42
)
# Tokenization function
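# The mixed datasets don't share a single text column name: check 'text' first,
# then 'content', then fall back to the first column.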
def tokenize(examples):
    if 'text' in examples:
        texts = examples['text']
    elif 'content' in examples:
        texts = examples['content']
    else:
        texts = list(examples.values())[0]
    return tokenizer(
        texts,
        truncation=True,
        max_length=MAX_LENGTH,
        padding=False
    )
# Tokenize datasets
print("Tokenizing...")
tokenized_train = train_dataset.map(
    tokenize,
    batched=True,
    remove_columns=train_dataset.column_names
).filter(lambda x: len(x['input_ids']) >= 128)
# Validation set
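# Note: this validation slice comes from the same FineWeb-Edu train split used
# for training, so some overlap with training examples is possible.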
val_dataset = load_dataset(
    "HuggingFaceFW/fineweb-edu",
    name="sample-10BT",
    split="train",
    streaming=True
).take(1000)
val_tokenized = val_dataset.map(
    tokenize,
    batched=True,
    remove_columns=val_dataset.column_names
).filter(lambda x: len(x['input_ids']) >= 128)
# Build model
print("\nBuilding model...")
config = RecursiveLanguageModelConfig(
    vocab_size=len(tokenizer),
    embedding_dim=512,
    num_layers=6,
    num_attention_heads=8,
    max_recursion_steps=5,
    max_position_embeddings=512,
    intermediate_size=2048,
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.pad_token_id,
    simple_recursion_steps=1,
    medium_recursion_steps=3,
    complex_recursion_steps=5,
    use_adaptive_stopping=True,
    hidden_dropout_prob=0.1,
    attention_dropout_prob=0.1
)
model = RecursiveLanguageModel(config)
params = sum(p.numel() for p in model.parameters()) / 1e6
print(f"Model parameters: {params:.1f}M")
# Clear cache
torch.cuda.empty_cache()
gc.collect()
# Training setup
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False
)
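# With mlm=False the collator builds causal-LM batches: labels are a copy of
# input_ids with padding positions set to -100.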
steps_per_epoch = TOTAL_SAMPLES // (BATCH_SIZE * GRAD_ACCUM)
max_steps = steps_per_epoch * EPOCHS
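# 50,000 samples / 32 effective batch size = 1,562 steps per epoch, 4,686 steps over 3 epochs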
print(f"\nTraining steps: {max_steps}")
print(f"Effective batch size: {BATCH_SIZE * GRAD_ACCUM}")
training_args = TrainingArguments(
    output_dir="./checkpoints",
    max_steps=max_steps,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=LEARNING_RATE,
    weight_decay=0.01,
    warmup_steps=500,
    fp16=True,
    logging_steps=100,
    eval_strategy="steps",
    eval_steps=1000,
    save_steps=1000,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    report_to="none",
    max_grad_norm=1.0,
    save_safetensors=False,  # Use PyTorch format instead of safetensors
)
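# Note: load_best_model_at_end needs evaluation and checkpointing on the same
# cadence (both every 1,000 steps here).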
# Custom trainer with perplexity
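# Overrides: compute_loss uses the model's own loss output, evaluation_loop adds
# perplexity = exp(eval_loss), and training_step clears the CUDA cache every 50 steps.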
class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        outputs = model(**inputs)
        return (outputs.loss, outputs) if return_outputs else outputs.loss

    def evaluation_loop(self, dataloader, description, prediction_loss_only=None,
                        ignore_keys=None, metric_key_prefix="eval"):
        output = super().evaluation_loop(
            dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
        )
        if output.metrics.get(f"{metric_key_prefix}_loss") is not None:
            try:
                perplexity = math.exp(output.metrics[f"{metric_key_prefix}_loss"])
                output.metrics[f"{metric_key_prefix}_perplexity"] = perplexity
            except OverflowError:
                output.metrics[f"{metric_key_prefix}_perplexity"] = float("inf")
        return output

    def training_step(self, model, inputs, num_items_in_batch=None):
        loss = super().training_step(model, inputs, num_items_in_batch)
        if self.state.global_step % 50 == 0:
            torch.cuda.empty_cache()
        return loss
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=val_tokenized,
    data_collator=data_collator
)
# Train
print("\nStarting training...")
print("-" * 60)
try:
    trainer.train()

    # Final evaluation
    print("\nFinal evaluation...")
    metrics = trainer.evaluate()

    print("\n" + "=" * 60)
    print("FINAL RESULTS:")
    print("=" * 60)
    print(f"Evaluation Loss: {metrics['eval_loss']:.4f}")
    if 'eval_perplexity' in metrics:
        print(f"Perplexity: {metrics['eval_perplexity']:.2f}")
    else:
        try:
            perplexity = math.exp(metrics['eval_loss'])
            print(f"Perplexity: {perplexity:.2f}")
        except OverflowError:
            print("Perplexity: inf (loss too high)")
    print("=" * 60 + "\n")

    # Save with custom method (handles tied weights properly)
    print("Saving model...")
    model.save_pretrained("./recursive-lm")
    tokenizer.save_pretrained("./recursive-lm")
    print("Model saved successfully!")
except KeyboardInterrupt:
    print("\n\nTraining interrupted by user")
    print("Saving current model state...")
    model.save_pretrained("./recursive-lm-interrupted")
    tokenizer.save_pretrained("./recursive-lm-interrupted")
except Exception as e:
    print(f"\n\nTraining stopped due to: {e}")
    import traceback
    traceback.print_exc()
    # Try to save anyway
    try:
        print("\nAttempting to save model...")
        model.save_pretrained("./recursive-lm-error")
        tokenizer.save_pretrained("./recursive-lm-error")
        print("Model saved!")
    except Exception:
        print("Could not save model")
print("\nTraining complete!")