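"""Aetheris command-line interface: train, generate, serve, and info subcommands."""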
import argparse
import os

import torch
import torch.nn.functional as F

from aetheris.config import AetherisConfig
from aetheris.data import create_streaming_loader, get_tokenizer
from aetheris.model import HybridMambaMoE
from aetheris.trainer import Trainer
from aetheris.utils import load_latest_checkpoint, calculate_model_stats


def train_command(args):
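    """Run the two-stage pipeline: pre-training on SlimPajama, then SFT on OASST1."""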
print(f"\n{'='*70}")
print(f"Aetheris Training")
print(f"Config: {args.config}")
if args.hf_token:
print(f"Using Hugging Face token: {args.hf_token[:10]}...")
from huggingface_hub import login
login(token=args.hf_token)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.cuda.empty_cache()
config = AetherisConfig.from_yaml(args.config)
tokenizer = get_tokenizer()
print(f"Device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
print(f"Model Size: d_model={config.d_model}, layers={config.n_layer}")
print(f"{'='*70}\n")
model = HybridMambaMoE(config).to(device)
# Apply weight initialization
print("Applying proper weight initialization...")
model.apply(model._init_weights)
# Calculate model stats
stats = calculate_model_stats(model)
print(f"Total Parameters: {stats['total_params']:,}")
print(f"Trainable Parameters: {stats['trainable_params']:,}")
# Use lower learning rate for stability
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01,
betas=(0.9, 0.95), eps=1e-8, fused=False if device.type == 'cpu' else True)
scaler = torch.amp.GradScaler('cuda' if device.type == 'cuda' else 'cpu', init_scale=2**10)
start_step, current_stage = load_latest_checkpoint(model, optimizer, scaler, device, args.checkpoint_dir, args.checkpoint_name)
trainer = Trainer(model, optimizer, scaler, config, device, args.checkpoint_dir)
# --- STAGE 1: PRE-TRAINING ---
if current_stage == "Pre-Training" or start_step == 0:
pt_loader = create_streaming_loader("cerebras/SlimPajama-627B", "train",
tokenizer, config, args.batch_size, mode="pretrain",
hf_token=args.hf_token, start_step=start_step)
# Validation loader (no skipping needed, always from start of val set)
pt_val_loader = create_streaming_loader("cerebras/SlimPajama-627B", "validation",
tokenizer, config, args.batch_size, mode="pretrain",
hf_token=args.hf_token)
start_step = trainer.train_epoch(pt_loader, total_steps=args.pretrain_steps,
start_step=start_step, stage_name="Pre-Training",
val_loader=pt_val_loader)
current_stage = "SFT"
start_step = 0
# --- STAGE 2: SFT ---
print("\n=== STAGE 2: SFT ===")
for param_group in optimizer.param_groups:
param_group['lr'] = 5e-5
sft_loader = create_streaming_loader("OpenAssistant/oasst1", "train",
tokenizer, config, args.batch_size, mode="sft",
hf_token=args.hf_token, start_step=start_step)
sft_val_loader = create_streaming_loader("OpenAssistant/oasst1", "validation",
tokenizer, config, args.batch_size, mode="sft",
hf_token=args.hf_token)
trainer.train_epoch(sft_loader, total_steps=args.sft_steps,
start_step=start_step, stage_name="SFT",
val_loader=sft_val_loader)
print("\nTraining Complete!")
@torch.no_grad()
def generate_command(args):
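    """Autoregressively sample from the latest checkpoint, streaming tokens to stdout.

    Applies a repetition penalty, temperature scaling, and top-p (or top-k)
    filtering before multinomial sampling; generation stops at EOS or after
    max_new_tokens tokens.
    """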
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    config = AetherisConfig.from_yaml(args.config)
    tokenizer = get_tokenizer()

    model = HybridMambaMoE(config).to(device).to(config.torch_dtype)
    load_latest_checkpoint(model, None, None, device, args.checkpoint_dir, args.checkpoint_name)
    model.eval()

    prompt = args.prompt
    max_new_tokens = args.max_new_tokens
    temperature = args.temperature
    top_k = args.top_k
    top_p = args.top_p
    repetition_penalty = args.repetition_penalty

    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    generated_ids = input_ids.clone()
    history_ids = set(input_ids[0].tolist())

    print("-" * 50)
    print(f"Prompt: {prompt}")
    print("Generated Continuation:")

    # Skip autocast when the model already runs in float32.
    use_autocast = config.torch_dtype != torch.float32

    for _ in range(max_new_tokens):
        with torch.amp.autocast(device.type, dtype=config.torch_dtype, enabled=use_autocast):
            outputs = model(generated_ids)
        next_token_logits = outputs['logits'][:, -1, :]

        # Repetition penalty: dampen logits of tokens already generated.
        for token_id in history_ids:
            if token_id < next_token_logits.size(-1):
                logit = next_token_logits[0, token_id].item()
                if logit > 0:
                    next_token_logits[0, token_id] = logit / repetition_penalty
                else:
                    next_token_logits[0, token_id] = logit * repetition_penalty

        # Temperature scaling
        if temperature > 0:
            next_token_logits = next_token_logits / temperature

        # Top-p (nucleus) sampling takes precedence over top-k.
        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift right so the first token crossing the threshold is kept.
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = False
            indices_to_remove = sorted_indices[sorted_indices_to_remove]
            next_token_logits.scatter_(1, indices_to_remove.unsqueeze(0), float('-inf'))
        elif top_k > 0:
            top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k)
            next_token_logits = torch.full_like(next_token_logits, float('-inf'))
            next_token_logits.scatter_(1, top_k_indices, top_k_logits)

        # Sample from the filtered distribution.
        next_token_probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(next_token_probs, num_samples=1)
        next_token_item = next_token.item()

        if next_token_item == tokenizer.eos_token_id:
            break

        generated_ids = torch.cat([generated_ids, next_token], dim=-1)
        history_ids.add(next_token_item)

        new_token_text = tokenizer.decode(next_token.squeeze().tolist(), skip_special_tokens=True)
        print(new_token_text, end="", flush=True)

    print("\n" + "-" * 50)

def info_command(args):
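    """Report total, dense, and expert parameter counts, plus the active per-token load."""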
    config = AetherisConfig.from_yaml(args.config)
    model = HybridMambaMoE(config)

    total_params = 0
    dense_params = 0   # Parameters active for EVERY token
    expert_params = 0  # Parameters in all MoE experts

    for name, param in model.named_parameters():
        numel = param.numel()
        total_params += numel
        if 'experts' in name:
            expert_params += numel
        else:
            dense_params += numel

    single_expert_size = expert_params / config.num_experts if config.num_experts > 0 else 0
    active_per_token_params = dense_params + (single_expert_size * config.top_k)

    def format_count(count):
        return f"{count / 1_000_000:.2f}M"

    print("=" * 50)
    print("Hybrid Mamba-MoE Model Parameter Analysis")
    print("=" * 50)
    print(f"Total Model Layers (N_Layer): {config.n_layer}")
    print(f"MoE Experts per Layer: {config.num_experts}")
    print(f"Active Experts (Top-K): {config.top_k}")
    print("-" * 50)
    print(f"Total Parameters (Checkpoint Size): {format_count(total_params)}")
    print(f"Dense (Always Active) Parameters: {format_count(dense_params)}")
    print(f"Expert-Only Parameters: {format_count(expert_params)}")
    print("-" * 50)
    print(f"**Active Parameters (Per-Token Compute Load): {format_count(active_per_token_params)}**")
    print("  (Dense parameters plus the top-k active expert parameters)")
    print("=" * 50)

def main():
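    """Parse CLI arguments and dispatch to the train/generate/serve/info subcommand."""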
    parser = argparse.ArgumentParser(description="Aetheris CLI")
    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    # Train Command
    train_parser = subparsers.add_parser("train", help="Train the model")
    train_parser.add_argument("--config", type=str, default="configs/default.yaml", help="Path to config file")
    train_parser.add_argument("--checkpoint_dir", type=str, default="checkpoints", help="Directory to save checkpoints")
    train_parser.add_argument("--hf_token", type=str, default=os.environ.get("HF_TOKEN"), help="HuggingFace token")
    train_parser.add_argument("--batch_size", type=int, default=2, help="Batch size")
    train_parser.add_argument("--pretrain_steps", type=int, default=50000, help="Number of pretraining steps")
    train_parser.add_argument("--sft_steps", type=int, default=1000, help="Number of SFT steps")
    train_parser.add_argument("--checkpoint_name", type=str, default="checkpoint_current.pth", help="Checkpoint file name to load from")

    # Generate Command
    gen_parser = subparsers.add_parser("generate", help="Generate text")
    gen_parser.add_argument("--config", type=str, default="configs/default.yaml", help="Path to config file")
    gen_parser.add_argument("--checkpoint_dir", type=str, default="checkpoints", help="Directory with checkpoints")
    gen_parser.add_argument("--checkpoint_name", type=str, default="checkpoint_current.pth", help="Checkpoint file name")
    gen_parser.add_argument("--prompt", type=str, default="The quick brown fox", help="Prompt for generation")
    gen_parser.add_argument("--max_new_tokens", type=int, default=100, help="Max new tokens to generate")
    gen_parser.add_argument("--temperature", type=float, default=0.8, help="Sampling temperature")
    gen_parser.add_argument("--top_k", type=int, default=0, help="Top-k sampling (0 disables)")
    gen_parser.add_argument("--top_p", type=float, default=0.9, help="Top-p (nucleus) sampling")
    gen_parser.add_argument("--repetition_penalty", type=float, default=3.0, help="Repetition penalty")

    # Serve Command
    serve_parser = subparsers.add_parser("serve", help="Start the API server")
    serve_parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind")
    serve_parser.add_argument("--port", type=int, default=8000, help="Port to bind")
    serve_parser.add_argument("--config", type=str, default="configs/default.yaml", help="Path to config file")
    serve_parser.add_argument("--checkpoint_dir", type=str, default="checkpoints", help="Directory with checkpoints")
    serve_parser.add_argument("--checkpoint_name", type=str, default="checkpoint_current.pth", help="Checkpoint file name")

    # Info Command
    info_parser = subparsers.add_parser("info", help="Show model info")
    info_parser.add_argument("--config", type=str, default="configs/default.yaml", help="Path to config file")

    args = parser.parse_args()

    if args.command == "train":
        train_command(args)
    elif args.command == "generate":
        generate_command(args)
elif args.command == "serve":
import uvicorn
from aetheris.api.server import app, get_engine
# Initialize engine before starting server
engine = get_engine()
# You might want to pass config/checkpoint paths to get_engine here if it supported arguments
# For now, it defaults or we need to modify get_engine or InferenceEngine to take args.
# But `get_engine` is a simple global accessor.
# Better: Initialize a global engine with args here.
from aetheris.inference import InferenceEngine
import aetheris.api.server
aetheris.api.server.engine = InferenceEngine(
config_path=args.config,
checkpoint_dir=args.checkpoint_dir,
checkpoint_name=args.checkpoint_name
)
uvicorn.run(app, host=args.host, port=args.port)
elif args.command == "info":
info_command(args)
else:
parser.print_help()
if __name__ == "__main__":
    main()