George Yang
Initial deployment: Add GPU Memory Calculator with Docker
36ed1cd
"""CLI interface for GPU Memory Calculator."""
import json
import sys
from pathlib import Path
from typing import TYPE_CHECKING, Literal
import click
if TYPE_CHECKING:
from gpu_mem_calculator.core.calculator import GPUMemoryCalculator
from gpu_mem_calculator.core.models import MemoryResult
@click.group()
@click.version_option(version="0.1.0")
def main() -> None:
    """GPU Memory Calculator for LLM Training.

    Calculate GPU memory requirements for training Large Language Models
    with various training engines (PyTorch DDP, DeepSpeed, Megatron-LM, FSDP).
    """
    # Root command group: subcommands attach themselves via ``@main.command()``.
    # The group callback has no work of its own, so the docstring (which click
    # shows as ``--help`` text) is the entire body.
def _calculator_from_preset(preset: str) -> "GPUMemoryCalculator":
    """Build a calculator from a named preset configuration.

    Echoes an error and exits with status 1 when ``preset`` is unknown.
    The preset dict is written to a temporary JSON file so it can be loaded
    through ``GPUMemoryCalculator.from_config_file``; that file is removed
    even when writing or loading it raises.
    """
    import tempfile

    from gpu_mem_calculator.core.calculator import GPUMemoryCalculator
    from gpu_mem_calculator.config.presets import get_preset_config

    preset_config = get_preset_config(preset)
    if preset_config is None:
        click.echo(
            f"Error: Preset '{preset}' not found. "
            "Use 'gpu-mem-calc presets' to list available presets.",
            err=True,
        )
        sys.exit(1)

    tmp = tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False)
    temp_path = tmp.name
    try:
        with tmp as f:
            json.dump(preset_config, f, indent=2)
        return GPUMemoryCalculator.from_config_file(temp_path)
    finally:
        # delete=False means nobody else cleans this up; do it on every path.
        Path(temp_path).unlink(missing_ok=True)


def _render_result(
    result: "MemoryResult",
    calculator: "GPUMemoryCalculator",
    fmt: str,
) -> str:
    """Render ``result`` as ``"json"``, ``"yaml"``, or (anything else) a table.

    Exits with status 1 if YAML is requested but PyYAML is not installed.
    """
    if fmt == "json":
        return json.dumps(result.model_dump(mode="json"), indent=2)
    if fmt == "yaml":
        try:
            import yaml  # type: ignore[import-untyped]
        except ImportError:
            click.echo(
                "Error: YAML format requires PyYAML. Install with: pip install pyyaml",
                err=True,
            )
            sys.exit(1)
        return yaml.dump(result.model_dump(mode="json"), default_flow_style=False)
    return _format_result_as_table(result, calculator)


@main.command()
@click.option(
    "--config",
    "-c",
    type=click.Path(exists=True),
    help="Path to JSON configuration file",
)
@click.option(
    "--preset",
    "-p",
    type=str,
    help="Name of a preset model configuration",
)
@click.option(
    "--output",
    "-o",
    type=click.Path(),
    help="Output file path (default: stdout)",
)
@click.option(
    "--format",
    "-f",
    type=click.Choice(["json", "yaml", "table"]),
    default="table",
    help="Output format (default: table)",
)
def calculate(
    config: str | None,
    preset: str | None,
    output: str | None,
    format: Literal["json", "yaml", "table"],
) -> None:
    """Calculate GPU memory requirements from config file or preset.

    Examples:
        gpu-mem-calc calculate --config configs/llama2_7b.json
        gpu-mem-calc calculate --preset llama2-7b
        gpu-mem-calc calculate -p mixtral-8x7b --format json
    """
    # Exactly one configuration source must be supplied.
    if not config and not preset:
        click.echo("Error: Either --config or --preset is required", err=True)
        sys.exit(1)
    if config and preset:
        click.echo("Error: Cannot use both --config and --preset", err=True)
        sys.exit(1)

    try:
        if preset:
            calculator = _calculator_from_preset(preset)
        else:
            from gpu_mem_calculator.core.calculator import GPUMemoryCalculator

            # The guards above guarantee ``config`` is set when ``preset`` is not.
            assert config is not None
            calculator = GPUMemoryCalculator.from_config_file(config)

        result = calculator.calculate()
        output_text = _render_result(result, calculator, format)

        if output:
            # Explicit UTF-8: the table output contains characters such as
            # "✓"/"✗" that platform default encodings (e.g. cp1252) reject.
            Path(output).write_text(output_text, encoding="utf-8")
            click.echo(f"Results written to {output}")
        else:
            click.echo(output_text)
    except Exception as e:
        # Top-level CLI boundary: report any failure and exit non-zero.
        # (SystemExit from sys.exit() above is not an Exception subclass.)
        click.echo(f"Error: {e}", err=True)
        sys.exit(1)
