File size: 5,093 Bytes
3386f25
 
 
37b20e3
 
 
3386f25
 
 
 
 
 
37b20e3
 
5de41f0
 
 
 
 
 
3386f25
37b20e3
 
 
3386f25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37b20e3
 
 
 
 
3386f25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37b20e3
 
 
3386f25
 
 
 
 
 
 
 
37b20e3
 
3386f25
 
 
 
 
 
 
37b20e3
 
 
 
3386f25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37b20e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3386f25
 
 
 
 
 
 
 
 
 
 
 
 
 
37b20e3
3386f25
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
"""
StyleForge - Fused Instance Normalization Wrapper
Python interface for the fused InstanceNorm CUDA kernel.

On ZeroGPU: Uses pre-compiled kernels from HuggingFace dataset.
On local: JIT compiles from source.
"""

import os
from pathlib import Path
from typing import Callable, Optional

import torch
import torch.nn as nn

# Check if running on ZeroGPU - use same detection as app.py
# Importability of the `spaces` package is treated as the ZeroGPU marker:
# _ZERO_GPU is True exactly when `from spaces import GPU` succeeds.
# (`GPU` itself is not referenced elsewhere in this file.)
try:
    from spaces import GPU
    _ZERO_GPU = True
except ImportError:
    _ZERO_GPU = False

# Import local build utilities (only if not on ZeroGPU)
# On ZeroGPU `compile_inline` is intentionally left undefined; the JIT loader
# raises RuntimeError in ZeroGPU mode before it would ever reference it.
if not _ZERO_GPU:
    from .cuda_build import compile_inline

# Global module cache
# _instance_norm_module: JIT-compiled extension; None until first successful build.
# _cuda_available: cached torch.cuda.is_available() result; None = not yet checked.
_instance_norm_module = None
_cuda_available = None


def check_cuda_available():
    """Report whether CUDA is usable, querying torch at most once.

    The first call evaluates ``torch.cuda.is_available()`` and stores the
    answer in the module-level ``_cuda_available``. Subsequent calls return
    that stored value without asking torch again.
    """
    global _cuda_available
    if _cuda_available is None:
        _cuda_available = torch.cuda.is_available()
    return _cuda_available


def get_instance_norm_module():
    """Return the JIT-compiled InstanceNorm extension, building it on first use.

    A successfully compiled module is stored in the module-level
    ``_instance_norm_module`` and returned directly on later calls.

    Raises:
        RuntimeError: In ZeroGPU mode (kernels must be pre-loaded via
            ``__init__.py``) or when CUDA is unavailable.
        FileNotFoundError: If ``instance_norm.cu`` is not next to this file.
        Exception: Anything raised by ``compile_inline`` is re-raised after
            printing a diagnostic.
    """
    global _instance_norm_module

    cached = _instance_norm_module
    if cached is not None:
        return cached

    # This loader is only for local JIT compilation; ZeroGPU uses
    # pre-compiled kernels wired up elsewhere.
    if _ZERO_GPU:
        raise RuntimeError("ZeroGPU mode: Pre-compiled kernels should be loaded via __init__.py")
    if not check_cuda_available():
        raise RuntimeError("CUDA is not available. Cannot use fused InstanceNorm kernel.")

    source_file = Path(__file__).parent / "instance_norm.cu"
    if not source_file.exists():
        raise FileNotFoundError(f"InstanceNorm kernel not found at {source_file}")
    kernel_text = source_file.read_text()

    print("Compiling fused InstanceNorm kernel...")
    try:
        _instance_norm_module = compile_inline(
            name='fused_instance_norm',
            cuda_source=kernel_text,
            functions=['forward'],
            build_directory=Path('build'),
            verbose=False,
        )
        print("InstanceNorm compilation complete!")
    except Exception as err:
        print(f"Failed to compile InstanceNorm kernel: {err}")
        print("Falling back to PyTorch implementation.")
        raise

    return _instance_norm_module


