|
|
"""
|
|
|
Main entry point for Mamba Swarm
|
|
|
100 units of 70M parameter Mamba encoders for distributed language modeling
|
|
|
"""
|
|
|
|
|
|
import os
|
|
|
import sys
|
|
|
import argparse
|
|
|
import logging
|
|
|
import asyncio
|
|
|
from pathlib import Path
|
|
|
from typing import Dict, Any, Optional
|
|
|
|
|
|
|
|
|
project_root = Path(__file__).parent
|
|
|
sys.path.insert(0, str(project_root))
|
|
|
|
|
|
|
|
|
from core.config import MambaSwarmConfig
|
|
|
from system.mambaSwarm import SwarmEngine
|
|
|
from system.inference import InferenceEngine
|
|
|
from api.api_server import run_server
|
|
|
from api.load_balancer import run_load_balancer, LoadBalancingStrategy
|
|
|
from training.trainer import DistributedTrainer
|
|
|
from monitoring.metrics import MambaSwarmMetrics
|
|
|
from monitoring.profiler import MambaSwarmProfiler
|
|
|
from monitoring.evaluator import MambaSwarmEvaluator
|
|
|
from checkpoints.checkpoint_manager import CheckpointManager
|
|
|
from training.trainer import setup_logging, get_device_info
|
|
|
|
|
|
def setup_argument_parser():
|
|
|
"""Setup command line argument parser"""
|
|
|
parser = argparse.ArgumentParser(description="Mamba Swarm - Distributed Language Model")
|
|
|
|
|
|
|
|
|
parser.add_argument("mode", choices=["train", "serve", "evaluate", "load_balance"],
|
|
|
help="Operation mode")
|
|
|
|
|
|
|
|
|
parser.add_argument("--config", type=str, default="config/default.yaml",
|
|
|
help="Configuration file path")
|
|
|
parser.add_argument("--checkpoint", type=str, default=None,
|
|
|
help="Checkpoint to load")
|
|
|
|
|
|
|
|
|
parser.add_argument("--epochs", type=int, default=10,
|
|
|
help="Number of training epochs")
|
|
|
parser.add_argument("--batch-size", type=int, default=8,
|
|
|
help="Training batch size")
|
|
|
parser.add_argument("--learning-rate", type=float, default=1e-4,
|
|
|
help="Learning rate")
|
|
|
parser.add_argument("--data-path", type=str, default="data/",
|
|
|
help="Training data path")
|
|
|
|
|
|
|
|
|
parser.add_argument("--host", type=str, default="0.0.0.0",
|
|
|
help="Server host")
|
|
|
parser.add_argument("--port", type=int, default=8000,
|
|
|
help="Server port")
|
|
|
parser.add_argument("--workers", type=int, default=1,
|
|
|
help="Number of worker processes")
|
|
|
|
|
|
|
|
|
parser.add_argument("--servers", type=str, nargs="+",
|
|
|
help="Backend server addresses (host:port)")
|
|
|
parser.add_argument("--strategy", type=str, default="resource_aware",
|
|
|
choices=["round_robin", "least_connections", "weighted_round_robin",
|
|
|
"least_response_time", "hash_based", "resource_aware"],
|
|
|
help="Load balancing strategy")
|
|
|
|
|
|
|
|
|
parser.add_argument("--eval-data", type=str, default="data/eval/",
|
|
|
help="Evaluation data path")
|
|
|
parser.add_argument("--output-report", type=str, default=None,
|
|
|
help="Evaluation report output path")
|
|
|
|
|
|
|
|
|
parser.add_argument("--num-encoders", type=int, default=100,
|
|
|
help="Number of Mamba encoders")
|
|
|
parser.add_argument("--encoder-params", type=int, default=70000000,
|
|
|
help="Parameters per encoder (70M)")
|
|
|
parser.add_argument("--device", type=str, default="auto",
|
|
|
help="Device to use (cuda, cpu, auto)")
|
|
|
parser.add_argument("--distributed", action="store_true",
|
|
|
help="Enable distributed training")
|
|
|
|
|
|
|
|
|
parser.add_argument("--enable-metrics", action="store_true",
|
|
|
help="Enable metrics collection")
|
|
|
parser.add_argument("--enable-profiling", action="store_true",
|
|
|
help="Enable performance profiling")
|
|
|
parser.add_argument("--metrics-port", type=int, default=9090,
|
|
|
help="Metrics server port")
|
|
|
|
|
|
|
|
|
parser.add_argument("--log-level", type=str, default="INFO",
|
|
|
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
|
|
|
help="Logging level")
|
|
|
parser.add_argument("--log-file", type=str, default=None,
|
|
|
help="Log file path")
|
|
|
|
|
|
return parser
|
|
|
|
|
|
async def train_mode(args, config: MambaSwarmConfig):
|
|
|
"""Training mode"""
|
|
|
logging.info("Starting Mamba Swarm training...")
|
|
|
|
|
|
|
|
|
metrics = MambaSwarmMetrics() if args.enable_metrics else None
|
|
|
profiler = MambaSwarmProfiler() if args.enable_profiling else None
|
|
|
|
|
|
|
|
|
swarm_engine = SwarmEngine(config)
|
|
|
swarm_engine.initialize()
|
|
|
|
|
|
|
|
|
checkpoint_manager = CheckpointManager(
|
|
|
checkpoint_dir=config.checkpoint_dir,
|
|
|
max_checkpoints=config.max_checkpoints,
|
|
|
save_interval=config.save_interval
|
|
|
)
|
|
|
|
|
|
|
|
|
if args.checkpoint:
|
|
|
checkpoint_data = checkpoint_manager.load_checkpoint(args.checkpoint)
|
|
|
if checkpoint_data:
|
|
|
swarm_engine.load_state_dict(checkpoint_data["model_state"])
|
|
|
logging.info(f"Loaded checkpoint: {args.checkpoint}")
|
|
|
|
|
|
|
|
|
trainer = DistributedTrainer(
|
|
|
swarm_engine=swarm_engine,
|
|
|
config=config,
|
|
|
checkpoint_manager=checkpoint_manager,
|
|
|
metrics=metrics,
|
|
|
profiler=profiler
|
|
|
)
|
|
|
|
|
|
try:
|
|
|
|
|
|
if metrics:
|
|
|
metrics.start_monitoring()
|
|
|
if profiler:
|
|
|
profiler.start_profiling()
|
|
|
|
|
|
|
|
|
await trainer.train(
|
|
|
data_path=args.data_path,
|
|
|
epochs=args.epochs,
|
|
|
batch_size=args.batch_size,
|
|
|
learning_rate=args.learning_rate
|
|
|
)
|
|
|
|
|
|
finally:
|
|
|
|
|
|
if metrics:
|
|
|
metrics.stop_monitoring()
|
|
|
if profiler:
|
|
|
profiler.cleanup()
|
|
|
swarm_engine.shutdown()
|
|
|
|
|
|
def serve_mode(args, config: MambaSwarmConfig):
|
|
|
"""API serving mode"""
|
|
|
logging.info("Starting Mamba Swarm API server...")
|
|
|
|
|
|
|
|
|
run_server(
|
|
|
host=args.host,
|
|
|
port=args.port,
|
|
|
workers=args.workers
|
|
|
)
|
|
|
|
|
|
def load_balance_mode(args, config: MambaSwarmConfig):
|
|
|
"""Load balancer mode"""
|
|
|
logging.info("Starting Mamba Swarm load balancer...")
|
|
|
|
|
|
|
|
|
servers = []
|
|
|
for server_addr in args.servers or []:
|
|
|
if ":" in server_addr:
|
|
|
host, port = server_addr.split(":", 1)
|
|
|
servers.append((host, int(port)))
|
|
|
else:
|
|
|
servers.append((server_addr, 8000))
|
|
|
|
|
|
if not servers:
|
|
|
logging.error("No backend servers specified")
|
|
|
return
|
|
|
|
|
|
|
|
|
strategy_map = {
|
|
|
"round_robin": LoadBalancingStrategy.ROUND_ROBIN,
|
|
|
"least_connections": LoadBalancingStrategy.LEAST_CONNECTIONS,
|
|
|
"weighted_round_robin": LoadBalancingStrategy.WEIGHTED_ROUND_ROBIN,
|
|
|
"least_response_time": LoadBalancingStrategy.LEAST_RESPONSE_TIME,
|
|
|
"hash_based": LoadBalancingStrategy.HASH_BASED,
|
|
|
"resource_aware": LoadBalancingStrategy.RESOURCE_AWARE
|
|
|
}
|
|
|
|
|
|
strategy = strategy_map.get(args.strategy, LoadBalancingStrategy.RESOURCE_AWARE)
|
|
|
|
|
|
|
|
|
run_load_balancer(
|
|
|
servers=servers,
|
|
|
host=args.host,
|
|
|
port=args.port,
|
|
|
strategy=strategy
|
|
|
)
|
|
|
|
|
|
async def evaluate_mode(args, config: MambaSwarmConfig):
|
|
|
"""Evaluation mode"""
|
|
|
logging.info("Starting Mamba Swarm evaluation...")
|
|
|
|
|
|
|
|
|
swarm_engine = SwarmEngine(config)
|
|
|
swarm_engine.initialize()
|
|
|
|
|
|
|
|
|
if args.checkpoint:
|
|
|
checkpoint_manager = CheckpointManager(config.checkpoint_dir)
|
|
|
checkpoint_data = checkpoint_manager.load_checkpoint(args.checkpoint)
|
|
|
if checkpoint_data:
|
|
|
swarm_engine.load_state_dict(checkpoint_data["model_state"])
|
|
|
logging.info(f"Loaded checkpoint: {args.checkpoint}")
|
|
|
|
|
|
|
|
|
evaluator = MambaSwarmEvaluator(swarm_engine, config.__dict__)
|
|
|
|
|
|
try:
|
|
|
|
|
|
result = evaluator.run_comprehensive_evaluation()
|
|
|
|
|
|
|
|
|
print(f"\nEvaluation Results:")
|
|
|
print(f"Overall Score: {result.overall_score:.3f}")
|
|
|
print(f"Execution Time: {result.execution_time:.2f}s")
|
|
|
print(f"Total Metrics: {len(result.individual_metrics)}")
|
|
|
|
|
|
|
|
|
print(f"\nTop Metrics:")
|
|
|
for metric in result.individual_metrics[:10]:
|
|
|
print(f" {metric.metric_name}: {metric.score:.3f}")
|
|
|
|
|
|
|
|
|
output_path = args.output_report or f"evaluation_report_{int(result.timestamp)}.json"
|
|
|
report_file = evaluator.export_evaluation_report(result, output_path)
|
|
|
print(f"\nDetailed report saved to: {report_file}")
|
|
|
|
|
|
finally:
|
|
|
swarm_engine.shutdown()
|
|
|
|
|
|
def validate_config(args) -> MambaSwarmConfig:
|
|
|
"""Validate and create configuration"""
|
|
|
|
|
|
|
|
|
if os.path.exists(args.config):
|
|
|
config = MambaSwarmConfig.from_file(args.config)
|
|
|
else:
|
|
|
logging.warning(f"Config file {args.config} not found, using defaults")
|
|
|
config = MambaSwarmConfig()
|
|
|
|
|
|
|
|
|
if args.num_encoders:
|
|
|
config.num_encoders = args.num_encoders
|
|
|
if args.encoder_params:
|
|
|
config.encoder_params = args.encoder_params
|
|
|
|
|
|
|
|
|
if args.device == "auto":
|
|
|
device_info = get_device_info()
|
|
|
config.device = "cuda" if device_info["cuda_available"] else "cpu"
|
|
|
else:
|
|
|
config.device = args.device
|
|
|
|
|
|
|
|
|
total_params = config.num_encoders * config.encoder_params
|
|
|
logging.info(f"Configuration: {config.num_encoders} encoders × {config.encoder_params/1e6:.0f}M params = {total_params/1e9:.1f}B total parameters")
|
|
|
|
|
|
return config
|
|
|
|
|
|
def main():
|
|
|
"""Main entry point"""
|
|
|
parser = setup_argument_parser()
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
|
|
|
setup_logging(
|
|
|
level=getattr(logging, args.log_level),
|
|
|
log_file=args.log_file
|
|
|
)
|
|
|
|
|
|
|
|
|
print("=" * 60)
|
|
|
print("🐍 Mamba Swarm - Distributed Language Model")
|
|
|
print("100 × 70M Parameter Mamba Encoders")
|
|
|
print("=" * 60)
|
|
|
|
|
|
|
|
|
try:
|
|
|
config = validate_config(args)
|
|
|
except Exception as e:
|
|
|
logging.error(f"Configuration validation failed: {e}")
|
|
|
sys.exit(1)
|
|
|
|
|
|
|
|
|
device_info = get_device_info()
|
|
|
logging.info(f"System: {device_info['cpu_count']} CPUs, {device_info['memory_gb']:.1f}GB RAM")
|
|
|
if device_info["cuda_available"]:
|
|
|
logging.info(f"GPU: {device_info['gpu_count']} devices, {device_info['gpu_memory_gb']:.1f}GB VRAM")
|
|
|
|
|
|
|
|
|
try:
|
|
|
if args.mode == "train":
|
|
|
asyncio.run(train_mode(args, config))
|
|
|
elif args.mode == "serve":
|
|
|
serve_mode(args, config)
|
|
|
elif args.mode == "load_balance":
|
|
|
load_balance_mode(args, config)
|
|
|
elif args.mode == "evaluate":
|
|
|
asyncio.run(evaluate_mode(args, config))
|
|
|
else:
|
|
|
logging.error(f"Unknown mode: {args.mode}")
|
|
|
sys.exit(1)
|
|
|
|
|
|
except KeyboardInterrupt:
|
|
|
logging.info("Received interrupt signal, shutting down...")
|
|
|
except Exception as e:
|
|
|
logging.error(f"Application error: {e}", exc_info=True)
|
|
|
sys.exit(1)
|
|
|
|
|
|
logging.info("Mamba Swarm shutdown complete")
|
|
|
|
|
|
def print_usage_examples():
|
|
|
"""Print usage examples"""
|
|
|
examples = """
|
|
|
Usage Examples:
|
|
|
|
|
|
1. Training:
|
|
|
python main.py train --data-path ./data/train --epochs 10 --batch-size 8 --enable-metrics
|
|
|
|
|
|
2. Serving:
|
|
|
python main.py serve --host 0.0.0.0 --port 8000 --checkpoint best_model.pt
|
|
|
|
|
|
3. Load Balancing:
|
|
|
python main.py load_balance --servers localhost:8000 localhost:8001 localhost:8002 --strategy resource_aware
|
|
|
|
|
|
4. Evaluation:
|
|
|
python main.py evaluate --checkpoint best_model.pt --eval-data ./data/eval --output-report eval_results.json
|
|
|
|
|
|
5. Distributed Training:
|
|
|
python main.py train --distributed --num-encoders 100 --batch-size 4 --enable-profiling
|
|
|
|
|
|
Configuration File Example (config.yaml):
|
|
|
---
|
|
|
num_encoders: 100
|
|
|
encoder_params: 70000000
|
|
|
hidden_size: 2048
|
|
|
num_layers: 32
|
|
|
vocab_size: 50000
|
|
|
max_sequence_length: 2048
|
|
|
device: "auto"
|
|
|
checkpoint_dir: "./checkpoints"
|
|
|
max_checkpoints: 10
|
|
|
save_interval: 1000
|
|
|
learning_rate: 1e-4
|
|
|
warmup_steps: 1000
|
|
|
weight_decay: 0.01
|
|
|
gradient_clip_norm: 1.0
|
|
|
mixed_precision: true
|
|
|
gradient_accumulation_steps: 8
|
|
|
"""
|
|
|
print(examples)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
|
if len(sys.argv) == 2 and sys.argv[1] in ["--help-examples", "-he"]:
|
|
|
print_usage_examples()
|
|
|
sys.exit(0)
|
|
|
|
|
|
main() |