# ane-kan-runtime / scripts/gpu_connection_test.py
# Deployed by JohnGenetica — "Deploy ANE KAN runtime Space", commit 201cf4d (verified).
# NOTE(review): the lines above were HuggingFace Spaces page residue pasted in
# front of the shebang; commented out so the file parses. The shebang below is
# no longer line 1, so run this via `python3 scripts/gpu_connection_test.py`.
#!/usr/bin/env python3
"""GPU connection tests for Colab, HF Spaces, and local backends.
Tests device detection, mixed precision, model placement, forward pass,
and backward pass on all available GPU targets.
Usage:
python3 scripts/gpu_connection_test.py # auto-detect
python3 scripts/gpu_connection_test.py --target cuda # force CUDA
python3 scripts/gpu_connection_test.py --target mps # force MPS
python3 scripts/gpu_connection_test.py --target cpu # force CPU
python3 scripts/gpu_connection_test.py --full # include training loop test
"""
from __future__ import annotations
import argparse
import sys
import time
from dataclasses import dataclass
from typing import List, Tuple
@dataclass
class TestResult:
name: str
passed: bool
detail: str
elapsed_ms: float = 0.0
def _run_test(name: str, fn) -> TestResult:
t0 = time.time()
try:
detail = fn()
elapsed = (time.time() - t0) * 1000
return TestResult(name, True, detail, elapsed)
except Exception as e:
elapsed = (time.time() - t0) * 1000
return TestResult(name, False, str(e), elapsed)
def test_torch_import() -> str:
    """Report the installed PyTorch version string."""
    import torch

    return "torch " + torch.__version__
def test_cuda_available() -> str:
    """Describe the first CUDA device, or note that CUDA is absent.

    Returns a summary string with device name, compute capability and
    total memory in GB when CUDA is available.
    """
    import torch

    if not torch.cuda.is_available():
        return "CUDA not available (expected on MPS/CPU)"
    name = torch.cuda.get_device_name(0)
    cap = torch.cuda.get_device_capability()
    # Bug fix: the property is `total_memory`, not `total_mem` — the
    # original raised AttributeError on every CUDA machine.
    mem = torch.cuda.get_device_properties(0).total_memory / 1e9
    return f"{name}, compute {cap[0]}.{cap[1]}, {mem:.1f}GB"
def test_mps_available() -> str:
    """Check for the Apple Metal (MPS) torch backend."""
    import torch

    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is None or not mps_backend.is_available():
        return "MPS not available (expected on Linux/Colab)"
    return "MPS available"
def test_accelerate_import() -> str:
    """Instantiate an Accelerator and report its resolved device/precision."""
    from accelerate import Accelerator

    accelerator = Accelerator()
    device, mp = accelerator.device, accelerator.mixed_precision
    return f"device={device}, mp={mp}"
def test_device_resolution() -> str:
    """Resolve the preferred device in priority order: cuda > mps > cpu."""
    import torch

    if torch.cuda.is_available():
        return "cuda"
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"
def test_mixed_precision_support() -> str:
    """Smoke-test bf16 (Ampere+) or fp16 matmul on the CUDA device."""
    import torch

    if not torch.cuda.is_available():
        return "skipped (no CUDA)"
    cap = torch.cuda.get_device_capability()
    # Compute capability >= 8.0 (Ampere) implies native bfloat16 support.
    use_bf16 = cap[0] >= 8
    dtype = torch.bfloat16 if use_bf16 else torch.float16
    mat = torch.randn(4, 4, device="cuda", dtype=dtype)
    _ = mat @ mat.T  # exercise a matmul in the chosen dtype
    label = "bf16" if use_bf16 else "fp16"
    return f"{label} supported (compute {cap[0]}.{cap[1]})"
def test_model_placement(target: str) -> str:
    """Move the small KAN-JEPA model to *target* and report placement.

    Returns the parameter count and the device the parameters actually
    landed on (may differ from the request if the move silently failed).
    """
    import torch
    from training.core.kan_jepa_generator import create_kan_jepa_model
    # Removed unused `SimpleVocab` import — it was never referenced here.

    device = torch.device(target)
    model = create_kan_jepa_model(100, "small").to(device)
    n_params = sum(p.numel() for p in model.parameters())
    actual_dev = next(model.parameters()).device
    return f"{n_params:,} params on {actual_dev}"
def test_forward_pass(target: str) -> str:
    """Run a no-grad forward pass of the small model on *target*."""
    import torch
    from training.core.kan_jepa_generator import create_kan_jepa_model

    device = torch.device(target)
    model = create_kan_jepa_model(100, "small").to(device)
    model.eval()
    source_tokens = torch.randint(1, 50, (2, 10), device=device)
    target_tokens = torch.randint(1, 50, (2, 8), device=device)
    with torch.no_grad():
        logits, info = model(source_tokens, target_tokens)
    jepa = info['jepa_loss'].item()
    return f"logits={list(logits.shape)}, jepa_loss={jepa:.4f}"
