""" StyleForge CUDA Kernels Package Custom CUDA kernels for accelerated neural style transfer. For ZeroGPU/HuggingFace: Pre-compiled kernels are downloaded from HF dataset. For local: Kernels are JIT-compiled if prebuilt not available. """ import torch import os from pathlib import Path # Try to import CUDA kernels, fall back gracefully _CUDA_KERNELS_AVAILABLE = False _FusedInstanceNorm2d = None _KERNELS_COMPILED = False _LOADED_KERNEL_FUNC = None # Check if running on ZeroGPU or HuggingFace Spaces # Use the same detection as app.py - presence of spaces package try: from spaces import GPU _ZERO_GPU = True except ImportError: _ZERO_GPU = False # Path to pre-compiled kernels _PREBUILT_PATH = Path(__file__).parent / "prebuilt" _PREBUILT_PATH.mkdir(exist_ok=True) # HuggingFace dataset for prebuilt kernels _KERNEL_DATASET = "oliau/styleforge-kernels" # You'll need to create this dataset def _download_kernels_from_dataset(): """Download pre-compiled kernels from HuggingFace dataset.""" try: from huggingface_hub import hf_hub_download import sys print(f"Looking for kernels in dataset: {_KERNEL_DATASET}") # Known kernel file name kernel_file = "fused_instance_norm.so" # Download directly to the kernels directory try: local_path = hf_hub_download( repo_id=_KERNEL_DATASET, filename=kernel_file, repo_type="dataset", local_dir=str(_PREBUILT_PATH.parent), local_dir_use_symlinks=False ) print(f"Successfully downloaded kernel: {kernel_file} -> {local_path}") return True except Exception as e: print(f"Failed to download {kernel_file}: {e}") # Try alternative paths in case the file is in a subdirectory for subdir in ["", "kernels/", "prebuilt/", "build/"]: try: alt_path = subdir + kernel_file local_path = hf_hub_download( repo_id=_KERNEL_DATASET, filename=alt_path, repo_type="dataset", local_dir=str(_PREBUILT_PATH.parent), local_dir_use_symlinks=False ) print(f"Successfully downloaded kernel from {alt_path}: {local_path}") return True except Exception: continue return False except ImportError as e: print(f"huggingface_hub not available: {e}") return False except Exception as e: print(f"Failed to download kernels from dataset: {e}") import traceback traceback.print_exc() return False def check_cuda_kernels(): """Check if CUDA kernels are available.""" return _CUDA_KERNELS_AVAILABLE def get_fused_instance_norm(num_features, **kwargs): """ Get FusedInstanceNorm2d module or PyTorch fallback. On ZeroGPU: Uses pre-compiled kernels if available. On local: May use custom fused kernels (prebuilt or JIT). """ if _FusedInstanceNorm2d is not None: try: return _FusedInstanceNorm2d(num_features, **kwargs) except Exception: pass # Fallback to PyTorch (still GPU-accelerated, just not custom fused) return torch.nn.InstanceNorm2d(num_features, affine=kwargs.get('affine', True)) def load_prebuilt_kernels(): """ Try to load pre-compiled CUDA kernels from the kernels directory. On HuggingFace, downloads from dataset if local files not found. Returns True if successful, False otherwise. 
""" global _FusedInstanceNorm2d, _CUDA_KERNELS_AVAILABLE, _KERNELS_COMPILED if _KERNELS_COMPILED: return _CUDA_KERNELS_AVAILABLE # Check for kernels in the kernels directory (parent of prebuilt) and prebuilt/ kernels_dir = Path(__file__).parent kernel_files = list(kernels_dir.glob("*.so")) + list(kernels_dir.glob("*.pyd")) kernel_files += list(_PREBUILT_PATH.glob("*.so")) + list(_PREBUILT_PATH.glob("*.pyd")) # Try downloading from dataset if not found locally (on ZeroGPU or if CUDA available) # IMPORTANT: Don't call torch.cuda.is_available() on ZeroGPU at module level! if not kernel_files: print(f"No local pre-compiled kernels found. _ZERO_GPU={_ZERO_GPU}") # On ZeroGPU, always try to download without checking CUDA # On local, check CUDA first before downloading should_download = _ZERO_GPU if not _ZERO_GPU: try: should_download = torch.cuda.is_available() except: should_download = False if should_download: print("Trying HuggingFace dataset...") if _download_kernels_from_dataset(): # Check again after download - look in kernels directory kernel_files = list(kernels_dir.glob("*.so")) + list(kernels_dir.glob("*.pyd")) kernel_files += list(_PREBUILT_PATH.glob("*.so")) + list(_PREBUILT_PATH.glob("*.pyd")) if not kernel_files: print("No pre-compiled kernels found") return False print(f"Found kernel files: {[f.name for f in kernel_files]}") try: import sys import ctypes # Try to load each kernel file for kernel_file in kernel_files: try: # First try to load as a Python extension module module_name = kernel_file.stem spec = __import__('importlib.util').util.spec_from_file_location(module_name, kernel_file) if spec and spec.loader: mod = __import__('importlib.util').util.module_from_spec(spec) spec.loader.exec_module(mod) print(f"Loaded pre-compiled kernel module: {kernel_file.name}") # Check what functions are available in the module available_funcs = [attr for attr in dir(mod) if not attr.startswith('_')] print(f"Available functions in kernel: {available_funcs}") # Try to find the forward function with common naming patterns forward_func = None for func_name in ['fused_instance_norm_forward', 'forward', 'fused_instance_norm', 'instance_norm_forward', 'fused_inst_norm']: if hasattr(mod, func_name): forward_func = getattr(mod, func_name) print(f"Using function: {func_name}") break if forward_func is None: print(f"Warning: No suitable forward function found in {kernel_file.name}") continue # Store the kernel function globally for use with FusedInstanceNorm2d _LOADED_KERNEL_FUNC = forward_func # Create factory function that uses the wrapper with pre-loaded kernel def make_fused_instance_norm(num_features, **kwargs): from .instance_norm_wrapper import FusedInstanceNorm2d # Pass the pre-loaded kernel function return FusedInstanceNorm2d(num_features, kernel_func=forward_func, **kwargs) _FusedInstanceNorm2d = make_fused_instance_norm _CUDA_KERNELS_AVAILABLE = True _KERNELS_COMPILED = True print(f"Successfully initialized FusedInstanceNorm2d from {kernel_file.name}") return True except Exception as e: print(f"Failed to load {kernel_file.name} as Python module: {e}") # Try loading as raw ctypes library try: lib = ctypes.CDLL(str(kernel_file)) print(f"Loaded {kernel_file.name} as ctypes library") # Could add ctypes wrapper here if needed except Exception as e2: print(f"Failed to load {kernel_file.name} as ctypes: {e2}") continue except Exception as e: print(f"Failed to load prebuilt kernels: {e}") return False def compile_kernels(): """ Compile CUDA kernels on-demand. 
def compile_kernels():
    """
    Compile CUDA kernels on demand.

    On ZeroGPU: downloads pre-compiled kernels from the dataset.
    On local: compiles the custom CUDA kernels.
    """
    global _CUDA_KERNELS_AVAILABLE, _FusedInstanceNorm2d, _KERNELS_COMPILED

    if _KERNELS_COMPILED:
        return _CUDA_KERNELS_AVAILABLE

    # On ZeroGPU, try to download pre-compiled kernels from the dataset.
    if _ZERO_GPU:
        print("ZeroGPU mode: attempting to download pre-compiled kernels from dataset...")
        if load_prebuilt_kernels():
            print("Successfully loaded pre-compiled CUDA kernels from dataset!")
            return True
        print("No pre-compiled kernels found in dataset, using PyTorch GPU fallback")
        _KERNELS_COMPILED = True
        return False

    # First, try pre-compiled kernels (locally too).
    if load_prebuilt_kernels():
        print("Using pre-compiled CUDA kernels!")
        return True

    # Check CUDA availability (safe here since we're not on ZeroGPU).
    try:
        if not torch.cuda.is_available():
            _KERNELS_COMPILED = True
            return False
    except Exception:
        _KERNELS_COMPILED = True
        return False

    try:
        from .instance_norm_wrapper import FusedInstanceNorm2d
        _FusedInstanceNorm2d = FusedInstanceNorm2d
        _CUDA_KERNELS_AVAILABLE = True
        _KERNELS_COMPILED = True
        print("CUDA kernels compiled successfully!")
        return True
    except Exception as e:
        print(f"Failed to compile CUDA kernels: {e}")
        print("Using PyTorch InstanceNorm2d fallback")
        _KERNELS_COMPILED = True
        return False


# Auto-initialize on import.
if _ZERO_GPU:
    # On ZeroGPU, try to download pre-compiled kernels.
    print("ZeroGPU detected: attempting to load pre-compiled kernels from dataset...")
    if load_prebuilt_kernels():
        print("Using pre-compiled CUDA kernels from dataset!")
    else:
        print("No pre-compiled kernels available, using PyTorch GPU fallback")
    _KERNELS_COMPILED = True
else:
    # Locally, compile only if CUDA is available.
    try:
        if torch.cuda.is_available():
            compile_kernels()
    except Exception:
        _KERNELS_COMPILED = True

# Note: FusedInstanceNorm2d is not a module-level name here (only the private
# _FusedInstanceNorm2d factory), so exporting it would break `import *`.
# Use get_fused_instance_norm() instead.
__all__ = [
    'check_cuda_kernels',
    'get_fused_instance_norm',
    'compile_kernels',
    'load_prebuilt_kernels',
]
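
# Typical consumer flow, as a sketch. The import path ("kernels") and the channel
# count are assumptions based on this file's location, not guarantees:
#
#   from kernels import compile_kernels, check_cuda_kernels, get_fused_instance_norm
#
#   compile_kernels()                    # no-op if import-time init already ran
#   print("custom kernels:", check_cuda_kernels())
#   norm = get_fused_instance_norm(64)   # safe either way: falls back to InstanceNorm2d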