# StyleForge/kernels/__init__.py
# Deployed via github-actions[bot] — "Deploy from GitHub - 2026-01-21 11:50:26" (commit 1dda790)
"""
StyleForge CUDA Kernels Package
Custom CUDA kernels for accelerated neural style transfer.
For ZeroGPU/HuggingFace: Pre-compiled kernels are downloaded from HF dataset.
For local: Kernels are JIT-compiled if prebuilt not available.
"""
import torch
import os
from pathlib import Path
# Try to import CUDA kernels, fall back gracefully
# Module-level state shared by the loader functions below; mutated by
# load_prebuilt_kernels() / compile_kernels().
_CUDA_KERNELS_AVAILABLE = False  # True once a custom fused kernel is usable
_FusedInstanceNorm2d = None  # factory/class producing fused instance-norm modules
_KERNELS_COMPILED = False  # True once a load/compile attempt has finished
_LOADED_KERNEL_FUNC = None  # raw forward function from a loaded kernel module
# Check if running on ZeroGPU or HuggingFace Spaces
# Use the same detection as app.py - presence of spaces package
try:
    from spaces import GPU  # noqa: F401 — imported only to detect the Spaces environment
    _ZERO_GPU = True
except ImportError:
    _ZERO_GPU = False
# Path to pre-compiled kernels
_PREBUILT_PATH = Path(__file__).parent / "prebuilt"
_PREBUILT_PATH.mkdir(exist_ok=True)  # NOTE: import-time side effect; needs a writable package dir
# HuggingFace dataset for prebuilt kernels
_KERNEL_DATASET = "oliau/styleforge-kernels" # You'll need to create this dataset
def _download_kernels_from_dataset():
    """Download pre-compiled kernels from the HuggingFace dataset.

    Tries the dataset root first, then common subdirectories, downloading
    into the kernels package directory so ``load_prebuilt_kernels`` can
    discover the file with its glob.

    Returns:
        bool: True if a kernel file was downloaded, False otherwise
        (missing ``huggingface_hub``, network failure, or file not found).
    """
    try:
        from huggingface_hub import hf_hub_download
    except ImportError as e:
        print(f"huggingface_hub not available: {e}")
        return False
    try:
        print(f"Looking for kernels in dataset: {_KERNEL_DATASET}")
        # Known kernel file name
        kernel_file = "fused_instance_norm.so"
        # Try the repo root first, then common subdirectories. (The original
        # code attempted the root path twice — once up front and again as the
        # "" entry of the fallback loop — so the candidates are unified here.)
        for subdir in ("", "kernels/", "prebuilt/", "build/"):
            candidate = subdir + kernel_file
            try:
                local_path = hf_hub_download(
                    repo_id=_KERNEL_DATASET,
                    filename=candidate,
                    repo_type="dataset",
                    # Download directly into the kernels directory.
                    local_dir=str(_PREBUILT_PATH.parent),
                    local_dir_use_symlinks=False,
                )
                print(f"Successfully downloaded kernel from {candidate}: {local_path}")
                return True
            except Exception as e:
                print(f"Failed to download {candidate}: {e}")
        return False
    except Exception as e:
        print(f"Failed to download kernels from dataset: {e}")
        import traceback
        traceback.print_exc()
        return False
def check_cuda_kernels():
    """Return True if custom CUDA kernels have been loaded.

    Reflects the module-level ``_CUDA_KERNELS_AVAILABLE`` flag, which is set
    by ``load_prebuilt_kernels()`` / ``compile_kernels()``.
    """
    return _CUDA_KERNELS_AVAILABLE
