feather-a10-runtime / overlay /scripts /bootstrap_benchmark_env.py
Jackoatmon's picture
Update Feather training runtime image
951f760 verified
#!/usr/bin/env python3
from __future__ import annotations

import importlib.util
import json
import shutil
import sys

import torch
# Translates Python import names to the distribution names handed to pip when
# building install commands and hints. Import names missing from this mapping
# are looked up with ``.get(name, name)`` and therefore passed to pip verbatim.
PACKAGE_MAP: dict[str, str] = dict(
    mamba_ssm="mamba-ssm",
    transformers="transformers",
)
def build_install_command(
    *,
    missing_dependencies: list[str],
    python_executable: str | None = None,
) -> list[str]:
    """Return the argv that pip-installs the given missing dependencies.

    Import names are translated to pip distribution names through
    ``PACKAGE_MAP``; names without an entry are passed through unchanged.

    Args:
        missing_dependencies: Import names of the modules that are missing.
        python_executable: Interpreter used to run ``-m pip``. Defaults to the
            interpreter executing this script (``sys.executable``), so packages
            are installed into the same environment whose torch/CUDA state is
            diagnosed elsewhere in the report. Falls back to ``"python"`` when
            ``sys.executable`` is empty or ``None``.

    Returns:
        ``[]`` when nothing is missing, otherwise
        ``[<python>, "-m", "pip", "install", <packages...>]``.
    """
    packages = [PACKAGE_MAP.get(name, name) for name in missing_dependencies]
    if not packages:
        return []
    # A bare "python" may not exist (the shebang targets python3) or may resolve
    # to a different interpreter/venv than the one that ran this diagnosis.
    interpreter = python_executable or sys.executable or "python"
    return [interpreter, "-m", "pip", "install", *packages]
def diagnose_install_blockers(
    *,
    missing_dependencies: list[str],
    torch_version: str,
    cuda_available: bool,
    nvcc_present: bool,
) -> list[str]:
    """Explain why installing the missing dependencies is likely to fail.

    Only ``mamba_ssm`` is inspected. When it is among the missing dependencies,
    two conditions are reported, in this order:

    * the torch version string contains ``"+cpu"`` or CUDA is not available;
    * ``nvcc`` is not present.

    Args:
        missing_dependencies: Import names of the modules that are missing.
        torch_version: The installed torch version string.
        cuda_available: Whether torch reports CUDA as available.
        nvcc_present: Whether an ``nvcc`` executable was found.

    Returns:
        A new list of human-readable blocker messages. It is empty when
        ``mamba_ssm`` is not missing or no blocker applies.
    """
    if "mamba_ssm" not in missing_dependencies:
        return []

    # "+cpu" is presumably the local-version tag of CPU-only torch wheels;
    # a missing CUDA device is treated the same way.
    cpu_only_torch = "+cpu" in torch_version or not cuda_available
    conditions = (
        (cpu_only_torch, "mamba_ssm install likely blocked by CPU-only torch runtime"),
        (not nvcc_present, "mamba_ssm install likely blocked because nvcc is unavailable"),
    )
    return [message for triggered, message in conditions if triggered]
def build_bootstrap_report(*, missing_dependencies: list[str]) -> dict[str, object]:
    """Assemble a JSON-serialisable readiness report for the benchmark env.

    The report is "ready" only when ``missing_dependencies`` is empty. Otherwise
    it carries a human-readable install hint (using pip names from
    ``PACKAGE_MAP``), the argv from ``build_install_command``, and any blockers
    that ``diagnose_install_blockers`` detects for the current torch/CUDA/nvcc
    environment.

    Args:
        missing_dependencies: Import names of the modules that are missing.

    Returns:
        A dict with the keys ``ready``, ``missing_dependencies``,
        ``install_hint``, ``install_command`` and ``install_blockers``.
    """
    pip_names = [PACKAGE_MAP.get(dep, dep) for dep in missing_dependencies]
    if len(missing_dependencies) == 0:
        ready = True
        install_hint = ""
    else:
        ready = False
        install_hint = "Install missing benchmark dependencies: " + ", ".join(pip_names)

    # Inspect the live environment: torch build/version, CUDA visibility, and
    # whether an nvcc executable is on PATH.
    torch_version = getattr(torch, "__version__", "unknown")
    install_blockers = diagnose_install_blockers(
        missing_dependencies=missing_dependencies,
        torch_version=torch_version,
        cuda_available=torch.cuda.is_available(),
        nvcc_present=shutil.which("nvcc") is not None,
    )
    install_command = build_install_command(missing_dependencies=missing_dependencies)

    report: dict[str, object] = {
        "ready": ready,
        "missing_dependencies": list(missing_dependencies),
        "install_hint": install_hint,
        "install_command": install_command,
        "install_blockers": install_blockers,
    }
    return report
def main() -> int:
    """Probe the benchmark dependencies and print the bootstrap report as JSON.

    Every import name in ``PACKAGE_MAP`` is located with
    ``importlib.util.find_spec``. Names that cannot be found are reported as
    missing, so ``ready`` reflects the actual environment rather than a fixed
    list. Always returns 0: the script is diagnostic, and readiness is conveyed
    by the printed report.
    """
    # find_spec on a top-level name locates the module without importing it,
    # so the probe does not run the packages' import-time code.
    missing = [name for name in PACKAGE_MAP if importlib.util.find_spec(name) is None]
    report = build_bootstrap_report(missing_dependencies=missing)
    print(json.dumps(report, indent=2, sort_keys=True))
    return 0


if __name__ == "__main__":
    raise SystemExit(main())