silveroxides committed on
Commit
28030ea
·
1 Parent(s): aee2946

fix: Use GPU for MXFP8/NVFP4 formats, CPU for others

Browse files

- gpu_quantize() with 600s duration for CUDA-requiring formats
- cpu_quantize() for FP8 tensorwise/block and INT8

Files changed (1) hide show
  1. app.py +16 -9
app.py CHANGED
@@ -147,11 +147,15 @@ def upload_model_as_pr(
147
  except Exception as e:
148
  return f"❌ Upload failed: {str(e)}"
149
 
150
- @spaces.GPU(duration=30)
151
- def gpu_check():
152
- """Minimal GPU function to satisfy ZeroGPU space requirements."""
153
- import torch
154
- return torch.cuda.is_available()
 
 
 
 
155
 
156
 
157
  def quantize_model(
@@ -230,7 +234,13 @@ def quantize_model(
230
  )
231
 
232
  try:
233
- result = convert(config)
 
 
 
 
 
 
234
 
235
  if not result.success:
236
  status_log.append(f"❌ Quantization failed: {result.error}")
@@ -254,9 +264,6 @@ def quantize_model(
254
  )
255
  status_log.append(upload_status)
256
 
257
- # Brief GPU check to satisfy ZeroGPU requirements
258
- gpu_check()
259
-
260
  return result.output_path, "\n\n".join(status_log)
261
 
262
  except Exception as e:
 
147
  except Exception as e:
148
  return f"❌ Upload failed: {str(e)}"
149
 
150
+ @spaces.GPU(duration=600)
151
+ def gpu_quantize(config):
152
+ """Run quantization on GPU for formats that require CUDA (MXFP8, NVFP4)."""
153
+ return convert(config)
154
+
155
+
156
+ def cpu_quantize(config):
157
+ """Run quantization on CPU for formats that don't require CUDA."""
158
+ return convert(config)
159
 
160
 
161
  def quantize_model(
 
234
  )
235
 
236
  try:
237
+ # Use GPU for formats that require CUDA, CPU for others
238
+ requires_gpu = format_config["format"] in ("mxfp8", "nvfp4")
239
+ if requires_gpu:
240
+ status_log.append("🖥️ Using GPU for quantization...")
241
+ result = gpu_quantize(config)
242
+ else:
243
+ result = cpu_quantize(config)
244
 
245
  if not result.success:
246
  status_log.append(f"❌ Quantization failed: {result.error}")
 
264
  )
265
  status_log.append(upload_status)
266
 
 
 
 
267
  return result.output_path, "\n\n".join(status_log)
268
 
269
  except Exception as e: