| from typing import List, Optional, Union | |
| import torch | |
def apply_token_bitmask_inplace_cuda(
    logits: torch.Tensor,
    bitmask: torch.Tensor,
    indices: Optional[Union[List[int], torch.Tensor]] = None,
) -> None:
    """Normalize ``indices`` and dispatch to the ``sgl_kernel`` bitmask op.

    Args:
        logits: Tensor handed to the kernel as its first argument; its device
            is the one ``indices`` is placed on.
        bitmask: Tensor handed to the kernel unchanged.
            NOTE(review): expected shape/dtype/device are defined by the
            kernel, not by this wrapper -- confirm against the op schema.
        indices: Optional selection passed as the kernel's third argument.
            A Python list is materialized as an ``int32`` tensor on
            ``logits.device``; a tensor is moved to ``logits.device`` with its
            dtype left as-is; ``None`` is forwarded untouched.

    Returns:
        None. The op name suggests ``logits`` is modified in place by the
        kernel -- presumably masking disallowed tokens; verify against the
        kernel implementation.
    """
    if indices is not None:
        if isinstance(indices, list):
            # Build the index tensor directly on the logits' device.
            indices = torch.tensor(
                indices, dtype=torch.int32, device=logits.device
            )
        else:
            # Caller supplied a tensor: only relocate it, do not cast it.
            indices = indices.to(logits.device)

    torch.ops.sgl_kernel.apply_token_bitmask_inplace_cuda(logits, bitmask, indices)