File size: 12,283 Bytes
201cf4d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
#!/usr/bin/env python3
"""GPU connection tests for Colab, HF Spaces, and local backends.

Tests device detection, mixed precision, model placement, forward pass,
and backward pass on all available GPU targets.

Usage:
    python3 scripts/gpu_connection_test.py              # auto-detect
    python3 scripts/gpu_connection_test.py --target cuda # force CUDA
    python3 scripts/gpu_connection_test.py --target mps  # force MPS
    python3 scripts/gpu_connection_test.py --target cpu  # force CPU
    python3 scripts/gpu_connection_test.py --full        # include training loop test
"""
from __future__ import annotations

import argparse
import sys
import time
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class TestResult:
    """Outcome of a single connection test, as produced by `_run_test`."""
    name: str  # human-readable test label shown in the report
    passed: bool  # True when the test function returned without raising
    detail: str  # success detail string, or the exception message on failure
    elapsed_ms: float = 0.0  # wall-clock duration of the test in milliseconds


def _run_test(name: str, fn) -> TestResult:
    """Run *fn* and capture success/failure plus elapsed wall time.

    Args:
        name: Label recorded on the result.
        fn: Zero-argument callable returning a detail string; any exception
            it raises is caught and recorded as a failure.

    Returns:
        TestResult with passed/detail/elapsed_ms filled in.
    """
    # perf_counter() is monotonic and higher-resolution than time.time(),
    # so elapsed_ms can never go negative if the system clock is adjusted
    # mid-test (e.g. NTP sync on a fresh Colab VM).
    t0 = time.perf_counter()
    try:
        detail = fn()
        elapsed = (time.perf_counter() - t0) * 1000
        return TestResult(name, True, detail, elapsed)
    except Exception as e:
        elapsed = (time.perf_counter() - t0) * 1000
        return TestResult(name, False, str(e), elapsed)


def test_torch_import() -> str:
    """Verify that torch imports and report its version string."""
    import torch
    version = torch.__version__
    return "torch " + str(version)


def test_cuda_available() -> str:
    """Report CUDA device name, compute capability, and total memory.

    Returns:
        A summary string, or a skip note when CUDA is absent.
    """
    import torch
    if not torch.cuda.is_available():
        return "CUDA not available (expected on MPS/CPU)"
    name = torch.cuda.get_device_name(0)
    cap = torch.cuda.get_device_capability()
    # Fix: the property is `total_memory`; the previous `total_mem` raised
    # AttributeError on every CUDA machine, so this test always failed there.
    mem = torch.cuda.get_device_properties(0).total_memory / 1e9
    return f"{name}, compute {cap[0]}.{cap[1]}, {mem:.1f}GB"


def test_mps_available() -> str:
    """Check whether the Apple MPS backend is present and usable."""
    import torch
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        return "MPS available"
    return "MPS not available (expected on Linux/Colab)"


def test_accelerate_import() -> str:
    """Instantiate an Accelerator and report its device and mixed precision."""
    from accelerate import Accelerator
    accelerator = Accelerator()
    device = accelerator.device
    mp = accelerator.mixed_precision
    return f"device={device}, mp={mp}"


def test_device_resolution() -> str:
    """Resolve the preferred compute device: cuda > mps > cpu."""
    import torch
    if torch.cuda.is_available():
        return "cuda"
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"


def test_mixed_precision_support() -> str:
    """Probe a reduced-precision matmul on CUDA; skipped elsewhere."""
    import torch
    if not torch.cuda.is_available():
        return "skipped (no CUDA)"
    cap = torch.cuda.get_device_capability()
    # Ampere (compute 8.x) and newer support bfloat16; older cards get fp16.
    use_bf16 = cap[0] >= 8
    dtype = torch.bfloat16 if use_bf16 else torch.float16
    probe = torch.randn(4, 4, device="cuda", dtype=dtype)
    _ = probe @ probe.T
    label = "bf16" if use_bf16 else "fp16"
    return f"{label} supported (compute {cap[0]}.{cap[1]})"


def test_model_placement(target: str) -> str:
    """Create the small KAN-JEPA model and move it to *target*.

    Args:
        target: torch device string ("cuda" / "mps" / "cpu").

    Returns:
        Parameter count and the device the parameters actually landed on.
    """
    import torch
    # Fix: removed unused `SimpleVocab` import — it was never referenced here.
    from training.core.kan_jepa_generator import create_kan_jepa_model

    device = torch.device(target)
    model = create_kan_jepa_model(100, "small")
    model = model.to(device)
    n_params = sum(p.numel() for p in model.parameters())
    actual_dev = next(model.parameters()).device
    return f"{n_params:,} params on {actual_dev}"