class FusedInstanceNorm2d(nn.Module):
    """
    Fused Instance Normalization 2D Module with automatic fallback.

    On ZeroGPU: Uses pre-compiled kernels if available.
    On local: May use JIT-compiled kernels.

    All execution paths (pre-loaded kernel, JIT-compiled kernel, PyTorch
    fallback) apply the same affine parameters ``gamma``/``beta``. The result
    therefore does not depend on which path happens to run.

    Args:
        num_features: Number of channels ``C`` of the ``(B, C, H, W)`` input.
        eps: Value added to the variance for numerical stability.
        affine: If True, ``gamma``/``beta`` are learnable parameters;
            otherwise they are fixed buffers of ones/zeros.
        track_running_stats: Accepted for signature compatibility but
            ignored; ``self.track_running_stats`` is always False.
        use_vectorized: Forwarded to the JIT-compiled kernel.
        kernel_func: Optional pre-loaded kernel, called as
            ``kernel_func(x, gamma, beta, eps)``.
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        affine: bool = True,
        track_running_stats: bool = False,
        use_vectorized: bool = True,
        kernel_func: Optional[Callable[..., torch.Tensor]] = None,  # Pre-loaded kernel function
    ):
        super().__init__()

        self.num_features = num_features
        self.eps = eps
        self.use_vectorized = use_vectorized
        # Running statistics are not supported; the constructor argument is ignored.
        self.track_running_stats = False
        self._kernel_func = kernel_func  # Pre-loaded from __init__.py

        # Enable CUDA if a kernel function is provided, or (off ZeroGPU) when
        # CUDA is available for the local JIT path. check_cuda_available() is
        # only consulted when not on ZeroGPU.
        self._use_cuda = self._kernel_func is not None or (
            not _ZERO_GPU and check_cuda_available()
        )

        if affine:
            self.gamma = nn.Parameter(torch.ones(num_features))
            self.beta = nn.Parameter(torch.zeros(num_features))
        else:
            self.register_buffer('gamma', torch.ones(num_features))
            self.register_buffer('beta', torch.zeros(num_features))

        # Kept registered so existing state_dict keys (`_pytorch_norm.*`) still
        # load. forward() does NOT use its weight/bias. The fallback uses
        # gamma/beta so it matches the fused kernels.
        self._pytorch_norm = nn.InstanceNorm2d(num_features, eps=eps, affine=affine)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Instance-normalize a 4D ``(B, C, H, W)`` tensor and apply gamma/beta.

        Tries the pre-loaded kernel, then the local JIT kernel, then PyTorch.

        Raises:
            ValueError: If ``x`` is not 4-dimensional.
        """
        if x.dim() != 4:
            raise ValueError(f"Input must be 4D (B, C, H, W), got {x.dim()}D")

        # Use pre-compiled kernel if available
        if self._kernel_func is not None and x.is_cuda:
            try:
                return self._kernel_func(x.contiguous(), self.gamma, self.beta, self.eps)
            except Exception as e:
                print(f"Custom kernel failed: {e}, falling back to PyTorch")
                # Continue to PyTorch fallback

        # Local JIT-compiled kernel: only off ZeroGPU and when no pre-loaded
        # kernel was supplied.
        if self._use_cuda and x.is_cuda and not _ZERO_GPU and self._kernel_func is None:
            try:
                module = get_instance_norm_module()
            except Exception:
                # Loading/compiling failed (compile errors are already printed
                # by get_instance_norm_module). Disable this path so later
                # forward passes do not retry the expensive compilation.
                self._use_cuda = False
            else:
                try:
                    return module.forward(
                        x.contiguous(),
                        self.gamma,
                        self.beta,
                        self.eps,
                        self.use_vectorized,
                    )
                except Exception:
                    # A kernel-call failure may be specific to this input, so
                    # only this call falls back; the JIT path stays enabled.
                    pass

        # PyTorch fallback (still GPU accelerated when x is on GPU). Apply the
        # module's own gamma/beta instead of _pytorch_norm's separate
        # weight/bias, so results agree with the fused kernels and with
        # checkpoints that store gamma/beta.
        return torch.nn.functional.instance_norm(
            x, weight=self.gamma, bias=self.beta, eps=self.eps
        )


# Alias for compatibility
# FusedInstanceNorm2dAuto is the very same class object as FusedInstanceNorm2d
# (not a subclass), kept so callers importing this name keep working.
FusedInstanceNorm2dAuto = FusedInstanceNorm2d