Andrew committed on
Commit
7e3f77a
·
1 Parent(s): a496155

Patch torch.sort bool CUDA compatibility for endpoint runtime

Browse files
Files changed (1) hide show
  1. handler.py +27 -0
handler.py CHANGED
@@ -91,6 +91,9 @@ class EndpointHandler:
91
  # ACE-Step dynamic config imports layer_type_validation from transformers.
92
  # Some endpoint base images ship a transformers build without this helper.
93
  self._patch_transformers_layer_validation()
 
 
 
94
 
95
  try:
96
  from acestep.handler import AceStepHandler
@@ -159,6 +162,30 @@ class EndpointHandler:
159
 
160
  cu.layer_type_validation = _fallback_layer_type_validation
161
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  def _ensure_llm_initialized(self) -> bool:
163
  if self.llm_handler is None:
164
  self.llm_error = "LLM handler is not available"
 
91
  # ACE-Step dynamic config imports layer_type_validation from transformers.
92
  # Some endpoint base images ship a transformers build without this helper.
93
  self._patch_transformers_layer_validation()
94
+ # Some CUDA/torch combinations used by managed endpoint images don't support
95
+ # sorting bool tensors on CUDA. ACE-Step/Transformers paths can hit this.
96
+ self._patch_torch_sort_bool_cuda()
97
 
98
  try:
99
  from acestep.handler import AceStepHandler
 
162
 
163
  cu.layer_type_validation = _fallback_layer_type_validation
164
 
165
+ @staticmethod
166
+ def _patch_torch_sort_bool_cuda() -> None:
167
+ if torch is None or not hasattr(torch, "sort"):
168
+ return
169
+ if getattr(torch.sort, "__name__", "") == "_sort_bool_cuda_compat":
170
+ return
171
+
172
+ _orig_sort = torch.sort
173
+
174
+ def _sort_bool_cuda_compat(input_tensor, *args, **kwargs):
175
+ if (
176
+ isinstance(input_tensor, torch.Tensor)
177
+ and input_tensor.is_cuda
178
+ and input_tensor.dtype == torch.bool
179
+ ):
180
+ out = _orig_sort(input_tensor.to(torch.uint8), *args, **kwargs)
181
+ values = out.values.to(torch.bool) if hasattr(out, "values") else out[0].to(torch.bool)
182
+ indices = out.indices if hasattr(out, "indices") else out[1]
183
+ return values, indices
184
+ return _orig_sort(input_tensor, *args, **kwargs)
185
+
186
+ _sort_bool_cuda_compat.__name__ = "_sort_bool_cuda_compat"
187
+ torch.sort = _sort_bool_cuda_compat
188
+
189
  def _ensure_llm_initialized(self) -> bool:
190
  if self.llm_handler is None:
191
  self.llm_error = "LLM handler is not available"