def test_backward_pass(target: str) -> str:
    """Run forward + backward on *target*; report loss and gradient stats."""
    import torch
    from training.core.kan_jepa_generator import create_kan_jepa_model

    device = torch.device(target)
    model = create_kan_jepa_model(100, "small").to(device)
    model.train()
    src = torch.randint(1, 50, (4, 12), device=device)
    tgt = torch.randint(1, 50, (4, 10), device=device)
    # Teacher-forcing style: feed target shifted by one position.
    logits, info = model(src, tgt[:, :-1])
    loss = logits.mean() + info["jepa_loss"]
    loss.backward()
    grad_norms = [
        p.grad.norm().item() for p in model.parameters() if p.grad is not None
    ]
    return (f"loss={loss.item():.4f}, grad_params={len(grad_norms)}, "
            f"max_grad={max(grad_norms):.4f}")
def test_mixed_precision_forward(target: str) -> str:
import torch
if target != "cuda":
return "skipped (CUDA only)"
from training.core.kan_jepa_generator import create_kan_jepa_model
device = torch.device("cuda")
model = create_kan_jepa_model(100, "small").to(device)
model.train()
cap = torch.cuda.get_device_capability()
dtype = torch.bfloat16 if cap[0] >= 8 else torch.float16
src = torch.randint(1, 50, (4, 12), device=device)
tgt = torch.randint(1, 50, (4, 10), device=device)
with torch.autocast(device_type="cuda", dtype=dtype):
logits, info = model(src, tgt[:, :-1])
loss = logits.mean() + info["jepa_loss"]
loss.backward()
return f"autocast {dtype} OK, loss={loss.item():.4f}"
def test_accelerate_training(target: str) -> str:
    """Run a tiny 3-epoch training loop through AccelerateTrainer.

    *target* is accepted for call-site symmetry with the other device
    tests; the trainer resolves its own device internally.
    """
    try:
        from training.core.accelerate_trainer import AccelerateTrainer, AccelerateConfig
    except ImportError:
        return "accelerate_trainer not available"
    from training.core.kan_jepa_generator import create_kan_jepa_model
    from training.core.bidirectional_generator import SimpleVocab

    pairs = [
        ("Find all nodes", "MATCH (n) RETURN n"),
        ("Count people", "MATCH (p:Person) RETURN count(p)"),
        ("Find movies", "MATCH (m:Movie) RETURN m"),
        ("Who knows who", "MATCH (a)-[:KNOWS]->(b) RETURN a, b"),
    ]
    corpus = [text for pair in pairs for text in pair]
    vocab = SimpleVocab.build_from_corpus(corpus, max_size=100)
    model = create_kan_jepa_model(len(vocab), "small")
    cfg = AccelerateConfig(
        epochs=3,
        batch_size=2,
        gradient_accumulation_steps=1,
        mixed_precision="no",
        log_every=1,
        eval_samples=0,
    )
    result = AccelerateTrainer(model, vocab, pairs, cfg).train(verbose=False)
    return (f"3 epochs OK, loss={result['final_loss']:.4f}, "
            f"{result['training_time_s']:.1f}s on {result['device']}")
def test_colab_detection() -> str:
    """Detect whether we are executing inside Google Colab."""
    try:
        import google.colab  # noqa: F401
    except ImportError:
        return "not in Colab (local environment)"
    return "running in Colab"
def test_hf_space_detection() -> str:
    """Detect whether we are executing inside a HuggingFace Space."""
    import os

    space_id = os.environ.get("SPACE_ID")
    return f"HF Space: {space_id}" if space_id else "not in HF Spaces"
def test_modular_max() -> str:
    """Test Modular MAX / Mojo availability."""
    try:
        import max as _max
    except ImportError:
        return "MAX not installed (pip install modular)"
    version = getattr(_max, "__version__", "unknown")
    return f"MAX {version} available"
def test_mlx_available() -> str:
    """Test the Apple MLX framework with a tiny matmul."""
    try:
        import mlx.core as mx

        version = getattr(mx, "__version__", "unknown")
        product = mx.ones((4, 4)) @ mx.ones((4, 4))
        mx.eval(product)  # MLX is lazy; eval forces the computation
        return f"MLX {version}, matmul OK, unified memory"
    except ImportError:
        return "MLX not installed (pip install mlx)"
    except Exception as e:
        return f"MLX import OK but compute failed: {e}"
