# NOTE: removed non-Python residue pasted above the module docstring
# (Hugging Face file-viewer chrome: "Running on Zero", blame hashes,
# and a line-number gutter). It was not part of the source.
"""
StyleForge CUDA Kernels Package
Custom CUDA kernels for accelerated neural style transfer.
For ZeroGPU/HuggingFace: Pre-compiled kernels are downloaded from HF dataset.
For local: Kernels are JIT-compiled if prebuilt not available.
"""
import torch
import os
from pathlib import Path
# Try to import CUDA kernels, fall back gracefully.
# Module-level state; mutated by load_prebuilt_kernels() / compile_kernels().
_CUDA_KERNELS_AVAILABLE = False  # True once a custom kernel path is usable
_FusedInstanceNorm2d = None  # factory/class producing the fused module, or None
_KERNELS_COMPILED = False  # True once a load/compile attempt has been made
_LOADED_KERNEL_FUNC = None  # raw forward function loaded from a .so/.pyd
# Check if running on ZeroGPU or HuggingFace Spaces
# Use the same detection as app.py - presence of spaces package
try:
    from spaces import GPU
    _ZERO_GPU = True
except ImportError:
    _ZERO_GPU = False
# Path to pre-compiled kernels
_PREBUILT_PATH = Path(__file__).parent / "prebuilt"
_PREBUILT_PATH.mkdir(exist_ok=True)  # ensure the download target exists at import time
# HuggingFace dataset for prebuilt kernels
_KERNEL_DATASET = "oliau/styleforge-kernels"  # You'll need to create this dataset
def _download_kernels_from_dataset():
"""Download pre-compiled kernels from HuggingFace dataset."""
try:
from huggingface_hub import hf_hub_download
import sys
print(f"Looking for kernels in dataset: {_KERNEL_DATASET}")
# Known kernel file name
kernel_file = "fused_instance_norm.so"
# Download directly to the kernels directory
try:
local_path = hf_hub_download(
repo_id=_KERNEL_DATASET,
filename=kernel_file,
repo_type="dataset",
local_dir=str(_PREBUILT_PATH.parent),
local_dir_use_symlinks=False
)
print(f"Successfully downloaded kernel: {kernel_file} -> {local_path}")
return True
except Exception as e:
print(f"Failed to download {kernel_file}: {e}")
# Try alternative paths in case the file is in a subdirectory
for subdir in ["", "kernels/", "prebuilt/", "build/"]:
try:
alt_path = subdir + kernel_file
local_path = hf_hub_download(
repo_id=_KERNEL_DATASET,
filename=alt_path,
repo_type="dataset",
local_dir=str(_PREBUILT_PATH.parent),
local_dir_use_symlinks=False
)
print(f"Successfully downloaded kernel from {alt_path}: {local_path}")
return True
except Exception:
continue
return False
except ImportError as e:
print(f"huggingface_hub not available: {e}")
return False
except Exception as e:
print(f"Failed to download kernels from dataset: {e}")
import traceback
traceback.print_exc()
return False
def check_cuda_kernels():
    """Return True when the custom fused CUDA kernels have been loaded."""
    return bool(_CUDA_KERNELS_AVAILABLE)
def get_fused_instance_norm(num_features, **kwargs):
    """
    Get FusedInstanceNorm2d module or PyTorch fallback.

    On ZeroGPU: Uses pre-compiled kernels if available.
    On local: May use custom fused kernels (prebuilt or JIT).

    Args:
        num_features: number of channels of the 4D input.
        **kwargs: options forwarded to the fused module; on fallback, only
            options InstanceNorm2d understands are forwarded.

    Returns:
        torch.nn.Module: fused implementation when available, otherwise a
        stock ``torch.nn.InstanceNorm2d``.
    """
    if _FusedInstanceNorm2d is not None:
        try:
            return _FusedInstanceNorm2d(num_features, **kwargs)
        except Exception:
            pass  # fall through to the stock PyTorch implementation
    # Fallback to PyTorch (still GPU-accelerated, just not custom fused).
    # FIX: forward every option InstanceNorm2d supports — the original only
    # honored `affine` and silently dropped eps/momentum/track_running_stats.
    supported = ('eps', 'momentum', 'affine', 'track_running_stats')
    fallback_kwargs = {k: v for k, v in kwargs.items() if k in supported}
    fallback_kwargs.setdefault('affine', True)  # preserve the original default
    return torch.nn.InstanceNorm2d(num_features, **fallback_kwargs)