# (max size in billions of parameters, hidden_size, num_layers) buckets used by
# ``quick`` to guess a transformer shape from a bare parameter count. Buckets
# are checked in order; the final ``inf`` bucket catches everything larger.
# These are rough approximations, not real model architectures.
_QUICK_ARCHITECTURES: tuple[tuple[float, int, int], ...] = (
    (1, 768, 12),
    (7, 4096, 32),
    (13, 5120, 40),
    (30, 6656, 60),
    (65, 8192, 80),
    (float("inf"), 12288, 96),
)


def _estimate_architecture(params_billions: float) -> tuple[int, int]:
    """Return ``(hidden_size, num_layers)`` guessed from a size in billions."""
    for max_params, hidden_size, num_layers in _QUICK_ARCHITECTURES:
        if params_billions <= max_params:
            return hidden_size, num_layers
    # Only reachable for NaN (every ``<=`` is False); use the largest bucket.
    _, hidden_size, num_layers = _QUICK_ARCHITECTURES[-1]
    return hidden_size, num_layers


@main.command()
@click.argument(
    "params",
    # Model size must be strictly positive.
    type=click.FloatRange(min=0, min_open=True),
    required=True,
)
@click.option(
    "--gpus",
    "-g",
    type=click.IntRange(min=1),
    default=1,
    help="Number of GPUs (default: 1)",
)
@click.option(
    "--gpu-mem",
    "-m",
    type=click.FloatRange(min=0, min_open=True),
    default=80.0,
    help="GPU memory in GB (default: 80.0)",
)
@click.option(
    "--engine",
    "-e",
    type=click.Choice(["pytorch", "deepspeed", "megatron", "fsdp"]),
    default="pytorch",
    help="Training engine (default: pytorch)",
)
@click.option(
    "--dtype",
    "-d",
    type=click.Choice(["fp32", "fp16", "bf16"]),
    default="bf16",
    help="Data type (default: bf16)",
)
def quick(
    params: float,
    gpus: int,
    gpu_mem: float,
    engine: str,
    dtype: str,
) -> None:
    """Quick calculation from model size (in billions of parameters).

    Example:
        gpu-mem-calc quick 7 --gpus 8 --engine deepspeed
    """
    try:
        from gpu_mem_calculator.core.calculator import GPUMemoryCalculator
        from gpu_mem_calculator.core.models import (
            DType,
            EngineConfig,
            EngineType,
            GPUConfig,
            ModelConfig,
            ParallelismConfig,
            TrainingConfig,
        )

        # CLI engine/dtype names -> model enums (keys match the click.Choice lists).
        engine_map = {
            "pytorch": EngineType.PYTORCH_DDP,
            "deepspeed": EngineType.DEEPSPEED,
            "megatron": EngineType.MEGATRON_LM,
            "fsdp": EngineType.FSDP,
        }
        dtype_map = {
            "fp32": DType.FP32,
            "fp16": DType.FP16,
            "bf16": DType.BF16,
        }

        # round(), not int(): products like 1.005 * 1e9 land a hair below the
        # integer in floating point, and int() would truncate them.
        num_params = round(params * 1e9)
        hidden_size, num_layers = _estimate_architecture(params)

        model_config = ModelConfig(
            name="quick-estimate",
            num_parameters=num_params,
            num_layers=num_layers,
            hidden_size=hidden_size,
            # Assumes a fixed head dimension of 128.
            num_attention_heads=hidden_size // 128,
            vocab_size=32000,
            max_seq_len=2048,
        )
        training_config = TrainingConfig(
            batch_size=1,
            gradient_accumulation_steps=1,
            dtype=dtype_map[dtype],
        )
        # Pure data parallelism across all requested GPUs.
        parallelism_config = ParallelismConfig(data_parallel_size=gpus)
        engine_config = EngineConfig(
            type=engine_map[engine],
            zero_stage=2 if engine == "deepspeed" else None,
        )
        gpu_config = GPUConfig(num_gpus=gpus, gpu_memory_gb=gpu_mem)

        calculator = GPUMemoryCalculator(
            model_config=model_config,
            training_config=training_config,
            parallelism_config=parallelism_config,
            engine_config=engine_config,
            gpu_config=gpu_config,
        )
        result = calculator.calculate()
        click.echo(_format_result_as_table(result, calculator))
    except Exception as e:
        click.echo(f"Error: {e}", err=True)
        sys.exit(1)