def test_snowflake_available() -> str:
    """Test Snowflake ML SDK availability / SPCS environment."""
    import os

    account = os.environ.get("SNOWFLAKE_ACCOUNT")
    if account:
        return f"SPCS environment: {account}"
    try:
        import snowflake.ml  # noqa: F401
    except ImportError:
        return "snowflake-ml not installed (pip install snowflake-ml-python)"
    return "snowflake-ml-python installed (set SNOWFLAKE_ACCOUNT to connect)"
def test_unified_backend() -> str:
    """Probe every backend adapter and report which ones are usable."""
    import os
    import sys

    # Ensure the repo root is importable when run from elsewhere.
    sys.path.insert(0, os.getcwd())
    from training.core.unified_backend import detect_backend, probe_all_backends

    selected = detect_backend()
    infos = probe_all_backends()
    found = [info.name for info in infos if info.available]
    missing = [info.name for info in infos if not info.available]
    return (f"selected={selected.name}, "
            f"available=[{', '.join(found)}], "
            f"not_found=[{', '.join(missing)}]")
def test_memory_estimate() -> str:
    """Estimate GPU memory needed to train the full text2cypher model.

    Builds the model on CPU only to count parameters; no device transfer.
    """
    # Removed unused `import torch` — only the project factory is needed here.
    from training.core.kan_jepa_generator import create_kan_jepa_model

    model = create_kan_jepa_model(2000, "text2cypher")
    n_params = sum(p.numel() for p in model.parameters())
    # fp32: 4 bytes/param. The 3x factor approximates Adam training state:
    # params + grads + two moment buffers.
    mem_fp32 = n_params * 4 * 3 / 1e6  # MB
    mem_fp16 = n_params * 2 * 3 / 1e6
    return f"{n_params:,} params, est. {mem_fp32:.0f}MB fp32 / {mem_fp16:.0f}MB fp16"
def run_all(target: str, full: bool = False) -> List[TestResult]:
    """Execute every connection test in order.

    Environment-detection tests run first, then model tests against
    *target*; *full* additionally runs the training-loop test.
    """
    environment_tests = [
        ("torch import", test_torch_import),
        ("CUDA available", test_cuda_available),
        ("MPS available", test_mps_available),
        ("Colab detection", test_colab_detection),
        ("HF Space detection", test_hf_space_detection),
        ("device resolution", test_device_resolution),
        ("accelerate import", test_accelerate_import),
        ("mixed precision support", test_mixed_precision_support),
        ("Modular MAX / Mojo", test_modular_max),
        ("Apple MLX", test_mlx_available),
        ("Snowflake SPCS", test_snowflake_available),
        ("unified backend (7 adapters)", test_unified_backend),
        ("memory estimate", test_memory_estimate),
    ]
    results = [_run_test(label, fn) for label, fn in environment_tests]

    device_tests = [
        (f"model placement [{target}]", test_model_placement),
        (f"forward pass [{target}]", test_forward_pass),
        (f"backward pass [{target}]", test_backward_pass),
        (f"mixed precision fwd [{target}]", test_mixed_precision_forward),
    ]
    if full:
        device_tests.append(
            (f"accelerate training [{target}]", test_accelerate_training))
    for label, fn in device_tests:
        # Bind fn as a default to avoid the late-binding closure pitfall.
        results.append(_run_test(label, lambda fn=fn: fn(target)))
    return results
def main():
    """CLI entry point: parse args, resolve the device, run and report."""
    parser = argparse.ArgumentParser(description="GPU connection tests")
    parser.add_argument("--target", choices=["auto", "cuda", "mps", "cpu"],
                        default="auto", help="Target device")
    parser.add_argument("--full", action="store_true",
                        help="Include training loop test")
    args = parser.parse_args()

    # Resolve "auto" to the best available backend: cuda > mps > cpu.
    target = args.target
    if target == "auto":
        import torch

        if torch.cuda.is_available():
            target = "cuda"
        else:
            mps_backend = getattr(torch.backends, 'mps', None)
            target = "mps" if mps_backend is not None and mps_backend.is_available() else "cpu"

    print(f"=== GPU Connection Tests (target: {target}) ===\n")
    results = run_all(target, args.full)

    # Report each test row, then a pass/fail summary.
    n_passed = 0
    for result in results:
        n_passed += result.passed
        status = "PASS" if result.passed else "FAIL"
        print(f" [{status}] {result.name}: {result.detail} ({result.elapsed_ms:.0f}ms)")
    print(f"\n{n_passed}/{len(results)} tests passed")

    failures = [r for r in results if not r.passed]
    if failures:
        print("\nFailed tests:")
        for r in failures:
            print(f" - {r.name}: {r.detail}")
        sys.exit(1)


if __name__ == "__main__":
    main()