def load_prebuilt_kernels():
    """
    Try to load pre-compiled CUDA kernels from the kernels directory.

    On HuggingFace, downloads from dataset if local files not found.

    Returns:
        bool: True if a kernel module was loaded and wired up, False otherwise.
    """
    # BUG FIX: `_LOADED_KERNEL_FUNC` was previously assigned below WITHOUT
    # being listed in this global declaration, so the assignment created a
    # throwaway local and the module-level handle was never updated.
    global _FusedInstanceNorm2d, _CUDA_KERNELS_AVAILABLE, _KERNELS_COMPILED, _LOADED_KERNEL_FUNC
    if _KERNELS_COMPILED:
        return _CUDA_KERNELS_AVAILABLE

    kernels_dir = Path(__file__).parent

    def _glob_kernel_files():
        # Kernels may sit next to this file or inside prebuilt/.
        found = list(kernels_dir.glob("*.so")) + list(kernels_dir.glob("*.pyd"))
        found += list(_PREBUILT_PATH.glob("*.so")) + list(_PREBUILT_PATH.glob("*.pyd"))
        return found

    kernel_files = _glob_kernel_files()

    # Try downloading from dataset if not found locally (on ZeroGPU or if CUDA available)
    # IMPORTANT: Don't call torch.cuda.is_available() on ZeroGPU at module level!
    if not kernel_files:
        print(f"No local pre-compiled kernels found. _ZERO_GPU={_ZERO_GPU}")
        # On ZeroGPU, always try to download without checking CUDA;
        # on local, check CUDA first before downloading.
        should_download = _ZERO_GPU
        if not _ZERO_GPU:
            try:
                should_download = torch.cuda.is_available()
            except Exception:  # narrowed from the original bare `except:`
                should_download = False
        if should_download:
            print("Trying HuggingFace dataset...")
            if _download_kernels_from_dataset():
                # Check again after download
                kernel_files = _glob_kernel_files()

    if not kernel_files:
        print("No pre-compiled kernels found")
        return False
    print(f"Found kernel files: {[f.name for f in kernel_files]}")

    try:
        import ctypes
        import importlib.util  # replaces the fragile __import__('importlib.util') dance

        # Try to load each kernel file
        for kernel_file in kernel_files:
            try:
                # First try to load as a Python extension module
                module_name = kernel_file.stem
                spec = importlib.util.spec_from_file_location(module_name, kernel_file)
                if spec and spec.loader:
                    mod = importlib.util.module_from_spec(spec)
                    spec.loader.exec_module(mod)
                    print(f"Loaded pre-compiled kernel module: {kernel_file.name}")
                    # Check what functions are available in the module
                    available_funcs = [attr for attr in dir(mod) if not attr.startswith('_')]
                    print(f"Available functions in kernel: {available_funcs}")
                    # Try to find the forward function with common naming patterns
                    forward_func = None
                    for func_name in ('fused_instance_norm_forward', 'forward', 'fused_instance_norm',
                                      'instance_norm_forward', 'fused_inst_norm'):
                        if hasattr(mod, func_name):
                            forward_func = getattr(mod, func_name)
                            print(f"Using function: {func_name}")
                            break
                    if forward_func is None:
                        print(f"Warning: No suitable forward function found in {kernel_file.name}")
                        continue
                    # Store the kernel function globally (now actually global — see fix above)
                    _LOADED_KERNEL_FUNC = forward_func

                    # Factory that hands the pre-loaded kernel to the wrapper.
                    # The kernel is bound via a default argument so the closure
                    # cannot be affected by later loop iterations.
                    def make_fused_instance_norm(num_features, _kernel=forward_func, **kwargs):
                        from .instance_norm_wrapper import FusedInstanceNorm2d
                        return FusedInstanceNorm2d(num_features, kernel_func=_kernel, **kwargs)

                    _FusedInstanceNorm2d = make_fused_instance_norm
                    _CUDA_KERNELS_AVAILABLE = True
                    _KERNELS_COMPILED = True
                    print(f"Successfully initialized FusedInstanceNorm2d from {kernel_file.name}")
                    return True
            except Exception as e:
                print(f"Failed to load {kernel_file.name} as Python module: {e}")
                # Fall back to loading as a raw ctypes library (diagnostic only;
                # no wrapper is wired up for this path).
                try:
                    ctypes.CDLL(str(kernel_file))
                    print(f"Loaded {kernel_file.name} as ctypes library")
                except Exception as e2:
                    print(f"Failed to load {kernel_file.name} as ctypes: {e2}")
                continue
    except Exception as e:
        print(f"Failed to load prebuilt kernels: {e}")
    return False
