Andrew committed on
Commit
56eb0fe
·
1 Parent(s): 7e3f77a

Patch Tensor.sort/argsort bool CUDA compatibility

Browse files
Files changed (1) hide show
  1. handler.py +37 -0
handler.py CHANGED
@@ -170,6 +170,9 @@ class EndpointHandler:
170
  return
171
 
172
  _orig_sort = torch.sort
 
 
 
173
 
174
  def _sort_bool_cuda_compat(input_tensor, *args, **kwargs):
175
  if (
@@ -186,6 +189,40 @@ class EndpointHandler:
186
  _sort_bool_cuda_compat.__name__ = "_sort_bool_cuda_compat"
187
  torch.sort = _sort_bool_cuda_compat
188
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
  def _ensure_llm_initialized(self) -> bool:
190
  if self.llm_handler is None:
191
  self.llm_error = "LLM handler is not available"
 
170
  return
171
 
172
  _orig_sort = torch.sort
173
+ _orig_tensor_sort = getattr(torch.Tensor, "sort", None)
174
+ _orig_argsort = getattr(torch, "argsort", None)
175
+ _orig_tensor_argsort = getattr(torch.Tensor, "argsort", None)
176
 
177
  def _sort_bool_cuda_compat(input_tensor, *args, **kwargs):
178
  if (
 
189
  _sort_bool_cuda_compat.__name__ = "_sort_bool_cuda_compat"
190
  torch.sort = _sort_bool_cuda_compat
191
 
192
if callable(_orig_tensor_sort):
    def _tensor_sort_bool_cuda_compat(self, *args, **kwargs):
        """Tensor.sort shim: route CUDA bool tensors through uint8.

        Some CUDA builds cannot sort ``torch.bool`` tensors directly.
        Sorting the uint8 view yields the same ordering (False < True)
        and the same indices, so we cast, sort, and cast values back.
        All other tensors go straight to the original implementation.
        """
        if self.is_cuda and self.dtype == torch.bool:
            out = _orig_tensor_sort(self.to(torch.uint8), *args, **kwargs)
            if hasattr(out, "values"):
                values, indices = out.values, out.indices
            else:
                values, indices = out[0], out[1]
            values = values.to(torch.bool)
            # Preserve Tensor.sort's named-tuple return type so callers
            # using ``result.values`` / ``result.indices`` keep working;
            # fall back to a plain tuple if the structseq can't be built.
            try:
                return torch.return_types.sort((values, indices))
            except (AttributeError, TypeError):
                return values, indices
        return _orig_tensor_sort(self, *args, **kwargs)

    _tensor_sort_bool_cuda_compat.__name__ = "_tensor_sort_bool_cuda_compat"
    torch.Tensor.sort = _tensor_sort_bool_cuda_compat
203
+
204
if callable(_orig_argsort):
    def _argsort_bool_cuda_compat(input_tensor, *args, **kwargs):
        """torch.argsort shim: argsort CUDA bool tensors via a uint8 view.

        Indices from argsorting the uint8 copy match those of the bool
        tensor (False < True), so no conversion of the result is needed.
        """
        needs_cast = (
            isinstance(input_tensor, torch.Tensor)
            and input_tensor.is_cuda
            and input_tensor.dtype == torch.bool
        )
        target = input_tensor.to(torch.uint8) if needs_cast else input_tensor
        return _orig_argsort(target, *args, **kwargs)

    _argsort_bool_cuda_compat.__name__ = "_argsort_bool_cuda_compat"
    torch.argsort = _argsort_bool_cuda_compat
216
+
217
if callable(_orig_tensor_argsort):
    def _tensor_argsort_bool_cuda_compat(self, *args, **kwargs):
        """Tensor.argsort shim: argsort CUDA bool tensors via a uint8 view.

        The returned indices are identical to those a direct bool argsort
        would produce, so the result is passed through unchanged.
        """
        source = self
        if self.is_cuda and self.dtype == torch.bool:
            source = self.to(torch.uint8)
        return _orig_tensor_argsort(source, *args, **kwargs)

    _tensor_argsort_bool_cuda_compat.__name__ = "_tensor_argsort_bool_cuda_compat"
    torch.Tensor.argsort = _tensor_argsort_bool_cuda_compat
225
+
226
  def _ensure_llm_initialized(self) -> bool:
227
  if self.llm_handler is None:
228
  self.llm_error = "LLM handler is not available"