github-actions[bot] committed on
Commit
d9cf8c2
·
1 Parent(s): 64c64e2

Deploy from GitHub - 2026-01-21 09:09:33

Browse files
Files changed (1) hide show
  1. kernels/__init__.py +87 -41
kernels/__init__.py CHANGED
@@ -29,26 +29,31 @@ _KERNEL_DATASET = "oliau/styleforge-kernels" # You'll need to create this datas
29
  def _download_kernels_from_dataset():
30
  """Download pre-compiled kernels from HuggingFace dataset."""
31
  try:
32
- from huggingface_hub import hf_hub_download, HfFileSystem
33
- fs = HfFileSystem()
34
 
35
- # List all .so and .pyd files in the dataset
 
 
36
  kernel_files = []
37
  try:
38
- files = fs.ls(f"datasets/{_KERNEL_DATASET}")
39
- for f in files:
40
- if f['name'].endswith(('.so', '.pyd')):
41
- kernel_files.append(Path(f['name']).name)
42
- except Exception:
43
- # Dataset might not exist yet
44
  return False
45
 
46
  if not kernel_files:
 
47
  return False
48
 
49
- # Download each kernel file
 
50
  for kernel_file in kernel_files:
51
  try:
 
52
  local_path = hf_hub_download(
53
  repo_id=_KERNEL_DATASET,
54
  filename=kernel_file,
@@ -56,12 +61,13 @@ def _download_kernels_from_dataset():
56
  local_dir=str(_PREBUILT_PATH.parent),
57
  local_dir_use_symlinks=False
58
  )
 
59
  print(f"Downloaded kernel: {kernel_file}")
60
  except Exception as e:
61
  print(f"Failed to download {kernel_file}: {e}")
62
  continue
63
 
64
- return True
65
  except ImportError:
66
  print("huggingface_hub not available, skipping kernel download")
67
  return False
@@ -93,7 +99,7 @@ def get_fused_instance_norm(num_features, **kwargs):
93
 
94
  def load_prebuilt_kernels():
95
  """
96
- Try to load pre-compiled CUDA kernels from prebuilt/ directory.
97
  On HuggingFace, downloads from dataset if local files not found.
98
 
99
  Returns True if successful, False otherwise.
@@ -103,36 +109,56 @@ def load_prebuilt_kernels():
103
  if _KERNELS_COMPILED:
104
  return _CUDA_KERNELS_AVAILABLE
105
 
106
- # Check if prebuilt kernels exist locally
107
- prebuilt_files = list(_PREBUILT_PATH.glob("*.so")) + list(_PREBUILT_PATH.glob("*.pyd"))
 
 
108
 
109
  # On HuggingFace Spaces, try downloading from dataset if not found locally
110
- if not prebuilt_files and _ZERO_GPU:
111
  print("No local pre-compiled kernels found. Trying HuggingFace dataset...")
112
  if _download_kernels_from_dataset():
113
- # Check again after download
114
- prebuilt_files = list(_PREBUILT_PATH.glob("*.so")) + list(_PREBUILT_PATH.glob("*.pyd"))
 
115
 
116
- if not prebuilt_files:
117
  print("No pre-compiled kernels found")
118
  return False
119
 
 
 
120
  try:
121
- # Try to import from prebuilt directory
122
  import sys
123
- if str(_PREBUILT_PATH) not in sys.path:
124
- sys.path.insert(0, str(_PREBUILT_PATH))
125
 
126
- # Try to load the prebuilt module
127
- for kernel_file in prebuilt_files:
128
  try:
129
- # Import the compiled module
130
  module_name = kernel_file.stem
131
  spec = __import__('importlib.util').util.spec_from_file_location(module_name, kernel_file)
132
  if spec and spec.loader:
133
  mod = __import__('importlib.util').util.module_from_spec(spec)
134
  spec.loader.exec_module(mod)
135
- print(f"Loaded pre-compiled kernel: {kernel_file.name}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
 
137
  # Create FusedInstanceNorm2d class
138
  class PrebuiltFusedInstanceNorm2d(torch.nn.Module):
@@ -147,21 +173,34 @@ def load_prebuilt_kernels():
147
  self.register_buffer('gamma', torch.ones(num_features))
148
  self.register_buffer('beta', torch.zeros(num_features))
149
  self._pytorch_norm = torch.nn.InstanceNorm2d(num_features, **kwargs)
 
150
 
151
  def forward(self, x):
152
  try:
153
- return mod.fused_instance_norm_forward(
 
154
  x.contiguous(), self.gamma, self.beta, self.eps
155
  )
156
- except Exception:
 
 
157
  return self._pytorch_norm(x)
158
 
159
  _FusedInstanceNorm2d = PrebuiltFusedInstanceNorm2d
160
  _CUDA_KERNELS_AVAILABLE = True
161
  _KERNELS_COMPILED = True
 
162
  return True
 
163
  except Exception as e:
164
- print(f"Failed to load {kernel_file.name}: {e}")
 
 
 
 
 
 
 
165
  continue
166
 
167
  except Exception as e:
@@ -174,7 +213,7 @@ def compile_kernels():
174
  """
