| """ |
| Cost estimation utilities for cloud GPU training. |
| |
| Usage: |
| from src.cost_estimate import CostTracker, detect_hardware |
| |
| tracker = CostTracker(gpu_type="RTX_A4000") |
| tracker.start() |
| # ... training loop ... |
| tracker.update(epoch=1, total_epochs=100) |
tracker.summary(epoch=100, total_epochs=100)
| """ |
|
|
| import time |
| from dataclasses import dataclass |
| from typing import Optional |
|
|
|
|
| |
# Approximate on-demand GPU rental prices in USD per hour, keyed by the
# identifiers produced by HardwareInfo.get_gpu_type().
# NOTE(review): these look like community-cloud rates -- verify against the
# provider's current price list before relying on the estimates.
GPU_PRICES = {
    "RTX_A4000": 0.20,
    "RTX_A5000": 0.28,
    "RTX_3090": 0.22,
    "RTX_4090": 0.44,
    "A40": 0.39,
    "A100_40GB": 1.09,
    "A100_80GB": 1.59,
    "H100": 2.49,
    "CPU": 0.0,  # CPU-only / local runs are treated as free
}
|
|
|
|
def detect_cloud_provider() -> str:
    """Best-effort detection of the cloud provider this process runs on.

    Checks, in order: provider-specific environment variables, the
    link-local metadata service (1s timeout per probe), ``/etc/hostname``,
    and the DMI vendor/product strings exposed under ``/sys``.

    Returns:
        A short provider id ("runpod", "linode", "aws", "gcp", "azure",
        "lambda", "vast", "colab", "kaggle") or "local" if nothing matches.
    """
    import os
    import urllib.request

    # 1. Environment variables injected by each provider's runtime.
    env_markers = (
        (("RUNPOD_POD_ID",), "runpod"),
        (("LINODE_ID", "LINODE_DATACENTER_ID"), "linode"),
        (("AWS_EXECUTION_ENV", "AWS_REGION"), "aws"),
        (("GOOGLE_CLOUD_PROJECT", "GCP_PROJECT"), "gcp"),
        (("AZURE_CLIENT_ID", "MSI_ENDPOINT"), "azure"),
        (("LAMBDA_LABS_API_KEY",), "lambda"),
        (("VAST_CONTAINERLABEL",), "vast"),
        (("COLAB_GPU",), "colab"),
        (("KAGGLE_KERNEL_RUN_TYPE",), "kaggle"),
    )
    for names, provider in env_markers:
        if any(os.getenv(n) for n in names):
            return provider

    # 2. Link-local metadata service. Uses urllib instead of shelling out
    # to curl, so detection does not depend on an external binary.
    def _fetch(url, headers=None):
        """Return response bytes, or None if the endpoint is unreachable."""
        try:
            req = urllib.request.Request(url, headers=headers or {})
            with urllib.request.urlopen(req, timeout=1) as resp:
                return resp.read()
        except Exception:
            return None

    body = _fetch("http://169.254.169.254/v1/instance")
    if body and b"instance" in body.lower():
        return "linode"

    body = _fetch("http://169.254.169.254/latest/meta-data/ami-id")
    if body and b"ami-" in body:
        return "aws"

    # GCP's metadata server requires this header on every request.
    body = _fetch(
        "http://metadata.google.internal/computeMetadata/v1/",
        headers={"Metadata-Flavor": "Google"},
    )
    if body:
        return "gcp"

    # 3. Filesystem hints: hostname and DMI vendor/product strings.
    def _read(path):
        """Return the lowercased contents of *path*, or "" on any error."""
        try:
            with open(path, "r") as f:
                return f.read().lower()
        except Exception:
            return ""

    if "linode" in _read("/etc/hostname"):
        return "linode"

    vendor = _read("/sys/class/dmi/id/sys_vendor")
    for marker, provider in (
        ("linode", "linode"),
        ("amazon", "aws"),
        ("google", "gcp"),
        ("microsoft", "azure"),
    ):
        if marker in vendor:
            return provider

    product = _read("/sys/class/dmi/id/product_name")
    if "linode" in product:
        return "linode"
    if "amazon" in product or "ec2" in product:
        return "aws"
    if "google" in product:
        return "gcp"

    return "local"
