File size: 2,083 Bytes
951f760
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#!/usr/bin/env python3
from __future__ import annotations

import json
import shutil

import torch


# Maps a dependency name as it appears in a ``missing_dependencies`` list
# (e.g. "mamba_ssm") to the name handed to ``pip install`` (e.g. "mamba-ssm").
# Lookups in this file use ``PACKAGE_MAP.get(name, name)``, so a name without
# an entry here is passed through unchanged.
PACKAGE_MAP = {
    "mamba_ssm": "mamba-ssm",
    "transformers": "transformers",
}


def build_install_command(*, missing_dependencies: list[str]) -> list[str]:
    """Return the pip command (as an argv list) that installs the missing packages.

    Each entry of ``missing_dependencies`` is translated through
    ``PACKAGE_MAP``; names without a mapping are used as-is.

    Args:
        missing_dependencies: Names of the dependencies that are not available.

    Returns:
        ``["python", "-m", "pip", "install", <packages...>]``, or an empty
        list when there is nothing to install.
    """
    pip_targets = [PACKAGE_MAP.get(dep, dep) for dep in missing_dependencies]
    if not pip_targets:
        # Nothing is missing, so there is no command to suggest.
        return []
    return ["python", "-m", "pip", "install"] + pip_targets


def diagnose_install_blockers(
    *,
    missing_dependencies: list[str],
    torch_version: str,
    cuda_available: bool,
    nvcc_present: bool,
) -> list[str]:
    """List the reasons an install of ``mamba_ssm`` is likely to fail.

    Only ``mamba_ssm`` is diagnosed; when it is not among the missing
    dependencies the result is always empty.

    Args:
        missing_dependencies: Names of the dependencies that are not available.
        torch_version: Version string of the installed torch; a ``"+cpu"``
            substring is treated as a CPU-only runtime.
        cuda_available: Whether torch reports CUDA as available.
        nvcc_present: Whether an ``nvcc`` executable was found.

    Returns:
        Human-readable blocker messages, CPU-only runtime first, then the
        missing ``nvcc``. Empty when no blocker was detected.
    """
    if "mamba_ssm" not in missing_dependencies:
        return []

    cpu_only_runtime = "+cpu" in torch_version or not cuda_available
    # Each check pairs a condition with its message; tuple order is output order.
    checks = (
        (cpu_only_runtime, "mamba_ssm install likely blocked by CPU-only torch runtime"),
        (not nvcc_present, "mamba_ssm install likely blocked because nvcc is unavailable"),
    )
    return [message for triggered, message in checks if triggered]


def build_bootstrap_report(*, missing_dependencies: list[str]) -> dict[str, object]:
    """Assemble a readiness report for the given missing dependencies.

    The report combines the caller-supplied list with probes of the local
    environment (torch version, CUDA availability, presence of ``nvcc`` on
    ``PATH``) that are fed to ``diagnose_install_blockers``.

    Args:
        missing_dependencies: Names of the dependencies that are not available.

    Returns:
        A dict with the keys ``ready``, ``missing_dependencies``,
        ``install_hint``, ``install_command`` and ``install_blockers``.
    """
    ready = len(missing_dependencies) == 0
    if ready:
        install_hint = ""
    else:
        package_names = ", ".join(
            PACKAGE_MAP.get(dep, dep) for dep in missing_dependencies
        )
        install_hint = f"Install missing benchmark dependencies: {package_names}"

    # Probe the local environment before diagnosing install blockers.
    installed_torch_version = getattr(torch, "__version__", "unknown")
    has_cuda = torch.cuda.is_available()
    has_nvcc = shutil.which("nvcc") is not None
    install_blockers = diagnose_install_blockers(
        missing_dependencies=missing_dependencies,
        torch_version=installed_torch_version,
        cuda_available=has_cuda,
        nvcc_present=has_nvcc,
    )

    install_command = build_install_command(missing_dependencies=missing_dependencies)
    return {
        "ready": ready,
        # Copy so the report does not alias the caller's list.
        "missing_dependencies": list(missing_dependencies),
        "install_hint": install_hint,
        "install_command": install_command,
        "install_blockers": install_blockers,
    }


def main() -> int:
    """Print a bootstrap report as pretty-printed JSON and return 0.

    The report is always built for a fixed input, ``["mamba_ssm"]``; this
    function does not itself detect which dependencies are missing.

    Returns:
        The process exit status, always ``0``.
    """
    bootstrap_report = build_bootstrap_report(missing_dependencies=["mamba_ssm"])
    rendered = json.dumps(bootstrap_report, indent=2, sort_keys=True)
    print(rendered)
    return 0


# When executed as a script, run main() and use its return value as the
# process exit status.
if __name__ == "__main__":
    raise SystemExit(main())