@main.command()
@click.argument("config_path", type=click.Path(exists=True))
def validate(config_path: str) -> None:
    """Validate a configuration file.

    Example:
        gpu-mem-calc validate configs/my_config.json
    """
    # Parsing *is* the validation: any exception raised while building the
    # full configuration is reported as a failure with exit status 1.
    try:
        from gpu_mem_calculator.config import ConfigParser

        ConfigParser.parse_full_config(config_path)
        click.echo(f"✓ Configuration file '{config_path}' is valid")
    except Exception as exc:
        click.echo(f"✗ Validation failed: {exc}", err=True)
        sys.exit(1)
@main.command()
@click.option(
    "--format",
    "-f",
    type=click.Choice(["list", "json", "table"]),
    default="list",
    help="Output format (default: list)",
)
def presets(format: str) -> None:
    """List available model preset configurations.

    Examples:
        gpu-mem-calc presets
        gpu-mem-calc presets --format table
        gpu-mem-calc presets -f json
    """
    try:
        from gpu_mem_calculator.config.presets import list_presets

        all_presets = list_presets()
        if not all_presets:
            click.echo("No presets found.")
            return

        if format == "json":
            click.echo(json.dumps(all_presets, indent=2))
            return

        # Both human-readable formats list presets alphabetically by name.
        ordered = sorted(all_presets.items())

        if format == "table":
            from rich.console import Console
            from rich.table import Table

            table = Table(
                title="Available Model Presets",
                show_header=True,
                header_style="bold magenta",
            )
            table.add_column("Preset Name", style="cyan", width=25)
            table.add_column("Display Name", style="green", width=30)
            table.add_column("Description", style="yellow")
            for name, info in ordered:
                # Tolerate a missing/empty description, as the list format
                # does, instead of failing the whole listing with KeyError.
                table.add_row(name, info["display_name"], info.get("description") or "")
            Console().print(table)
        else:  # list format
            click.echo("Available model presets:\n")
            for name, info in ordered:
                click.echo(f"  {name:25} - {info['display_name']}")
                if info.get("description"):
                    click.echo(f"{'':27}{info['description']}")
                click.echo()
    except Exception as e:
        click.echo(f"Error: {e}", err=True)
        sys.exit(1)
def _format_result_as_table(result: "MemoryResult", calculator: "GPUMemoryCalculator") -> str:
    """Render a memory calculation result as a rich table and return it as text.

    Both annotations are string forward references: ``MemoryResult`` and
    ``GPUMemoryCalculator`` are only imported under ``TYPE_CHECKING``, so an
    unquoted annotation would raise NameError when this module is imported.

    Args:
        result: Calculation result. Its per-GPU/total/CPU memory, component
            breakdown, fit status, utilization and (if truthy) recommended
            batch size are tabulated.
        calculator: The calculator that produced ``result``. Not read by the
            current implementation; kept for the callers' signature.

    Returns:
        The rendered table as a single string.
    """
    from io import StringIO

    from rich.console import Console
    from rich.table import Table

    table = Table(
        title="GPU Memory Calculation Results",
        show_header=True,
        header_style="bold magenta",
    )
    table.add_column("Metric", style="cyan", width=30)
    table.add_column("Value", style="green")

    # Headline totals
    table.add_row("Memory per GPU", f"{result.total_memory_per_gpu_gb:.2f} GB")
    table.add_row("Total GPU Memory", f"{result.total_memory_all_gpus_gb:.2f} GB")
    table.add_row("CPU Memory", f"{result.cpu_memory_gb:.2f} GB")
    table.add_row("", "")  # Spacer

    # Per-component breakdown
    table.add_row("Model Parameters", f"{result.breakdown.model_params_gb:.2f} GB")
    table.add_row("Gradients", f"{result.breakdown.gradients_gb:.2f} GB")
    table.add_row("Optimizer States", f"{result.breakdown.optimizer_states_gb:.2f} GB")
    table.add_row("Activations", f"{result.breakdown.activations_gb:.2f} GB")
    table.add_row("Overhead", f"{result.breakdown.overhead_gb:.2f} GB")
    table.add_row("", "")  # Spacer

    # Feasibility
    status = "✓ Fits" if result.fits_on_gpu else "✗ OOM"
    table.add_row("Status", status)
    table.add_row("Memory Utilization", f"{result.memory_utilization_percent:.1f}%")
    if result.recommended_batch_size:
        table.add_row("Recommended Batch Size", str(result.recommended_batch_size))

    # Capture the rendered table instead of printing it directly.
    buffer = StringIO()
    console = Console()
    # NOTE(review): the Console is created against the real stdout and only then
    # redirected, presumably so terminal colour/width detection is kept; the
    # captured text may therefore contain ANSI styling — confirm this is wanted
    # when the result is written to a file via --output.
    console.file = buffer
    console.print(table)
    return buffer.getvalue()
if __name__ == "__main__":
    # Executed directly as a script: hand control to the click command group.
    main()