def get_fused_instance_norm(num_features, **kwargs):
    """
    Get FusedInstanceNorm2d module or PyTorch fallback.
    On ZeroGPU: Uses pre-compiled kernels if available.
    On local: May use custom fused kernels (prebuilt or JIT).

    Args:
        num_features: Number of channels for the instance norm.
        **kwargs: Forwarded to the fused implementation; for the PyTorch
            fallback, the InstanceNorm2d-compatible options (``affine``,
            ``eps``, ``momentum``, ``track_running_stats``) are honored.

    Returns:
        torch.nn.Module: Fused instance norm if available, otherwise a
        plain ``torch.nn.InstanceNorm2d``.
    """
    if _FusedInstanceNorm2d is not None:
        try:
            return _FusedInstanceNorm2d(num_features, **kwargs)
        except Exception:
            # Best-effort: any construction failure falls through to PyTorch.
            pass
    # Fallback to PyTorch (still GPU-accelerated, just not custom fused).
    # FIX: previously only `affine` was forwarded — eps/momentum/etc. were
    # silently dropped, so fused and fallback paths could disagree.
    fallback_kwargs = {
        k: kwargs[k]
        for k in ("eps", "momentum", "track_running_stats")
        if k in kwargs
    }
    return torch.nn.InstanceNorm2d(
        num_features,
        affine=kwargs.get('affine', True),
        **fallback_kwargs,
    )
def load_prebuilt_kernels():
    """
    Try to load pre-compiled CUDA kernels from the kernels directory.
    On HuggingFace, downloads from dataset if local files not found.

    Returns:
        bool: True if a usable kernel module was loaded, False otherwise.
    """
    # BUG FIX: _LOADED_KERNEL_FUNC was missing from this global statement, so
    # the assignment below only ever created a dead local variable.
    global _FusedInstanceNorm2d, _CUDA_KERNELS_AVAILABLE, _KERNELS_COMPILED, _LOADED_KERNEL_FUNC
    if _KERNELS_COMPILED:
        return _CUDA_KERNELS_AVAILABLE

    kernels_dir = Path(__file__).parent

    def _find_kernel_files():
        # Compiled extensions may live in the kernels directory itself
        # (where the dataset download lands) or in prebuilt/.
        files = list(kernels_dir.glob("*.so")) + list(kernels_dir.glob("*.pyd"))
        files += list(_PREBUILT_PATH.glob("*.so")) + list(_PREBUILT_PATH.glob("*.pyd"))
        return files

    kernel_files = _find_kernel_files()
    # Try downloading from dataset if not found locally (on ZeroGPU or if CUDA available)
    # IMPORTANT: Don't call torch.cuda.is_available() on ZeroGPU at module level!
    if not kernel_files:
        print(f"No local pre-compiled kernels found. _ZERO_GPU={_ZERO_GPU}")
        # On ZeroGPU, always try to download without checking CUDA.
        # On local, check CUDA first before downloading.
        should_download = _ZERO_GPU
        if not _ZERO_GPU:
            try:
                should_download = torch.cuda.is_available()
            except Exception:
                should_download = False
        if should_download:
            print("Trying HuggingFace dataset...")
            if _download_kernels_from_dataset():
                # Check again after download.
                kernel_files = _find_kernel_files()
    if not kernel_files:
        print("No pre-compiled kernels found")
        return False

    print(f"Found kernel files: {[f.name for f in kernel_files]}")
    try:
        import importlib.util
        import ctypes
        for kernel_file in kernel_files:
            try:
                # First try to load as a Python extension module.
                module_name = kernel_file.stem
                spec = importlib.util.spec_from_file_location(module_name, kernel_file)
                if spec and spec.loader:
                    mod = importlib.util.module_from_spec(spec)
                    spec.loader.exec_module(mod)
                    print(f"Loaded pre-compiled kernel module: {kernel_file.name}")
                    # Check what functions are available in the module.
                    available_funcs = [attr for attr in dir(mod) if not attr.startswith('_')]
                    print(f"Available functions in kernel: {available_funcs}")
                    # Try to find the forward function with common naming patterns.
                    forward_func = None
                    for func_name in ('fused_instance_norm_forward', 'forward', 'fused_instance_norm',
                                      'instance_norm_forward', 'fused_inst_norm'):
                        if hasattr(mod, func_name):
                            forward_func = getattr(mod, func_name)
                            print(f"Using function: {func_name}")
                            break
                    if forward_func is None:
                        print(f"Warning: No suitable forward function found in {kernel_file.name}")
                        continue
                    # Store the kernel function globally (now actually global).
                    _LOADED_KERNEL_FUNC = forward_func

                    # Bind forward_func as a default argument so the factory
                    # is immune to late-binding if this loop ever continues.
                    def make_fused_instance_norm(num_features, _kernel=forward_func, **kwargs):
                        from .instance_norm_wrapper import FusedInstanceNorm2d
                        # Pass the pre-loaded kernel function.
                        return FusedInstanceNorm2d(num_features, kernel_func=_kernel, **kwargs)

                    _FusedInstanceNorm2d = make_fused_instance_norm
                    _CUDA_KERNELS_AVAILABLE = True
                    _KERNELS_COMPILED = True
                    print(f"Successfully initialized FusedInstanceNorm2d from {kernel_file.name}")
                    return True
            except Exception as e:
                print(f"Failed to load {kernel_file.name} as Python module: {e}")
                # Try loading as raw ctypes library.
                try:
                    ctypes.CDLL(str(kernel_file))
                    print(f"Loaded {kernel_file.name} as ctypes library")
                    # Could add ctypes wrapper here if needed.
                except Exception as e2:
                    print(f"Failed to load {kernel_file.name} as ctypes: {e2}")
                continue
    except Exception as e:
        print(f"Failed to load prebuilt kernels: {e}")
    return False
