Spaces:
Running
Running
Fix meta tensor device mismatch by also patching torch.arange
FSQ.__init__ calls torch.arange() through the torch module to build
its implicit codebook. The meta device context set by newer transformers
(which ignores low_cpu_mem_usage=False) redirects this to meta device,
causing a device mismatch with our CPU-forced _levels buffer.
Patch torch.arange globally during model loading to force device="cpu",
alongside the existing tensor() patches for both VQ module namespaces.
- acestep/handler.py +26 -8
acestep/handler.py (CHANGED, +26 −8)

@@ -484,13 +484,22 @@ class AceStepHandler:
         attn_candidates.append("eager")

         # Patch vector_quantize_pytorch to avoid meta tensor failures.
-        #
-        #
-        #
-        #
-        #
-        #
-        #
+        #
+        # Newer transformers ignores low_cpu_mem_usage/fast_init and
+        # always wraps model __init__ in `with torch.device("meta"):`,
+        # which redirects ALL factory functions to meta device.
+        #
+        # vector_quantize_pytorch's FSQ and ResidualFSQ do real
+        # computation during __init__ (assertions, codebook building)
+        # that is incompatible with meta tensors.
+        #
+        # We patch two layers:
+        # 1. `tensor` in both VQ module namespaces (they use
+        #    `from torch import tensor`, a direct reference)
+        # 2. `torch.arange` globally (FSQ calls `torch.arange()`
+        #    through the torch module to build its implicit codebook)
+        #
+        # Both patches force device="cpu", then we restore after loading.
         _vq_patched = False
         _vq_originals = {}
         try:
@@ -498,17 +507,25 @@ class AceStepHandler:
             from vector_quantize_pytorch import finite_scalar_quantization as _fsq_mod

             _orig_tensor = _rfsq_mod.tensor
+            _orig_arange = torch.arange

             def _cpu_tensor(data, *args, **kwargs):
                 kwargs["device"] = "cpu"
                 return _orig_tensor(data, *args, **kwargs)

+            def _cpu_arange(*args, **kwargs):
+                if "device" not in kwargs:
+                    kwargs["device"] = "cpu"
+                return _orig_arange(*args, **kwargs)
+
             _vq_originals["rfsq"] = _rfsq_mod.tensor
             _vq_originals["fsq"] = _fsq_mod.tensor
+            _vq_originals["arange"] = _orig_arange
             _rfsq_mod.tensor = _cpu_tensor
             _fsq_mod.tensor = _cpu_tensor
+            torch.arange = _cpu_arange
             _vq_patched = True
-            logger.info("[initialize_service] Patched vector_quantize_pytorch for meta device compat")
+            logger.info("[initialize_service] Patched vector_quantize_pytorch + torch.arange for meta device compat")
         except (ImportError, AttributeError):
             pass

@@ -535,6 +552,7 @@ class AceStepHandler:
         if _vq_patched:
             _rfsq_mod.tensor = _vq_originals["rfsq"]
             _fsq_mod.tensor = _vq_originals["fsq"]
+            torch.arange = _vq_originals["arange"]

         if self.model is None:
             raise RuntimeError(