File size: 10,834 Bytes
3386f25
 
 
4aca758
20cfecf
1282ba1
3386f25
 
 
4aca758
1282ba1
3386f25
 
 
 
4aca758
37b20e3
4aca758
5de41f0
 
 
 
 
 
 
3386f25
1282ba1
 
20cfecf
 
 
 
 
 
 
 
 
37b20e3
 
20cfecf
d9cf8c2
 
37b20e3
 
 
 
20cfecf
37b20e3
 
 
 
 
 
 
 
 
d9cf8c2
37b20e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20cfecf
 
37b20e3
 
20cfecf
 
 
37b20e3
 
20cfecf
1282ba1
3386f25
 
 
 
 
 
 
1282ba1
 
 
 
 
 
3386f25
 
 
 
 
1282ba1
3386f25
 
 
1282ba1
 
d9cf8c2
20cfecf
1282ba1
 
 
 
 
 
 
 
d9cf8c2
 
 
 
20cfecf
5de41f0
1dda790
5de41f0
 
1dda790
 
 
 
 
 
 
 
 
 
5de41f0
 
 
 
 
20cfecf
d9cf8c2
20cfecf
1282ba1
 
d9cf8c2
 
1282ba1
 
d9cf8c2
1282ba1
d9cf8c2
 
1282ba1
d9cf8c2
1282ba1
 
 
 
 
d9cf8c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1282ba1
37b20e3
 
 
 
 
 
 
 
 
 
1282ba1
 
d9cf8c2
1282ba1
d9cf8c2
1282ba1
d9cf8c2
 
 
 
 
 
 
 
1282ba1
 
 
 
 
 
 
 
4aca758
 
 
 
d9cf8c2
1282ba1
4aca758
 
 
 
 
 
d9cf8c2
 
 
 
 
 
 
 
 
 
 
 
1282ba1
 
 
 
1dda790
 
 
 
 
 
4aca758
 
 
3386f25
 
 
 
4aca758
 
 
 
 
 
 
 
 
 
1282ba1
d9cf8c2
 
 
1282ba1
37a2a42
1282ba1
d9cf8c2
 
1dda790
 
 
 
 
 
 
3386f25
 
 
 
 
 
4aca758
1282ba1
3386f25
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
"""
StyleForge CUDA Kernels Package
Custom CUDA kernels for accelerated neural style transfer.

For ZeroGPU/HuggingFace: Pre-compiled kernels are downloaded from HF dataset.
For local: Kernels are JIT-compiled if prebuilt not available.
"""

import torch
import os
from pathlib import Path

# Module-level state for kernel loading. These are mutated by
# load_prebuilt_kernels() / compile_kernels() below; callers should go
# through check_cuda_kernels() and get_fused_instance_norm() instead of
# reading them directly.
_CUDA_KERNELS_AVAILABLE = False  # True once a custom fused kernel is usable
_FusedInstanceNorm2d = None      # factory/class for the fused module, or None
_KERNELS_COMPILED = False        # True once load/compile has been attempted (success or not)
_LOADED_KERNEL_FUNC = None       # raw forward function loaded from a prebuilt .so

# Check if running on ZeroGPU or HuggingFace Spaces.
# Use the same detection as app.py - presence of the `spaces` package.
try:
    from spaces import GPU
    _ZERO_GPU = True
except ImportError:
    _ZERO_GPU = False

# Directory where pre-compiled kernel binaries are stored/downloaded.
_PREBUILT_PATH = Path(__file__).parent / "prebuilt"
_PREBUILT_PATH.mkdir(exist_ok=True)

# HuggingFace dataset that hosts the prebuilt kernel binaries.
_KERNEL_DATASET = "oliau/styleforge-kernels"  # You'll need to create this dataset


def _download_kernels_from_dataset():
    """Download a pre-compiled CUDA kernel from the HuggingFace dataset.

    Tries the bare filename first, then a few likely subdirectory layouts
    (``kernels/``, ``prebuilt/``, ``build/``), saving the file under the
    kernels package directory.

    Returns:
        bool: True if a kernel file was downloaded, False otherwise.
    """
    try:
        from huggingface_hub import hf_hub_download

        print(f"Looking for kernels in dataset: {_KERNEL_DATASET}")

        # Known kernel file name
        kernel_file = "fused_instance_norm.so"

        # Try each candidate location exactly once. (Previously the bare
        # filename was attempted twice: once up front and again via the
        # "" entry of the fallback loop.)
        for subdir in ["", "kernels/", "prebuilt/", "build/"]:
            candidate = subdir + kernel_file
            try:
                local_path = hf_hub_download(
                    repo_id=_KERNEL_DATASET,
                    filename=candidate,
                    repo_type="dataset",
                    local_dir=str(_PREBUILT_PATH.parent),
                    local_dir_use_symlinks=False,
                )
                print(f"Successfully downloaded kernel: {candidate} -> {local_path}")
                return True
            except Exception as e:
                print(f"Failed to download {candidate}: {e}")
        return False

    except ImportError as e:
        print(f"huggingface_hub not available: {e}")
        return False
    except Exception as e:
        print(f"Failed to download kernels from dataset: {e}")
        import traceback
        traceback.print_exc()
        return False