175
  Compile CUDA kernels on-demand.
176
 
177
- On ZeroGPU: Tries pre-compiled kernels first.
178
  On local: Compiles custom CUDA kernels.
179
  """
180
  global _CUDA_KERNELS_AVAILABLE, _FusedInstanceNorm2d, _KERNELS_COMPILED
@@ -182,17 +221,22 @@ def compile_kernels():
182
  if _KERNELS_COMPILED:
183
  return _CUDA_KERNELS_AVAILABLE
184
 
185
- # First, try pre-compiled kernels
 
 
 
 
 
 
 
 
 
 
 
186
  if load_prebuilt_kernels():
187
  print("Using pre-compiled CUDA kernels!")
188
  return True
189
 
190
- # Fall back to JIT compilation (only on local, not ZeroGPU)
191
- if _ZERO_GPU:
192
- print("ZeroGPU mode: No pre-compiled kernels found, using PyTorch fallback")
193
- _KERNELS_COMPILED = True
194
- return False
195
-
196
  if not torch.cuda.is_available():
197
  _KERNELS_COMPILED = True
198
  return False
@@ -212,14 +256,16 @@ def compile_kernels():
212
 
213
 
214
  # Auto-compile on import for non-ZeroGPU environments with CUDA
215
- if torch.cuda.is_available() and not _ZERO_GPU:
216
- compile_kernels()
217
- elif _ZERO_GPU:
218
- # On ZeroGPU, try prebuilt kernels
219
  if load_prebuilt_kernels():
220
  print("Using pre-compiled CUDA kernels from dataset!")
221
  else:
222
- print("No pre-compiled kernels, using PyTorch GPU fallback")
 
 
 
223
 
224
 
225
  __all__ = [
 
29
  def _download_kernels_from_dataset():
30
  """Download pre-compiled kernels from HuggingFace dataset."""
31
  try:
32
+ from huggingface_hub import hf_hub_download, HfFileSystem, list_repo_files
33
+ import re
34
 
35
+ print(f"Looking for kernels in dataset: {_KERNEL_DATASET}")
36
+
37
+ # List all files in the dataset
38
  kernel_files = []
39
  try:
40
+ all_files = list_repo_files(_KERNEL_DATASET, repo_type="dataset")
41
+ # Filter for .so files (Linux) and .pyd files (Windows)
42
+ kernel_files = [f for f in all_files if f.endswith(('.so', '.pyd'))]
43
+ print(f"Found kernel files in dataset: {kernel_files}")
44
+ except Exception as e:
45
+ print(f"Could not list dataset files: {e}")
46
  return False
47
 
48
  if not kernel_files:
49
+ print("No kernel files (.so/.pyd) found in dataset")
50
  return False
51
 
52
+ # Download each kernel file to the kernels directory (parent of prebuilt/)
53
+ downloaded = []
54
  for kernel_file in kernel_files:
55
  try:
56
+ # Download to the kernels directory (parent of prebuilt)
57
  local_path = hf_hub_download(
58
  repo_id=_KERNEL_DATASET,
59
  filename=kernel_file,
 
61
  local_dir=str(_PREBUILT_PATH.parent),
62
  local_dir_use_symlinks=False
63
  )
64
+ downloaded.append(kernel_file)
65
  print(f"Downloaded kernel: {kernel_file}")
66
  except Exception as e:
67
  print(f"Failed to download {kernel_file}: {e}")
68
  continue
69
 
70
+ return len(downloaded) > 0
71
  except ImportError:
72
  print("huggingface_hub not available, skipping kernel download")
73
  return False
 
99
 
100
  def load_prebuilt_kernels():
101
  """
102
+ Try to load pre-compiled CUDA kernels from the kernels directory.
103
  On HuggingFace, downloads from dataset if local files not found.
104
 
105
  Returns True if successful, False otherwise.
 
109
  if _KERNELS_COMPILED:
110
  return _CUDA_KERNELS_AVAILABLE
111
 
112
+ # Check for kernels in the kernels directory (parent of prebuilt) and prebuilt/
113
+ kernels_dir = Path(__file__).parent
114
+ kernel_files = list(kernels_dir.glob("*.so")) + list(kernels_dir.glob("*.pyd"))
115
+ kernel_files += list(_PREBUILT_PATH.glob("*.so")) + list(_PREBUILT_PATH.glob("*.pyd"))
116
 
117
  # On HuggingFace Spaces, try downloading from dataset if not found locally
118
+ if not kernel_files and _ZERO_GPU:
119
  print("No local pre-compiled kernels found. Trying HuggingFace dataset...")
120
  if _download_kernels_from_dataset():
121
+ # Check again after download - look in kernels directory
122
+ kernel_files = list(kernels_dir.glob("*.so")) + list(kernels_dir.glob("*.pyd"))
123
+ kernel_files += list(_PREBUILT_PATH.glob("*.so")) + list(_PREBUILT_PATH.glob("*.pyd"))
124
 
125
+ if not kernel_files:
126
  print("No pre-compiled kernels found")
127
  return False
128
 
