Spaces:
Running
on
Zero
Running
on
Zero
ACE-Step Custom
committed on
Commit
·
aa918f7
1
Parent(s):
052ca84
Fix: Remove device_map to prevent meta tensor initialization - ACE-Step custom model needs direct device placement
Browse files
- acestep/handler.py +6 -5
- acestep/llm_inference.py +9 -8
acestep/handler.py
CHANGED
|
@@ -495,7 +495,6 @@ class AceStepHandler:
|
|
| 495 |
torch_dtype=self.dtype,
|
| 496 |
low_cpu_mem_usage=False, # Disable memory-efficient weight loading
|
| 497 |
_fast_init=False, # Disable fast initialization (prevents meta device)
|
| 498 |
-
device_map={"": device}, # Explicitly map all components to target device
|
| 499 |
)
|
| 500 |
attn_implementation = candidate
|
| 501 |
break
|
|
@@ -569,9 +568,10 @@ class AceStepHandler:
|
|
| 569 |
# Determine target device for VAE
|
| 570 |
vae_device = device if not self.offload_to_cpu else "cpu"
|
| 571 |
self.vae = AutoencoderOobleck.from_pretrained(
|
| 572 |
-
vae_checkpoint_path
|
| 573 |
-
device_map={"": vae_device} # Explicitly map to target device
|
| 574 |
)
|
|
|
|
|
|
|
| 575 |
if not self.offload_to_cpu:
|
| 576 |
# Keep VAE in GPU precision when resident on accelerator.
|
| 577 |
vae_dtype = self._get_vae_dtype(device)
|
|
@@ -602,9 +602,10 @@ class AceStepHandler:
|
|
| 602 |
# Determine target device for text encoder
|
| 603 |
text_encoder_device = device if not self.offload_to_cpu else "cpu"
|
| 604 |
self.text_encoder = AutoModel.from_pretrained(
|
| 605 |
-
text_encoder_path
|
| 606 |
-
device_map={"": text_encoder_device} # Explicitly map to target device
|
| 607 |
)
|
|
|
|
|
|
|
| 608 |
if not self.offload_to_cpu:
|
| 609 |
self.text_encoder = self.text_encoder.to(self.dtype)
|
| 610 |
else:
|
|
|
|
| 495 |
torch_dtype=self.dtype,
|
| 496 |
low_cpu_mem_usage=False, # Disable memory-efficient weight loading
|
| 497 |
_fast_init=False, # Disable fast initialization (prevents meta device)
|
|
|
|
| 498 |
)
|
| 499 |
attn_implementation = candidate
|
| 500 |
break
|
|
|
|
| 568 |
# Determine target device for VAE
|
| 569 |
vae_device = device if not self.offload_to_cpu else "cpu"
|
| 570 |
self.vae = AutoencoderOobleck.from_pretrained(
|
| 571 |
+
vae_checkpoint_path
|
|
|
|
| 572 |
)
|
| 573 |
+
# Move VAE to target device
|
| 574 |
+
self.vae = self.vae.to(vae_device)
|
| 575 |
if not self.offload_to_cpu:
|
| 576 |
# Keep VAE in GPU precision when resident on accelerator.
|
| 577 |
vae_dtype = self._get_vae_dtype(device)
|
|
|
|
| 602 |
# Determine target device for text encoder
|
| 603 |
text_encoder_device = device if not self.offload_to_cpu else "cpu"
|
| 604 |
self.text_encoder = AutoModel.from_pretrained(
|
| 605 |
+
text_encoder_path
|
|
|
|
| 606 |
)
|
| 607 |
+
# Move text encoder to target device
|
| 608 |
+
self.text_encoder = self.text_encoder.to(text_encoder_device)
|
| 609 |
if not self.offload_to_cpu:
|
| 610 |
self.text_encoder = self.text_encoder.to(self.dtype)
|
| 611 |
else:
|
acestep/llm_inference.py
CHANGED
|
@@ -278,9 +278,10 @@ class LLMHandler:
|
|
| 278 |
target_device = device if not self.offload_to_cpu else "cpu"
|
| 279 |
self.llm = AutoModelForCausalLM.from_pretrained(
|
| 280 |
model_path,
|
| 281 |
-
trust_remote_code=True
|
| 282 |
-
device_map={"": target_device} # Explicitly map to target device
|
| 283 |
)
|
|
|
|
|
|
|
| 284 |
if not self.offload_to_cpu:
|
| 285 |
self.llm = self.llm.to(self.dtype)
|
| 286 |
else:
|
|
@@ -3024,13 +3025,13 @@ class LLMHandler:
|
|
| 3024 |
self._hf_model_for_scoring = AutoModelForCausalLM.from_pretrained(
|
| 3025 |
model_path,
|
| 3026 |
trust_remote_code=True,
|
| 3027 |
-
torch_dtype=self.dtype
|
| 3028 |
-
device_map={"": str(device)} # Explicitly map to vLLM device
|
| 3029 |
)
|
|
|
|
|
|
|
| 3030 |
load_time = time.time() - start_time
|
| 3031 |
logger.info(f"HuggingFace model loaded in {load_time:.2f}s")
|
| 3032 |
|
| 3033 |
-
# Already on device from device_map
|
| 3034 |
self._hf_model_for_scoring.eval()
|
| 3035 |
|
| 3036 |
logger.info(f"HuggingFace model for scoring ready on {device}")
|
|
@@ -3054,13 +3055,13 @@ class LLMHandler:
|
|
| 3054 |
self._hf_model_for_scoring = AutoModelForCausalLM.from_pretrained(
|
| 3055 |
model_path,
|
| 3056 |
trust_remote_code=True,
|
| 3057 |
-
torch_dtype=self.dtype
|
| 3058 |
-
device_map={"": device} # Explicitly map to target device
|
| 3059 |
)
|
|
|
|
|
|
|
| 3060 |
load_time = time.time() - start_time
|
| 3061 |
logger.info(f"HuggingFace model loaded in {load_time:.2f}s")
|
| 3062 |
|
| 3063 |
-
# Already on device from device_map
|
| 3064 |
self._hf_model_for_scoring.eval()
|
| 3065 |
|
| 3066 |
logger.info(f"HuggingFace model for scoring ready on {device}")
|
|
|
|
| 278 |
target_device = device if not self.offload_to_cpu else "cpu"
|
| 279 |
self.llm = AutoModelForCausalLM.from_pretrained(
|
| 280 |
model_path,
|
| 281 |
+
trust_remote_code=True
|
|
|
|
| 282 |
)
|
| 283 |
+
# Move model to target device
|
| 284 |
+
self.llm = self.llm.to(target_device)
|
| 285 |
if not self.offload_to_cpu:
|
| 286 |
self.llm = self.llm.to(self.dtype)
|
| 287 |
else:
|
|
|
|
| 3025 |
self._hf_model_for_scoring = AutoModelForCausalLM.from_pretrained(
|
| 3026 |
model_path,
|
| 3027 |
trust_remote_code=True,
|
| 3028 |
+
torch_dtype=self.dtype
|
|
|
|
| 3029 |
)
|
| 3030 |
+
# Move model to vLLM device
|
| 3031 |
+
self._hf_model_for_scoring = self._hf_model_for_scoring.to(device)
|
| 3032 |
load_time = time.time() - start_time
|
| 3033 |
logger.info(f"HuggingFace model loaded in {load_time:.2f}s")
|
| 3034 |
|
|
|
|
| 3035 |
self._hf_model_for_scoring.eval()
|
| 3036 |
|
| 3037 |
logger.info(f"HuggingFace model for scoring ready on {device}")
|
|
|
|
| 3055 |
self._hf_model_for_scoring = AutoModelForCausalLM.from_pretrained(
|
| 3056 |
model_path,
|
| 3057 |
trust_remote_code=True,
|
| 3058 |
+
torch_dtype=self.dtype
|
|
|
|
| 3059 |
)
|
| 3060 |
+
# Move model to target device
|
| 3061 |
+
self._hf_model_for_scoring = self._hf_model_for_scoring.to(device)
|
| 3062 |
load_time = time.time() - start_time
|
| 3063 |
logger.info(f"HuggingFace model loaded in {load_time:.2f}s")
|
| 3064 |
|
|
|
|
| 3065 |
self._hf_model_for_scoring.eval()
|
| 3066 |
|
| 3067 |
logger.info(f"HuggingFace model for scoring ready on {device}")
|