ACE-Step Custom committed on
Commit
aa918f7
·
1 Parent(s): 052ca84

Fix: Remove device_map to prevent meta tensor initialization - ACE-Step custom model needs direct device placement

Browse files
Files changed (2) hide show
  1. acestep/handler.py +6 -5
  2. acestep/llm_inference.py +9 -8
acestep/handler.py CHANGED
@@ -495,7 +495,6 @@ class AceStepHandler:
495
  torch_dtype=self.dtype,
496
  low_cpu_mem_usage=False, # Disable memory-efficient weight loading
497
  _fast_init=False, # Disable fast initialization (prevents meta device)
498
- device_map={"": device}, # Explicitly map all components to target device
499
  )
500
  attn_implementation = candidate
501
  break
@@ -569,9 +568,10 @@ class AceStepHandler:
569
  # Determine target device for VAE
570
  vae_device = device if not self.offload_to_cpu else "cpu"
571
  self.vae = AutoencoderOobleck.from_pretrained(
572
- vae_checkpoint_path,
573
- device_map={"": vae_device} # Explicitly map to target device
574
  )
 
 
575
  if not self.offload_to_cpu:
576
  # Keep VAE in GPU precision when resident on accelerator.
577
  vae_dtype = self._get_vae_dtype(device)
@@ -602,9 +602,10 @@ class AceStepHandler:
602
  # Determine target device for text encoder
603
  text_encoder_device = device if not self.offload_to_cpu else "cpu"
604
  self.text_encoder = AutoModel.from_pretrained(
605
- text_encoder_path,
606
- device_map={"": text_encoder_device} # Explicitly map to target device
607
  )
 
 
608
  if not self.offload_to_cpu:
609
  self.text_encoder = self.text_encoder.to(self.dtype)
610
  else:
 
495
  torch_dtype=self.dtype,
496
  low_cpu_mem_usage=False, # Disable memory-efficient weight loading
497
  _fast_init=False, # Disable fast initialization (prevents meta device)
 
498
  )
499
  attn_implementation = candidate
500
  break
 
568
  # Determine target device for VAE
569
  vae_device = device if not self.offload_to_cpu else "cpu"
570
  self.vae = AutoencoderOobleck.from_pretrained(
571
+ vae_checkpoint_path
 
572
  )
573
+ # Move VAE to target device
574
+ self.vae = self.vae.to(vae_device)
575
  if not self.offload_to_cpu:
576
  # Keep VAE in GPU precision when resident on accelerator.
577
  vae_dtype = self._get_vae_dtype(device)
 
602
  # Determine target device for text encoder
603
  text_encoder_device = device if not self.offload_to_cpu else "cpu"
604
  self.text_encoder = AutoModel.from_pretrained(
605
+ text_encoder_path
 
606
  )
607
+ # Move text encoder to target device
608
+ self.text_encoder = self.text_encoder.to(text_encoder_device)
609
  if not self.offload_to_cpu:
610
  self.text_encoder = self.text_encoder.to(self.dtype)
611
  else:
acestep/llm_inference.py CHANGED
@@ -278,9 +278,10 @@ class LLMHandler:
278
  target_device = device if not self.offload_to_cpu else "cpu"
279
  self.llm = AutoModelForCausalLM.from_pretrained(
280
  model_path,
281
- trust_remote_code=True,
282
- device_map={"": target_device} # Explicitly map to target device
283
  )
 
 
284
  if not self.offload_to_cpu:
285
  self.llm = self.llm.to(self.dtype)
286
  else:
@@ -3024,13 +3025,13 @@ class LLMHandler:
3024
  self._hf_model_for_scoring = AutoModelForCausalLM.from_pretrained(
3025
  model_path,
3026
  trust_remote_code=True,
3027
- torch_dtype=self.dtype,
3028
- device_map={"": str(device)} # Explicitly map to vLLM device
3029
  )
 
 
3030
  load_time = time.time() - start_time
3031
  logger.info(f"HuggingFace model loaded in {load_time:.2f}s")
3032
 
3033
- # Already on device from device_map
3034
  self._hf_model_for_scoring.eval()
3035
 
3036
  logger.info(f"HuggingFace model for scoring ready on {device}")
@@ -3054,13 +3055,13 @@ class LLMHandler:
3054
  self._hf_model_for_scoring = AutoModelForCausalLM.from_pretrained(
3055
  model_path,
3056
  trust_remote_code=True,
3057
- torch_dtype=self.dtype,
3058
- device_map={"": device} # Explicitly map to target device
3059
  )
 
 
3060
  load_time = time.time() - start_time
3061
  logger.info(f"HuggingFace model loaded in {load_time:.2f}s")
3062
 
3063
- # Already on device from device_map
3064
  self._hf_model_for_scoring.eval()
3065
 
3066
  logger.info(f"HuggingFace model for scoring ready on {device}")
 
278
  target_device = device if not self.offload_to_cpu else "cpu"
279
  self.llm = AutoModelForCausalLM.from_pretrained(
280
  model_path,
281
+ trust_remote_code=True
 
282
  )
283
+ # Move model to target device
284
+ self.llm = self.llm.to(target_device)
285
  if not self.offload_to_cpu:
286
  self.llm = self.llm.to(self.dtype)
287
  else:
 
3025
  self._hf_model_for_scoring = AutoModelForCausalLM.from_pretrained(
3026
  model_path,
3027
  trust_remote_code=True,
3028
+ torch_dtype=self.dtype
 
3029
  )
3030
+ # Move model to vLLM device
3031
+ self._hf_model_for_scoring = self._hf_model_for_scoring.to(device)
3032
  load_time = time.time() - start_time
3033
  logger.info(f"HuggingFace model loaded in {load_time:.2f}s")
3034
 
 
3035
  self._hf_model_for_scoring.eval()
3036
 
3037
  logger.info(f"HuggingFace model for scoring ready on {device}")
 
3055
  self._hf_model_for_scoring = AutoModelForCausalLM.from_pretrained(
3056
  model_path,
3057
  trust_remote_code=True,
3058
+ torch_dtype=self.dtype
 
3059
  )
3060
+ # Move model to target device
3061
+ self._hf_model_for_scoring = self._hf_model_for_scoring.to(device)
3062
  load_time = time.time() - start_time
3063
  logger.info(f"HuggingFace model loaded in {load_time:.2f}s")
3064
 
 
3065
  self._hf_model_for_scoring.eval()
3066
 
3067
  logger.info(f"HuggingFace model for scoring ready on {device}")