ACE-Step Custom committed on
Commit
180e887
·
1 Parent(s): 897bbd6

Force traditional loading: low_cpu_mem_usage=False + explicit device move

Browse files
Files changed (1) hide show
  1. acestep/handler.py +17 -16
acestep/handler.py CHANGED
@@ -487,23 +487,24 @@ class AceStepHandler:
487
  last_attn_error = None
488
  self.model = None
489
 
490
- # Use device context to force model initialization on CPU instead of meta device
491
  # ACE-Step's ResidualFSQ performs tensor assertions during __init__ that fail on meta device
492
- with torch.device("cpu"):
493
- for candidate in attn_candidates:
494
- try:
495
- logger.info(f"[initialize_service] Attempting to load model with attention implementation: {candidate}")
496
- self.model = AutoModel.from_pretrained(
497
- acestep_v15_checkpoint_path,
498
- trust_remote_code=True,
499
- attn_implementation=candidate,
500
- dtype="bfloat16" # Use string like official demo
501
- )
502
- attn_implementation = candidate
503
- break
504
- except Exception as e:
505
- last_attn_error = e
506
- logger.warning(f"[initialize_service] Failed to load model with {candidate}: {e}")
 
507
 
508
  if self.model is None:
509
  raise RuntimeError(
 
487
  last_attn_error = None
488
  self.model = None
489
 
490
+ # Force traditional loading to avoid meta device initialization
491
  # ACE-Step's ResidualFSQ performs tensor assertions during __init__ that fail on meta device
492
+ for candidate in attn_candidates:
493
+ try:
494
+ logger.info(f"[initialize_service] Attempting to load model with attention implementation: {candidate}")
495
+ self.model = AutoModel.from_pretrained(
496
+ acestep_v15_checkpoint_path,
497
+ trust_remote_code=True,
498
+ attn_implementation=candidate,
499
+ torch_dtype=torch.bfloat16,
500
+ low_cpu_mem_usage=False, # Disable meta device
501
+ device_map=None # No automatic device mapping
502
+ ).to(self.device) # Explicitly move to target device after load
503
+ attn_implementation = candidate
504
+ break
505
+ except Exception as e:
506
+ last_attn_error = e
507
+ logger.warning(f"[initialize_service] Failed to load model with {candidate}: {e}")
508
 
509
  if self.model is None:
510
  raise RuntimeError(