def check_cuda_kernels():
    """Report whether the custom fused CUDA kernels have been loaded."""
    kernels_ready = _CUDA_KERNELS_AVAILABLE
    return kernels_ready


def get_fused_instance_norm(num_features, **kwargs):
    """
    Get FusedInstanceNorm2d module or PyTorch fallback.

    On ZeroGPU: Uses pre-compiled kernels if available.
    On local: May use custom fused kernels (prebuilt or JIT).

    Args:
        num_features: Number of channels expected in the input.
        **kwargs: Forwarded to the fused module; for the PyTorch fallback,
            only the InstanceNorm2d-compatible options are honored.

    Returns:
        torch.nn.Module: Fused module if available, else
        ``torch.nn.InstanceNorm2d``.
    """
    if _FusedInstanceNorm2d is not None:
        try:
            return _FusedInstanceNorm2d(num_features, **kwargs)
        except Exception:
            pass
    # Fallback to PyTorch (still GPU-accelerated, just not custom fused).
    # Forward the InstanceNorm2d options the caller supplied instead of
    # silently dropping eps/momentum/track_running_stats (previously only
    # 'affine' was honored); kernel-specific kwargs are filtered out so
    # they don't crash the fallback constructor.
    fallback_kwargs = {
        key: kwargs[key]
        for key in ("eps", "momentum", "affine", "track_running_stats")
        if key in kwargs
    }
    fallback_kwargs.setdefault("affine", True)
    return torch.nn.InstanceNorm2d(num_features, **fallback_kwargs)


def load_prebuilt_kernels():
    """
    Try to load pre-compiled CUDA kernels from the kernels directory.
    On HuggingFace, downloads from dataset if local files not found.

    Returns True if successful, False otherwise.
    """
    # BUG FIX: _LOADED_KERNEL_FUNC was previously assigned without being
    # declared global, so the assignment only created a dead local and the
    # module-level variable was never updated.
    global _FusedInstanceNorm2d, _CUDA_KERNELS_AVAILABLE, _KERNELS_COMPILED
    global _LOADED_KERNEL_FUNC

    if _KERNELS_COMPILED:
        return _CUDA_KERNELS_AVAILABLE

    # Check for kernels in the kernels directory (parent of prebuilt) and prebuilt/
    kernels_dir = Path(__file__).parent
    kernel_files = list(kernels_dir.glob("*.so")) + list(kernels_dir.glob("*.pyd"))
    kernel_files += list(_PREBUILT_PATH.glob("*.so")) + list(_PREBUILT_PATH.glob("*.pyd"))

    # Try downloading from dataset if not found locally (on ZeroGPU or if CUDA available)
    # IMPORTANT: Don't call torch.cuda.is_available() on ZeroGPU at module level!
    if not kernel_files:
        print(f"No local pre-compiled kernels found. _ZERO_GPU={_ZERO_GPU}")
        # On ZeroGPU, always try to download without checking CUDA.
        # On local, check CUDA first before downloading.
        should_download = _ZERO_GPU
        if not _ZERO_GPU:
            try:
                should_download = torch.cuda.is_available()
            except Exception:
                should_download = False

        if should_download:
            print("Trying HuggingFace dataset...")
            if _download_kernels_from_dataset():
                # Check again after download - look in kernels directory
                kernel_files = list(kernels_dir.glob("*.so")) + list(kernels_dir.glob("*.pyd"))
                kernel_files += list(_PREBUILT_PATH.glob("*.so")) + list(_PREBUILT_PATH.glob("*.pyd"))

    if not kernel_files:
        print("No pre-compiled kernels found")
        return False

    print(f"Found kernel files: {[f.name for f in kernel_files]}")

    try:
        import ctypes
        import importlib.util

        # Try to load each kernel file until one provides a usable forward fn.
        for kernel_file in kernel_files:
            try:
                # First try to load as a Python extension module.
                module_name = kernel_file.stem
                spec = importlib.util.spec_from_file_location(module_name, kernel_file)
                if spec and spec.loader:
                    mod = importlib.util.module_from_spec(spec)
                    spec.loader.exec_module(mod)
                    print(f"Loaded pre-compiled kernel module: {kernel_file.name}")

                    # Check what functions are available in the module
                    available_funcs = [attr for attr in dir(mod) if not attr.startswith('_')]
                    print(f"Available functions in kernel: {available_funcs}")

                    # Try to find the forward function with common naming patterns
                    forward_func = None
                    for func_name in ['fused_instance_norm_forward', 'forward', 'fused_instance_norm',
                                      'instance_norm_forward', 'fused_inst_norm']:
                        if hasattr(mod, func_name):
                            forward_func = getattr(mod, func_name)
                            print(f"Using function: {func_name}")
                            break

                    if forward_func is None:
                        print(f"Warning: No suitable forward function found in {kernel_file.name}")
                        continue

                    # Store the kernel function globally for use with FusedInstanceNorm2d
                    _LOADED_KERNEL_FUNC = forward_func

                    # Factory that builds the wrapper around the pre-loaded kernel.
                    # forward_func is captured by closure; safe because we return
                    # immediately after, so it can't be rebound by a later iteration.
                    def make_fused_instance_norm(num_features, **kwargs):
                        from .instance_norm_wrapper import FusedInstanceNorm2d
                        # Pass the pre-loaded kernel function
                        return FusedInstanceNorm2d(num_features, kernel_func=forward_func, **kwargs)

                    _FusedInstanceNorm2d = make_fused_instance_norm
                    _CUDA_KERNELS_AVAILABLE = True
                    _KERNELS_COMPILED = True
                    print(f"Successfully initialized FusedInstanceNorm2d from {kernel_file.name}")
                    return True

            except Exception as e:
                print(f"Failed to load {kernel_file.name} as Python module: {e}")
                # Try loading as raw ctypes library
                try:
                    lib = ctypes.CDLL(str(kernel_file))
                    print(f"Loaded {kernel_file.name} as ctypes library")
                    # Could add ctypes wrapper here if needed
                except Exception as e2:
                    print(f"Failed to load {kernel_file.name} as ctypes: {e2}")
                continue

    except Exception as e:
        print(f"Failed to load prebuilt kernels: {e}")

    return False


