"""
Enhanced Training Script - Uses your existing trainer.py with HF datasets

This integrates with your current MambaSwarmTrainer system
"""

import json
import logging
import os
import sys
from pathlib import Path

# Make the project root importable so `core` and `training` resolve
# regardless of which directory this script is launched from.
# NOTE: must run before the project-local imports below.
project_root = Path(__file__).parent.parent
sys.path.append(str(project_root))

from core.config import MambaConfig
from training.trainer import MambaSwarmTrainer

from datasets import load_dataset

logger = logging.getLogger(__name__)
| |
|
def prepare_hf_dataset_for_existing_system(dataset_name: str = "wikitext-103-v1",
                                           output_path: str = "train_data.txt") -> str:
    """
    Download a Hugging Face dataset and flatten it into the plain-text file
    format that the existing MambaSwarmTrainer expects (one passage per
    paragraph, blank-line separated).

    Parameters
    ----------
    dataset_name:
        Known names ("wikitext-103-v1", "openwebtext", "tiny_shakespeare")
        get dedicated load settings; anything else is loaded by name and
        assumed to expose a "text" column.
    output_path:
        Destination text file, written UTF-8.

    Returns
    -------
    str
        ``output_path`` — always returned; if the download or conversion
        fails, a synthetic fallback corpus is written there instead.
    """
    log = logging.getLogger(__name__)
    log.info(f"📥 Loading {dataset_name} from Hugging Face...")

    try:
        if dataset_name == "wikitext-103-v1":
            dataset = load_dataset("wikitext", "wikitext-103-v1", split="train")
            text_column = "text"
        elif dataset_name == "openwebtext":
            # openwebtext is huge; cap at the first 10k examples.
            dataset = load_dataset("openwebtext", split="train[:10000]")
            text_column = "text"
        elif dataset_name == "tiny_shakespeare":
            dataset = load_dataset("tiny_shakespeare", split="train")
            text_column = "text"
        else:
            # Unknown dataset: load by name and assume a "text" column.
            dataset = load_dataset(dataset_name, split="train")
            text_column = "text"

        log.info("📝 Converting to text format...")
        with open(output_path, 'w', encoding='utf-8') as f:
            for example in dataset:
                text = example.get(text_column, "")
                # Skip blank/near-empty rows (wikitext has many header-only lines).
                if text and len(text.strip()) > 20:
                    f.write(text.strip() + "\n\n")

        log.info(f"✅ Dataset saved to {output_path}")
        return output_path

    except Exception as e:
        log.error(f"❌ Failed to load {dataset_name}: {e}")
        # Best-effort fallback: synthesize simple training data so the
        # downstream training pipeline can still run end-to-end.
        log.info("Creating fallback training data...")
        with open(output_path, 'w', encoding='utf-8') as f:
            for i in range(1000):
                f.write(f"This is training example number {i}. It contains meaningful text for language modeling.\n\n")
        return output_path
|
def run_existing_trainer_with_hf_data(dataset_name: str = "wikitext-103-v1") -> bool:
    """
    Run the existing MambaSwarmTrainer end-to-end on a Hugging Face dataset.

    Parameters
    ----------
    dataset_name:
        Hugging Face dataset to convert and train on (default: wikitext-103).

    Returns
    -------
    bool
        True when training, checkpointing, and evaluation all complete;
        False when trainer construction or the training pipeline fails.
    """
    logger.info("🚀 Starting Mamba Swarm Training with HF Data")
    logger.info("=" * 60)

    # Step 1: materialize the HF dataset as the plain-text file the trainer reads.
    logger.info("Step 1: Preparing Hugging Face dataset...")
    dataset_path = prepare_hf_dataset_for_existing_system(dataset_name, "train_data.txt")

    logger.info("Step 2: Creating MambaConfig...")
    config = MambaConfig(
        # Model architecture (GPT-2 vocab size).
        vocab_size=50257,
        d_model=768,
        n_layers=8,
        # Optimization settings.
        batch_size=2,
        learning_rate=1e-4,
        max_seq_len=512,
        # Swarm layout.
        num_specialists=20,
        # Schedule.
        warmup_steps=100,
        max_steps=2000,
        train_data_path=dataset_path
    )

    logger.info("✅ Config created:")
    logger.info(f" - Model: {config.d_model}D, {config.n_layers} layers")
    logger.info(f" - Specialists: {config.num_specialists}")
    logger.info(f" - Batch size: {config.batch_size}")
    logger.info(f" - Training data: {config.train_data_path}")

    logger.info("Step 3: Initializing MambaSwarmTrainer...")
    try:
        trainer = MambaSwarmTrainer(config)
        logger.info("✅ Trainer initialized successfully")
    except Exception as e:
        logger.error(f"❌ Trainer initialization failed: {e}")
        return False

    logger.info("Step 4: Starting training pipeline...")
    logger.info("This will run your 4-phase training:")
    logger.info(" Phase 1: Foundation training")
    logger.info(" Phase 2: Specialist training")
    logger.info(" Phase 3: Aggregator training")
    logger.info(" Phase 4: End-to-end fine-tuning")

    try:
        trainer.full_training_pipeline()
        logger.info("🎉 Training completed successfully!")

        # Persist a checkpoint of the trained swarm.
        checkpoint_dir = "checkpoints"
        os.makedirs(checkpoint_dir, exist_ok=True)
        checkpoint_path = os.path.join(checkpoint_dir, "mamba_swarm_hf_trained.pt")
        trainer.save_checkpoint(checkpoint_path)
        logger.info(f"💾 Checkpoint saved: {checkpoint_path}")

        logger.info("📊 Running evaluation...")
        eval_results = trainer.evaluate(eval_steps=50)
        logger.info(f"Evaluation results: {eval_results}")

        return True

    except Exception as e:
        logger.error(f"❌ Training failed: {e}")
        return False
|
def quick_test_run() -> None:
    """Quick smoke test: tiny model, tiny dataset, foundation phase only."""
    logger.info("🧪 Quick Test Run")

    dataset_path = prepare_hf_dataset_for_existing_system("tiny_shakespeare", "test_data.txt")

    # Deliberately minimal configuration so the run finishes quickly.
    config = MambaConfig(
        d_model=256,
        n_layers=4,
        batch_size=1,
        num_specialists=5,
        warmup_steps=10,
        max_steps=50,
        train_data_path=dataset_path
    )

    trainer = MambaSwarmTrainer(config)

    # Only the first training phase — enough to prove the plumbing works.
    logger.info("Running foundation training only...")
    trainer.train_foundation_phase(num_steps=20)

    logger.info("✅ Quick test completed!")
|
if __name__ == "__main__":
    import argparse

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s'
    )

    parser = argparse.ArgumentParser(description="Enhanced Mamba training with HF datasets")
    parser.add_argument("--quick-test", action="store_true", help="Run quick test with minimal settings")
    parser.add_argument("--dataset", default="wikitext-103-v1", help="HuggingFace dataset to use")

    args = parser.parse_args()

    if args.quick_test:
        quick_test_run()
    else:
        # NOTE(review): --dataset is parsed but not forwarded here, so the
        # training entry point still uses its own default dataset — confirm
        # intent and wire `args.dataset` through once the callee accepts it.
        success = run_existing_trainer_with_hf_data()
        if success:
            print("\n🎉 Training completed successfully!")
            print("Your trained Mamba swarm model is ready to use!")
        else:
            print("\n❌ Training failed. Check the logs above for details.")