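"""Command-line interface for Aetheris.

Subcommands (wired up in ``main()`` below):
    train     -- two-stage training: pre-training on SlimPajama, then SFT on OpenAssistant/oasst1
    generate  -- sample a continuation for a prompt from a saved checkpoint
    serve     -- start the API server via uvicorn
    info      -- print a parameter breakdown (total / dense / expert / active per token)

The invocation examples in the comments below assume this file is run as
``python cli.py <command> ...``; the actual entry-point name is an assumption.
"""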
import argparse
import os
import sys

import torch
import torch.nn.functional as F

from aetheris.config import AetherisConfig
from aetheris.data import create_streaming_loader, get_tokenizer
from aetheris.model import HybridMambaMoE
from aetheris.trainer import Trainer
from aetheris.utils import load_latest_checkpoint, calculate_model_stats


def train_command(args):
    print(f"\n{'='*70}")
    print("Aetheris Training")
    print(f"Config: {args.config}")

    if args.hf_token:
        print(f"Using Hugging Face token: {args.hf_token[:10]}...")
        from huggingface_hub import login
        login(token=args.hf_token)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if device.type == 'cuda':
        # TF32 matmuls are faster on Ampere+ GPUs with negligible accuracy impact
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.cuda.empty_cache()

    config = AetherisConfig.from_yaml(args.config)
    tokenizer = get_tokenizer()

    print(f"Device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
    print(f"Model Size: d_model={config.d_model}, layers={config.n_layer}")
    print(f"{'='*70}\n")

    model = HybridMambaMoE(config).to(device)

    # Apply weight initialization
    print("Applying proper weight initialization...")
    model.apply(model._init_weights)

    # Calculate model stats
    stats = calculate_model_stats(model)
    print(f"Total Parameters: {stats['total_params']:,}")
    print(f"Trainable Parameters: {stats['trainable_params']:,}")

    # Use a lower learning rate for stability; the fused AdamW kernel is used only on CUDA
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01,
                                  betas=(0.9, 0.95), eps=1e-8,
                                  fused=(device.type == 'cuda'))
    scaler = torch.amp.GradScaler('cuda' if device.type == 'cuda' else 'cpu', init_scale=2**10)

    start_step, current_stage = load_latest_checkpoint(model, optimizer, scaler, device,
                                                       args.checkpoint_dir, args.checkpoint_name)
    trainer = Trainer(model, optimizer, scaler, config, device, args.checkpoint_dir)
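
    # `load_latest_checkpoint` returns the step and stage name saved in the latest
    # checkpoint, so an interrupted run resumes where it left off: the pre-training
    # block below runs only for a fresh start (start_step == 0) or when the saved
    # stage is still "Pre-Training"; otherwise training jumps straight to SFT.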
    # --- STAGE 1: PRE-TRAINING ---
    if current_stage == "Pre-Training" or start_step == 0:
        pt_loader = create_streaming_loader("cerebras/SlimPajama-627B", "train",
                                            tokenizer, config, args.batch_size, mode="pretrain",
                                            hf_token=args.hf_token, start_step=start_step)
        # Validation loader (no skipping needed, always read from the start of the val split)
        pt_val_loader = create_streaming_loader("cerebras/SlimPajama-627B", "validation",
                                                tokenizer, config, args.batch_size, mode="pretrain",
                                                hf_token=args.hf_token)
        start_step = trainer.train_epoch(pt_loader, total_steps=args.pretrain_steps,
                                         start_step=start_step, stage_name="Pre-Training",
                                         val_loader=pt_val_loader)
        current_stage = "SFT"
        start_step = 0

    # --- STAGE 2: SFT ---
    print("\n=== STAGE 2: SFT ===")
    for param_group in optimizer.param_groups:
        param_group['lr'] = 5e-5

    sft_loader = create_streaming_loader("OpenAssistant/oasst1", "train",
                                         tokenizer, config, args.batch_size, mode="sft",
                                         hf_token=args.hf_token, start_step=start_step)
    sft_val_loader = create_streaming_loader("OpenAssistant/oasst1", "validation",
                                             tokenizer, config, args.batch_size, mode="sft",
                                             hf_token=args.hf_token)
    trainer.train_epoch(sft_loader, total_steps=args.sft_steps,
                        start_step=start_step, stage_name="SFT",
                        val_loader=sft_val_loader)

    print("\nTraining Complete!")


def generate_command(args):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    config = AetherisConfig.from_yaml(args.config)
    tokenizer = get_tokenizer()

    model = HybridMambaMoE(config).to(device).to(config.torch_dtype)
    load_latest_checkpoint(model, None, None, device, args.checkpoint_dir, args.checkpoint_name)
    model.eval()

    prompt = args.prompt
    max_new_tokens = args.max_new_tokens
    temperature = args.temperature
    top_k = args.top_k
    top_p = args.top_p
    repetition_penalty = args.repetition_penalty

    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    generated_ids = input_ids.clone()
    history_ids = set(input_ids[0].tolist())

    print("-" * 50)
    print(f"Prompt: {prompt}")
    print("Generated Continuation:")
    for _ in range(max_new_tokens):
        # Autocast only when the model runs in reduced precision (skip for float32)
        use_autocast = config.torch_dtype != torch.float32
        with torch.no_grad():
            if use_autocast:
                with torch.amp.autocast('cuda' if device.type == 'cuda' else 'cpu',
                                        dtype=model.config.torch_dtype):
                    outputs = model(generated_ids)
            else:
                outputs = model(generated_ids)
        next_token_logits = outputs['logits'][:, -1, :]

        # Repetition penalty: push down logits of tokens that already appeared
        # (assumes batch size 1, as does the top-p filtering below)
        for token_id in history_ids:
            if token_id < next_token_logits.size(-1):
                logit = next_token_logits[0, token_id].item()
                if logit > 0:
                    next_token_logits[0, token_id] = logit / repetition_penalty
                else:
                    next_token_logits[0, token_id] = logit * repetition_penalty

        # Temperature
        if temperature > 0:
            next_token_logits = next_token_logits / temperature

        # Top-p (nucleus) / top-k filtering
        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift right so the first token that crosses the threshold is kept
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = False
            indices_to_remove = sorted_indices[sorted_indices_to_remove]
            next_token_logits.scatter_(1, indices_to_remove.unsqueeze(0), float('-inf'))
        elif top_k > 0:
            top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k)
            next_token_logits = torch.full_like(next_token_logits, float('-inf'))
            next_token_logits.scatter_(1, top_k_indices, top_k_logits)

        # Sample the next token
        next_token_probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(next_token_probs, num_samples=1)
        next_token_item = next_token.item()

        if next_token_item == tokenizer.eos_token_id:
            break

        generated_ids = torch.cat([generated_ids, next_token], dim=-1)
        history_ids.add(next_token_item)

        new_token_text = tokenizer.decode(next_token.squeeze().tolist(), skip_special_tokens=True)
        print(new_token_text, end="", flush=True)

    print("\n" + "-" * 50)


def info_command(args):
    config = AetherisConfig.from_yaml(args.config)
    model = HybridMambaMoE(config)

    total_params = 0
    dense_params = 0   # parameters active for EVERY token
    expert_params = 0  # parameters across all MoE experts

    for name, param in model.named_parameters():
        numel = param.numel()
        total_params += numel
        if 'experts' in name:
            expert_params += numel
        else:
            dense_params += numel

    single_expert_size = expert_params / config.num_experts if config.num_experts > 0 else 0
    active_per_token_params = dense_params + (single_expert_size * config.top_k)
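
    # Worked example with made-up numbers: if dense_params = 120M and expert_params = 480M
    # spread over num_experts = 8 (60M per expert) with top_k = 2, the per-token compute
    # load is 120M + 2 * 60M = 240M, while the checkpoint stores all 600M parameters.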

    def format_count(count):
        return f"{count / 1_000_000:.2f}M"

    print("=" * 50)
    print("Hybrid Mamba-MoE Model Parameter Analysis")
    print("=" * 50)
    print(f"Total Model Layers (N_Layer): {config.n_layer}")
    print(f"MoE Experts per Layer: {config.num_experts}")
    print(f"Active Experts (Top-K): {config.top_k}")
    print("-" * 50)
    print(f"Total Parameters (Checkpoint Size): {format_count(total_params)}")
    print(f"Dense (Always Active) Parameters: {format_count(dense_params)}")
    print(f"Expert-Only Parameters: {format_count(expert_params)}")
    print("-" * 50)
    print(f"**Active Parameters (Per-Token Compute Load): {format_count(active_per_token_params)}**")
    print("  (Dense parameters + the K active expert parameters)")
    print("=" * 50)


def main():
    parser = argparse.ArgumentParser(description="Aetheris CLI")
    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    # Train command
    train_parser = subparsers.add_parser("train", help="Train the model")
    train_parser.add_argument("--config", type=str, default="configs/default.yaml", help="Path to config file")
    train_parser.add_argument("--checkpoint_dir", type=str, default="checkpoints", help="Directory to save checkpoints")
    train_parser.add_argument("--hf_token", type=str, default=os.environ.get("HF_TOKEN"), help="Hugging Face token")
    train_parser.add_argument("--batch_size", type=int, default=2, help="Batch size")
    train_parser.add_argument("--pretrain_steps", type=int, default=50000, help="Number of pretraining steps")
    train_parser.add_argument("--sft_steps", type=int, default=1000, help="Number of SFT steps")
    train_parser.add_argument("--checkpoint_name", type=str, default="checkpoint_current.pth", help="Checkpoint file name to load from")

    # Generate command
    gen_parser = subparsers.add_parser("generate", help="Generate text")
    gen_parser.add_argument("--config", type=str, default="configs/default.yaml", help="Path to config file")
    gen_parser.add_argument("--checkpoint_dir", type=str, default="checkpoints", help="Directory with checkpoints")
    gen_parser.add_argument("--checkpoint_name", type=str, default="checkpoint_current.pth", help="Checkpoint file name")
    gen_parser.add_argument("--prompt", type=str, default="The quick brown fox", help="Prompt for generation")
    gen_parser.add_argument("--max_new_tokens", type=int, default=100, help="Max new tokens to generate")
    gen_parser.add_argument("--temperature", type=float, default=0.8, help="Sampling temperature")
    gen_parser.add_argument("--top_k", type=int, default=0, help="Top-k sampling (0 disables it)")
    gen_parser.add_argument("--top_p", type=float, default=0.9, help="Top-p (nucleus) sampling")
    gen_parser.add_argument("--repetition_penalty", type=float, default=3.0, help="Repetition penalty")

    # Serve command
    serve_parser = subparsers.add_parser("serve", help="Start the API server")
    serve_parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind")
    serve_parser.add_argument("--port", type=int, default=8000, help="Port to bind")
    serve_parser.add_argument("--config", type=str, default="configs/default.yaml", help="Path to config file")
    serve_parser.add_argument("--checkpoint_dir", type=str, default="checkpoints", help="Directory with checkpoints")
    serve_parser.add_argument("--checkpoint_name", type=str, default="checkpoint_current.pth", help="Checkpoint file name")

    # Info command
    info_parser = subparsers.add_parser("info", help="Show model info")
    info_parser.add_argument("--config", type=str, default="configs/default.yaml", help="Path to config file")

    args = parser.parse_args()
| if args.command == "train": | |
| train_command(args) | |
| elif args.command == "generate": | |
| generate_command(args) | |
| elif args.command == "serve": | |
| import uvicorn | |
| from aetheris.api.server import app, get_engine | |
| # Initialize engine before starting server | |
| engine = get_engine() | |
| # You might want to pass config/checkpoint paths to get_engine here if it supported arguments | |
| # For now, it defaults or we need to modify get_engine or InferenceEngine to take args. | |
| # But `get_engine` is a simple global accessor. | |
| # Better: Initialize a global engine with args here. | |
| from aetheris.inference import InferenceEngine | |
| import aetheris.api.server | |
| aetheris.api.server.engine = InferenceEngine( | |
| config_path=args.config, | |
| checkpoint_dir=args.checkpoint_dir, | |
| checkpoint_name=args.checkpoint_name | |
| ) | |
| uvicorn.run(app, host=args.host, port=args.port) | |
| elif args.command == "info": | |
| info_command(args) | |
| else: | |
| parser.print_help() | |
| if __name__ == "__main__": | |
| main() | |