Synesthesia / scripts /verify_rocm_stack.py
Ashiedu's picture
Sync unified workbench
0490201 verified
"""End-to-end stack verification for Synesthesia ROCm 7.2.1 environment.
Performs HARD checks (drivers, venvs, GPUs) and SOFT checks (models, auth).
"""
import os
import subprocess
import sys
from typing import Dict, List, Optional, Tuple

from rich.console import Console
from rich.table import Table

from ML_Pipeline.shared import env
# Module-level rich console; the output calls in this file print through it.
console = Console()
def run_check(
    cmd: List[str],
    env_vars: Optional[Dict[str, str]] = None,
    timeout: float = 15.0,
) -> Tuple[bool, str]:
    """Run a check command and report whether it succeeded.

    Args:
        cmd: Argument vector to execute (passed to ``subprocess.run`` as a
            list, so no shell is involved).
        env_vars: Environment for the child process. ``None`` makes the child
            inherit the current process environment.
        timeout: Seconds to wait before the command is abandoned.

    Returns:
        ``(True, stdout)`` when the command exits with status 0. Otherwise
        ``(False, reason)``, where *reason* is stderr, falling back to stdout
        and then to the exit code so a failure never has an empty detail.
        Timeouts and launch errors (e.g. a missing executable) are reported
        the same way instead of being raised.
    """
    try:
        res = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            timeout=timeout,
            env=env_vars,
        )
    except subprocess.TimeoutExpired:
        # str(TimeoutExpired) begins with the full command line, which buries
        # the actual reason once callers truncate the detail; keep it short.
        return False, f"timed out after {timeout:g}s"
    except Exception as exc:
        # Deliberate best-effort boundary: a check that cannot even be
        # launched is a failed check, not a crash of the whole verification.
        return False, str(exc)

    if res.returncode == 0:
        return True, res.stdout.strip()

    # Some tools fail silently or write their diagnostics to stdout.
    detail = res.stderr.strip() or res.stdout.strip() or f"exit code {res.returncode}"
    return False, detail
def _summarize_detail(detail: str, success: bool, limit: int = 50) -> str:
    """Condense command output to one line of at most *limit* characters (plus "...")."""
    lines = [line.strip() for line in detail.splitlines() if line.strip()]
    if not lines:
        return ""
    # On failure the informative part is the final line (for a Python
    # traceback that is the exception itself); on success the first line.
    chosen = lines[0] if success else lines[-1]
    return chosen[:limit] + "..." if len(chosen) > limit else chosen


def verify_stack():
    """Run every stack check, print a results table and exit.

    Each check is a command executed through ``run_check`` with the
    environment returned by ``env.get_env_dict()``. A failing HARD check
    makes the process exit with status 1; a failing SOFT check is only shown
    as a warning. When no HARD check fails the process exits with status 0.
    """
    # Local import: check output is arbitrary text and must not be parsed as
    # rich markup when placed in the table.
    from rich.markup import escape

    console.print("[bold blue]🌌 Synesthesia Stack Verification: ROCm 7.2.1[/]\n")

    # Pre-fetch env dict; the same environment is handed to every check.
    e_dict = env.get_env_dict()

    # NOTE(review): the venv and model paths below are relative, so results
    # depend on the current working directory — presumably the repository
    # root; confirm against how this script is invoked.
    checks = [
        # HARD CHECKS: Failure means unusable stack
        ("ROCm Install", ["ls", "/opt/rocm"], "HARD"),
        ("rocm-smi visibility", ["rocm-smi", "--showproductname"], "HARD"),
        ("JAX GPU Visibility", [".venv-jax/bin/python", "-c", "import jax; assert len(jax.devices()) > 0"], "HARD"),
        ("TF-ROCm Visibility", [".venv-jax/bin/python", "-c", "import tensorflow as tf; assert len(tf.config.list_physical_devices('GPU')) > 0"], "HARD"),
        ("PyTorch Visibility", [".venv-torch/bin/python", "-c", "import torch; assert torch.cuda.is_available()"], "HARD"),
        # Compare (major, minor) numerically: comparing version strings is
        # lexicographic, so '0.100.0' would sort below '0.43.0'.
        ("BitsAndBytes ROCm", [".venv-torch/bin/python", "-c", "import bitsandbytes as bnb; assert tuple(int(p) for p in bnb.__version__.split('.')[:2]) >= (0, 43)"], "HARD"),
        ("IREE Runtime", [".venv-jax/bin/python", "-c", "import iree.runtime"], "HARD"),
        # SOFT CHECKS: Warning only
        ("HF Authentication", [".venv-torch/bin/python", "-c", "import huggingface_hub; huggingface_hub.whoami()"], "SOFT"),
        ("Model Directory", ["ls", "Content/MLModels"], "SOFT"),
    ]

    table = Table(title="Health Check Results")
    table.add_column("Component", style="cyan")
    table.add_column("Tier", style="magenta")
    table.add_column("Status", justify="center")
    table.add_column("Detail", style="dim")

    hard_failed = False
    for name, cmd, tier in checks:
        success, detail = run_check(cmd, env_vars=e_dict)

        if success:
            status_str = "[green]PASS[/]"
        elif tier == "SOFT":
            status_str = "[yellow]WARN[/]"
        else:
            status_str = "[red]FAIL[/]"

        # Truncate first, then escape, so an escape sequence is never cut.
        table.add_row(name, tier, status_str, escape(_summarize_detail(detail, success)))

        if not success and tier == "HARD":
            hard_failed = True

    console.print(table)

    if hard_failed:
        console.print("\n[bold red]CRITICAL: One or more HARD checks failed.[/]")
        console.print("Ensure ROCm 7.2.1 is installed and bootstrap_venvs.sh has been run.")
        sys.exit(1)
    else:
        console.print("\n[bold green]✅ All HARD checks passed. Stack is ready for generation.[/]")
        sys.exit(0)
# Script entry point: run the verification only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    verify_stack()