# NOTE(review): the lines below replaced HuggingFace web-UI chrome (Space
# status banner, file size, commit hashes, and the viewer's line-number
# gutter) that was accidentally captured with the file and is not valid
# Python. Original metadata preserved for reference:
#   Spaces: Running on Zero
#   File size: 5,093 Bytes
"""
StyleForge - Fused Instance Normalization Wrapper
Python interface for the fused InstanceNorm CUDA kernel.
On ZeroGPU: Uses pre-compiled kernels from HuggingFace dataset.
On local: JIT compiles from source.
"""
import os
from pathlib import Path
from typing import Callable, Optional

import torch
import torch.nn as nn
# Check if running on ZeroGPU - use same detection as app.py
# (importing `spaces` succeeds only inside a ZeroGPU Space; the imported
# name itself is unused — the import's success/failure is the signal)
try:
    from spaces import GPU
    _ZERO_GPU = True
except ImportError:
    _ZERO_GPU = False
# Import local build utilities (only if not on ZeroGPU)
# NOTE: deliberately conditional — cuda_build may not be importable in the
# ZeroGPU runtime, and it is never needed there.
if not _ZERO_GPU:
    from .cuda_build import compile_inline
# Global module cache
# _instance_norm_module: lazily-compiled CUDA extension (None until first use)
# _cuda_available: memoized result of torch.cuda.is_available() (None = unknown)
_instance_norm_module = None
_cuda_available = None
def check_cuda_available():
    """Check if CUDA is available and kernels can be compiled.

    The result of ``torch.cuda.is_available()`` is memoized in the
    module-level ``_cuda_available`` flag, so repeated calls are cheap.
    """
    global _cuda_available
    if _cuda_available is None:
        _cuda_available = torch.cuda.is_available()
    return _cuda_available
def get_instance_norm_module():
    """Lazy-load and compile the InstanceNorm kernel.

    Compiles ``instance_norm.cu`` (located next to this file) on first use
    and caches the resulting extension module in ``_instance_norm_module``;
    subsequent calls return the cached module. This is the local-JIT path
    only — on ZeroGPU, pre-compiled kernels are loaded by ``__init__.py``.

    Raises:
        RuntimeError: When running on ZeroGPU, or when CUDA is unavailable.
        FileNotFoundError: When the ``.cu`` source file is missing.
        Exception: Whatever ``compile_inline`` raises on a failed build
            (logged, then re-raised so callers can fall back to PyTorch).
    """
    global _instance_norm_module
    if _instance_norm_module is None:
        # ZeroGPU deployments must never reach the JIT path.
        if _ZERO_GPU:
            raise RuntimeError("ZeroGPU mode: Pre-compiled kernels should be loaded via __init__.py")
        if not check_cuda_available():
            raise RuntimeError("CUDA is not available. Cannot use fused InstanceNorm kernel.")
        kernel_path = Path(__file__).parent / "instance_norm.cu"
        if not kernel_path.exists():
            raise FileNotFoundError(f"InstanceNorm kernel not found at {kernel_path}")
        cuda_source = kernel_path.read_text()
        print("Compiling fused InstanceNorm kernel...")
        try:
            _instance_norm_module = compile_inline(
                name='fused_instance_norm',
                cuda_source=cuda_source,
                functions=['forward'],
                build_directory=Path('build'),
                verbose=False,
            )
        except Exception as err:
            print(f"Failed to compile InstanceNorm kernel: {err}")
            print("Falling back to PyTorch implementation.")
            raise
        print("InstanceNorm compilation complete!")
    return _instance_norm_module
class FusedInstanceNorm2d(nn.Module):
    """
    Fused Instance Normalization 2D Module with automatic fallback.

    Normalizes each (sample, channel) plane of a 4D input independently,
    then applies the per-channel affine transform (gamma, beta).

    Kernel selection at forward time, in order of preference:
      1. A pre-loaded kernel function (``kernel_func``) — ZeroGPU path.
      2. A locally JIT-compiled CUDA kernel (non-ZeroGPU with CUDA).
      3. Plain PyTorch instance norm (always available).

    Args:
        num_features: Number of channels C of the (B, C, H, W) input.
        eps: Numerical-stability constant added to the variance.
        affine: When True, gamma/beta are trainable ``nn.Parameter``s;
            otherwise they are fixed ones/zeros buffers.
        track_running_stats: Accepted for signature compatibility with
            ``nn.InstanceNorm2d`` but always forced to False — the fused
            kernel does not maintain running statistics.
        use_vectorized: Forwarded to the JIT-compiled kernel to select its
            vectorized code path (the pre-loaded kernel ignores it).
        kernel_func: Optional pre-compiled kernel entry point with
            signature ``(x, gamma, beta, eps) -> Tensor``.
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        affine: bool = True,
        track_running_stats: bool = False,
        use_vectorized: bool = True,
        kernel_func: Optional[Callable] = None  # Pre-loaded kernel function
    ):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.use_vectorized = use_vectorized
        # Running stats are unsupported by the fused kernel, so the
        # constructor argument is deliberately ignored.
        self.track_running_stats = False
        self._kernel_func = kernel_func  # Pre-loaded from __init__.py
        # Enable CUDA if kernel function is provided OR not on ZeroGPU with CUDA available
        self._use_cuda = (self._kernel_func is not None) or (check_cuda_available() if not _ZERO_GPU else False)
        if affine:
            self.gamma = nn.Parameter(torch.ones(num_features))
            self.beta = nn.Parameter(torch.zeros(num_features))
        else:
            self.register_buffer('gamma', torch.ones(num_features))
            self.register_buffer('beta', torch.zeros(num_features))
        # Fallback to PyTorch InstanceNorm
        self._pytorch_norm = nn.InstanceNorm2d(num_features, eps=eps, affine=affine)
        if affine:
            # BUG FIX: the fallback previously owned its own independently
            # initialized weight/bias, so once gamma/beta were trained the
            # fused path and the PyTorch fallback applied *different* affine
            # transforms. Share the parameters so both paths agree.
            self._pytorch_norm.weight = self.gamma
            self._pytorch_norm.bias = self.beta

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` of shape (B, C, H, W), preferring the fused kernels.

        Raises:
            ValueError: If ``x`` is not 4-dimensional.
        """
        if x.dim() != 4:
            raise ValueError(f"Input must be 4D (B, C, H, W), got {x.dim()}D")
        # Use pre-compiled kernel if available
        if self._kernel_func is not None and x.is_cuda:
            try:
                result = self._kernel_func(
                    x.contiguous(),
                    self.gamma,
                    self.beta,
                    self.eps
                )
                return result
            except Exception as e:
                print(f"Custom kernel failed: {e}, falling back to PyTorch")
                # Continue to PyTorch fallback
        # Use CUDA kernel if available and on CUDA device (local JIT compilation)
        if self._use_cuda and x.is_cuda and not _ZERO_GPU and self._kernel_func is None:
            try:
                module = get_instance_norm_module()
                output = module.forward(
                    x.contiguous(),
                    self.gamma,
                    self.beta,
                    self.eps,
                    self.use_vectorized
                )
                return output
            except Exception:
                # Fallback to PyTorch (compilation or launch failure)
                pass
        # PyTorch fallback (still GPU accelerated, just not custom fused).
        # Shares gamma/beta with the kernel paths when affine=True.
        return self._pytorch_norm(x)


# Alias for compatibility
FusedInstanceNorm2dAuto = FusedInstanceNorm2d
# NOTE(review): removed a stray "|" left over from the web-viewer scrape.