File size: 2,221 Bytes
4303959
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from __future__ import annotations
import os
from typing import Optional

import torch


def env_true(name: str, default: bool = False) -> bool:
    """Interpret the environment variable *name* as a boolean flag.

    Args:
        name: Name of the environment variable to read.
        default: Value returned when the variable is not set at all.

    Returns:
        ``default`` when the variable is unset; otherwise ``True`` only if
        the value, after trimming whitespace and lower-casing, is one of
        ``"1"``, ``"true"``, ``"yes"`` or ``"on"``.
    """
    truthy = ("1", "true", "yes", "on")
    raw = os.getenv(name)
    # A set-but-unrecognised value (including the empty string) is False,
    # not ``default``: only a missing variable falls back.
    return default if raw is None else (raw.strip().lower() in truthy)


def env_int(name: str, default: int) -> int:
    """Read the environment variable *name* as an integer.

    Args:
        name: Name of the environment variable to read.
        default: Value used when the variable is unset, and returned as-is
            when the lookup or the integer conversion fails.

    Returns:
        The parsed integer, or ``default`` on any failure.
    """
    try:
        # When the variable is unset the textual form of ``default`` is
        # parsed instead, so both paths go through ``int``.
        raw = os.getenv(name, str(default))
        parsed = int(raw)
    except Exception:
        return default
    return parsed


def is_sm89(device: Optional[torch.device | str] = None) -> bool:
    """Report whether *device* is a CUDA device with compute capability 8.9.

    Args:
        device: Device to probe, either a ``torch.device`` or a device
            string such as ``"cuda:0"`` or ``"cpu"``. When ``None``, the
            default CUDA device is probed if CUDA is available, otherwise
            the CPU device is used (which yields ``False``).

    Returns:
        ``True`` only when the device is of type ``"cuda"`` and
        ``torch.cuda.get_device_capability`` reports ``(8, 9)``. Non-CUDA
        devices, unparseable device specifications and any failure while
        querying the capability all yield ``False``.
    """
    if device is None:
        dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    elif isinstance(device, torch.device):
        dev = device
    else:
        # Accept device strings as most torch APIs do, instead of failing
        # with AttributeError on ``.type`` below.
        try:
            dev = torch.device(device)
        except (RuntimeError, TypeError, ValueError):
            return False
    if dev.type != "cuda":
        return False
    try:
        return torch.cuda.get_device_capability(dev) == (8, 9)
    except Exception:
        # Best-effort probe: a missing driver, a CPU-only build or an
        # invalid device index all mean "not SM 8.9" rather than an error.
        return False


def torch_triton_version_pairing_ok() -> bool:
    """Heuristically check that the installed torch and triton versions match.

    The pairing rule applied to torch 2.x releases is:

    * torch 2.2.x expects triton 2.2.x
    * torch 2.3.x expects triton 2.3.x
    * torch 2.4 and later expect triton 3.x

    Returns:
        ``True`` when the pairing matches the rule above, and also whenever
        the rule does not apply: torch is not a 2.x release, torch is 2.0 or
        2.1, or the torch version string cannot be parsed. ``False`` when
        the rule applies and triton is missing or has a different version.
    """
    try:
        import triton  # noqa: F401

        tv = str(triton.__version__)
    except ImportError:
        tv = "<none>"
    except Exception:
        tv = "<unknown>"
    tt = str(getattr(torch, "__version__", "") or "")

    # Drop any local build suffix ("+cu121", "+git...") before splitting.
    torch_release = tt.split("+")[0].split(".")
    try:
        t_major = int(torch_release[0])
        t_minor = int(torch_release[1])
    except (ValueError, IndexError):
        return True  # unparseable torch version: do not gate
    if t_major != 2:
        return True  # do not gate non-2.x

    triton_release = tv.split("+")[0].split(".")
    if t_minor in (2, 3):
        # Compare the full "2.<minor>" prefix. Matching on the minor number
        # alone would demand triton 3.x for torch 2.3 and accept any
        # triton 2.x for torch 2.2.
        return triton_release[:2] == ["2", str(t_minor)]
    if t_minor >= 4:
        return triton_release[0] == "3"
    return True


def execution_routing_summary() -> dict:
    """Collect routing-related runtime probes and environment flags.

    Returns:
        A dict with these keys, in this order:

        * ``"cuda"``: result of ``torch.cuda.is_available()``
        * ``"sm89"``: result of ``is_sm89()`` for the default device
        * ``"torch"``: ``torch.__version__``
        * ``"triton"``: ``triton.__version__``, or ``"<none>"`` when triton
          cannot be imported or its version cannot be read
        * ``"NSA_USE_TRITON_SEL"``, ``"NSA_TRITON_SEL_FORCE"``,
          ``"NSA_USE_FA2"``: the same-named environment variables read
          with ``env_true`` and a default of ``False``
    """
    summary = {}
    summary["cuda"] = torch.cuda.is_available()
    summary["sm89"] = is_sm89()
    summary["torch"] = torch.__version__

    # triton is optional; any failure to import or inspect it is reported
    # as the "<none>" placeholder rather than raised.
    try:
        import triton

        triton_version = triton.__version__
    except Exception:
        triton_version = "<none>"
    summary["triton"] = triton_version

    for flag in ("NSA_USE_TRITON_SEL", "NSA_TRITON_SEL_FORCE", "NSA_USE_FA2"):
        summary[flag] = env_true(flag, False)
    return summary