#!/usr/bin/env python3 """GPU connection tests for Colab, HF Spaces, and local backends. Tests device detection, mixed precision, model placement, forward pass, and backward pass on all available GPU targets. Usage: python3 scripts/gpu_connection_test.py # auto-detect python3 scripts/gpu_connection_test.py --target cuda # force CUDA python3 scripts/gpu_connection_test.py --target mps # force MPS python3 scripts/gpu_connection_test.py --target cpu # force CPU python3 scripts/gpu_connection_test.py --full # include training loop test """ from __future__ import annotations import argparse import sys import time from dataclasses import dataclass from typing import List, Tuple @dataclass class TestResult: name: str passed: bool detail: str elapsed_ms: float = 0.0 def _run_test(name: str, fn) -> TestResult: t0 = time.time() try: detail = fn() elapsed = (time.time() - t0) * 1000 return TestResult(name, True, detail, elapsed) except Exception as e: elapsed = (time.time() - t0) * 1000 return TestResult(name, False, str(e), elapsed) def test_torch_import() -> str: import torch return f"torch {torch.__version__}" def test_cuda_available() -> str: import torch if not torch.cuda.is_available(): return "CUDA not available (expected on MPS/CPU)" name = torch.cuda.get_device_name(0) cap = torch.cuda.get_device_capability() mem = torch.cuda.get_device_properties(0).total_mem / 1e9 return f"{name}, compute {cap[0]}.{cap[1]}, {mem:.1f}GB" def test_mps_available() -> str: import torch if not hasattr(torch.backends, 'mps') or not torch.backends.mps.is_available(): return "MPS not available (expected on Linux/Colab)" return "MPS available" def test_accelerate_import() -> str: from accelerate import Accelerator acc = Accelerator() return f"device={acc.device}, mp={acc.mixed_precision}" def test_device_resolution() -> str: import torch if torch.cuda.is_available(): return "cuda" if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): return "mps" return "cpu" def test_mixed_precision_support() -> str: import torch if not torch.cuda.is_available(): return "skipped (no CUDA)" cap = torch.cuda.get_device_capability() if cap[0] >= 8: # Test bf16 x = torch.randn(4, 4, device="cuda", dtype=torch.bfloat16) y = x @ x.T return f"bf16 supported (compute {cap[0]}.{cap[1]})" # Test fp16 x = torch.randn(4, 4, device="cuda", dtype=torch.float16) y = x @ x.T return f"fp16 supported (compute {cap[0]}.{cap[1]})" def test_model_placement(target: str) -> str: import torch from training.core.kan_jepa_generator import create_kan_jepa_model from training.core.bidirectional_generator import SimpleVocab device = torch.device(target) model = create_kan_jepa_model(100, "small") model = model.to(device) n_params = sum(p.numel() for p in model.parameters()) actual_dev = next(model.parameters()).device return f"{n_params:,} params on {actual_dev}" def test_forward_pass(target: str) -> str: import torch from training.core.kan_jepa_generator import create_kan_jepa_model device = torch.device(target) model = create_kan_jepa_model(100, "small").to(device) model.eval() src = torch.randint(1, 50, (2, 10), device=device) tgt = torch.randint(1, 50, (2, 8), device=device) with torch.no_grad(): logits, info = model(src, tgt) return f"logits={list(logits.shape)}, jepa_loss={info['jepa_loss'].item():.4f}" def test_backward_pass(target: str) -> str: import torch from training.core.kan_jepa_generator import create_kan_jepa_model device = torch.device(target) model = create_kan_jepa_model(100, "small").to(device) model.train() src = torch.randint(1, 50, (4, 12), device=device) tgt = torch.randint(1, 50, (4, 10), device=device) logits, info = model(src, tgt[:, :-1]) loss = logits.mean() + info["jepa_loss"] loss.backward() grad_norms = [] for p in model.parameters(): if p.grad is not None: grad_norms.append(p.grad.norm().item()) return f"loss={loss.item():.4f}, grad_params={len(grad_norms)}, max_grad={max(grad_norms):.4f}" def test_mixed_precision_forward(target: str) -> str: import torch if target != "cuda": return "skipped (CUDA only)" from training.core.kan_jepa_generator import create_kan_jepa_model device = torch.device("cuda") model = create_kan_jepa_model(100, "small").to(device) model.train() cap = torch.cuda.get_device_capability() dtype = torch.bfloat16 if cap[0] >= 8 else torch.float16 src = torch.randint(1, 50, (4, 12), device=device) tgt = torch.randint(1, 50, (4, 10), device=device) with torch.autocast(device_type="cuda", dtype=dtype): logits, info = model(src, tgt[:, :-1]) loss = logits.mean() + info["jepa_loss"] loss.backward() return f"autocast {dtype} OK, loss={loss.item():.4f}" def test_accelerate_training(target: str) -> str: try: from training.core.accelerate_trainer import AccelerateTrainer, AccelerateConfig except ImportError: return "accelerate_trainer not available" from training.core.kan_jepa_generator import create_kan_jepa_model from training.core.bidirectional_generator import SimpleVocab pairs = [ ("Find all nodes", "MATCH (n) RETURN n"), ("Count people", "MATCH (p:Person) RETURN count(p)"), ("Find movies", "MATCH (m:Movie) RETURN m"), ("Who knows who", "MATCH (a)-[:KNOWS]->(b) RETURN a, b"), ] vocab = SimpleVocab.build_from_corpus( [t for p in pairs for t in p], max_size=100) model = create_kan_jepa_model(len(vocab), "small") cfg = AccelerateConfig( epochs=3, batch_size=2, gradient_accumulation_steps=1, mixed_precision="no", log_every=1, eval_samples=0) trainer = AccelerateTrainer(model, vocab, pairs, cfg) result = trainer.train(verbose=False) return f"3 epochs OK, loss={result['final_loss']:.4f}, {result['training_time_s']:.1f}s on {result['device']}" def test_colab_detection() -> str: """Detect if running inside Google Colab.""" try: import google.colab # noqa: F401 return "running in Colab" except ImportError: return "not in Colab (local environment)" def test_hf_space_detection() -> str: """Detect if running inside HuggingFace Spaces.""" import os if os.environ.get("SPACE_ID"): return f"HF Space: {os.environ['SPACE_ID']}" return "not in HF Spaces" def test_modular_max() -> str: """Test Modular MAX / Mojo availability.""" try: import max as _max ver = getattr(_max, "__version__", "unknown") return f"MAX {ver} available" except ImportError: return "MAX not installed (pip install modular)" def test_mlx_available() -> str: """Test Apple MLX framework.""" try: import mlx.core as mx ver = mx.__version__ if hasattr(mx, "__version__") else "unknown" # Quick compute test a = mx.ones((4, 4)) b = mx.ones((4, 4)) c = a @ b mx.eval(c) return f"MLX {ver}, matmul OK, unified memory" except ImportError: return "MLX not installed (pip install mlx)" except Exception as e: return f"MLX import OK but compute failed: {e}" def test_snowflake_available() -> str: """Test Snowflake ML SDK availability.""" import os if os.environ.get("SNOWFLAKE_ACCOUNT"): return f"SPCS environment: {os.environ['SNOWFLAKE_ACCOUNT']}" try: import snowflake.ml # noqa: F401 return "snowflake-ml-python installed (set SNOWFLAKE_ACCOUNT to connect)" except ImportError: return "snowflake-ml not installed (pip install snowflake-ml-python)" def test_unified_backend() -> str: """Test unified backend detection across all 7 backends.""" import sys, os sys.path.insert(0, os.getcwd()) from training.core.unified_backend import detect_backend, probe_all_backends backend = detect_backend() all_infos = probe_all_backends() available = [i.name for i in all_infos if i.available] unavailable = [i.name for i in all_infos if not i.available] return (f"selected={backend.name}, " f"available=[{', '.join(available)}], " f"not_found=[{', '.join(unavailable)}]") def test_memory_estimate() -> str: """Estimate GPU memory for full training.""" import torch from training.core.kan_jepa_generator import create_kan_jepa_model model = create_kan_jepa_model(2000, "text2cypher") n_params = sum(p.numel() for p in model.parameters()) # fp32: 4 bytes/param. With Adam: ~3x (params + grads + 2 momentum) mem_fp32 = n_params * 4 * 3 / 1e6 # MB mem_fp16 = n_params * 2 * 3 / 1e6 return f"{n_params:,} params, est. {mem_fp32:.0f}MB fp32 / {mem_fp16:.0f}MB fp16" def run_all(target: str, full: bool = False) -> List[TestResult]: results = [] # Environment detection results.append(_run_test("torch import", test_torch_import)) results.append(_run_test("CUDA available", test_cuda_available)) results.append(_run_test("MPS available", test_mps_available)) results.append(_run_test("Colab detection", test_colab_detection)) results.append(_run_test("HF Space detection", test_hf_space_detection)) results.append(_run_test("device resolution", test_device_resolution)) results.append(_run_test("accelerate import", test_accelerate_import)) results.append(_run_test("mixed precision support", test_mixed_precision_support)) results.append(_run_test("Modular MAX / Mojo", test_modular_max)) results.append(_run_test("Apple MLX", test_mlx_available)) results.append(_run_test("Snowflake SPCS", test_snowflake_available)) results.append(_run_test("unified backend (7 adapters)", test_unified_backend)) results.append(_run_test("memory estimate", test_memory_estimate)) # Model tests on target device results.append(_run_test(f"model placement [{target}]", lambda: test_model_placement(target))) results.append(_run_test(f"forward pass [{target}]", lambda: test_forward_pass(target))) results.append(_run_test(f"backward pass [{target}]", lambda: test_backward_pass(target))) results.append(_run_test(f"mixed precision fwd [{target}]", lambda: test_mixed_precision_forward(target))) if full: results.append(_run_test(f"accelerate training [{target}]", lambda: test_accelerate_training(target))) return results def main(): parser = argparse.ArgumentParser(description="GPU connection tests") parser.add_argument("--target", choices=["auto", "cuda", "mps", "cpu"], default="auto", help="Target device") parser.add_argument("--full", action="store_true", help="Include training loop test") args = parser.parse_args() # Resolve target if args.target == "auto": import torch if torch.cuda.is_available(): target = "cuda" elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): target = "mps" else: target = "cpu" else: target = args.target print(f"=== GPU Connection Tests (target: {target}) ===\n") results = run_all(target, args.full) # Print results passed = sum(1 for r in results if r.passed) total = len(results) for r in results: status = "PASS" if r.passed else "FAIL" print(f" [{status}] {r.name}: {r.detail} ({r.elapsed_ms:.0f}ms)") print(f"\n{passed}/{total} tests passed") if passed < total: failed = [r for r in results if not r.passed] print("\nFailed tests:") for r in failed: print(f" - {r.name}: {r.detail}") sys.exit(1) if __name__ == "__main__": main()