pedroapfilho committed on
Commit
0c47532
·
unverified ·
1 Parent(s): 773a0e0

Fix meta tensor: patch residual_fsq module's tensor reference, not torch.tensor

Browse files

ResidualFSQ uses 'from torch import tensor' - a direct module-level
reference that bypasses our torch.tensor patch. Now we replace the
'tensor' name in the residual_fsq module namespace directly, forcing
CPU device for all tensor creation during FSQ init.

Files changed (1) hide show
  1. acestep/handler.py +20 -22
acestep/handler.py CHANGED
@@ -483,32 +483,30 @@ class AceStepHandler:
483
  if "eager" not in attn_candidates:
484
  attn_candidates.append("eager")
485
 
486
- # Monkey-patch ResidualFSQ to handle ZeroGPU meta tensors.
487
- # ZeroGPU redirects tensor creation to meta device via
488
- # __torch_function__. ResidualFSQ.__init__ does
489
- # `assert (torch.tensor(levels) > 1).all()` which fails
490
- # because .all() can't evaluate meta tensors.
491
- # Fix: force torch.tensor calls in FSQ init to use CPU.
 
 
 
 
492
  _fsq_patched = False
 
493
  try:
494
  from vector_quantize_pytorch import residual_fsq as _fsq_mod
495
- _orig_fsq_init = _fsq_mod.ResidualFSQ.__init__
496
-
497
- def _meta_safe_fsq_init(self_fsq, **kwargs):
498
- _real_torch_tensor = torch.tensor
499
- def _cpu_tensor(data, *a, **kw):
500
- kw["device"] = "cpu"
501
- return _real_torch_tensor(data, *a, **kw)
502
- torch.tensor = _cpu_tensor
503
- try:
504
- _orig_fsq_init(self_fsq, **kwargs)
505
- finally:
506
- torch.tensor = _real_torch_tensor
507
 
508
- _fsq_mod.ResidualFSQ.__init__ = _meta_safe_fsq_init
509
  _fsq_patched = True
510
- logger.info("[initialize_service] Patched ResidualFSQ for ZeroGPU compat")
511
- except ImportError:
512
  pass
513
 
514
  last_attn_error = None
@@ -532,7 +530,7 @@ class AceStepHandler:
532
  logger.warning(f"[initialize_service] Failed to load model with {candidate}: {e}")
533
  finally:
534
  if _fsq_patched:
535
- _fsq_mod.ResidualFSQ.__init__ = _orig_fsq_init
536
 
537
  if self.model is None:
538
  raise RuntimeError(
 
483
  if "eager" not in attn_candidates:
484
  attn_candidates.append("eager")
485
 
486
+ # Patch ResidualFSQ to avoid meta tensor failures.
487
+ # ResidualFSQ uses `from torch import tensor` (a direct
488
+ # module-level reference). During model init, transformers
489
+ # sets a meta device context that makes all `tensor()` calls
490
+ # create meta tensors. ResidualFSQ then does:
491
+ # levels_tensor = tensor(levels)
492
+ # assert (levels_tensor > 1).all() # fails on meta
493
+ # Fix: replace the `tensor` name in the residual_fsq module
494
+ # namespace with a CPU-forcing version. All derived operations
495
+ # on CPU tensors stay on CPU, so the assertion works.
496
  _fsq_patched = False
497
+ _orig_tensor_fn = None
498
  try:
499
  from vector_quantize_pytorch import residual_fsq as _fsq_mod
500
+ _orig_tensor_fn = _fsq_mod.tensor
501
+
502
+ def _cpu_tensor(data, *args, **kwargs):
503
+ kwargs["device"] = "cpu"
504
+ return _orig_tensor_fn(data, *args, **kwargs)
 
 
 
 
 
 
 
505
 
506
+ _fsq_mod.tensor = _cpu_tensor
507
  _fsq_patched = True
508
+ logger.info("[initialize_service] Patched residual_fsq.tensor for meta device compat")
509
+ except (ImportError, AttributeError):
510
  pass
511
 
512
  last_attn_error = None
 
530
  logger.warning(f"[initialize_service] Failed to load model with {candidate}: {e}")
531
  finally:
532
  if _fsq_patched:
533
+ _fsq_mod.tensor = _orig_tensor_fn
534
 
535
  if self.model is None:
536
  raise RuntimeError(