| from __future__ import annotations | |
| import torch | |
| from sglang.srt.mem_cache.allocator import PagedTokenToKVPoolAllocator | |
| from sglang.srt.utils import get_num_new_pages | |
def alloc_extend_kernel_ascend(
    prefix_lens,
    seq_lens,
    last_loc,
    free_pages,
    out_indices,
    page_size,
    device,
):
    """Write the KV-cache slot index of every newly extended token.

    For request ``i`` the tokens at positions ``prefix_lens[i]`` up to
    ``seq_lens[i]`` (exclusive) receive slots, written contiguously into
    ``out_indices`` in request order.  Each request is handled in up to
    three parts:

    1. the remainder of the page the prefix ends in, continuing from
       ``last_loc[i] + 1``;
    2. whole new pages taken from ``free_pages``;
    3. the leading slots of one more page from ``free_pages`` when the
       sequence does not end on a page boundary.

    A slot index is ``page_id * page_size + offset_in_page``.

    Args:
        prefix_lens: 1-D integer tensor, already-cached length per request.
        seq_lens: 1-D integer tensor, total length per request after the
            extension.
        last_loc: 1-D integer tensor; ``last_loc[i] + 1`` is the first slot
            used when request ``i`` continues inside its current page.
        free_pages: 1-D integer tensor of page ids, consumed from the front.
        out_indices: Pre-allocated 1-D tensor, filled in place.
        page_size: Number of slots per page.
        device: Device on which the per-page offset tensor is created.

    Returns:
        None.  ``out_indices`` is modified in place; ``free_pages`` is only
        read, so trimming the consumed pages is left to the caller.
    """
    pos_in_page = torch.arange(page_size, device=device, dtype=torch.int32)

    # Read the lengths back once as Python ints so the per-request branching
    # below does not evaluate device tensors for truthiness on every step.
    prefix_lens_host = prefix_lens.tolist()
    seq_lens_host = seq_lens.tolist()

    write_pos = 0  # next unwritten position in out_indices
    page_cursor = 0  # next unused entry of free_pages

    for i, (pre_len, seq_len) in enumerate(zip(prefix_lens_host, seq_lens_host)):
        # Prefix length rounded up to a page boundary: the first position that
        # can no longer live in the page the prefix ends in.
        aligned_prefix = (pre_len + page_size - 1) // page_size * page_size

        # Part 1: finish the partially filled page the prefix ends in.
        num_part1 = min(seq_len, aligned_prefix) - pre_len
        if num_part1 > 0:
            out_indices[write_pos : write_pos + num_part1] = (
                last_loc[i] + 1 + pos_in_page[:num_part1]
            )
            write_pos += num_part1

        # The whole extension fit into the existing page: no new page is
        # needed.  Without this guard the page count below turns negative
        # (and truthy) and the trailing-partial-page part would claim a page
        # this request never received.
        if seq_len <= aligned_prefix:
            continue

        # Part 2: whole pages that are completely filled by this request.
        num_full_pages = seq_len // page_size - aligned_prefix // page_size
        if num_full_pages > 0:
            page_starts = (
                free_pages[page_cursor : page_cursor + num_full_pages] * page_size
            )
            num_part2 = num_full_pages * page_size
            out_indices[write_pos : write_pos + num_part2] = (
                page_starts.view(-1, 1) + pos_in_page.view(1, -1)
            ).view(-1)
            write_pos += num_part2
            page_cursor += num_full_pages

        # Part 3: leading slots of one more page for the unaligned tail.
        num_part3 = seq_len % page_size
        if num_part3 > 0:
            out_indices[write_pos : write_pos + num_part3] = (
                free_pages[page_cursor] * page_size + pos_in_page[:num_part3]
            )
            write_pos += num_part3
            page_cursor += 1
