File size: 3,123 Bytes
3386f25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""
Minimal CUDA build utilities for Hugging Face Spaces
"""

from pathlib import Path
from typing import Any, List, Optional

import torch
from torch.utils.cpp_extension import load_inline

# Process-wide cache of already-compiled extension modules, keyed by the
# extension name passed to compile_inline(); repeated calls with the same
# name return the cached module instead of recompiling.
_COMPILED_MODULES = {}


def compile_inline(
    name: str,
    cuda_source: str,
    cpp_source: str = '',
    functions: Optional[List[str]] = None,
    build_directory: Optional[Path] = None,
    verbose: bool = False,
) -> Any:
    """
    Compile CUDA code inline using PyTorch's JIT compilation.

    Results are memoized in the module-level ``_COMPILED_MODULES`` cache,
    keyed by ``name``, so repeated calls with the same name return the
    same compiled extension without recompiling.

    Args:
        name: Unique extension name; also used as the cache key.
        cuda_source: CUDA (.cu) source code to compile.
        cpp_source: Optional C++ source (bindings/declarations).
        functions: Names of functions to auto-bind from the sources
            (forwarded to ``load_inline``).
        build_directory: Directory for build artifacts (forwarded to
            ``load_inline``).
        verbose: Print compilation progress and timing.

    Returns:
        The compiled extension module.

    Raises:
        Exception: Any compilation failure from ``load_inline`` is
            re-raised after optional logging.
    """
    import time

    if name in _COMPILED_MODULES:
        return _COMPILED_MODULES[name]

    if verbose:
        print(f"Compiling {name}...")

    start_time = time.time()

    # Pick architecture-appropriate nvcc flags for the local GPU;
    # fall back to plain -O3 when no GPU-specific flags are available.
    cuda_info = get_cuda_info()
    extra_cuda_cflags = cuda_info.get('extra_cuda_cflags', ['-O3'])

    # Keyword arguments shared by both load_inline call variants.
    # `functions` and `build_directory` were previously accepted but
    # silently ignored — forward them so callers' bindings/build dirs work.
    kwargs = dict(
        name=name,
        cpp_sources=[cpp_source] if cpp_source else [],
        cuda_sources=[cuda_source] if cuda_source else [],
        extra_cuda_cflags=extra_cuda_cflags,
        verbose=verbose,
    )
    if functions is not None:
        kwargs['functions'] = functions
    if build_directory is not None:
        kwargs['build_directory'] = str(build_directory)

    try:
        try:
            # Newer PyTorch accepts the with_pybind11 keyword.
            module = load_inline(with_pybind11=True, **kwargs)
        except TypeError:
            # Fall back to older PyTorch API without with_pybind11.
            module = load_inline(**kwargs)

        elapsed = time.time() - start_time

        if verbose:
            print(f"{name} compiled successfully in {elapsed:.2f}s")

        _COMPILED_MODULES[name] = module
        return module

    except Exception as e:
        if verbose:
            print(f"Failed to compile {name}: {e}")
        raise


def get_cuda_info() -> dict:
    """Return CUDA system information and suitable nvcc build flags.

    Returns:
        dict with keys:
            - cuda_available (bool)
            - cuda_version (str | None): None on CPU-only builds.
            - pytorch_version (str)
        and, when a CUDA device is present:
            - compute_capability (str, e.g. "8.6")
            - device_name (str)
            - extra_cuda_cflags (list[str]): nvcc flags with gencode
              entries for the device's architecture and older ones.
        Without a CUDA device, extra_cuda_cflags is just ['-O3'].
    """
    info = {
        'cuda_available': torch.cuda.is_available(),
        'cuda_version': torch.version.cuda,
        'pytorch_version': torch.__version__,
    }

    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability(0)
        info['compute_capability'] = f"{major}.{minor}"
        info['device_name'] = torch.cuda.get_device_name(0)

        # Architecture-specific flags: include every architecture at or
        # below the device's capability so the fat binary also runs on
        # slightly older GPUs.
        extra_cuda_cflags = ['-O3', '--use_fast_math']

        if major >= 7:
            extra_cuda_cflags.append('-gencode=arch=compute_70,code=sm_70')
        # BUG FIX: original condition was `major >= 7 or ...`, whose
        # second clause was dead code and which wrongly added an sm_75
        # target for a 7.0 (Volta) device. Mirror the >= 8.9 check below.
        if major > 7 or (major == 7 and minor >= 5):
            extra_cuda_cflags.append('-gencode=arch=compute_75,code=sm_75')
        if major >= 8:
            extra_cuda_cflags.append('-gencode=arch=compute_80,code=sm_80')
            extra_cuda_cflags.append('-gencode=arch=compute_86,code=sm_86')
        if major >= 9 or (major == 8 and minor >= 9):
            extra_cuda_cflags.append('-gencode=arch=compute_89,code=sm_89')

        info['extra_cuda_cflags'] = extra_cuda_cflags

    else:
        info['extra_cuda_cflags'] = ['-O3']

    return info