# build-tools/bitsandbytes/cuda_specs.py
# Uploaded by salmankhanpm ("Add files using upload-large-folder tool"), commit dc9bb20 (verified).
import dataclasses
from functools import lru_cache
import logging
import re
import subprocess
from typing import Optional
import torch
@dataclasses.dataclass(frozen=True)
class CUDASpecs:
    """Immutable summary of the CUDA/HIP environment visible to PyTorch.

    Attributes:
        highest_compute_capability: Largest ``(major, minor)`` compute
            capability among the visible devices.
        cuda_version_string: Compact CUDA/HIP version string (e.g. ``"121"``
            for version 12.1).
        cuda_version_tuple: ``(major, minor)`` CUDA/HIP version.
    """

    highest_compute_capability: tuple[int, int]
    cuda_version_string: str
    cuda_version_tuple: tuple[int, int]

    @property
    def has_imma(self) -> bool:
        """Report IMMA support for this environment.

        Always True on a ROCm/HIP build of PyTorch. Otherwise True when the
        highest compute capability is at least 7.5. (The name suggests integer
        matrix-multiply-accumulate support; the checks here are the only
        criterion applied.)
        """
        # ``torch.version.hip`` is a version string or None. Coerce it so the
        # property really returns a bool, as annotated, instead of leaking the
        # HIP version string to callers.
        return bool(torch.version.hip) or self.highest_compute_capability >= (7, 5)
def get_compute_capabilities() -> list[tuple[int, int]]:
    """List the compute capability of every visible CUDA/HIP device.

    Each entry is a ``(major, minor)`` tuple. The list is sorted in ascending
    order and is empty when ``torch.cuda.device_count()`` is zero.
    """
    capabilities = []
    for index in range(torch.cuda.device_count()):
        capabilities.append(torch.cuda.get_device_capability(index))
    capabilities.sort()
    return capabilities
@lru_cache(None)
def get_cuda_version_tuple() -> Optional[tuple[int, int]]:
    """Return the CUDA (or, failing that, HIP) version PyTorch was built with.

    The version is parsed as ``(major, minor)`` from ``torch.version.cuda``,
    or from ``torch.version.hip`` when the CUDA field is empty. Returns
    ``None`` when neither is set or the string cannot be parsed. The result
    is cached after the first call.
    """
    try:
        # CUDA takes precedence; ROCm builds fall through to the HIP version.
        raw_version = torch.version.cuda or torch.version.hip
        if not raw_version:
            return None
        components = raw_version.split(".")
        if len(components) < 2:
            return None
        major, minor = components[0], components[1]
        return (int(major), int(minor))
    except (AttributeError, ValueError, IndexError):
        return None
def get_cuda_version_string() -> Optional[str]:
    """Return the CUDA/HIP ``major.minor`` version with the dot removed.

    For example, version 12.1 yields ``"121"`` and 11.8 yields ``"118"``.
    Returns ``None`` when ``get_cuda_version_tuple`` cannot determine a version.
    """
    version_tuple = get_cuda_version_tuple()
    if version_tuple is None:
        return None
    major, minor = version_tuple
    # Concatenate rather than compute ``major * 10 + minor``. The arithmetic
    # form is only correct for single-digit minors: 12.10 would become "130",
    # which is indistinguishable from 13.0. For minor < 10 both forms agree.
    return f"{major}{minor}"
def get_cuda_specs() -> Optional[CUDASpecs]:
    """Collect the CUDA/HIP specifications of the current environment.

    Returns:
        A ``CUDASpecs`` built from the highest visible compute capability and
        the CUDA/HIP version. Returns ``None`` in any of these cases:
        ``torch.cuda.is_available()`` is False, no device is visible, the
        version cannot be determined, or any step raises.
    """
    if not torch.cuda.is_available():
        return None

    try:
        compute_capabilities = get_compute_capabilities()
        if not compute_capabilities:
            return None

        version_tuple = get_cuda_version_tuple()
        if version_tuple is None:
            return None

        version_string = get_cuda_version_string()
        if version_string is None:
            return None

        # get_compute_capabilities() returns a sorted list, so the last entry
        # is the highest capability.
        return CUDASpecs(
            highest_compute_capability=compute_capabilities[-1],
            cuda_version_string=version_string,
            cuda_version_tuple=version_tuple,
        )
    except Exception:
        # Detection stays best-effort (None on failure), but the failure is
        # no longer hidden: log it with the traceback so a broken setup can be
        # diagnosed.
        logging.getLogger(__name__).warning(
            "Could not determine CUDA/HIP specs; treating CUDA as unavailable.",
            exc_info=True,
        )
        return None
def get_rocm_gpu_arch() -> str:
    """Return the ROCm GPU architecture name reported by ``rocminfo``.

    Returns:
        The first ``gfx...`` agent name found in ``rocminfo`` output (e.g.
        ``"gfx90a"``). Returns ``"unknown"`` when PyTorch is not a ROCm/HIP
        build, when no such name is found, or when running ``rocminfo`` fails
        or times out.
    """
    logger = logging.getLogger(__name__)
    try:
        if not torch.version.hip:
            return "unknown"
        # Bound the wait. Without a timeout, subprocess.run blocks
        # indefinitely if rocminfo stalls. A timeout raises
        # subprocess.TimeoutExpired, which is handled below like any other
        # failure.
        result = subprocess.run(["rocminfo"], capture_output=True, text=True, timeout=30)
        match = re.search(r"Name:\s+gfx([a-zA-Z\d]+)", result.stdout)
        return "gfx" + match.group(1) if match else "unknown"
    except Exception as e:
        logger.error("Could not detect ROCm GPU architecture: %s", e)
        if torch.cuda.is_available():
            logger.warning(
                "\nROCm GPU architecture detection failed despite ROCm being available.\n",
            )
        return "unknown"
def get_rocm_warpsize() -> int:
    """Return the GPU warp (wavefront) size.

    Returns 32 when PyTorch is not a ROCm/HIP build. On ROCm, parses the first
    two-digit ``Wavefront Size`` entry from ``rocminfo`` output. Falls back to
    64 when no entry matches or when running ``rocminfo`` fails or times out.
    """
    logger = logging.getLogger(__name__)
    try:
        if not torch.version.hip:
            # nvidia cards always use 32 warp size
            return 32
        # Bound the wait so a stalled rocminfo cannot hang the caller. A
        # timeout raises subprocess.TimeoutExpired and takes the fallback path.
        result = subprocess.run(["rocminfo"], capture_output=True, text=True, timeout=30)
        # The pattern requires exactly two decimal digits, e.g. "64(0x40)" or
        # "32(0x20)"; entries such as "0(0x0)" never match.
        match = re.search(r"Wavefront Size:\s+([0-9]{2})\(0x[0-9]{2}\)", result.stdout)
        # default to 64 to be safe
        return int(match.group(1)) if match else 64
    except Exception as e:
        logger.error(
            "Could not detect ROCm warp size: %s. Defaulting to 64. (some 4-bit functions may not work!)",
            e,
        )
        if torch.cuda.is_available():
            logger.warning(
                "\nROCm warp size detection failed despite ROCm being available.\n",
            )
        return 64