def compile_kernels():
    """
    Compile CUDA kernels on-demand.

    On ZeroGPU: Downloads pre-compiled kernels from dataset.
    On local: Compiles custom CUDA kernels via the wrapper module.

    Returns:
        bool: True if custom kernels are usable, False when falling back to
        stock PyTorch InstanceNorm2d.
    """
    global _CUDA_KERNELS_AVAILABLE, _FusedInstanceNorm2d, _KERNELS_COMPILED
    if _KERNELS_COMPILED:
        return _CUDA_KERNELS_AVAILABLE
    # On ZeroGPU, try to download pre-compiled kernels from dataset
    if _ZERO_GPU:
        print("ZeroGPU mode: Attempting to download pre-compiled kernels from dataset...")
        if load_prebuilt_kernels():
            print("Successfully loaded pre-compiled CUDA kernels from dataset!")
            return True
        print("No pre-compiled kernels found in dataset, using PyTorch GPU fallback")
        _KERNELS_COMPILED = True
        return False
    # First, try pre-compiled kernels (for local too)
    if load_prebuilt_kernels():
        print("Using pre-compiled CUDA kernels!")
        return True
    # Check CUDA availability (safe here since we're not on ZeroGPU).
    # `except Exception` narrows the original bare `except:`.
    try:
        if not torch.cuda.is_available():
            _KERNELS_COMPILED = True
            return False
    except Exception:
        _KERNELS_COMPILED = True
        return False
    try:
        # JIT-compiles (or imports) the custom fused kernel wrapper.
        from .instance_norm_wrapper import FusedInstanceNorm2d
        _FusedInstanceNorm2d = FusedInstanceNorm2d
        _CUDA_KERNELS_AVAILABLE = True
        _KERNELS_COMPILED = True
        print("CUDA kernels compiled successfully!")
        return True
    except Exception as e:
        print(f"Failed to compile CUDA kernels: {e}")
        print("Using PyTorch InstanceNorm2d fallback")
        _KERNELS_COMPILED = True
        return False
# Auto-compile on import for non-ZeroGPU environments with CUDA.
if _ZERO_GPU:
    # On ZeroGPU, try to download pre-compiled kernels; never probe
    # torch.cuda at import time (forbidden on ZeroGPU).
    print("ZeroGPU detected: Attempting to load pre-compiled kernels from dataset...")
    if load_prebuilt_kernels():
        print("Using pre-compiled CUDA kernels from dataset!")
    else:
        print("No pre-compiled kernels available, using PyTorch GPU fallback")
        _KERNELS_COMPILED = True
else:
    # On local, check if CUDA is available and compile.
    # FIX: plain `else` replaces the redundant `elif not _ZERO_GPU`, and the
    # bare `except:` is narrowed to `except Exception:`.
    try:
        if torch.cuda.is_available():
            compile_kernels()
    except Exception:
        _KERNELS_COMPILED = True
__all__ = [
'check_cuda_kernels',
'get_fused_instance_norm',
'FusedInstanceNorm2d',
'compile_kernels',
'load_prebuilt_kernels',
]