import torch
import torch.nn as nn
import math
from transformers import AutoTokenizer, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import load_dataset, interleave_datasets
from mixture_of_recursion import RecursiveLanguageModel, RecursiveLanguageModelConfig
import gc

# Configuration
TOTAL_SAMPLES = 50000
BATCH_SIZE = 1
GRAD_ACCUM = 32
EPOCHS = 3
LEARNING_RATE = 3e-4
MAX_LENGTH = 384

print("Starting training with 50K premium samples")
print("-" * 60)

# Load tokenizer
print("\nLoading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
print(f"Tokenizer vocab size: {len(tokenizer)}")
print(f"Pad token ID: {tokenizer.pad_token_id}")

# Load datasets (streamed, shuffled, and capped at the configured mixture sizes)
print("\nLoading datasets...")
print("  FineWeb-Edu (45%)")
fineweb = load_dataset(
    "HuggingFaceFW/fineweb-edu",
    name="sample-10BT",
    split="train",
    streaming=True
).shuffle(seed=42).take(int(TOTAL_SAMPLES * 0.45))

print("  Cosmopedia (30%)")
cosmopedia = load_dataset(
    "HuggingFaceTB/cosmopedia",
    "web_samples_v1",
    split="train",
    streaming=True
).shuffle(seed=42).take(int(TOTAL_SAMPLES * 0.30))

print("  OpenWebText (25%)")
openwebtext = load_dataset(
    "openwebtext",
    split="train",
    streaming=True
).shuffle(seed=42).take(int(TOTAL_SAMPLES * 0.25))

# Mix datasets
print("\nMixing datasets...")
train_dataset = interleave_datasets(
    [fineweb, cosmopedia, openwebtext],
    probabilities=[0.45, 0.30, 0.25],
    seed=42
)

# Tokenization function
def tokenize(examples):
    # The source datasets name their text column differently; fall back to the first column.
    if 'text' in examples:
        texts = examples['text']
    elif 'content' in examples:
        texts = examples['content']
    else:
        texts = list(examples.values())[0]
    return tokenizer(
        texts,
        truncation=True,
        max_length=MAX_LENGTH,
        padding=False
    )

# Tokenize datasets, keeping only sequences with at least 128 tokens
print("Tokenizing...")
tokenized_train = train_dataset.map(
    tokenize,
    batched=True,
    remove_columns=train_dataset.column_names
).filter(lambda x: len(x['input_ids']) >= 128)

# Validation set
val_dataset = load_dataset(
    "HuggingFaceFW/fineweb-edu",
    name="sample-10BT",
    split="train",
    streaming=True
).take(1000)

val_tokenized = val_dataset.map(
    tokenize,
    batched=True,
    remove_columns=val_dataset.column_names
).filter(lambda x: len(x['input_ids']) >= 128)

# Build model
print("\nBuilding model...")
config = RecursiveLanguageModelConfig(
    vocab_size=len(tokenizer),
    embedding_dim=512,
    num_layers=6,
    num_attention_heads=8,
    max_recursion_steps=5,
    max_position_embeddings=512,
    intermediate_size=2048,
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.pad_token_id,
    simple_recursion_steps=1,
    medium_recursion_steps=3,
    complex_recursion_steps=5,
    use_adaptive_stopping=True,
    hidden_dropout_prob=0.1,
    attention_dropout_prob=0.1
)
model = RecursiveLanguageModel(config)
params = sum(p.numel() for p in model.parameters()) / 1e6
print(f"Model parameters: {params:.1f}M")

# Clear cache
if torch.cuda.is_available():
    torch.cuda.empty_cache()
gc.collect()

# Training setup
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False
)

steps_per_epoch = TOTAL_SAMPLES // (BATCH_SIZE * GRAD_ACCUM)
max_steps = steps_per_epoch * EPOCHS
print(f"\nTraining steps: {max_steps}")
print(f"Effective batch size: {BATCH_SIZE * GRAD_ACCUM}")

training_args = TrainingArguments(
    output_dir="./checkpoints",
    max_steps=max_steps,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=LEARNING_RATE,
    weight_decay=0.01,
    warmup_steps=500,
    fp16=True,
    logging_steps=100,
    eval_strategy="steps",
    eval_steps=1000,
    save_steps=1000,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    report_to="none",
    max_grad_norm=1.0,
    save_safetensors=False,  # use the PyTorch format instead of safetensors
)

# Custom trainer with perplexity reporting
class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        outputs = model(**inputs)
        return (outputs.loss, outputs) if return_outputs else outputs.loss

    def evaluation_loop(self, dataloader, description, prediction_loss_only=None,
                        ignore_keys=None, metric_key_prefix="eval"):
        output = super().evaluation_loop(
            dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
        )
        # Report perplexity alongside the evaluation loss.
        if output.metrics.get(f"{metric_key_prefix}_loss") is not None:
            try:
                perplexity = math.exp(output.metrics[f"{metric_key_prefix}_loss"])
                output.metrics[f"{metric_key_prefix}_perplexity"] = perplexity
            except OverflowError:
                output.metrics[f"{metric_key_prefix}_perplexity"] = float("inf")
        return output

    def training_step(self, model, inputs, num_items_in_batch=None):
        loss = super().training_step(model, inputs, num_items_in_batch)
        # Periodically release cached GPU memory to limit fragmentation.
        if self.state.global_step % 50 == 0 and torch.cuda.is_available():
            torch.cuda.empty_cache()
        return loss

trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=val_tokenized,
    data_collator=data_collator
)

# Train
print("\nStarting training...")
print("-" * 60)

try:
    trainer.train()

    # Final evaluation
    print("\nFinal evaluation...")
    metrics = trainer.evaluate()

    print("\n" + "=" * 60)
    print("FINAL RESULTS:")
    print("=" * 60)
    print(f"Evaluation Loss: {metrics['eval_loss']:.4f}")
    if 'eval_perplexity' in metrics:
        print(f"Perplexity: {metrics['eval_perplexity']:.2f}")
    else:
        try:
            perplexity = math.exp(metrics['eval_loss'])
            print(f"Perplexity: {perplexity:.2f}")
        except OverflowError:
            print("Perplexity: inf (loss too high)")
    print("=" * 60 + "\n")

    # Save with custom method (handles tied weights properly)
    print("Saving model...")
    model.save_pretrained("./recursive-lm")
    tokenizer.save_pretrained("./recursive-lm")
    print("Model saved successfully!")

except KeyboardInterrupt:
    print("\n\nTraining interrupted by user")
    print("Saving current model state...")
    model.save_pretrained("./recursive-lm-interrupted")
    tokenizer.save_pretrained("./recursive-lm-interrupted")

except Exception as e:
    print(f"\n\nTraining stopped due to: {e}")
    import traceback
    traceback.print_exc()
    # Try to save anyway
    try:
        print("\nAttempting to save model...")
        model.save_pretrained("./recursive-lm-error")
        tokenizer.save_pretrained("./recursive-lm-error")
        print("Model saved!")
    except Exception:
        print("Could not save model")

print("\nTraining complete!")
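
# --- Optional post-training sanity check (a minimal sketch, not part of the original
# pipeline). It reloads the saved checkpoint and prints a short greedy continuation.
# Assumptions: RecursiveLanguageModel subclasses transformers.PreTrainedModel, so
# from_pretrained() is available, and the class implements generate(); if it does not,
# replace the generate() call with a manual forward/argmax decoding loop.
def quick_generation_check(model_dir="./recursive-lm", prompt="Recursion is"):
    """Reload the saved checkpoint and print a short greedy continuation."""
    try:
        reloaded_model = RecursiveLanguageModel.from_pretrained(model_dir)
        reloaded_tokenizer = AutoTokenizer.from_pretrained(model_dir)
        reloaded_model.eval()
        inputs = reloaded_tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
            generated = reloaded_model.generate(
                **inputs,
                max_new_tokens=40,
                do_sample=False,
                pad_token_id=reloaded_tokenizer.pad_token_id,
            )
        print(reloaded_tokenizer.decode(generated[0], skip_special_tokens=True))
    except Exception as e:
        # The custom model may not support generate(); fail softly either way.
        print(f"Generation check skipped: {e}")

# Uncomment to run after a successful training run:
# quick_generation_check()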