Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| StyleForge - Fused Instance Normalization Wrapper | |
| Python interface for the fused InstanceNorm CUDA kernel. | |
| On ZeroGPU: uses pre-compiled kernels from a Hugging Face dataset. | |
| Locally: JIT-compiles the kernel from source. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| from pathlib import Path | |
| from typing import Optional | |
| import os | |
# ZeroGPU detection (the original note says this mirrors app.py): the
# environment is treated as ZeroGPU exactly when `spaces` can be imported.
try:
    from spaces import GPU
except ImportError:
    _ZERO_GPU = False
else:
    _ZERO_GPU = True

# The local JIT build helper is only imported outside ZeroGPU.
if not _ZERO_GPU:
    from .cuda_build import compile_inline

# Module-level caches, populated lazily by the functions in this module.
_instance_norm_module = None  # compiled kernel module once it has been built
_cuda_available = None  # cached result of the CUDA availability probe
def check_cuda_available():
    """Return whether ``torch.cuda.is_available()`` reports a usable CUDA device.

    The probe runs at most once per process; its result is stored in the
    module-level ``_cuda_available`` and returned on every later call.
    """
    global _cuda_available
    if _cuda_available is None:
        # First call: ask PyTorch and remember the answer.
        _cuda_available = torch.cuda.is_available()
    return _cuda_available
def get_instance_norm_module():
    """Return the compiled fused InstanceNorm module, building it on first use.

    A successful build is stored in the module-level ``_instance_norm_module``
    and handed back on every later call without recompiling.

    Raises:
        RuntimeError: in ZeroGPU mode, or when CUDA is not available.
        FileNotFoundError: when ``instance_norm.cu`` is not next to this file.
        Exception: any error raised while compiling is logged and re-raised.
    """
    global _instance_norm_module

    if _instance_norm_module is None:
        # JIT compilation is the local-only path; under ZeroGPU the kernels
        # are expected to have been loaded by __init__.py instead.
        if _ZERO_GPU:
            raise RuntimeError("ZeroGPU mode: Pre-compiled kernels should be loaded via __init__.py")

        if not check_cuda_available():
            raise RuntimeError("CUDA is not available. Cannot use fused InstanceNorm kernel.")

        source_file = Path(__file__).parent / "instance_norm.cu"
        if not source_file.exists():
            raise FileNotFoundError(f"InstanceNorm kernel not found at {source_file}")

        kernel_code = source_file.read_text()

        print("Compiling fused InstanceNorm kernel...")
        try:
            _instance_norm_module = compile_inline(
                name='fused_instance_norm',
                cuda_source=kernel_code,
                functions=['forward'],
                build_directory=Path('build'),
                verbose=False,
            )
            print("InstanceNorm compilation complete!")
        except Exception as err:
            # Report the failure here, then let the caller decide what to do.
            print(f"Failed to compile InstanceNorm kernel: {err}")
            print("Falling back to PyTorch implementation.")
            raise

    return _instance_norm_module
class FusedInstanceNorm2d(nn.Module):
    """
    Fused Instance Normalization 2D module with automatic fallback.

    A forward pass tries these paths in order and returns the first result:

    1. ``kernel_func`` (a pre-loaded kernel function), when one was given and
       the input is a CUDA tensor.
    2. The JIT-compiled module from ``get_instance_norm_module()``, when no
       ``kernel_func`` was given, CUDA was available at construction time,
       the input is a CUDA tensor and the process is not in ZeroGPU mode.
    3. A plain ``nn.InstanceNorm2d``.

    NOTE(review): the fallback ``nn.InstanceNorm2d`` owns its own affine
    parameters, separate from ``gamma``/``beta`` used by the kernel paths.
    The two sets start out equal but can diverge once either is trained or
    loaded from a checkpoint -- confirm which set is authoritative.

    NOTE(review): the kernel paths call a raw ``forward`` function; whether
    gradients flow through it is not visible here -- confirm before training.

    Args:
        num_features: Number of channels ``C`` of the ``(B, C, H, W)`` input.
        eps: Value forwarded to the kernels and to the fallback norm.
        affine: If True, ``gamma``/``beta`` are learnable parameters;
            otherwise they are registered as constant buffers.
        track_running_stats: Not supported. The attribute is always stored as
            ``False``; passing ``True`` emits a warning.
        use_vectorized: Flag forwarded to the JIT-compiled kernel.
        kernel_func: Optional pre-loaded kernel, called as
            ``kernel_func(x, gamma, beta, eps)``.
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        affine: bool = True,
        track_running_stats: bool = False,
        use_vectorized: bool = True,
        kernel_func: Optional[callable] = None  # Pre-loaded kernel function
    ):
        super().__init__()

        if track_running_stats:
            # Previously this request was dropped without any signal.
            import warnings
            warnings.warn(
                "FusedInstanceNorm2d does not support track_running_stats=True; "
                "running statistics are not tracked.",
                stacklevel=2,
            )

        self.num_features = num_features
        self.eps = eps
        self.use_vectorized = use_vectorized
        self.track_running_stats = False
        self._kernel_func = kernel_func  # Pre-loaded from __init__.py

        # A custom kernel is considered usable when one was handed in, or when
        # running outside ZeroGPU with CUDA available (local JIT path).
        self._use_cuda = (kernel_func is not None) or (
            not _ZERO_GPU and check_cuda_available()
        )

        if affine:
            self.gamma = nn.Parameter(torch.ones(num_features))
            self.beta = nn.Parameter(torch.zeros(num_features))
        else:
            self.register_buffer('gamma', torch.ones(num_features))
            self.register_buffer('beta', torch.zeros(num_features))

        # Fallback to PyTorch InstanceNorm
        self._pytorch_norm = nn.InstanceNorm2d(num_features, eps=eps, affine=affine)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize a 4D input, preferring a custom kernel when one is usable.

        Args:
            x: Input tensor; must have exactly four dimensions.

        Returns:
            The normalized tensor from the first path that succeeds.

        Raises:
            ValueError: if ``x`` is not 4D.
        """
        if x.dim() != 4:
            raise ValueError(f"Input must be 4D (B, C, H, W), got {x.dim()}D")

        if x.is_cuda:
            if self._kernel_func is not None:
                # Pre-compiled kernel supplied by the caller.
                try:
                    return self._kernel_func(
                        x.contiguous(),
                        self.gamma,
                        self.beta,
                        self.eps
                    )
                except Exception as e:
                    # Best-effort: any kernel error drops to the PyTorch path.
                    print(f"Custom kernel failed: {e}, falling back to PyTorch")
            elif self._use_cuda and not _ZERO_GPU:
                # Local JIT path. Loading and running are handled separately
                # because a load failure repeats on every call, whereas a run
                # failure may be specific to this input.
                try:
                    module = get_instance_norm_module()
                except Exception as e:
                    # Stop retrying: each attempt re-reads and recompiles the
                    # kernel source, which is far too costly per forward pass.
                    self._use_cuda = False
                    print(f"Fused InstanceNorm kernel unavailable: {e}, falling back to PyTorch")
                else:
                    try:
                        return module.forward(
                            x.contiguous(),
                            self.gamma,
                            self.beta,
                            self.eps,
                            self.use_vectorized
                        )
                    except Exception as e:
                        print(f"Custom kernel failed: {e}, falling back to PyTorch")

        # PyTorch fallback (still GPU accelerated, just not custom fused)
        return self._pytorch_norm(x)


# Alias for compatibility
FusedInstanceNorm2dAuto = FusedInstanceNorm2d