#!/usr/bin/env python3
"""Quick environment verification for Synesthesia ROCm stack.

Checks:
1. ROCm SMI visibility
2. env.py imports and exports correct vars
3. PyTorch GPU detection
4. HF_TOKEN is set
5. JAX ROCm visibility (optional)

Returns exit code 0 only if all critical checks pass.
"""

import os
import subprocess
import sys
from pathlib import Path

# Tokens of this length or shorter are treated as probably invalid.
_MIN_HF_TOKEN_LEN = 10

# Environment variables that env.get_env_dict() must provide.
_REQUIRED_ENV_VARS = (
    "HSA_OVERRIDE_GFX_VERSION",
    "HSA_ENABLE_SDMA",
    "JAX_PLATFORMS",
)


def _ensure_project_root_on_path():
    """Put the directory two levels above this file on sys.path (once).

    Lets ``ML_Pipeline.shared`` be imported when the script is run directly.
    The membership test avoids growing sys.path with duplicates on each call.
    """
    project_root = str(Path(__file__).resolve().parent.parent)
    if project_root not in sys.path:
        sys.path.insert(0, project_root)


def check_rocm_smi():
    """Check if ROCm SMI detects the GPU.

    Returns:
        (passed, message). ``passed`` is True whenever ``rocm-smi`` exits 0.
        The message is the line mentioning "6700" if present, otherwise the
        first non-empty line of output.
    """
    try:
        result = subprocess.run(
            ["rocm-smi", "--showproductname"],
            capture_output=True,
            text=True,
            timeout=10,
        )
    except FileNotFoundError:
        return False, "rocm-smi not found (ROCm not installed?)"
    except subprocess.TimeoutExpired:
        return False, "rocm-smi timed out"
    except Exception as exc:  # e.g. PermissionError launching the binary
        return False, str(exc)

    if result.returncode != 0:
        return False, f"rocm-smi failed: {result.stderr.strip()}"

    lines = [line.strip() for line in result.stdout.splitlines() if line.strip()]
    # Prefer the line that actually names the expected card.
    gpu_line = next((line for line in lines if "6700" in line), None)
    if gpu_line is not None:
        return True, gpu_line
    first_line = lines[0] if lines else "unknown"
    return True, f"ROCm detected (GPU: {first_line})"


def check_env_py():
    """Check if env.py imports and exports correct variables.

    Imports ``ML_Pipeline.shared.env`` and verifies that ``get_env_dict()``
    provides every name in ``_REQUIRED_ENV_VARS``.

    Returns:
        (passed, message). Import failures and any other error are reported
        as a failed check rather than raised.
    """
    try:
        _ensure_project_root_on_path()
        from ML_Pipeline.shared import env

        env_vars = env.get_env_dict()
        # Only absent/empty values count as missing: a legitimate falsy value
        # such as HSA_ENABLE_SDMA=0 must not be reported as unset.
        missing = [
            name
            for name in _REQUIRED_ENV_VARS
            if env_vars.get(name) is None or env_vars.get(name) == ""
        ]
        if missing:
            return False, f"Missing env vars: {missing}"
        hf_state = "set" if env_vars.get("HF_TOKEN") else "NOT SET"
        return True, f"All env vars set (HF_TOKEN: {hf_state})"
    except ImportError as exc:
        return False, f"Import error: {exc}"
    except Exception as exc:
        return False, f"Error: {exc}"


def check_torch_gpu():
    """Check if PyTorch can see ROCm GPU.

    Passes only when PyTorch is a ROCm/HIP build AND at least one GPU is
    visible. A ROCm build with no visible GPU is a failure, because this is
    a critical check.

    Returns:
        (passed, message).
    """
    try:
        import torch

        # torch.version.hip is None (or absent) on non-ROCm builds.
        if getattr(torch.version, "hip", None) is None:
            return False, "PyTorch ROCm not available"
        # ROCm builds expose HIP devices through the torch.cuda API.
        if not torch.cuda.is_available():
            return False, "ROCm available but no GPU visible"
        return True, f"ROCm GPU detected: {torch.cuda.get_device_name(0)}"
    except ImportError:
        return False, "PyTorch not installed"
    except Exception as exc:
        return False, f"Error: {exc}"


