# Copyright 2025 Musab Gultekin and the LlamaFactory team.
#
# This code is based on Musab Gultekin's functionary library.
# https://github.com/MeetKai/functionary/blob/main/functionary/train/packing/monkey_patch_packing.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# MIT License
#
# Copyright (c) 2023 Musab Gultekin
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from typing import TYPE_CHECKING

import torch
import torch.nn.functional as F

from ...extras import logging


if TYPE_CHECKING:
    from ...hparams import ModelArguments


logger = logging.get_logger(__name__)


def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
    r"""Get the sequence lengths in the current batch.

    e.g.
    ```python
    # input
    [
        [1, 1, 2, 2, 2, 0],
        [1, 2, 2, 3, 3, 3],
    ]
    # output
    [2, 3, 1, 2, 3]
    ```
    """
    bsz = attention_mask.size(0)
    dtype, device = attention_mask.dtype, attention_mask.device
    max_num = torch.max(attention_mask).item()  # the largest number of packed sequences in one sample
    counts: torch.Tensor = torch.zeros((bsz, max_num), dtype=dtype, device=device)
    for i in range(max_num):
        # count the tokens that belong to sequence id (i + 1) in every sample
        counts[:, i] = torch.sum(attention_mask == (i + 1), dim=-1)

    counts = counts.flatten()
    seqlens = counts[counts.nonzero().squeeze(dim=-1)]  # keep only the non-zero lengths
    return seqlens
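
# A worked trace (added note, not in the original module): for the docstring input,
# counts is [[2, 3, 0], [1, 2, 3]], i.e. tokens per sequence id in each row; it
# flattens to [2, 3, 0, 1, 2, 3], and dropping the zero entry yields [2, 3, 1, 2, 3].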


def get_unpad_data(attention_mask: "torch.Tensor") -> tuple["torch.Tensor", "torch.Tensor", int]:
    r"""Prepare the indices and seqlens for the flash attention varlen function.

    Returns:
        indices: indices of non-masked tokens from the flattened sequence.
        cu_seqlens: the cumulative sequence lengths in the current batch, always starting from 0.
        max_seqlen_in_batch: the largest seqlen in the current batch.

    e.g.
    ```python
    # input
    [
        [1, 1, 2, 2, 2, 0],
        [1, 2, 2, 3, 3, 3],
    ]
    # output
    [0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11]
    [0, 2, 5, 6, 8, 11]
    3
    ```
    """
    seqlens_in_batch = get_seqlens_in_batch(attention_mask)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()  # positions of non-padding tokens
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))  # prepend a leading zero
    return indices, cu_seqlens, max_seqlen_in_batch
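
# A worked trace (added note, not in the original module): for the docstring input,
# get_seqlens_in_batch returns [2, 3, 1, 2, 3]; its cumulative sum is [2, 5, 6, 8, 11],
# and padding one leading zero gives cu_seqlens = [0, 2, 5, 6, 8, 11]. These values
# line up with the cu_seqlens / max_seqlen arguments taken by FlashAttention's
# variable-length kernels (e.g. flash_attn_varlen_func).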


def configure_packing(model_args: "ModelArguments", is_trainable: bool) -> None:
    if not is_trainable or not model_args.block_diag_attn:
        return

    import transformers.modeling_flash_attention_utils

    # Monkey-patch the unpad helper so flash attention treats each packed sequence as a
    # separate segment (block diagonal attention) instead of attending across sequences.
    transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
    logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")
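

# A minimal smoke test (added sketch, not part of the original module). It checks the
# two helpers against the values shown in their docstrings. The relative imports above
# mean this file cannot run as a plain script; assuming the package is installed under
# its usual path, something like `python -m llamafactory.model.model_utils.packing`
# should work (the exact module path here is an assumption).
if __name__ == "__main__":
    mask = torch.tensor([
        [1, 1, 2, 2, 2, 0],
        [1, 2, 2, 3, 3, 3],
    ])
    assert get_seqlens_in_batch(mask).tolist() == [2, 3, 1, 2, 3]
    indices, cu_seqlens, max_seqlen = get_unpad_data(mask)
    assert indices.tolist() == [0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11]
    assert cu_seqlens.tolist() == [0, 2, 5, 6, 8, 11]
    assert max_seqlen == 3
    print("packing helpers OK")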