ACE-Step Custom committed on
Commit
fc9be45
·
1 Parent(s): aa918f7

Fix: Force CPU device context during model init to prevent meta tensor operations

Browse files
Files changed (1) hide show
  1. acestep/handler.py +21 -16
acestep/handler.py CHANGED
@@ -483,24 +483,29 @@ class AceStepHandler:
483
  if "eager" not in attn_candidates:
484
  attn_candidates.append("eager")
485
 
 
486
  last_attn_error = None
487
  self.model = None
488
- for candidate in attn_candidates:
489
- try:
490
- logger.info(f"[initialize_service] Attempting to load model with attention implementation: {candidate}")
491
- self.model = AutoModel.from_pretrained(
492
- acestep_v15_checkpoint_path,
493
- trust_remote_code=True,
494
- attn_implementation=candidate,
495
- torch_dtype=self.dtype,
496
- low_cpu_mem_usage=False, # Disable memory-efficient weight loading
497
- _fast_init=False, # Disable fast initialization (prevents meta device)
498
- )
499
- attn_implementation = candidate
500
- break
501
- except Exception as e:
502
- last_attn_error = e
503
- logger.warning(f"[initialize_service] Failed to load model with {candidate}: {e}")
 
 
 
 
504
 
505
  if self.model is None:
506
  raise RuntimeError(
 
483
  if "eager" not in attn_candidates:
484
  attn_candidates.append("eager")
485
 
486
+
487
  last_attn_error = None
488
  self.model = None
489
+
490
+ # Use device context to force model initialization on CPU instead of meta device
491
+ # ACE-Step's ResidualFSQ performs tensor assertions during __init__ that fail on meta device
492
+ with torch.device("cpu"):
493
+ for candidate in attn_candidates:
494
+ try:
495
+ logger.info(f"[initialize_service] Attempting to load model with attention implementation: {candidate}")
496
+ self.model = AutoModel.from_pretrained(
497
+ acestep_v15_checkpoint_path,
498
+ trust_remote_code=True,
499
+ attn_implementation=candidate,
500
+ torch_dtype=self.dtype,
501
+ low_cpu_mem_usage=False, # Disable memory-efficient weight loading
502
+ _fast_init=False, # Disable fast initialization
503
+ )
504
+ attn_implementation = candidate
505
+ break
506
+ except Exception as e:
507
+ last_attn_error = e
508
+ logger.warning(f"[initialize_service] Failed to load model with {candidate}: {e}")
509
 
510
  if self.model is None:
511
  raise RuntimeError(