"""AMD ROCm device management and monitoring."""

import csv
import os
import subprocess

from loguru import logger

# Bytes per GiB; PyTorch reports memory sizes in bytes.
_BYTES_PER_GB = 1024**3


def get_device() -> str:
    """Return the PyTorch device string to use: ``"cuda"`` or ``"cpu"``.

    ROCm builds of PyTorch expose AMD GPUs through the ``torch.cuda`` API, so
    ``"cuda"`` is returned for AMD GPUs too.

    Falls back to ``"cpu"`` when PyTorch is not installed, no GPU is visible,
    or probing the device raises.
    """
    try:
        import torch

        if torch.cuda.is_available():
            device_name = torch.cuda.get_device_name(0)
            logger.info(f"GPU detected: {device_name}")
            return "cuda"
    except ImportError:
        pass
    except Exception as exc:
        # Probing can fail at runtime (driver/runtime errors); treat it as
        # "no usable GPU" instead of crashing the caller, matching the
        # best-effort behaviour of the other helpers in this module.
        logger.warning(f"GPU probe failed: {exc}")
    logger.warning("No GPU available, falling back to CPU")
    return "cpu"


def get_vram_gb() -> float:
    """Return the total VRAM of GPU 0 in GiB, rounded to one decimal.

    Note: this is the device's *total* memory, not the currently free amount.
    Returns ``0.0`` when PyTorch is missing, no GPU is visible, or the query
    fails.
    """
    try:
        import torch

        if torch.cuda.is_available():
            total = torch.cuda.get_device_properties(0).total_memory
            return round(total / _BYTES_PER_GB, 1)
    except Exception as exc:
        # Best-effort: callers use 0.0 as "unknown / no GPU".
        logger.debug(f"Could not query VRAM: {exc}")
    return 0.0


def _parse_rocm_smi_csv(output: str) -> dict:
    """Map the header row of ``rocm-smi --csv`` output onto its first data row.

    Blank lines are skipped and cells are whitespace-stripped. Returns an empty
    dict when there is no header + data pair.
    """
    rows = [
        row
        for row in csv.reader(output.splitlines())
        if any(cell.strip() for cell in row)
    ]
    if len(rows) < 2:
        return {}
    headers = [cell.strip() for cell in rows[0]]
    values = [cell.strip() for cell in rows[1]]
    return dict(zip(headers, values))


def get_gpu_utilization() -> dict:
    """Return GPU utilization stats, preferring ``rocm-smi`` over PyTorch.

    First runs ``rocm-smi --showuse --showmemuse --csv`` (5 s timeout). If it
    succeeds and prints a header plus at least one data row, returns the
    first row as a ``{column: value}`` dict of strings.

    Otherwise falls back to PyTorch memory counters for GPU 0 and returns
    ``vram_used_gb``, ``vram_reserved_gb``, ``vram_total_gb`` and ``vram_pct``.

    Returns an empty dict if neither source is available.
    """
    try:
        result = subprocess.run(
            ["rocm-smi", "--showuse", "--showmemuse", "--csv"],
            capture_output=True,
            text=True,
            timeout=5,
        )
    except (OSError, subprocess.TimeoutExpired) as exc:
        # OSError covers a missing binary (FileNotFoundError) as well as
        # permission/exec failures; all of them mean "use the fallback".
        logger.debug(f"rocm-smi unavailable: {exc}")
    else:
        if result.returncode == 0:
            stats = _parse_rocm_smi_csv(result.stdout)
            if stats:
                return stats

    # Fallback: PyTorch memory stats
    try:
        import torch

        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated(0) / _BYTES_PER_GB
            reserved = torch.cuda.memory_reserved(0) / _BYTES_PER_GB
            total = torch.cuda.get_device_properties(0).total_memory / _BYTES_PER_GB
            return {
                "vram_used_gb": round(allocated, 2),
                "vram_reserved_gb": round(reserved, 2),
                "vram_total_gb": round(total, 2),
                # Guard against a zero-sized device; keep the type a float.
                "vram_pct": round(allocated / total * 100, 1) if total > 0 else 0.0,
            }
    except Exception as exc:
        logger.debug(f"PyTorch memory stats unavailable: {exc}")
    return {}


def get_optimal_batch_size(model_type: str = "whisper") -> int:
    """Return a batch size for ``model_type`` scaled to GPU 0's total VRAM.

    Tiers (VRAM in GiB as reported by :func:`get_vram_gb`):

    - ``"whisper"``: >=48 -> 32, >=24 -> 16, >=16 -> 8, otherwise 4.
    - ``"vision"``: >=80 -> 8, >=48 -> 4, otherwise 1.
    - any other model type: 1.
    """
    vram = get_vram_gb()
    if model_type == "whisper":
        if vram >= 48:
            return 32
        elif vram >= 24:
            return 16
        elif vram >= 16:
            return 8
        return 4
    elif model_type == "vision":
        if vram >= 80:
            return 8
        elif vram >= 48:
            return 4
        return 1
    return 1


def log_gpu_status():
    """Log current GPU stats, or device/VRAM summary when no stats are available."""
    stats = get_gpu_utilization()
    if stats:
        logger.info(f"GPU stats: {stats}")
    else:
        logger.info(f"GPU: {get_device()} | VRAM: {get_vram_gb()} GB")