|
|
|
|
@dataclass
class HardwareInfo:
    """Detected hardware information.

    Populated by ``detect_hardware()``; every field except ``device_type``
    is optional because detection is best-effort.
    """

    device_type: str                      # "cuda" or "cpu"
    gpu_name: Optional[str] = None        # e.g. "NVIDIA RTX A4000"
    gpu_memory_gb: Optional[float] = None
    cpu_name: Optional[str] = None
    cpu_cores: Optional[int] = None
    ram_gb: Optional[float] = None
    cloud_provider: str = "local"         # see detect_cloud_provider()

    def get_gpu_type(self) -> str:
        """Map the detected GPU name to a pricing key in ``GPU_PRICES``.

        Unknown GPUs fall back to "RTX_A4000" so a (possibly inaccurate)
        price estimate is always available.
        """
        if self.device_type == "cpu" or not self.gpu_name:
            return "CPU"

        name = self.gpu_name.upper()

        if "H100" in name:
            return "H100"
        if "A100" in name:
            # A100 ships in 40GB and 80GB variants at different rates.
            if self.gpu_memory_gb and self.gpu_memory_gb > 50:
                return "A100_80GB"
            return "A100_40GB"
        # BUGFIX: check "A4000"/"A5000" BEFORE "A40". "A40" is a substring
        # of "A4000", so the shorter pattern previously shadowed the longer
        # one and priced an RTX A4000 as an A40.
        if "A4000" in name:
            return "RTX_A4000"
        if "A5000" in name:
            return "RTX_A5000"
        if "A40" in name:
            return "A40"
        if "4090" in name:
            return "RTX_4090"
        if "3090" in name:
            return "RTX_3090"
        # Unknown GPU: assume the cheapest tracked tier.
        return "RTX_A4000"

    def to_dict(self) -> dict:
        """Convert to a plain dict (e.g. for experiment logging)."""
        return {
            "device_type": self.device_type,
            "gpu_name": self.gpu_name,
            "gpu_memory_gb": self.gpu_memory_gb,
            "cpu_name": self.cpu_name,
            "cpu_cores": self.cpu_cores,
            "ram_gb": self.ram_gb,
            "gpu_type": self.get_gpu_type(),
            "cloud_provider": self.cloud_provider,
        }

    def __str__(self) -> str:
        # Prefix with the provider only when not running locally.
        prefix = f"[{self.cloud_provider}] " if self.cloud_provider != "local" else ""
        if self.device_type == "cuda" and self.gpu_name:
            mem = f" ({self.gpu_memory_gb:.1f}GB)" if self.gpu_memory_gb else ""
            return f"{prefix}{self.gpu_name}{mem}"
        ram = f", {self.ram_gb:.1f}GB RAM" if self.ram_gb else ""
        return f"{prefix}CPU: {self.cpu_name or 'Unknown'} ({self.cpu_cores} cores{ram})"
|
|
|
|
def detect_hardware() -> HardwareInfo:
    """Detect available hardware (GPU/CPU), RAM and cloud provider.

    GPU detection requires torch with a working CUDA runtime; anything
    else (torch missing, CPU-only build) falls back to a CPU result.

    Returns:
        A populated :class:`HardwareInfo`; optional fields are ``None``
        when a probe fails.
    """
    import os
    import platform

    cloud_provider = detect_cloud_provider()

    # platform.processor() is often "" on Linux -- fall back to "Unknown".
    cpu_name = platform.processor() or "Unknown"
    cpu_cores = os.cpu_count()

    # Total physical RAM via sysconf: no subprocess (previously shelled
    # out to `free`), and works on any Unix exposing these names.
    try:
        page_size = os.sysconf("SC_PAGE_SIZE")
        page_count = os.sysconf("SC_PHYS_PAGES")
        if page_size > 0 and page_count > 0:
            ram_gb = (page_size * page_count) / (1024 ** 3)
        else:
            ram_gb = None  # sysconf reports -1 for undefined values
    except (AttributeError, ValueError, OSError):
        ram_gb = None  # e.g. Windows: os.sysconf is unavailable

    try:
        import torch
        if torch.cuda.is_available():
            props = torch.cuda.get_device_properties(0)
            return HardwareInfo(
                device_type="cuda",
                gpu_name=torch.cuda.get_device_name(0),
                gpu_memory_gb=props.total_memory / (1024 ** 3),
                cpu_name=cpu_name,
                cpu_cores=cpu_cores,
                ram_gb=ram_gb,
                cloud_provider=cloud_provider,
            )
    except Exception:
        # torch missing or CUDA probing failed -- treat as CPU-only.
        pass

    return HardwareInfo(
        device_type="cpu",
        cpu_name=cpu_name,
        cpu_cores=cpu_cores,
        ram_gb=ram_gb,
        cloud_provider=cloud_provider,
    )