def test_forward_pass(target: str) -> str:
    """Run an eval-mode forward pass on *target* and report the output shape."""
    import torch
    from training.core.kan_jepa_generator import create_kan_jepa_model

    dev = torch.device(target)
    model = create_kan_jepa_model(100, "small").to(dev)
    model.eval()

    source_ids = torch.randint(1, 50, (2, 10), device=dev)
    target_ids = torch.randint(1, 50, (2, 8), device=dev)

    with torch.no_grad():
        logits, info = model(source_ids, target_ids)

    jepa = info['jepa_loss'].item()
    return f"logits={list(logits.shape)}, jepa_loss={jepa:.4f}"


def test_backward_pass(target: str) -> str:
    """Run a train-mode forward+backward pass and summarize gradients.

    Args:
        target: torch device string.

    Returns:
        Loss value, count of parameters that received gradients, and the
        largest gradient norm (0.0 when nothing received a gradient).
    """
    import torch
    from training.core.kan_jepa_generator import create_kan_jepa_model

    device = torch.device(target)
    model = create_kan_jepa_model(100, "small").to(device)
    model.train()

    src = torch.randint(1, 50, (4, 12), device=device)
    tgt = torch.randint(1, 50, (4, 10), device=device)

    logits, info = model(src, tgt[:, :-1])
    loss = logits.mean() + info["jepa_loss"]
    loss.backward()

    grad_norms = [p.grad.norm().item()
                  for p in model.parameters() if p.grad is not None]

    # `default=0.0` keeps this from raising ValueError if no parameter got
    # a gradient (e.g. everything frozen) — we want the report, not a crash.
    max_grad = max(grad_norms, default=0.0)
    return f"loss={loss.item():.4f}, grad_params={len(grad_norms)}, max_grad={max_grad:.4f}"


def test_mixed_precision_forward(target: str) -> str:
    import torch
    if target != "cuda":
        return "skipped (CUDA only)"

    from training.core.kan_jepa_generator import create_kan_jepa_model

    device = torch.device("cuda")
    model = create_kan_jepa_model(100, "small").to(device)
    model.train()

    cap = torch.cuda.get_device_capability()
    dtype = torch.bfloat16 if cap[0] >= 8 else torch.float16

    src = torch.randint(1, 50, (4, 12), device=device)
    tgt = torch.randint(1, 50, (4, 10), device=device)

    with torch.autocast(device_type="cuda", dtype=dtype):
        logits, info = model(src, tgt[:, :-1])
        loss = logits.mean() + info["jepa_loss"]

    loss.backward()
    return f"autocast {dtype} OK, loss={loss.item():.4f}"


def test_accelerate_training(target: str) -> str:
    """Run a tiny 3-epoch training loop through AccelerateTrainer."""
    try:
        from training.core.accelerate_trainer import AccelerateTrainer, AccelerateConfig
    except ImportError:
        return "accelerate_trainer not available"
    from training.core.kan_jepa_generator import create_kan_jepa_model
    from training.core.bidirectional_generator import SimpleVocab

    # Minimal text->Cypher corpus: enough to exercise the loop, not learn.
    pairs = [
        ("Find all nodes", "MATCH (n) RETURN n"),
        ("Count people", "MATCH (p:Person) RETURN count(p)"),
        ("Find movies", "MATCH (m:Movie) RETURN m"),
        ("Who knows who", "MATCH (a)-[:KNOWS]->(b) RETURN a, b"),
    ]
    corpus = [text for pair in pairs for text in pair]
    vocab = SimpleVocab.build_from_corpus(corpus, max_size=100)
    model = create_kan_jepa_model(len(vocab), "small")

    cfg = AccelerateConfig(
        epochs=3, batch_size=2, gradient_accumulation_steps=1,
        mixed_precision="no", log_every=1, eval_samples=0)

    result = AccelerateTrainer(model, vocab, pairs, cfg).train(verbose=False)

    return f"3 epochs OK, loss={result['final_loss']:.4f}, {result['training_time_s']:.1f}s on {result['device']}"


def test_colab_detection() -> str:
    """Detect if running inside Google Colab."""
    try:
        import google.colab  # noqa: F401
    except ImportError:
        return "not in Colab (local environment)"
    return "running in Colab"


def test_hf_space_detection() -> str:
    """Detect if running inside HuggingFace Spaces."""
    import os
    space_id = os.environ.get("SPACE_ID")
    if space_id:
        return f"HF Space: {space_id}"
    return "not in HF Spaces"


def test_modular_max() -> str:
    """Test Modular MAX / Mojo availability."""
    try:
        import max as _max
    except ImportError:
        return "MAX not installed (pip install modular)"
    ver = getattr(_max, "__version__", "unknown")
    return f"MAX {ver} available"


def test_mlx_available() -> str:
    """Test Apple MLX framework."""
    try:
        import mlx.core as mx
        ver = getattr(mx, "__version__", "unknown")
        # Quick compute test: small matmul, forced to evaluate eagerly
        # (MLX is lazy, so mx.eval is what actually runs the kernel).
        product = mx.ones((4, 4)) @ mx.ones((4, 4))
        mx.eval(product)
        return f"MLX {ver}, matmul OK, unified memory"
    except ImportError:
        return "MLX not installed (pip install mlx)"
    except Exception as e:
        return f"MLX import OK but compute failed: {e}"


