"""
Minimal CUDA build utilities for Hugging Face Spaces
"""
from pathlib import Path
from typing import Any, List, Optional

import torch
from torch.utils.cpp_extension import load_inline
# Global module cache
_COMPILED_MODULES = {}
def compile_inline(
    name: str,
    cuda_source: str,
    cpp_source: str = '',
    functions: Optional[List[str]] = None,
    build_directory: Optional[Path] = None,
    verbose: bool = False,
) -> Any:
    """
    Compile CUDA code inline using PyTorch's JIT compilation.

    Compiled modules are cached in-process by *name*; repeated calls with
    the same name return the cached module without recompiling.

    Args:
        name: Unique extension name; also used as the cache key.
        cuda_source: CUDA (.cu) source code.
        cpp_source: Optional C++ source (e.g. explicit pybind11 bindings).
        functions: Function names for which load_inline should
            auto-generate pybind11 bindings.
        build_directory: Where to place build artifacts; defaults to
            PyTorch's extension cache directory when None.
        verbose: Print compile progress and timing.

    Returns:
        The compiled extension module.

    Raises:
        Exception: Re-raises any failure from load_inline after an
            optional diagnostic print.
    """
    import time
    if name in _COMPILED_MODULES:
        return _COMPILED_MODULES[name]
    if verbose:
        print(f"Compiling {name}...")
    start_time = time.time()
    # Architecture-specific nvcc flags detected from the current device.
    cuda_info = get_cuda_info()
    extra_cuda_cflags = cuda_info.get('extra_cuda_cflags', ['-O3'])
    # Keyword args shared by both load_inline call styles below.
    # BUG FIX: `functions` and `build_directory` were accepted by this
    # wrapper but never forwarded to load_inline.
    load_kwargs = dict(
        name=name,
        cpp_sources=[cpp_source] if cpp_source else [],
        cuda_sources=[cuda_source] if cuda_source else [],
        functions=functions,
        extra_cuda_cflags=extra_cuda_cflags,
        verbose=verbose,
    )
    if build_directory is not None:
        load_kwargs['build_directory'] = str(build_directory)
    try:
        # Try with with_pybind11 (newer PyTorch); fall back on TypeError.
        try:
            module = load_inline(with_pybind11=True, **load_kwargs)
        except TypeError:
            # Older PyTorch API without the with_pybind11 keyword.
            module = load_inline(**load_kwargs)
        elapsed = time.time() - start_time
        if verbose:
            print(f"{name} compiled successfully in {elapsed:.2f}s")
        _COMPILED_MODULES[name] = module
        return module
    except Exception as e:
        if verbose:
            print(f"Failed to compile {name}: {e}")
        raise
def _arch_flags(major: int, minor: int) -> List[str]:
    """Return nvcc flags for a device of compute capability major.minor.

    Gencode flags are cumulative: a newer device also gets code generated
    for the older architectures it supports via the fat binary.
    """
    flags = ['-O3', '--use_fast_math']
    if major >= 7:
        flags.append('-gencode=arch=compute_70,code=sm_70')
    # BUG FIX: was `major >= 7 or (major == 7 and minor >= 5)`, which
    # added sm_75 codegen for 7.0 devices too. The intended pattern
    # (mirroring the 8.9 check below) is "8.x and up, or exactly 7.5+".
    if major >= 8 or (major == 7 and minor >= 5):
        flags.append('-gencode=arch=compute_75,code=sm_75')
    if major >= 8:
        flags.append('-gencode=arch=compute_80,code=sm_80')
        flags.append('-gencode=arch=compute_86,code=sm_86')
    if major >= 9 or (major == 8 and minor >= 9):
        flags.append('-gencode=arch=compute_89,code=sm_89')
    return flags


def get_cuda_info() -> dict:
    """Get CUDA system information.

    Returns:
        dict with 'cuda_available', 'cuda_version', 'pytorch_version',
        and 'extra_cuda_cflags'. When a CUDA device is present, also
        'compute_capability' and 'device_name', and the cflags include
        architecture-specific gencode entries; otherwise the cflags are
        just ['-O3'].
    """
    info = {
        'cuda_available': torch.cuda.is_available(),
        'cuda_version': torch.version.cuda,
        'pytorch_version': torch.__version__,
    }
    if torch.cuda.is_available():
        # Query device 0 only; multi-GPU Spaces are assumed homogeneous.
        major, minor = torch.cuda.get_device_capability(0)
        info['compute_capability'] = f"{major}.{minor}"
        info['device_name'] = torch.cuda.get_device_name(0)
        info['extra_cuda_cflags'] = _arch_flags(major, minor)
    else:
        info['extra_cuda_cflags'] = ['-O3']
    return info