class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
    """Paged KV-cache slot allocator with Ascend NPU allocation paths.

    Overrides ``alloc_extend`` and ``alloc_decode``.  Both hand out slot
    indices of the form ``page_id * page_size + offset`` and consume page ids
    from the front of ``self.free_pages``.
    """

    def alloc_extend(
        self,
        prefix_lens: torch.Tensor,
        prefix_lens_cpu: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_cpu: torch.Tensor,
        last_loc: torch.Tensor,
        extend_num_tokens: int,
    ) -> torch.Tensor | None:
        """Allocate slots for the tokens that extend each request.

        Args:
            prefix_lens: Already-cached length of each request.
            prefix_lens_cpu: Not used by this implementation.
            seq_lens: Total length of each request after the extension.
            seq_lens_cpu: Not used by this implementation.
            last_loc: Per-request slot index; ``last_loc + 1`` is where a
                request continues inside its current page.
            extend_num_tokens: Length of the returned index tensor.

        Returns:
            An int32 tensor of ``extend_num_tokens`` slot indices, or ``None``
            when there are fewer free pages than the batch requires.
        """
        if self.debug_mode:
            assert torch.all(
                (last_loc + 1) % self.page_size == prefix_lens % self.page_size
            )

        # Pages needed = pages covering seq_len minus pages covering the prefix,
        # summed over the batch.  Kept as a tensor for the NPU op; ``.item()``
        # gives the host-side count used for bookkeeping.
        num_new_pages = (
            (seq_lens + self.page_size - 1) // self.page_size
            - (prefix_lens + self.page_size - 1) // self.page_size
        ).sum()
        num_new_pages_item = num_new_pages.item()
        if self.need_sort and num_new_pages_item > len(self.free_pages):
            self.merge_and_sort_free()

        if num_new_pages_item > len(self.free_pages):
            return None

        out_indices = torch.empty(
            (extend_num_tokens,), dtype=torch.int64, device=self.device
        )

        # NOTE(review): the 200-page cut-over between the custom NPU op and the
        # pure-PyTorch fallback is unexplained here — confirm its rationale.
        if num_new_pages_item < 200:
            # Imported for its side effect; presumably this registers
            # ``torch.ops.npu.alloc_extend`` — TODO confirm.
            import sgl_kernel_npu  # noqa: F401

            torch.ops.npu.alloc_extend(
                prefix_lens,
                seq_lens,
                last_loc,
                self.free_pages,
                self.page_size,
                out_indices,
                num_new_pages,
            )
        else:
            alloc_extend_kernel_ascend(
                prefix_lens,
                seq_lens,
                last_loc,
                self.free_pages,
                out_indices,
                self.page_size,
                self.device,
            )

        if self.debug_mode:
            assert len(torch.unique(out_indices)) == len(out_indices)

        # Drop the pages that were just handed out.
        self.free_pages = self.free_pages[num_new_pages_item:]
        return out_indices.int()

    def alloc_decode(
        self,
        seq_lens: torch.Tensor,
        seq_lens_cpu: torch.Tensor,
        last_loc: torch.Tensor,
    ) -> torch.Tensor | None:
        """Allocate one slot per request for a single decode step.

        A request whose ``seq_lens % page_size == 1`` starts a new page and
        receives that page's first slot; every other request receives
        ``last_loc + 1``.

        Args:
            seq_lens: Length of each request including the token being decoded.
            seq_lens_cpu: Host copy of the lengths, passed to
                ``get_num_new_pages`` to obtain the page count.
            last_loc: Per-request slot index preceding the slot to allocate.

        Returns:
            An int32 tensor with one slot index per request, or ``None`` when
            there are fewer free pages than the batch requires.
        """
        if self.debug_mode:
            assert torch.all(
                (last_loc + 2) % self.page_size == seq_lens % self.page_size
            )

        num_new_pages = get_num_new_pages(
            seq_lens=seq_lens_cpu,
            page_size=self.page_size,
            decode=True,
        )

        # NOTE(review): unlike alloc_extend, this retry is not gated on
        # ``self.need_sort`` — confirm that is intended.
        if num_new_pages > len(self.free_pages):
            self.merge_and_sort_free()

        if num_new_pages > len(self.free_pages):
            return None

        # 1 for requests that start a new page on this step, 0 otherwise.
        need_new_pages = (seq_lens % self.page_size == 1).int()
        end_new_pages = torch.cumsum(need_new_pages, 0)
        # Exclusive prefix sum: position in free_pages of each request's page.
        start_new_pages = end_new_pages - need_new_pages

        if num_new_pages == 0:
            out_indices = last_loc + 1
        else:
            # Requests that stay on their page are gathered too (their term is
            # multiplied by 0 below).  When such a request follows the last
            # page-taking one, its index equals the number of pages consumed,
            # which is out of range if the batch uses every free page.  Clamp
            # so the gather stays in bounds; page-taking requests are
            # unaffected as long as ``num_new_pages`` matches the mask count
            # (presumably guaranteed by get_num_new_pages — verify).
            page_idx = torch.clamp(start_new_pages, max=len(self.free_pages) - 1)
            out_indices = (last_loc + 1) * (1 - need_new_pages) + self.free_pages[
                page_idx
            ] * self.page_size * need_new_pages

        if self.debug_mode:
            assert len(torch.unique(out_indices)) == len(out_indices)

        # Drop the pages that were just handed out.
        self.free_pages = self.free_pages[num_new_pages:]
        return out_indices.int()
Xet Storage Details
- Size: 5.02 kB
- Xet hash: 851a9a1c72e6390bd1dd8b766af3e22bec4c106e9b91c3dbc342510f9a45f33b
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.