def compile_kernels():
    """
    Compile CUDA kernels on-demand.

    On ZeroGPU: Downloads pre-compiled kernels from dataset.
    On local: Compiles custom CUDA kernels (falls back to PyTorch
    InstanceNorm2d if compilation fails or CUDA is unavailable).

    Returns:
        bool: True if custom kernels are available, False if the PyTorch
        fallback will be used.
    """
    global _CUDA_KERNELS_AVAILABLE, _FusedInstanceNorm2d, _KERNELS_COMPILED

    # Idempotent: only attempt load/compile once per process.
    if _KERNELS_COMPILED:
        return _CUDA_KERNELS_AVAILABLE

    # On ZeroGPU, try to download pre-compiled kernels from dataset
    if _ZERO_GPU:
        print("ZeroGPU mode: Attempting to download pre-compiled kernels from dataset...")
        if load_prebuilt_kernels():
            print("Successfully loaded pre-compiled CUDA kernels from dataset!")
            return True
        else:
            print("No pre-compiled kernels found in dataset, using PyTorch GPU fallback")
            _KERNELS_COMPILED = True
            return False

    # First, try pre-compiled kernels (for local too)
    if load_prebuilt_kernels():
        print("Using pre-compiled CUDA kernels!")
        return True

    # Check CUDA availability (safe here since we're not on ZeroGPU).
    # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
    # are not swallowed.
    try:
        if not torch.cuda.is_available():
            _KERNELS_COMPILED = True
            return False
    except Exception:
        _KERNELS_COMPILED = True
        return False

    try:
        from .instance_norm_wrapper import FusedInstanceNorm2d
        _FusedInstanceNorm2d = FusedInstanceNorm2d
        _CUDA_KERNELS_AVAILABLE = True
        _KERNELS_COMPILED = True
        print("CUDA kernels compiled successfully!")
        return True
    except Exception as e:
        print(f"Failed to compile CUDA kernels: {e}")
        print("Using PyTorch InstanceNorm2d fallback")
        _KERNELS_COMPILED = True
        return False


# Auto-compile on import: ZeroGPU gets prebuilt kernels only; local
# environments with CUDA attempt a full load/compile.
if _ZERO_GPU:
    # On ZeroGPU, try to download pre-compiled kernels
    print("ZeroGPU detected: Attempting to load pre-compiled kernels from dataset...")
    if load_prebuilt_kernels():
        print("Using pre-compiled CUDA kernels from dataset!")
    else:
        print("No pre-compiled kernels available, using PyTorch GPU fallback")
    _KERNELS_COMPILED = True
else:
    # On local, check if CUDA is available and compile.
    # (`elif not _ZERO_GPU:` was tautological here — plain `else` is equivalent.)
    # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit propagate.
    try:
        if torch.cuda.is_available():
            compile_kernels()
    except Exception:
        _KERNELS_COMPILED = True


# Public API. NOTE: 'FusedInstanceNorm2d' was removed from this list — the
# module only defines the private `_FusedInstanceNorm2d` holder, so listing
# it made `from <package> import *` raise AttributeError.
__all__ = [
    'check_cuda_kernels',
    'get_fused_instance_norm',
    'compile_kernels',
    'load_prebuilt_kernels',
]