pedroapfilho committed on
Commit
bf4c82d
·
unverified ·
1 Parent(s): b0a0560

Fix meta tensor: monkey-patch ResidualFSQ to force CPU tensor creation

Browse files

ZeroGPU's __torch_function__ hooks redirect ALL tensor creation to
meta device, even with torch.device('cpu') context. The fix patches
torch.tensor within ResidualFSQ.__init__ to explicitly pass device='cpu',
so the levels assertion runs on real CPU tensors instead of meta tensors.

Files changed (1) hide show
  1. acestep/handler.py +40 -12
acestep/handler.py CHANGED
@@ -483,15 +483,40 @@ class AceStepHandler:
483
  if "eager" not in attn_candidates:
484
  attn_candidates.append("eager")
485
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
486
  last_attn_error = None
487
  self.model = None
488
- for candidate in attn_candidates:
489
- try:
490
- logger.info(f"[initialize_service] Attempting to load model with attention implementation: {candidate}")
491
- # Force CPU device context to override ZeroGPU's meta device
492
- # redirection. ResidualFSQ asserts on tensor values during
493
- # __init__, which fails on meta tensors.
494
- with torch.device("cpu"):
495
  self.model = AutoModel.from_pretrained(
496
  acestep_v15_checkpoint_path,
497
  trust_remote_code=True,
@@ -500,11 +525,14 @@ class AceStepHandler:
500
  low_cpu_mem_usage=False,
501
  _fast_init=False,
502
  )
503
- attn_implementation = candidate
504
- break
505
- except Exception as e:
506
- last_attn_error = e
507
- logger.warning(f"[initialize_service] Failed to load model with {candidate}: {e}")
 
 
 
508
 
509
  if self.model is None:
510
  raise RuntimeError(
 
483
  if "eager" not in attn_candidates:
484
  attn_candidates.append("eager")
485
 
486
+ # Monkey-patch ResidualFSQ to handle ZeroGPU meta tensors.
487
+ # ZeroGPU redirects tensor creation to meta device via
488
+ # __torch_function__. ResidualFSQ.__init__ does
489
+ # `assert (torch.tensor(levels) > 1).all()` which fails
490
+ # because .all() can't evaluate meta tensors.
491
+ # Fix: force torch.tensor calls in FSQ init to use CPU.
492
+ _fsq_patched = False
493
+ try:
494
+ from vector_quantize_pytorch import residual_fsq as _fsq_mod
495
+ _orig_fsq_init = _fsq_mod.ResidualFSQ.__init__
496
+
497
+ def _meta_safe_fsq_init(self_fsq, **kwargs):
498
+ _real_torch_tensor = torch.tensor
499
+ def _cpu_tensor(data, *a, **kw):
500
+ kw["device"] = "cpu"
501
+ return _real_torch_tensor(data, *a, **kw)
502
+ torch.tensor = _cpu_tensor
503
+ try:
504
+ _orig_fsq_init(self_fsq, **kwargs)
505
+ finally:
506
+ torch.tensor = _real_torch_tensor
507
+
508
+ _fsq_mod.ResidualFSQ.__init__ = _meta_safe_fsq_init
509
+ _fsq_patched = True
510
+ logger.info("[initialize_service] Patched ResidualFSQ for ZeroGPU compat")
511
+ except ImportError:
512
+ pass
513
+
514
  last_attn_error = None
515
  self.model = None
516
+ try:
517
+ for candidate in attn_candidates:
518
+ try:
519
+ logger.info(f"[initialize_service] Attempting to load model with attention implementation: {candidate}")
 
 
 
520
  self.model = AutoModel.from_pretrained(
521
  acestep_v15_checkpoint_path,
522
  trust_remote_code=True,
 
525
  low_cpu_mem_usage=False,
526
  _fast_init=False,
527
  )
528
+ attn_implementation = candidate
529
+ break
530
+ except Exception as e:
531
+ last_attn_error = e
532
+ logger.warning(f"[initialize_service] Failed to load model with {candidate}: {e}")
533
+ finally:
534
+ if _fsq_patched:
535
+ _fsq_mod.ResidualFSQ.__init__ = _orig_fsq_init
536
 
537
  if self.model is None:
538
  raise RuntimeError(