def test_snowflake_available() -> str:
    """Test Snowflake ML SDK availability."""
    import os
    account = os.environ.get("SNOWFLAKE_ACCOUNT")
    if account:
        return f"SPCS environment: {account}"
    try:
        import snowflake.ml  # noqa: F401
    except ImportError:
        return "snowflake-ml not installed (pip install snowflake-ml-python)"
    return "snowflake-ml-python installed (set SNOWFLAKE_ACCOUNT to connect)"


def test_unified_backend() -> str:
    """Test unified backend detection across all 7 backends."""
    import sys
    import os
    # Ensure the project root is importable when run from scripts/.
    sys.path.insert(0, os.getcwd())
    from training.core.unified_backend import detect_backend, probe_all_backends
    selected = detect_backend()
    infos = probe_all_backends()
    found = [info.name for info in infos if info.available]
    missing = [info.name for info in infos if not info.available]
    return (f"selected={selected.name}, "
            f"available=[{', '.join(found)}], "
            f"not_found=[{', '.join(missing)}]")


def test_memory_estimate() -> str:
    """Estimate GPU memory for full training.

    Returns:
        Parameter count plus rough fp32/fp16 memory estimates for Adam
        training of the text2cypher-sized model.
    """
    # Note: removed unused `import torch` — the model factory handles torch.
    from training.core.kan_jepa_generator import create_kan_jepa_model

    model = create_kan_jepa_model(2000, "text2cypher")
    n_params = sum(p.numel() for p in model.parameters())
    # Adam keeps 4 tensors per parameter: weights + grads + two momentum
    # buffers (m, v).  The previous x3 undercounted by one full copy even
    # though its own comment listed all four components.
    mem_fp32 = n_params * 4 * 4 / 1e6  # MB at 4 bytes/param
    mem_fp16 = n_params * 2 * 4 / 1e6  # MB at 2 bytes/param
    return f"{n_params:,} params, est. {mem_fp32:.0f}MB fp32 / {mem_fp16:.0f}MB fp16"


def run_all(target: str, full: bool = False) -> List[TestResult]:
    """Execute every connection test and return the collected results.

    Args:
        target: Device string the model-level tests should run on.
        full: When True, also run the (slower) accelerate training loop.
    """
    # Environment detection tests take no arguments.
    env_tests = [
        ("torch import", test_torch_import),
        ("CUDA available", test_cuda_available),
        ("MPS available", test_mps_available),
        ("Colab detection", test_colab_detection),
        ("HF Space detection", test_hf_space_detection),
        ("device resolution", test_device_resolution),
        ("accelerate import", test_accelerate_import),
        ("mixed precision support", test_mixed_precision_support),
        ("Modular MAX / Mojo", test_modular_max),
        ("Apple MLX", test_mlx_available),
        ("Snowflake SPCS", test_snowflake_available),
        ("unified backend (7 adapters)", test_unified_backend),
        ("memory estimate", test_memory_estimate),
    ]

    # Model tests are bound to the chosen target device.
    device_tests = [
        (f"model placement [{target}]", lambda: test_model_placement(target)),
        (f"forward pass [{target}]", lambda: test_forward_pass(target)),
        (f"backward pass [{target}]", lambda: test_backward_pass(target)),
        (f"mixed precision fwd [{target}]",
         lambda: test_mixed_precision_forward(target)),
    ]
    if full:
        device_tests.append((f"accelerate training [{target}]",
                             lambda: test_accelerate_training(target)))

    return [_run_test(name, fn) for name, fn in env_tests + device_tests]


def main():
    """Parse CLI arguments, run the suite, print a report, exit 1 on failure."""
    parser = argparse.ArgumentParser(description="GPU connection tests")
    parser.add_argument("--target", choices=["auto", "cuda", "mps", "cpu"],
                        default="auto", help="Target device")
    parser.add_argument("--full", action="store_true",
                        help="Include training loop test")
    args = parser.parse_args()

    # Resolve "auto" to the best available device: cuda > mps > cpu.
    target = args.target
    if target == "auto":
        import torch
        if torch.cuda.is_available():
            target = "cuda"
        elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            target = "mps"
        else:
            target = "cpu"

    print(f"=== GPU Connection Tests (target: {target}) ===\n")

    results = run_all(target, args.full)

    for r in results:
        status = "PASS" if r.passed else "FAIL"
        print(f"  [{status}] {r.name}: {r.detail} ({r.elapsed_ms:.0f}ms)")

    passed = sum(1 for r in results if r.passed)
    total = len(results)
    print(f"\n{passed}/{total} tests passed")

    # Non-zero exit so CI pipelines can gate on this script.
    failed = [r for r in results if not r.passed]
    if failed:
        print("\nFailed tests:")
        for r in failed:
            print(f"  - {r.name}: {r.detail}")
        sys.exit(1)


if __name__ == "__main__":
    main()