def check_hf_token():
    """Check if HF_TOKEN is set.

    Looks at the ``HF_TOKEN`` environment variable first, then falls back to
    ``ML_Pipeline.shared.env.HF_TOKEN``. Both sources use the same length
    check (must be longer than ``_MIN_HF_TOKEN_LEN`` characters).

    Returns:
        (passed, message).
    """
    token = os.environ.get("HF_TOKEN")
    if token:
        if len(token) > _MIN_HF_TOKEN_LEN:
            return True, f"HF_TOKEN set ({len(token)} chars)"
        return False, "HF_TOKEN set but too short (may be invalid)"

    # Fall back to env.py. Any failure to import it (missing module, error
    # inside env.py, no HF_TOKEN attribute) means "not set". Catching
    # Exception instead of a bare except lets SystemExit/KeyboardInterrupt through.
    try:
        _ensure_project_root_on_path()
        from ML_Pipeline.shared.env import HF_TOKEN
    except Exception:
        return False, "HF_TOKEN not set"

    if not HF_TOKEN:
        return False, "HF_TOKEN not set"
    if len(str(HF_TOKEN)) <= _MIN_HF_TOKEN_LEN:
        return False, "HF_TOKEN loaded from env.py but too short (may be invalid)"
    return True, "HF_TOKEN loaded from env.py"


def check_jax_rocm():
    """Check if JAX can see ROCm devices (optional).

    Returns:
        (passed, message). ``passed`` is None when JAX is not installed. It
        is False when no devices are found or when only CPU devices are
        visible, meaning ROCm is not in use.
    """
    try:
        import jax

        devices = jax.devices()
    except ImportError:
        return None, "JAX not installed (optional)"
    except Exception as exc:
        return False, f"Error: {exc}"

    if not devices:
        return False, "JAX installed but no devices found"

    # Device.platform ("cpu", "gpu", ...) is more informative than the class
    # name. Sorting makes the output deterministic.
    platforms = sorted({getattr(d, "platform", type(d).__name__) for d in devices})
    summary = f"JAX devices: {len(devices)} ({', '.join(platforms)})"
    if platforms == ["cpu"]:
        return False, f"{summary} - no ROCm/GPU device visible"
    return True, summary


def main():
    """Run all checks and report results.

    Returns:
        0 if every critical check passed, otherwise 1. A non-critical
        failure only prints a warning.
    """
    print("=" * 60)
    print("Synesthesia Environment Verification")
    print("=" * 60)
    print()

    checks = [
        ("ROCm SMI", check_rocm_smi, True),
        ("env.py", check_env_py, True),
        ("PyTorch GPU", check_torch_gpu, True),
        ("HF_TOKEN", check_hf_token, True),
        ("JAX ROCm", check_jax_rocm, False),
    ]

    critical_failed = False
    for name, check_fn, is_critical in checks:
        try:
            passed, message = check_fn()
        except Exception as exc:
            passed, message = False, f"Check failed: {exc}"

        # A None result (e.g. optional dependency missing) is "not passed".
        if passed:
            status = "✓ PASS"
        elif is_critical:
            status = "✗ FAIL"
            critical_failed = True
        else:
            status = "⚠ WARN"
        print(f"[{status}] {name}: {message}")

    print()
    print("=" * 60)
    if critical_failed:
        print("RESULT: CRITICAL CHECKS FAILED")
        print("Fix the above errors before proceeding.")
        return 1
    print("RESULT: ALL CRITICAL CHECKS PASSED")
    print("Environment is ready for runtime module implementation.")
    return 0


if __name__ == "__main__":
    sys.exit(main())