|
|
|
|
@dataclass
class CostTracker:
    """Track wall-clock training time and estimate the resulting cost.

    Call ``start()`` once before the training loop, ``update()`` inside
    it (it self-throttles to one report per interval), and ``summary()``
    when training finishes.
    """

    gpu_type: str = "RTX_A4000"  # key into GPU_PRICES

    def __post_init__(self):
        # Unknown GPU types are billed at the cheapest tracked rate.
        self.hourly_rate = GPU_PRICES.get(self.gpu_type, 0.20)
        self.start_time: Optional[float] = None
        self.last_report_time: Optional[float] = None
        self.report_interval = 300  # seconds between periodic reports

    def start(self):
        """Record the starting timestamp."""
        now = time.time()
        self.start_time = now
        self.last_report_time = now

    def elapsed_seconds(self) -> float:
        """Seconds since ``start()``; 0 if never started."""
        return 0 if self.start_time is None else time.time() - self.start_time

    def elapsed_hours(self) -> float:
        """Hours since ``start()``."""
        return self.elapsed_seconds() / 3600

    def current_cost(self) -> float:
        """Cost accrued so far, in USD."""
        return self.hourly_rate * self.elapsed_hours()

    def estimate_total_cost(self, progress: float) -> float:
        """Project the full-run cost by linear extrapolation.

        Args:
            progress: Fraction of training completed (0.0 to 1.0).
        """
        return self.current_cost() / progress if progress > 0 else 0

    def estimate_remaining_cost(self, progress: float) -> float:
        """Projected cost still to be incurred, in USD."""
        return self.estimate_total_cost(progress) - self.current_cost()

    def estimate_remaining_time(self, progress: float) -> float:
        """Projected seconds left, by linear extrapolation."""
        if progress <= 0:
            return 0
        done = self.elapsed_seconds()
        return done / progress - done

    def format_time(self, seconds: float) -> str:
        """Render a duration as e.g. '42s', '3.5m' or '2.1h'."""
        if seconds >= 3600:
            return f"{seconds / 3600:.1f}h"
        if seconds >= 60:
            return f"{seconds / 60:.1f}m"
        return f"{seconds:.0f}s"

    def format_cost(self, cost: float) -> str:
        """Render a USD amount with precision scaled to its magnitude."""
        if cost >= 1:
            return f"${cost:.2f}"
        if cost >= 0.01:
            return f"${cost:.3f}"
        return f"${cost:.4f}"

    def should_report(self) -> bool:
        """True when the report interval has elapsed (or nothing reported yet)."""
        last = self.last_report_time
        return last is None or time.time() - last >= self.report_interval

    def get_status(self, epoch: int, total_epochs: int) -> str:
        """One-line status: accrued cost, projected total, and ETA."""
        progress = epoch / total_epochs if total_epochs > 0 else 0
        fields = (
            f"Cost: {self.format_cost(self.current_cost())}",
            f"Est. total: {self.format_cost(self.estimate_total_cost(progress))}",
            f"ETA: {self.format_time(self.estimate_remaining_time(progress))}",
        )
        return " | ".join(fields)

    def update(self, epoch: int, total_epochs: int, force: bool = False) -> Optional[str]:
        """Return a status line when one is due (or ``force`` is set).

        Returns:
            The status string when the report interval has passed,
            otherwise ``None``.
        """
        if not (force or self.should_report()):
            return None
        self.last_report_time = time.time()
        return self.get_status(epoch, total_epochs)

    def summary(self, epoch: int, total_epochs: int) -> str:
        """Multi-line final cost report."""
        progress = epoch / total_epochs if total_epochs > 0 else 1.0
        rule = "=" * 50
        lines = [
            rule,
            "Cost Summary",
            rule,
            f" GPU: {self.gpu_type} (${self.hourly_rate}/hr)",
            f" Duration: {self.format_time(self.elapsed_seconds())}",
            f" Total cost: {self.format_cost(self.current_cost())}",
        ]
        if progress < 1.0:
            # Training stopped early: extrapolate what a full run would cost.
            lines.append(f" Est. full training: {self.format_cost(self.estimate_total_cost(progress))}")
        lines.append(rule)
        return "\n".join(lines)
|
|