"""FastAPI backend for GPU Memory Calculator web application.""" import hashlib import json import logging from pathlib import Path from typing import Any from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator from starlette.requests import Request from gpu_mem_calculator.config.presets import load_presets from gpu_mem_calculator.core.calculator import GPUMemoryCalculator from gpu_mem_calculator.core.models import ( EngineConfig, GPUConfig, InferenceConfig, InferenceEngineType, InterconnectType, MemoryResult, ModelConfig, NodeConfig, ParallelismConfig, TrainingConfig, ) from gpu_mem_calculator.core.multinode import MultiNodeCalculator from gpu_mem_calculator.exporters.manager import ExportFormat, ExportManager from gpu_mem_calculator.huggingface import ( HuggingFaceClient, HuggingFaceConfigMapper, HuggingFaceError, InvalidConfigError, ModelNotFoundError, PrivateModelAccessError, ) from gpu_mem_calculator.inference.calculator import InferenceMemoryCalculator # Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # Create FastAPI app app = FastAPI( title="GPU Memory Calculator", description="Calculate GPU memory requirements for LLM training", version="0.1.0", ) # Configure CORS app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # Setup templates and static files BASE_DIR = Path(__file__).parent templates = Jinja2Templates(directory=str(BASE_DIR / "templates")) # Mount static files static_dir = BASE_DIR / "static" if static_dir.exists(): app.mount("/static", StaticFiles(directory=str(static_dir)), name="static") # Request/Response models class CalculateRequest(BaseModel): """Request model for memory calculation with comprehensive validation.""" model: dict[str, Any] = Field(description="Model configuration") training: dict[str, Any] = Field(description="Training configuration") parallelism: dict[str, Any] | None = Field( default=None, description="Parallelism configuration", ) engine: dict[str, Any] | None = Field(default=None, description="Engine configuration") hardware: dict[str, Any] | None = Field(default=None, description="Hardware configuration") @field_validator("model") @classmethod def validate_moe_settings(cls, v: dict[str, Any]) -> dict[str, Any]: """Validate MoE-specific constraints.""" if v.get("moe_enabled"): num_experts = v.get("num_experts", 1) top_k = v.get("top_k", 1) if top_k > num_experts: raise ValueError(f"MoE top_k ({top_k}) cannot exceed num_experts ({num_experts})") if num_experts < 1 or num_experts > 256: raise ValueError(f"num_experts must be between 1 and 256, got {num_experts}") if top_k < 1 or top_k > 8: raise ValueError(f"top_k must be between 1 and 8, got {top_k}") return v @model_validator(mode="after") def validate_parallelism_consistency(self) -> "CalculateRequest": """Validate parallelism settings consistency.""" if self.parallelism and self.hardware: tensor_pp = self.parallelism.get("tensor_parallel_size", 1) pipeline_pp = self.parallelism.get("pipeline_parallel_size", 1) data_pp = self.parallelism.get("data_parallel_size", 1) num_gpus = self.hardware.get("num_gpus", 1) effective_gpus = tensor_pp * pipeline_pp * data_pp if effective_gpus != num_gpus: raise ValueError( f"Parallelism mismatch: tensor_pp ({tensor_pp}) × " f"pipeline_pp ({pipeline_pp}) × data_pp ({data_pp}) = " f"{effective_gpus} GPUs, but num_gpus is set to {num_gpus}. " f"These must match." ) # Validate sequence parallel requires tensor parallel > 1 if self.parallelism and self.parallelism.get("sequence_parallel"): tensor_pp = self.parallelism.get("tensor_parallel_size", 1) if tensor_pp <= 1: raise ValueError( f"Sequence parallelism requires tensor_parallel_size > 1, " f"got {tensor_pp}" ) return self @model_validator(mode="after") def validate_engine_settings(self) -> "CalculateRequest": """Validate engine-specific settings.""" if not self.engine: return self engine_type = self.engine.get("type") zero_stage = self.engine.get("zero_stage", 0) # ZeRO stages only valid for DeepSpeed engines if engine_type not in ["deepspeed", "megatron_deepspeed"] and zero_stage > 0: raise ValueError( f"ZeRO stages are only supported for DeepSpeed engines, " f"got engine_type='{engine_type}' with zero_stage={zero_stage}" ) # Validate ZeRO stage range if zero_stage < 0 or zero_stage > 3: raise ValueError(f"zero_stage must be between 0 and 3, got {zero_stage}") return self class PresetInfo(BaseModel): """Information about a preset model configuration.""" name: str display_name: str description: str config: dict[str, Any] class HuggingFaceRequest(BaseModel): """Request for fetching HuggingFace model metadata.""" model_config = ConfigDict(protected_namespaces=()) model_id: str = Field(description="HuggingFace model ID (e.g., meta-llama/Llama-2-7b-hf)") token: str | None = Field(default=None, description="HF token for private models") # Simple in-memory cache for calculation results # In production, use Redis or similar _calculation_cache: dict[str, tuple[MemoryResult, float]] = {} # key -> (result, timestamp) _CACHE_TTL = 3600 # 1 hour _MAX_CACHE_SIZE = 1000 def _cache_key_from_request(request: CalculateRequest) -> str: """Generate cache key from request.""" request_dict = request.model_dump() # Sort keys for consistent hashing request_str = json.dumps(request_dict, sort_keys=True) return hashlib.md5(request_str.encode()).hexdigest() def _get_cached_result(key: str) -> MemoryResult | None: """Get cached result if available and not expired.""" if key in _calculation_cache: result, timestamp = _calculation_cache[key] import time if time.time() - timestamp < _CACHE_TTL: return result else: # Expired, remove from cache del _calculation_cache[key] return None def _cache_result(key: str, result: MemoryResult) -> None: """Cache calculation result.""" import time # Simple cache eviction if too large if len(_calculation_cache) >= _MAX_CACHE_SIZE: # Remove oldest entry (first key) oldest_key = next(iter(_calculation_cache)) del _calculation_cache[oldest_key] _calculation_cache[key] = (result, time.time()) # Load presets at startup using shared preset loader # The shared loader reads from web/presets/models.json def _load_presets_from_shared() -> dict[str, PresetInfo]: """Load presets using the shared preset loader.""" all_presets = load_presets() return { name: PresetInfo( name=name, display_name=preset.get("display_name", name), description=preset.get("description", ""), config=preset.get("config", {}), ) for name, preset in all_presets.items() } PRESETS = _load_presets_from_shared() # API Routes @app.get("/") async def index(request: Request) -> Any: """Serve the main web page.""" return templates.TemplateResponse("index.html", {"request": request}) @app.get("/api/engines") async def list_engines() -> dict[str, str]: """List supported training engines.""" return { "pytorch_ddp": "PyTorch DDP (Distributed Data Parallel)", "deepspeed": "DeepSpeed ZeRO", "megatron_lm": "Megatron-LM", "fsdp": "PyTorch FSDP (Fully Sharded Data Parallel)", "megatron_deepspeed": "Megatron-LM + DeepSpeed", } @app.get("/api/optimizers") async def list_optimizers() -> dict[str, str]: """List supported optimizers.""" return { "adam": "Adam", "adamw": "AdamW", "adamw_8bit": "AdamW 8-bit", "sgd": "SGD", } @app.get("/api/dtypes") async def list_dtypes() -> dict[str, str]: """List supported data types.""" return { "fp32": "FP32 (32-bit floating point)", "fp16": "FP16 (16-bit floating point)", "bf16": "BF16 (16-bit bfloat)", "int8": "INT8 (8-bit integer)", "int4": "INT4 (4-bit integer)", } @app.get("/api/presets") async def list_presets() -> dict[str, dict[str, str]]: """List all preset model configurations.""" return { name: { "display_name": preset.display_name, "description": preset.description, } for name, preset in PRESETS.items() } @app.get("/api/preset/{preset_name}") async def get_preset(preset_name: str) -> dict[str, Any]: """Get a specific preset configuration.""" if preset_name not in PRESETS: raise HTTPException(status_code=404, detail=f"Preset '{preset_name}' not found") return PRESETS[preset_name].config @app.post("/api/hf/fetch") async def fetch_huggingface_model(request: HuggingFaceRequest) -> dict[str, Any]: """Fetch model metadata from HuggingFace Hub. Args: request: Request with model_id and optional token Returns: Model config with fields filled from HF, plus list of missing fields Raises: HTTPException: If model not found, access denied, or invalid config """ try: # Initialize HF client client = HuggingFaceClient(token=request.token) # Fetch metadata metadata = await client.fetch_model_metadata(request.model_id) # Map to ModelConfig mapper = HuggingFaceConfigMapper() result = mapper.map_to_model_config(metadata["config"], metadata.get("model_info")) return { "model_id": request.model_id, "config": result["config"], "missing_fields": result["missing_fields"], "found_fields": result["found_fields"], "warnings": [], } except PrivateModelAccessError as e: raise HTTPException( status_code=401, detail={ "error": "Authentication required", "message": str(e), "type": "auth_error", }, ) from e except ModelNotFoundError as e: raise HTTPException( status_code=404, detail={ "error": "Model not found", "message": str(e), "type": "not_found", }, ) from e except InvalidConfigError as e: raise HTTPException( status_code=422, detail={ "error": "Invalid model configuration", "message": str(e), "type": "invalid_config", }, ) from e except HuggingFaceError as e: logger.error(f"HuggingFace error: {str(e)}") raise HTTPException( status_code=500, detail={ "error": "HuggingFace API error", "message": str(e), "type": "api_error", }, ) from e except Exception as e: logger.error(f"Unexpected error fetching HF model: {str(e)}", exc_info=True) raise HTTPException( status_code=500, detail={ "error": "Internal server error", "message": "An unexpected error occurred", "type": "server_error", }, ) from e @app.post("/api/calculate") async def calculate_memory(request: CalculateRequest) -> MemoryResult: """Calculate GPU memory requirements. Args: request: Calculation request with model, training, and hardware configs Returns: MemoryResult with complete memory breakdown """ # Check cache first cache_key = _cache_key_from_request(request) cached_result = _get_cached_result(cache_key) if cached_result is not None: logger.info(f"Cache hit for key: {cache_key[:8]}...") return cached_result try: # Parse model configuration model_data = request.model.copy() # Parse num_parameters if it's a string (e.g., "7B", "7000M") if "num_parameters" in model_data and isinstance( model_data["num_parameters"], str, ): from gpu_mem_calculator.config.parser import ConfigParser model_data["num_parameters"] = ConfigParser._parse_num_params( model_data["num_parameters"], ) model_config = ModelConfig(**model_data) # Parse training configuration training_config = TrainingConfig(**request.training) # Parse optional configurations with defaults parallelism_config = ( ParallelismConfig(**request.parallelism) if request.parallelism else ParallelismConfig() ) engine_config = EngineConfig(**request.engine) if request.engine else EngineConfig() gpu_config = GPUConfig(**request.hardware) if request.hardware else GPUConfig() # Create calculator and compute calculator = GPUMemoryCalculator( model_config=model_config, training_config=training_config, parallelism_config=parallelism_config, engine_config=engine_config, gpu_config=gpu_config, ) result = calculator.calculate() # Cache the result _cache_result(cache_key, result) logger.info( f"Calculation successful: {model_config.name}, " f"{result.total_memory_per_gpu_gb:.2f} GB per GPU" ) return result except ValueError as e: # User input validation error logger.warning(f"Validation error: {str(e)}") raise HTTPException( status_code=400, detail={"error": "Validation error", "message": str(e), "type": "validation_error"}, ) from e except Exception as e: # Unexpected system error logger.error(f"Calculation error: {str(e)}", exc_info=True) raise HTTPException( status_code=500, detail={ "error": "Internal server error", "message": "An unexpected error occurred during calculation", }, ) from e @app.post("/api/export/deepspeed") async def export_deepspeed_config(request: CalculateRequest) -> dict[str, Any]: """Export DeepSpeed configuration file. Args: request: Calculation request with model, training, and hardware configs Returns: DeepSpeed config JSON and memory result """ try: # First calculate memory calc_result = await calculate_memory(request) # Generate DeepSpeed config parallelism = request.parallelism or {} training = request.training engine = request.engine or {} train_batch_size = ( training.get("batch_size", 1) * training.get("gradient_accumulation_steps", 1) * parallelism.get("data_parallel_size", 1) ) zero_stage = engine.get("zero_stage", 0) offload_optimizer = engine.get("offload_optimizer", "none") offload_param = engine.get("offload_param", "none") deepspeed_config = { "train_batch_size": train_batch_size, "train_micro_batch_size_per_gpu": training.get("batch_size", 1), "gradient_accumulation_steps": training.get("gradient_accumulation_steps", 1), "optimizer": { "type": training.get("optimizer", "AdamW"), "params": {"lr": 0.0001, "betas": [0.9, 0.999], "eps": 1e-8, "weight_decay": 0.01}, }, "scheduler": { "type": "WarmupLR", "params": {"warmup_min_lr": 0, "warmup_max_lr": 0.0001, "warmup_num_steps": 2000}, }, "fp16": {"enabled": training.get("dtype") in ["fp16", "int4", "int8"]}, "bf16": {"enabled": training.get("dtype") == "bf16"}, "zero_optimization": {"stage": zero_stage}, "gradient_clipping": training.get("gradient_clipping", 1.0), "steps_per_print": 100, } # Add offload config if ZeRO stage >= 1 if zero_stage >= 1: deepspeed_config["zero_optimization"]["offload_optimizer"] = { "device": offload_optimizer } deepspeed_config["zero_optimization"]["offload_param"] = {"device": offload_param} return {"config": deepspeed_config, "memory_result": calc_result} except HTTPException: raise except Exception as e: logger.error(f"DeepSpeed export error: {str(e)}", exc_info=True) raise HTTPException( status_code=500, detail=f"Failed to generate DeepSpeed config: {str(e)}" ) from e @app.post("/api/optimize/batch-size") async def optimize_batch_size(request: CalculateRequest) -> dict[str, Any]: """Find maximum batch size that fits in GPU memory. Uses binary search to find the maximum batch size that doesn't OOM. Args: request: Calculation request with model, training, and hardware configs Returns: Maximum batch size that fits and corresponding memory result """ try: # Create a mutable copy for testing from copy import deepcopy min_batch = 1 max_batch = 512 # Reasonable upper bound best_batch = 1 while min_batch <= max_batch: mid = (min_batch + max_batch) // 2 # Create modified request with test batch size test_request = deepcopy(request) test_request.training["batch_size"] = mid try: # Validate and calculate CalculateRequest.model_validate(test_request) result = await calculate_memory(test_request) if result.fits_on_gpu: best_batch = mid min_batch = mid + 1 else: max_batch = mid - 1 except (ValueError, HTTPException): # Invalid config or doesn't fit max_batch = mid - 1 # Get final result for best batch size final_request = deepcopy(request) final_request.training["batch_size"] = best_batch final_result = await calculate_memory(final_request) return {"max_batch_size": best_batch, "memory_result": final_result} except Exception as e: logger.error(f"Batch size optimization error: {str(e)}", exc_info=True) raise HTTPException( status_code=500, detail=f"Failed to optimize batch size: {str(e)}" ) from e @app.post("/api/validate") async def validate_config(request: CalculateRequest) -> dict[str, Any]: """Validate a configuration without calculating memory. Args: request: Configuration to validate Returns: Validation result with valid flag and any errors """ try: # Pydantic validation happens automatically when creating CalculateRequest # If we get here, the request is valid return {"valid": True, "errors": []} except ValueError as e: # Validation error return {"valid": False, "errors": [str(e)]} except Exception as e: # Unexpected error logger.error(f"Validation error: {str(e)}", exc_info=True) return {"valid": False, "errors": [str(e)]} @app.post("/api/explain-formula") async def explain_formula(request: CalculateRequest) -> dict[str, Any]: """Explain the memory formula used for calculation. Returns detailed information about which formula is being used, with the user's values plugged in, and links to documentation. Args: request: Calculation request with model, training, and hardware configs Returns: Formula explanation with formula type, breakdown, and references """ try: # Get configuration details engine_type = request.engine.get("type", "pytorch_ddp") if request.engine else "pytorch_ddp" num_params = request.model.get("num_parameters", 0) # Parse num_parameters if it's a string (e.g., "7B", "7000M") if isinstance(num_params, str): from gpu_mem_calculator.config.parser import ConfigParser num_params = ConfigParser._parse_num_params(num_params) optimizer = request.training.get("optimizer", "adamw") num_gpus = request.hardware.get("num_gpus", 1) if request.hardware else 1 batch_size = request.training.get("batch_size", 1) # Calculate memory to get the breakdown result = await calculate_memory(request) # Determine formula description based on engine type formula_info = { "engine_type": engine_type, "engine_name": _get_engine_name(engine_type), "formula_components": [], "total_memory_gb": round(result.total_memory_per_gpu_gb, 2), "breakdown": { "model_params_gb": round(result.breakdown.model_params_gb, 2), "gradients_gb": round(result.breakdown.gradients_gb, 2), "optimizer_states_gb": round(result.breakdown.optimizer_states_gb, 2), "activations_gb": round(result.breakdown.activations_gb, 2), "overhead_gb": round(result.breakdown.overhead_gb, 2), }, "references": _get_formula_references(engine_type), } # Add engine-specific formula details if engine_type == "pytorch_ddp": formula_info["formula_description"] = ( "PyTorch DDP stores complete copies of model parameters, gradients, " "and optimizer states on each GPU." ) formula_info["formula_components"] = [ { "name": "Model Parameters", "formula": f"{num_params:,} × 2 bytes (FP16/BF16)", "result": f"{result.breakdown.model_params_gb:.2f} GB", "description": "Full model stored on each GPU", }, { "name": "Gradients", "formula": f"{num_params:,} × 2 bytes (FP16)", "result": f"{result.breakdown.gradients_gb:.2f} GB", "description": "Full gradients during backward pass", }, { "name": "Optimizer States", "formula": _get_optimizer_formula(optimizer, num_params)["formula"], "result": f"{result.breakdown.optimizer_states_gb:.2f} GB", "description": _get_optimizer_formula(optimizer, num_params)["description"], }, ] elif engine_type in ["deepspeed", "megatron_deepspeed"]: zero_stage = request.engine.get("zero_stage", 0) if request.engine else 0 offload_optimizer = ( request.engine.get("offload_optimizer", "none") if request.engine else "none" ) offload_param = ( request.engine.get("offload_param", "none") if request.engine else "none" ) if zero_stage == 0: stage_name = "ZeRO-0 (Baseline)" formula_info["formula_description"] = ( f"{stage_name}: No memory optimization. Same as PyTorch DDP." ) elif zero_stage == 1: stage_name = "ZeRO-1" formula_info["formula_description"] = ( f"{stage_name}: Shards optimizer states across {num_gpus} GPUs. " f"Reduces optimizer memory by {num_gpus}x." ) elif zero_stage == 2: stage_name = "ZeRO-2" formula_info["formula_description"] = ( f"{stage_name}: Shards optimizer states AND gradients across {num_gpus} GPUs. " f"Reduces memory by {num_gpus}x for both components." ) elif zero_stage == 3: stage_name = "ZeRO-3" formula_info["formula_description"] = ( f"{stage_name}: Shards parameters, gradients, AND optimizer states. " f"Only largest layer stored intact. Linear memory reduction with GPU count." ) formula_info["zero_stage"] = zero_stage formula_info["offload_optimizer"] = offload_optimizer formula_info["offload_param"] = offload_param # Add ZeRO-specific components if zero_stage == 3: # Estimate largest layer (approx 10% of params for typical models) largest_params = num_params // 10 formula_info["formula_components"] = [ { "name": "Largest Layer", "formula": f"{largest_params:,} × 4 bytes (FP16 params + grads)", "result": f"{result.breakdown.model_params_gb:.2f} GB", "description": "Gathered during compute, largest layer kept intact", }, { "name": "Sharded Parameters", "formula": f"({num_params:,} × 2 bytes) / {num_gpus} GPUs", "result": "Included in model params", "description": "Remaining parameters sharded across GPUs", }, { "name": "Sharded Optimizer States", "formula": ( ( f"({_get_optimizer_formula(optimizer, num_params)['formula']}) " f"/ {num_gpus} GPUs" ) if offload_optimizer == "none" else f"Offloaded to {offload_optimizer}" ), "result": f"{result.breakdown.optimizer_states_gb:.2f} GB", "description": ( _get_optimizer_formula(optimizer, num_params)["description"] + " (sharded or offloaded)" ), }, ] else: # ZeRO-1 or ZeRO-2 formula_info["formula_components"] = [ { "name": "Model Parameters", "formula": f"{num_params:,} × 2 bytes (FP16)", "result": f"{result.breakdown.model_params_gb:.2f} GB", "description": "Full model on each GPU", }, { "name": "Gradients", "formula": ( f"{num_params:,} × 2 bytes" if zero_stage < 2 else f"({num_params:,} × 2 bytes) / {num_gpus} GPUs" ), "result": f"{result.breakdown.gradients_gb:.2f} GB", "description": ( "Sharded across GPUs" if zero_stage >= 2 else "Full gradients" ), }, { "name": "Optimizer States", "formula": ( ( f"({_get_optimizer_formula(optimizer, num_params)['formula']}) " f"/ {num_gpus} GPUs" ) if offload_optimizer == "none" else f"Offloaded to {offload_optimizer}" ), "result": f"{result.breakdown.optimizer_states_gb:.2f} GB", "description": ( _get_optimizer_formula(optimizer, num_params)["description"] + " (sharded or offloaded)" ), }, ] elif engine_type == "fsdp": sharding_strategy = ( request.engine.get("sharding_strategy", "full_shard") if request.engine else "full_shard" ) if sharding_strategy == "no_shard": strategy_name = "No Sharding (like DDP)" elif sharding_strategy == "shard_grad_op": strategy_name = "Shard Gradients + Optimizer (like ZeRO-2)" else: strategy_name = "Full Shard (like ZeRO-3)" formula_info["sharding_strategy"] = sharding_strategy formula_info["strategy_name"] = strategy_name formula_info["formula_description"] = f"FSDP with {strategy_name.lower()} strategy." elif engine_type == "megatron_lm": formula_info["formula_description"] = ( "Megatron-LM uses tensor and/or pipeline parallelism to " "split the model across GPUs, reducing memory per GPU." ) # Add parallelism info if request.parallelism: tp_size = request.parallelism.get("tensor_parallel_size", 1) pp_size = request.parallelism.get("pipeline_parallel_size", 1) formula_info["parallelism"] = { "tensor_parallel_size": tp_size, "pipeline_parallel_size": pp_size, } # Add activation memory explanation components: list[dict[str, Any]] = formula_info["formula_components"] # type: ignore[assignment] components.append( { "name": "Activations", "formula": ( f"batch_size({batch_size}) × seq_len × hidden_size × " f"layers × ~16 bytes/token/layer" ), "result": f"{result.breakdown.activations_gb:.2f} GB", "description": "Memory from intermediate activations during forward/backward pass", } ) return formula_info except Exception as e: logger.error(f"Formula explanation error: {str(e)}", exc_info=True) raise HTTPException( status_code=500, detail=f"Failed to generate formula explanation: {str(e)}" ) from e def _get_engine_name(engine_type: str) -> str: """Get human-readable engine name.""" names = { "pytorch_ddp": "PyTorch DDP (Distributed Data Parallel)", "deepspeed": "DeepSpeed ZeRO", "megatron_lm": "Megatron-LM", "fsdp": "PyTorch FSDP (Fully Sharded Data Parallel)", "megatron_deepspeed": "Megatron-LM + DeepSpeed", } return names.get(engine_type, engine_type) def _get_optimizer_formula(optimizer: str, num_params: int) -> dict[str, str]: """Get optimizer memory formula based on optimizer type. Args: optimizer: Optimizer type (adam, adamw, sgd, adamw_8bit) num_params: Number of model parameters Returns: Dictionary with 'formula' and 'description' keys """ num_params_formatted = f"{num_params:,}" if optimizer in ["adam", "adamw"]: return { "formula": f"{num_params_formatted} × 12 bytes (Adam/AdamW FP32)", "description": "4 bytes FP32 params + 4 bytes momentum + 4 bytes variance", } elif optimizer == "adamw_8bit": return { "formula": f"{num_params_formatted} × 2 bytes (AdamW 8-bit)", "description": "8-bit quantized optimizer states (2 bytes per parameter)", } elif optimizer == "sgd": return { "formula": f"{num_params_formatted} × 4 bytes (SGD)", "description": "4 bytes FP32 params (no momentum for SGD)", } else: # Default to AdamW return { "formula": f"{num_params_formatted} × 12 bytes (Adam/AdamW FP32)", "description": "4 bytes FP32 params + 4 bytes momentum + 4 bytes variance", } def _get_formula_references(engine_type: str) -> list[dict[str, str]]: """Get authoritative references for the formula.""" references = [ { "title": "EleutherAI Transformer Math 101", "url": "https://blog.eleuther.ai/transformer-math/", "description": "Comprehensive transformer memory breakdown with formulas", }, { "title": "Microsoft Research ZeRO Blog", "url": "https://www.microsoft.com/en-us/research/blog/zero-deepspeed-new-system-optimizations-enable-training-models-with-over-100-billion-parameters/", "description": "ZeRO optimization techniques and memory formulas", }, ] if engine_type in ["deepspeed", "megatron_deepspeed"]: references.append( { "title": "DeepSpeed Memory Documentation", "url": "https://deepspeed.readthedocs.io/en/latest/memory.html", "description": "Official DeepSpeed memory requirements and formulas", } ) elif engine_type == "megatron_lm" or engine_type == "megatron_deepspeed": references.append( { "title": "NVIDIA Megatron-LM", "url": "https://github.com/NVIDIA/Megatron-LM", "description": "Megatron-LM tensor and pipeline parallelism", } ) elif engine_type == "fsdp": references.append( { "title": "PyTorch FSDP Documentation", "url": "https://pytorch.org/docs/stable/fsdp.html", "description": "PyTorch Fully Sharded Data Parallel documentation", } ) return references @app.post("/api/inference/calculate") async def calculate_inference_memory(request: dict[str, Any]) -> dict[str, Any]: """Calculate GPU memory requirements for inference. Args: request: Dictionary with model, inference, and hardware configs Returns: Inference memory result with breakdown """ try: model_data = request.get("model", {}) inference_data = request.get("inference", {}) hardware_data = request.get("hardware", {}) # Parse num_parameters if it's a string if "num_parameters" in model_data and isinstance(model_data["num_parameters"], str): from gpu_mem_calculator.config.parser import ConfigParser model_data["num_parameters"] = ConfigParser._parse_num_params( model_data["num_parameters"] ) # Create model config model_config = ModelConfig(**model_data) # Create inference config kv_cache_quantization = inference_data.get("kv_cache_quantization", "none") if isinstance(kv_cache_quantization, str): from gpu_mem_calculator.core.models import KVCacheQuantization kv_cache_quantization = KVCacheQuantization(kv_cache_quantization) inference_config = InferenceConfig( batch_size=inference_data.get("batch_size", 1), kv_cache_quantization=kv_cache_quantization, use_kv_cache=inference_data.get("use_kv_cache", True), tensor_parallel_size=inference_data.get("tensor_parallel_size", 1), gpu_memory_utilization=inference_data.get("gpu_memory_utilization", 0.9), enable_streaming=inference_data.get("enable_streaming", False), # TGI-specific parameters max_total_tokens=inference_data.get("max_total_tokens"), max_input_tokens=inference_data.get("max_input_tokens"), max_batch_total_tokens=inference_data.get("max_batch_total_tokens"), tgi_quantize=inference_data.get("tgi_quantize", "none"), tgi_dtype=inference_data.get("tgi_dtype", "bfloat16"), sharded=inference_data.get("sharded", False), num_shard=inference_data.get("num_shard"), # vLLM-specific parameters block_size=inference_data.get("block_size"), swap_space_gb=inference_data.get("swap_space_gb", 0.0), enable_prefix_caching=inference_data.get("enable_prefix_caching", False), enforce_eager=inference_data.get("enforce_eager", False), max_num_batched_tokens=inference_data.get("max_num_batched_tokens"), max_num_seqs=inference_data.get("max_num_seqs"), vllm_quantization=inference_data.get("vllm_quantization", "none"), # TensorRT-LLM-specific parameters trt_max_batch_size=inference_data.get("trt_max_batch_size"), trt_max_input_len=inference_data.get("trt_max_input_len"), trt_max_seq_len=inference_data.get("trt_max_seq_len"), trt_max_beam_width=inference_data.get("trt_max_beam_width"), # SGLang-specific parameters chunk_size=inference_data.get("chunk_size"), max_running_requests=inference_data.get("max_running_requests"), disable_radix_cache=inference_data.get("disable_radix_cache", False), enable_p2p=inference_data.get("enable_p2p", False), disable_custom_all_reduce=inference_data.get("disable_custom_all_reduce", False), attention_backend=inference_data.get("attention_backend", "flashinfer"), enable_torch_compile=inference_data.get("enable_torch_compile", False), radix_cache_max_seq_len=inference_data.get("radix_cache_max_seq_len"), speculative_algo=inference_data.get("speculative_algo", "default"), multi_lora_enabled=inference_data.get("multi_lora_enabled", False), ) # Create GPU config gpu_config = GPUConfig( num_gpus=hardware_data.get("num_gpus", 1), gpu_memory_gb=hardware_data.get("gpu_memory_gb", 80), ) # Get engine type engine_type_str = inference_data.get("engine_type", "huggingface") engine_type_map = { "huggingface": InferenceEngineType.HUGGINGFACE, "vllm": InferenceEngineType.VLLM, "tgi": InferenceEngineType.TGI, "tensorrt_llm": InferenceEngineType.TENSORRT_LLM, "sglang": InferenceEngineType.SGLANG, } engine_type = engine_type_map.get(engine_type_str, InferenceEngineType.HUGGINGFACE) # Calculate inference memory calculator = InferenceMemoryCalculator(model_config, inference_config, gpu_config) result = calculator.calculate(engine_type) return { "total_memory_per_gpu_gb": result.total_memory_per_gpu_gb, "total_memory_all_gpus_gb": result.total_memory_all_gpus_gb, "breakdown": { "model_params_gb": result.breakdown.model_params_gb, "kv_cache_gb": result.breakdown.kv_cache_gb, "activations_gb": result.breakdown.activations_gb, "overhead_gb": result.breakdown.overhead_gb, }, "max_supported_batch_size": result.max_supported_batch_size, "estimated_throughput_tokens_per_sec": result.estimated_throughput_tokens_per_sec, "fits_on_gpu": result.fits_on_gpu, "memory_utilization_percent": result.memory_utilization_percent, } except Exception as e: logger.error(f"Inference calculation error: {str(e)}", exc_info=True) raise HTTPException( status_code=500, detail=f"Failed to calculate inference memory: {str(e)}" ) from e @app.post("/api/multinode/calculate") async def calculate_multinode(request: dict[str, Any]) -> dict[str, Any]: """Calculate network overhead for multi-node training. Args: request: Dictionary with model, training, parallelism, engine, and node configs Returns: Network overhead result with suggestions """ try: model_data = request.get("model", {}) training_data = request.get("training", {}) parallelism_data = request.get("parallelism", {}) engine_data = request.get("engine", {}) node_data = request.get("node_config", {}) # Parse num_parameters if it's a string if "num_parameters" in model_data and isinstance(model_data["num_parameters"], str): from gpu_mem_calculator.config.parser import ConfigParser model_data["num_parameters"] = ConfigParser._parse_num_params( model_data["num_parameters"] ) # Create minimal configs for multi-node calculation model_config = ModelConfig( name="multinode-model", num_parameters=model_data.get("num_parameters", 7_000_000_000), num_layers=32, hidden_size=4096, num_attention_heads=32, ) training_config = TrainingConfig( dtype=training_data.get("dtype", "bf16"), batch_size=training_data.get("batch_size", 4), ) parallelism_config = ParallelismConfig( tensor_parallel_size=parallelism_data.get("tensor_parallel_size", 1), pipeline_parallel_size=parallelism_data.get("pipeline_parallel_size", 1), sequence_parallel=parallelism_data.get("sequence_parallel", False), ) engine_config = EngineConfig( type=engine_data.get("type", "deepspeed"), zero_stage=engine_data.get("zero_stage", 3), ) interconnect_type_str = node_data.get("interconnect_type", "infiniband") interconnect_map = { "infiniband": InterconnectType.INFINIBAND, "nvlink": InterconnectType.NVLINK, "ethernet_200g": InterconnectType.ETHERNET_200G, "ethernet_100g": InterconnectType.ETHERNET_100G, "ethernet_25g": InterconnectType.ETHERNET_25G, "ethernet_10g": InterconnectType.ETHERNET_10G, } interconnect_type = interconnect_map.get(interconnect_type_str, InterconnectType.INFINIBAND) node_config = NodeConfig( num_nodes=node_data.get("num_nodes", 2), gpus_per_node=node_data.get("gpus_per_node", 8), interconnect_type=interconnect_type, ) # Calculate network overhead calculator = MultiNodeCalculator( model_config=model_config, training_config=training_config, parallelism_config=parallelism_config, node_config=node_config, engine_config=engine_config, ) overhead = calculator.calculate_network_overhead() # Generate optimization suggestions suggestions: list[str] = [] if overhead.total_overhead_gb > 10: suggestions.append("Consider reducing tensor parallelism to lower AllGather overhead") if overhead.estimated_overhead_ms_per_step and overhead.estimated_overhead_ms_per_step > 50: overhead_val = overhead.estimated_overhead_ms_per_step suggestions.append( f"High communication overhead ({overhead_val:.1f}ms/step). " "Consider upgrading interconnect or reducing model size." ) if interconnect_type_str.startswith("ethernet") and node_config.num_nodes > 2: suggestions.append( "Ethernet interconnect detected. For multi-node training, " "consider InfiniBand for better performance." ) return { "network_overhead": { "total_overhead_gb": overhead.total_overhead_gb, "allreduce_gb": overhead.allreduce_gb, "allgather_gb": overhead.allgather_gb, "reducescatter_gb": overhead.reducescatter_gb, "pipeline_gb": overhead.point_to_point_gb, "estimated_overhead_ms_per_step": overhead.estimated_overhead_ms_per_step, "communication_time_ms_per_step": None, "latency_overhead_ms": None, }, "suggestions": suggestions, } except Exception as e: logger.error(f"Multi-node calculation error: {str(e)}", exc_info=True) raise HTTPException( status_code=500, detail=f"Failed to calculate multi-node overhead: {str(e)}" ) from e @app.post("/api/export/{format}") async def export_framework_config(format: str, request: CalculateRequest) -> dict[str, Any]: """Export configuration to framework-specific format. Args: format: Export format (accelerate, lightning, axolotl, deepspeed, yaml, json) request: Calculation request with all configurations Returns: Exported configuration file content """ try: # Parse configurations model_data = request.model.copy() if "num_parameters" in model_data and isinstance(model_data["num_parameters"], str): from gpu_mem_calculator.config.parser import ConfigParser model_data["num_parameters"] = ConfigParser._parse_num_params( model_data["num_parameters"] ) model_config = ModelConfig(**model_data) training_config = TrainingConfig(**request.training) parallelism_config = ( ParallelismConfig(**request.parallelism) if request.parallelism else ParallelismConfig() ) engine_config = EngineConfig(**request.engine) if request.engine else EngineConfig() # Create minimal node config (not used for single-node export) node_config = NodeConfig(num_nodes=1, gpus_per_node=8) # Map format string to ExportFormat enum format_map = { "accelerate": ExportFormat.ACCELERATE, "lightning": ExportFormat.LIGHTNING, "axolotl": ExportFormat.AXOLOTL, "deepspeed": ExportFormat.DEEPSPEED, "yaml": ExportFormat.YAML, "json": ExportFormat.JSON, } export_format = format_map.get(format.lower()) if not export_format: raise HTTPException( status_code=400, detail=f"Unsupported export format: {format}. Supported: {list(format_map.keys())}", ) # Export configuration manager = ExportManager( model_config=model_config, training_config=training_config, parallelism_config=parallelism_config, engine_config=engine_config, node_config=node_config, ) result = manager.export(export_format) # Generate filename if isinstance(result, dict): filename = f"config_{format}.{result.get('extension', 'txt')}" else: filename = f"config.{format}" return { "format": format, "content": result, "filename": filename, } except HTTPException: raise except Exception as e: logger.error(f"Export error ({format}): {str(e)}", exc_info=True) raise HTTPException( status_code=500, detail=f"Failed to export {format} config: {str(e)}" ) from e def main() -> None: """Run the development server.""" import uvicorn uvicorn.run(app, host="0.0.0.0", port=8000, reload=True) if __name__ == "__main__": main()