def compile_kernels():
    """
    Compile or load CUDA kernels on-demand.
    On ZeroGPU: Downloads pre-compiled kernels from dataset (no compiler there).
    On local: Tries prebuilt kernels first, then the JIT-compiling wrapper.

    Returns:
        bool: True if custom kernels are usable, False if the caller should
        use the PyTorch fallback.
    """
    global _CUDA_KERNELS_AVAILABLE, _FusedInstanceNorm2d, _KERNELS_COMPILED
    if _KERNELS_COMPILED:
        return _CUDA_KERNELS_AVAILABLE
    # On ZeroGPU, try to download pre-compiled kernels from dataset.
    if _ZERO_GPU:
        print("ZeroGPU mode: Attempting to download pre-compiled kernels from dataset...")
        if load_prebuilt_kernels():
            print("Successfully loaded pre-compiled CUDA kernels from dataset!")
            return True
        print("No pre-compiled kernels found in dataset, using PyTorch GPU fallback")
        _KERNELS_COMPILED = True
        return False
    # First, try pre-compiled kernels (for local too).
    if load_prebuilt_kernels():
        print("Using pre-compiled CUDA kernels!")
        return True
    # Check CUDA availability (safe here since we're not on ZeroGPU).
    # FIX: was a bare `except:`, which would also swallow KeyboardInterrupt.
    try:
        cuda_ok = torch.cuda.is_available()
    except Exception:
        cuda_ok = False
    if not cuda_ok:
        _KERNELS_COMPILED = True
        return False
    try:
        from .instance_norm_wrapper import FusedInstanceNorm2d
        _FusedInstanceNorm2d = FusedInstanceNorm2d
        _CUDA_KERNELS_AVAILABLE = True
        print("CUDA kernels compiled successfully!")
        return True
    except Exception as e:
        print(f"Failed to compile CUDA kernels: {e}")
        print("Using PyTorch InstanceNorm2d fallback")
        return False
    finally:
        # Either way, a compile attempt has now been made.
        _KERNELS_COMPILED = True
# Auto-compile on import for non-ZeroGPU environments with CUDA.
if _ZERO_GPU:
    # On ZeroGPU, try to download pre-compiled kernels (no compiler toolchain).
    print("ZeroGPU detected: Attempting to load pre-compiled kernels from dataset...")
    if load_prebuilt_kernels():
        print("Using pre-compiled CUDA kernels from dataset!")
    else:
        print("No pre-compiled kernels available, using PyTorch GPU fallback")
    _KERNELS_COMPILED = True
else:
    # FIX: was `elif not _ZERO_GPU:` — a redundant re-test of the condition.
    # On local, check if CUDA is available and compile.
    try:
        if torch.cuda.is_available():
            compile_kernels()
    except Exception:
        # FIX: was a bare `except:`. Mark the attempt as done so later calls
        # short-circuit to the fallback.
        _KERNELS_COMPILED = True
# Public API of the kernels package.
# FIX: 'FusedInstanceNorm2d' was listed here but is never defined at module
# level (only the private _FusedInstanceNorm2d factory exists), so a star
# import raised AttributeError. Use get_fused_instance_norm() instead.
__all__ = [
    'check_cuda_kernels',
    'get_fused_instance_norm',
    'compile_kernels',
    'load_prebuilt_kernels',
]