| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from typing import Callable |
|
|
# Lazily-populated cache for the backend-specific padding/indexing helpers.
# All four stay None until the first call to `_get_attention_functions`.
_index_first_axis, _pad_input, _rearrange, _unpad_input = None, None, None, None


def _get_attention_functions() -> tuple[Callable, Callable, Callable, Callable]:
    """Resolve (once) and return the padding/indexing helpers for the current hardware.

    On the first call, chooses the implementations from
    ``verl.utils.npu_flash_attn_utils`` when
    ``is_torch_npu_available(check_device=False)`` is true, and from
    ``flash_attn.bert_padding`` otherwise. The chosen functions are stored in
    module-level globals. Later calls return them directly, without repeating
    the device probe or the imports.

    Returns:
        The tuple ``(index_first_axis, pad_input, rearrange, unpad_input)``.

    Raises:
        ImportError: If the selected backend module cannot be imported.
    """
    global _index_first_axis, _pad_input, _rearrange, _unpad_input

    # Fast path: reuse the previously resolved helpers. The globals were
    # always written but never consulted before this check existed.
    cached = (_index_first_axis, _pad_input, _rearrange, _unpad_input)
    if all(fn is not None for fn in cached):
        return cached

    # Imported inside the function so the backend packages are needed only
    # when the helpers are actually used.
    from verl.utils.device import is_torch_npu_available

    if is_torch_npu_available(check_device=False):
        from verl.utils.npu_flash_attn_utils import index_first_axis, pad_input, rearrange, unpad_input
    else:
        from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input

    _index_first_axis, _pad_input, _rearrange, _unpad_input = index_first_axis, pad_input, rearrange, unpad_input
    return _index_first_axis, _pad_input, _rearrange, _unpad_input
|
|
|
|
def index_first_axis(*args, **kwargs):
    """
    Unified entry point for `index_first_axis` across CUDA and NPU backends.

    Forwards all positional and keyword arguments unchanged to the
    device-specific implementation selected by `_get_attention_functions`:
      - On NPU (``is_torch_npu_available(check_device=False)`` is true):
        `verl.utils.npu_flash_attn_utils.index_first_axis`
      - Otherwise (CUDA): `flash_attn.bert_padding.index_first_axis`

    Users can call this function directly without worrying about the underlying device.

    Returns:
        Whatever the selected backend implementation returns.

    Raises:
        ImportError: If the selected backend module cannot be imported.
    """
    index_first_axis_impl, _, _, _ = _get_attention_functions()
    return index_first_axis_impl(*args, **kwargs)
|
|
|
|
def pad_input(*args, **kwargs):
    """
    Unified entry point for `pad_input` across CUDA and NPU backends.

    Forwards all positional and keyword arguments unchanged to the
    device-specific implementation selected by `_get_attention_functions`:
      - On NPU (``is_torch_npu_available(check_device=False)`` is true):
        `verl.utils.npu_flash_attn_utils.pad_input`
      - Otherwise (CUDA): `flash_attn.bert_padding.pad_input`

    Users can call this function directly without worrying about the underlying device.

    Returns:
        Whatever the selected backend implementation returns.

    Raises:
        ImportError: If the selected backend module cannot be imported.
    """
    _, pad_input_impl, _, _ = _get_attention_functions()
    return pad_input_impl(*args, **kwargs)
|
|
|
|
def rearrange(*args, **kwargs):
    """
    Unified entry point for `rearrange` across CUDA and NPU backends.

    Forwards all positional and keyword arguments unchanged to the
    device-specific implementation selected by `_get_attention_functions`:
      - On NPU (``is_torch_npu_available(check_device=False)`` is true):
        `verl.utils.npu_flash_attn_utils.rearrange`
      - Otherwise (CUDA): `flash_attn.bert_padding.rearrange`
        (NOTE(review): presumably einops' `rearrange` re-exported by flash_attn — confirm)

    Users can call this function directly without worrying about the underlying device.

    Returns:
        Whatever the selected backend implementation returns.

    Raises:
        ImportError: If the selected backend module cannot be imported.
    """
    _, _, rearrange_impl, _ = _get_attention_functions()
    return rearrange_impl(*args, **kwargs)
|
|
|
|
def unpad_input(*args, **kwargs):
    """
    Unified entry point for `unpad_input` across CUDA and NPU backends.

    Forwards all positional and keyword arguments unchanged to the
    device-specific implementation selected by `_get_attention_functions`:
      - On NPU (``is_torch_npu_available(check_device=False)`` is true):
        `verl.utils.npu_flash_attn_utils.unpad_input`
      - Otherwise (CUDA): `flash_attn.bert_padding.unpad_input`

    Users can call this function directly without worrying about the underlying device.

    Returns:
        Whatever the selected backend implementation returns.

    Raises:
        ImportError: If the selected backend module cannot be imported.
    """
    _, _, _, unpad_input_impl = _get_attention_functions()
    return unpad_input_impl(*args, **kwargs)
|
|
|
|
# Public API of this module: only the device-agnostic wrapper functions.
# The private resolver `_get_attention_functions` and its cache globals are
# deliberately left out.
__all__ = [
    "index_first_axis",
    "pad_input",
    "rearrange",
    "unpad_input",
]
|
|