129
+ print(f"Found kernel files: {[f.name for f in kernel_files]}")
130
+
131
  try:
 
132
  import sys
133
+ import ctypes
 
134
 
135
+ # Try to load each kernel file
136
+ for kernel_file in kernel_files:
137
  try:
138
+ # First try to load as a Python extension module
139
  module_name = kernel_file.stem
140
  spec = __import__('importlib.util').util.spec_from_file_location(module_name, kernel_file)
141
  if spec and spec.loader:
142
  mod = __import__('importlib.util').util.module_from_spec(spec)
143
  spec.loader.exec_module(mod)
144
+ print(f"Loaded pre-compiled kernel module: {kernel_file.name}")
145
+
146
+ # Check what functions are available in the module
147
+ available_funcs = [attr for attr in dir(mod) if not attr.startswith('_')]
148
+ print(f"Available functions in kernel: {available_funcs}")
149
+
150
+ # Try to find the forward function with common naming patterns
151
+ forward_func = None
152
+ for func_name in ['fused_instance_norm_forward', 'forward', 'fused_instance_norm',
153
+ 'instance_norm_forward', 'fused_inst_norm']:
154
+ if hasattr(mod, func_name):
155
+ forward_func = getattr(mod, func_name)
156
+ print(f"Using function: {func_name}")
157
+ break
158
+
159
+ if forward_func is None:
160
+ print(f"Warning: No suitable forward function found in {kernel_file.name}")
161
+ continue
162
 
163
  # Create FusedInstanceNorm2d class
164
  class PrebuiltFusedInstanceNorm2d(torch.nn.Module):
 
173
  self.register_buffer('gamma', torch.ones(num_features))
174
  self.register_buffer('beta', torch.zeros(num_features))
175
  self._pytorch_norm = torch.nn.InstanceNorm2d(num_features, **kwargs)
176
+ self._kernel_func = forward_func
177
 
178
  def forward(self, x):
179
  try:
180
+ # Try calling the kernel function
181
+ result = self._kernel_func(
182
  x.contiguous(), self.gamma, self.beta, self.eps
183
  )
184
+ return result
185
+ except Exception as e:
186
+ # Fallback to PyTorch
187
  return self._pytorch_norm(x)
188
 
189
  _FusedInstanceNorm2d = PrebuiltFusedInstanceNorm2d
190
  _CUDA_KERNELS_AVAILABLE = True
191
  _KERNELS_COMPILED = True
192
+ print(f"Successfully initialized FusedInstanceNorm2d from {kernel_file.name}")
193
  return True
194
+
195
  except Exception as e:
196
+ print(f"Failed to load {kernel_file.name} as Python module: {e}")
197
+ # Try loading as raw ctypes library
198
+ try:
199
+ lib = ctypes.CDLL(str(kernel_file))
200
+ print(f"Loaded {kernel_file.name} as ctypes library")
201
+ # Could add ctypes wrapper here if needed
202
+ except Exception as e2:
203
+ print(f"Failed to load {kernel_file.name} as ctypes: {e2}")
204
  continue
205
 
206
  except Exception as e:
 
213
  """
214
  Compile CUDA kernels on-demand.
215
 
216
+ On ZeroGPU: Downloads pre-compiled kernels from dataset.
217
  On local: Compiles custom CUDA kernels.
218
  """
219
  global _CUDA_KERNELS_AVAILABLE, _FusedInstanceNorm2d, _KERNELS_COMPILED
 
221
  if _KERNELS_COMPILED:
222
  return _CUDA_KERNELS_AVAILABLE
223
 
224
+ # On ZeroGPU, try to download pre-compiled kernels from dataset
225
+ if _ZERO_GPU:
226
+ print("ZeroGPU mode: Attempting to download pre-compiled kernels from dataset...")
227
+ if load_prebuilt_kernels():
228
+ print("Successfully loaded pre-compiled CUDA kernels from dataset!")
229
+ return True
230
+ else:
231
+ print("No pre-compiled kernels found in dataset, using PyTorch GPU fallback")
232
+ _KERNELS_COMPILED = True
233
+ return False
234
+
235
+ # First, try pre-compiled kernels (for local too)
236
  if load_prebuilt_kernels():
237
  print("Using pre-compiled CUDA kernels!")
238
  return True
239
 
 
 
 
 
 
 
240
  if not torch.cuda.is_available():
241
  _KERNELS_COMPILED = True
242
  return False
 
256
 
257
 
258
  # Auto-compile on import for non-ZeroGPU environments with CUDA
259
+ if _ZERO_GPU:
260
+ # On ZeroGPU, try to download pre-compiled kernels
261
+ print("ZeroGPU detected: Attempting to load pre-compiled kernels from dataset...")
 
262
  if load_prebuilt_kernels():
263
  print("Using pre-compiled CUDA kernels from dataset!")
264
  else:
265
+ print("No pre-compiled kernels available, using PyTorch GPU fallback")
266
+ _KERNELS_COMPILED = True
267
+ elif torch.cuda.is_available():
268
+ compile_kernels()
269
 
270
 
271
  __all__ = [