# Qwen3-8B-DMS-8x / dms_cache.py
# Uploaded with the upload-large-folder tool (revision 3844f4b, verified).
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import functools
import gc
import math
from typing import Any
import torch
from transformers import CacheLayerMixin
from transformers.cache_utils import Cache
def ceil_int_div(a: int, b: int) -> int:
    """Ceiling integer division: smallest integer >= a / b."""
    # Identity: (a - 1) // b + 1 == (a + b - 1) // b for every integer a and nonzero b,
    # since adding b to the numerator of a floor division adds exactly 1 to the result.
    return (a - 1) // b + 1
def float_ceil(a: float):
    """Round `a` up to the nearest integer."""
    # Identity: ceil(a) == -floor(-a); math.floor returns an int, so the result is an int.
    return -math.floor(-a)
def _aux_potential_eviction(
    vals_for_replacement: torch.Tensor,
    to_be_evicted_table_block_id: torch.Tensor,
    to_be_evicted_position_within_block: torch.Tensor,
    to_be_evicted_mask: torch.Tensor,
    block_table: torch.Tensor,
    blocks: torch.Tensor,
    page_batch_index: torch.Tensor,
    last_table_block_id: torch.Tensor,
    next_position_within_block: torch.Tensor,
):
    """Adding a new element to KV cache may lead to eviction of the last element in the DMS sliding window.

    Writes `vals_for_replacement[:, 0]` into the paged `blocks`, per (batch*head) row,
    in exactly one of two places selected by `to_be_evicted_mask` (1.0 = evict-and-reuse,
    0.0 = append): either over the slot of the element marked for eviction, or at the
    next free position of the row's last allocated page. Both writes always execute and
    are blended with the mask so that shapes stay static (compile-friendly); the write
    not selected by the mask rewrites the existing content unchanged.
    """
    # For each batch element the block table contains a list of blocks allocated for this batch element
    block_ids = block_table[page_batch_index, to_be_evicted_table_block_id]
    # Override the last element of the sliding window with the new element if the last element of the sliding window
    # is marked for the eviction and the window is full
    blocks[block_ids, to_be_evicted_position_within_block, :, :] = (
        blocks[block_ids, to_be_evicted_position_within_block, :, :] * (1 - to_be_evicted_mask[:, None, None])
        + vals_for_replacement[:, 0, None, :] * to_be_evicted_mask[:, None, None]
    )
    # Otherwise write the new element to the next position within the last allocated block
    block_ids = block_table[page_batch_index, last_table_block_id]
    blocks[block_ids, next_position_within_block, :, :] = blocks[
        block_ids, next_position_within_block, :, :
    ] * to_be_evicted_mask[:, None, None] + vals_for_replacement[:, 0, None, :] * (
        1 - to_be_evicted_mask[:, None, None]
    )
@torch.compile()
def _aux_update_single(
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    eviction_info: torch.Tensor,
    recent_info: torch.Tensor,
    recent_info_position: torch.Tensor,
    block_table: torch.Tensor,
    key_blocks: torch.Tensor,
    value_blocks: torch.Tensor,
    cache_seq_lenghts: torch.Tensor,
    page_batch_index: torch.Tensor,
) -> torch.Tensor:
    """Append one new key/value per (batch*head) row to the paged cache (decode step).

    If the oldest sliding-window entry of a row is marked for eviction, the new
    element overwrites that slot in place; otherwise it is appended at the next
    free position. All state updates are blended with masks so every tensor keeps
    a static shape under `torch.compile`. Mutates `recent_info`,
    `recent_info_position` and `cache_seq_lenghts` in place and returns a boolean
    mask of rows that just filled their last allocated page (caller must allocate
    a fresh page for those rows before the next update).
    """
    # page_batch, seq_len, head_dim = key_states.size()
    # page_batch, seq_len = eviction_info.size()
    # page_batch_index is a tensor of shape (page_batch,): 0, 1, 2, ... page_batch - 1
    block_size = key_blocks.size(1)
    # `recent_info_position` points to the next position in the sliding window; when the sliding window is full,
    # it points to the first position. Not-filled elements are not zeroed out and not marked for eviction
    # (see recent_info initialization).
    eviction_candidate_info_position = recent_info_position % recent_info.size(1)
    eviction_candidate_info = recent_info[
        page_batch_index, eviction_candidate_info_position
    ]  # Note that this is zeroed out in the beginning
    # `eviction_candidate_info[:, 1]` is 1 when the element is marked for eviction and 0 otherwise
    # `block_table[eviction_candidate_info[:, 0] // block_size]` is the block id where the element resides
    # and `eviction_candidate_info[:, 0] % block_size` is the position (offset) within the block
    to_be_evicted = eviction_candidate_info[:, 1] == 1
    # Same predicate in two dtypes: KV dtype for blending block contents, int32 for index math.
    to_be_evicted_kv = to_be_evicted.to(key_blocks.dtype)
    to_be_evicted_int = to_be_evicted.to(torch.int32)
    to_be_evicted_position = eviction_candidate_info[:, 0]
    to_be_evicted_table_block_id = to_be_evicted_position // block_size
    to_be_evicted_position_within_block = to_be_evicted_position % block_size
    last_table_block_id = cache_seq_lenghts // block_size
    next_position_within_block = cache_seq_lenghts % block_size
    # Write the new key either over the evicted slot or at the next free position.
    _aux_potential_eviction(
        vals_for_replacement=key_states,
        to_be_evicted_table_block_id=to_be_evicted_table_block_id,
        to_be_evicted_position_within_block=to_be_evicted_position_within_block,
        to_be_evicted_mask=to_be_evicted_kv,
        block_table=block_table,
        blocks=key_blocks,
        page_batch_index=page_batch_index,
        last_table_block_id=last_table_block_id,
        next_position_within_block=next_position_within_block,
    )
    # Same placement decision for the value tensor.
    _aux_potential_eviction(
        vals_for_replacement=value_states,
        to_be_evicted_table_block_id=to_be_evicted_table_block_id,
        to_be_evicted_position_within_block=to_be_evicted_position_within_block,
        to_be_evicted_mask=to_be_evicted_kv,
        block_table=block_table,
        blocks=value_blocks,
        page_batch_index=page_batch_index,
        last_table_block_id=last_table_block_id,
        next_position_within_block=next_position_within_block,
    )
    # Logical position of the just-written element: the reused slot on eviction,
    # otherwise the previous end of the sequence.
    final_position = to_be_evicted_position * to_be_evicted_int + (1 - to_be_evicted_int) * (cache_seq_lenghts)
    previous_recent_info_position = (recent_info_position + recent_info.size(1) - 1) % recent_info.size(1)
    # Update the eviction info for the previous element in the sliding window (if present)
    recent_info[page_batch_index, previous_recent_info_position, 1] = (
        eviction_info[:, 0] * (cache_seq_lenghts > 0).to(torch.int32)
    ).to(torch.int32)
    # No info about eviction yet for the new element
    recent_info[page_batch_index, eviction_candidate_info_position, 1] = 0
    recent_info[page_batch_index, eviction_candidate_info_position, 0] = final_position
    recent_info_position[...] += 1
    # The cache only grows when no slot was reused.
    cache_seq_lenghts[...] = cache_seq_lenghts + (1 - to_be_evicted_int)
    # At the beginning of this function call block_table[cache_seq_lenghts // block_size] points to a block with
    # at least one free position; need to maintain this invariant by detecting filled blocks
    requires_free_page = torch.logical_and((cache_seq_lenghts % block_size) == 0, to_be_evicted_int == 0)
    return requires_free_page
def _aux_get_recent_position_size(cache_seq_lenghts: torch.Tensor, dms_window_size: int) -> torch.Tensor:
return torch.clamp(cache_seq_lenghts, max=dms_window_size)
def _aux_get_first_recent_position(
recent_info_position: torch.Tensor,
cache_seq_lenghts: torch.Tensor,
dms_window_size: int,
) -> torch.Tensor:
return recent_info_position - _aux_get_recent_position_size(
cache_seq_lenghts=cache_seq_lenghts, dms_window_size=dms_window_size
)
def _aux_write_kv(
block_table: torch.Tensor,
blocks: torch.Tensor,
write_positions: torch.Tensor,
values: torch.Tensor,
page_batch_index: torch.Tensor,
):
page_batch, chunk_len = write_positions.size()
block_size = blocks.size(1)
block_table_id = write_positions // block_size
position_within_block = write_positions % block_size
block_id = block_table[page_batch_index[:, None], block_table_id]
assert (block_id != -1).all()
blocks[block_id, position_within_block, :, :] = values[:, :, None, :]
@torch.compile()
def _aux_update_many_handle_single_chunk(
    update_key_chunk: torch.Tensor,
    update_value_chunk: torch.Tensor,
    eviction_info_chunk: torch.Tensor,
    block_table: torch.Tensor,
    key_blocks: torch.Tensor,
    value_blocks: torch.Tensor,
    cache_seq_lenghts: torch.Tensor,
    recent_info: torch.Tensor,
    recent_info_position: torch.Tensor,
    page_batch_index: torch.Tensor,
    update_mask: torch.Tensor,
    true_update_size: torch.Tensor,
) -> torch.Tensor:
    """
    Used for prefilling the KV cache as each tensor has a fixed size.
    `true_update_size` represents the true number of elements to be added for each batch index.

    Writes one fixed-size chunk of keys/values per (batch*head) row, reusing slots of
    sliding-window entries already confirmed for eviction before consuming fresh
    positions at the end of the cache. Mutates `recent_info`, `recent_info_position`
    and `cache_seq_lenghts` in place; returns a boolean mask of rows that filled
    their last allocated page and need a new one.

    NOTE(review): `update_mask` is accepted but not referenced in this body;
    `true_update_size` (its per-row popcount, computed by the caller) is used instead.
    """
    assert update_key_chunk.size() == update_value_chunk.size()
    page_batch, chunk_len, head_dim = update_key_chunk.size()
    # The chunk must be strictly smaller than the sliding window so that window
    # bookkeeping below cannot wrap onto itself within one call.
    assert chunk_len < recent_info.size(1)
    assert eviction_info_chunk.size() == (page_batch, chunk_len)
    assert page_batch_index.size() == (page_batch,)
    block_size = key_blocks.size(1)
    device = update_key_chunk.device
    # First we update the eviction info for the previous element if present
    update_eviction_info_positions = (recent_info_position - 1) % recent_info.size(1)
    update_eviction_info_mask = (cache_seq_lenghts > 0).to(torch.int32)
    recent_info[page_batch_index, update_eviction_info_positions, 1] = (
        eviction_info_chunk[:, 0] * update_eviction_info_mask
        + (1 - update_eviction_info_mask) * recent_info[page_batch_index, update_eviction_info_positions, 1]
    ).to(torch.int32)
    chunk_indexer = torch.arange(chunk_len, dtype=torch.int32, device=device)
    # The following trick handles variable lens: if the index is longer than true_update_size, then pad the index
    # with the last element within the true_update_size, e.g., [0, 1, 2, 3, 4, 5] and true_update_size = [3]
    # means that we have [0, 1, 2, 2, 2, 2] . This will later be used to write the same element multiple times
    # while preserving the constant shapes of the tensors.
    potential_eviction_positions_in_recent_info = (
        recent_info_position[:, None] + torch.minimum(chunk_indexer[None, :], true_update_size[:, None] - 1)
    ) % recent_info.size(1)
    potential_eviction_positions_in_seq = recent_info[
        page_batch_index[:, None], potential_eviction_positions_in_recent_info, 0
    ]
    confirmed_evictions_mask = (
        recent_info[page_batch_index[:, None], potential_eviction_positions_in_recent_info, 1] == 1
    )
    # Account for the padding with the last element (as described above) to get a proper count of confirmed evictions
    confirmed_evictions_mask[:, 1:] = torch.logical_and(
        confirmed_evictions_mask[:, 1:],
        potential_eviction_positions_in_recent_info[:, 1:] != potential_eviction_positions_in_recent_info[:, :-1],
    )
    # Keep at most `true_update_size` reusable slots per row (one per new element).
    confirmed_evictions_cum_sum = confirmed_evictions_mask.to(torch.int32).cumsum(dim=-1)
    confirmed_evictions_mask = torch.logical_and(
        confirmed_evictions_mask,
        confirmed_evictions_cum_sum <= true_update_size[:, None],
    )
    # Count how many new positions are needed for each element of the batch
    num_confirmed_evictions = confirmed_evictions_mask.to(torch.int32).sum(dim=-1)
    new_positions_used = true_update_size - num_confirmed_evictions
    assert (new_positions_used >= 0).all()
    assert new_positions_used.size() == (page_batch,)
    # Fresh positions at the current end of the cache, padded with the last used one
    # (same constant-shape trick as above).
    new_free_positions = cache_seq_lenghts[:, None] + torch.clamp(
        torch.minimum(chunk_indexer[None, :], new_positions_used[:, None] - 1), min=0
    )
    assert new_free_positions.size() == (page_batch, chunk_len)
    assert new_free_positions.size() == potential_eviction_positions_in_seq.size()
    # Candidate targets: first the reusable (evicted) slots, then the fresh tail positions.
    potential_eviction_positions_in_seq = torch.cat(
        [
            potential_eviction_positions_in_seq,
            new_free_positions,
        ],
        dim=-1,
    )
    # Padding below allows for constant shape ops to take prefix of length new_positions_used from new_free_positions
    confirmed_evictions_padding = torch.zeros_like(confirmed_evictions_mask)
    padding_chunk_size = chunk_len - num_confirmed_evictions[:, None]
    indexer = torch.minimum(chunk_indexer[None, :], torch.clamp(padding_chunk_size - 1, min=0))
    confirmed_evictions_padding[page_batch_index[:, None], indexer] = True
    # If only post eviction positions are used, then have writing padding that ends in the last of those positions,
    # instead of the next free position
    confirmed_evictions_padding = torch.logical_and(confirmed_evictions_padding, padding_chunk_size > 0)
    confirmed_evictions_mask = torch.cat([confirmed_evictions_mask, confirmed_evictions_padding], dim=-1)
    pad_selector = (new_positions_used > 0).to(torch.int32)[:, None]
    potential_eviction_positions_in_seq[:, chunk_len:] = (
        pad_selector * potential_eviction_positions_in_seq[:, chunk_len:]
        + (1 - pad_selector) * potential_eviction_positions_in_seq[:, [chunk_len - 1]]
    )
    # Exactly chunk_len targets survive per row (reused slots first, then fresh slots).
    new_write_positions = potential_eviction_positions_in_seq[confirmed_evictions_mask].reshape(page_batch, chunk_len)
    _aux_write_kv(
        block_table=block_table,
        blocks=key_blocks,
        write_positions=new_write_positions,
        values=update_key_chunk,
        page_batch_index=page_batch_index,
    )
    _aux_write_kv(
        block_table=block_table,
        blocks=value_blocks,
        write_positions=new_write_positions,
        values=update_value_chunk,
        page_batch_index=page_batch_index,
    )
    recent_indexer = torch.minimum(chunk_indexer[None, :], torch.clamp(true_update_size[:, None] - 1, min=0))
    recent_info_indexer = (recent_info_position[:, None] + recent_indexer) % recent_info.size(1)
    # Update the info about last window positions
    non_empty_update = (true_update_size[:, None] > 0).to(torch.int32)
    recent_info[page_batch_index[:, None], recent_info_indexer, 0] = (
        new_write_positions * non_empty_update
        + recent_info[page_batch_index[:, None], recent_info_indexer, 0] * (1 - non_empty_update)
    ).to(torch.int32)
    # Eviction decisions are shifted by one: element i's decision arrives with element i+1,
    # so the newest element of the chunk has no decision yet (zero-filled).
    eviction_info_chunk = torch.cat(
        [
            eviction_info_chunk[:, 1:],
            torch.zeros_like(eviction_info_chunk[:, [0]]),
        ],
        dim=-1,
    )
    recent_info[page_batch_index[:, None], recent_info_indexer, 1] = (
        eviction_info_chunk[:, :] * non_empty_update
        + recent_info[page_batch_index[:, None], recent_info_indexer, 1] * (1 - non_empty_update)
    ).to(torch.int32)
    recent_info_position[...] += true_update_size
    cache_seq_lenghts[...] += new_positions_used
    require_free_pages = torch.logical_and(new_positions_used > 0, cache_seq_lenghts % block_size == 0)
    return require_free_pages
class DMSCache(Cache):
    """Model-level cache that replicates one `DMSPagedCacheLayer` per decoder layer.

    Legacy-cache conversion and iteration are deliberately unsupported.
    """

    def __init__(
        self,
        dms_window_size: int,
        max_context_length: int,
        offloading: bool = False,
        offload_only_non_sliding: bool = False,
        accomodate_min_initial_context_length: int = 2048,
        block_size: int = 256,
    ):
        # Every layer gets its own paged DMS cache built from this factory.
        layer_factory = functools.partial(
            DMSPagedCacheLayer,
            dms_window_size=dms_window_size,
            max_context_length=max_context_length,
            accomodate_min_initial_context_length=accomodate_min_initial_context_length,
            block_size=block_size,
        )
        super().__init__(
            layer_class_to_replicate=layer_factory,
            offloading=offloading,
            offload_only_non_sliding=offload_only_non_sliding,
        )

    def to_legacy_cache(self):
        raise NotImplementedError("Not Supported")

    @classmethod
    def from_legacy_cache(cls, *args, **kwargs):
        raise NotImplementedError("Not Supported")

    def early_initialization(self, *args, **kwargs):
        raise NotImplementedError("Not Supported")

    def __iter__(self):
        raise NotImplementedError("Not Supported")

    def __getitem__(self, layer_idx: int):
        assert layer_idx < len(self.layers)
        return self.layers[layer_idx]
class DMSPagedCacheLayer(CacheLayerMixin):
    """Paged per-layer KV cache with Dynamic Memory Sparsification (DMS) eviction.

    Keys/values live in fixed-size pages (`block_size` positions each); for every
    (batch * head) row, `block_table` maps logical cache positions to physical
    pages. `recent_info` tracks the last `dms_window_size` appended elements per
    row: channel 0 holds the element's logical position, channel 1 whether it is
    marked for eviction. Marked slots are overwritten by new elements instead of
    growing the cache. All state tensors are allocated lazily on the first
    `update` call.

    Fixes vs. previous revision: removed a leftover debug `print` in `reset`;
    `reorder_cache` now raises `NotImplementedError` instead of `assert False`
    (which disappears under `python -O`).
    """

    def __init__(
        self,
        dms_window_size: int,
        max_context_length: int,
        block_size: int = 256,
        growth_factor: float = 1.5,
        accomodate_min_initial_context_length: int = 4096,
    ):
        """
        Args:
            dms_window_size: Length of the DMS sliding window of recent elements.
            max_context_length: Hard upper bound on the per-row cache length.
            block_size: Number of positions stored in one physical page.
            growth_factor: Multiplicative factor applied when the page pool grows.
            accomodate_min_initial_context_length: Context length the initial
                page-pool allocation should cover without growing.
                NOTE(review): `DMSCache` passes a different default (2048).
        """
        super().__init__()
        assert block_size <= dms_window_size
        self.block_size = block_size
        self.dms_window_size = dms_window_size
        self.prefill_chunk_size = max(self.dms_window_size - 2, block_size)
        assert self.prefill_chunk_size > 0
        self.growth_factor = growth_factor
        self.min_initial_context_length = accomodate_min_initial_context_length
        self.max_context_length = max_context_length
        self.max_blocks_per_sequence = ceil_int_div(self.max_context_length, self.block_size)
        # State below is allocated in `lazy_initialization` on the first update.
        self.key_blocks = None
        self.value_blocks = None
        self.block_table = None
        self.free_page_ids = None
        self.cache_seq_lengths = None
        self.recent_info = None  # Position and eviction info of last window_size keys/values
        self.recent_info_position = None
        self.device = None
        self.cumulative_length = 0

    def offload(self):
        """Move all cache state tensors to CPU (asynchronously where possible)."""
        if self.key_blocks is not None:
            self.key_blocks = self.key_blocks.to("cpu", non_blocking=True)
            self.value_blocks = self.value_blocks.to("cpu", non_blocking=True)
            self.block_table = self.block_table.to("cpu", non_blocking=True)
            self.free_page_ids = self.free_page_ids.to("cpu", non_blocking=True)
            self.cache_seq_lengths = self.cache_seq_lengths.to("cpu", non_blocking=True)
            self.recent_info = self.recent_info.to("cpu", non_blocking=True)
            self.recent_info_position = self.recent_info_position.to("cpu", non_blocking=True)

    def prefetch(self):
        """Move all cache state tensors back to the compute device recorded at initialization."""
        if self.key_blocks is not None and self.key_blocks.device != self.device:
            self.key_blocks = self.key_blocks.to(self.device, non_blocking=True)
            self.value_blocks = self.value_blocks.to(self.device, non_blocking=True)
            self.block_table = self.block_table.to(self.device, non_blocking=True)
            self.free_page_ids = self.free_page_ids.to(self.device, non_blocking=True)
            self.cache_seq_lengths = self.cache_seq_lengths.to(self.device, non_blocking=True)
            self.recent_info = self.recent_info.to(self.device, non_blocking=True)
            self.recent_info_position = self.recent_info_position.to(self.device, non_blocking=True)

    def reset(self) -> None:
        """Resets the cache values while preserving the objects"""
        # (fix: removed leftover debug print of the initialization state)
        if self.key_blocks is not None:
            self.key_blocks = None
            self.value_blocks = None
            self.block_table = None
            self.free_page_ids = None
            self.cache_seq_lengths = None
            self.recent_info = None
            self.recent_info_position = None
            # Dropping the references is not enough to return GPU memory promptly.
            gc.collect()
            torch.cuda.empty_cache()
        self.cumulative_length = 0

    def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
        """Reorders this layer's cache for beam search."""
        # Raise explicitly: `assert False` would be stripped under `python -O`
        # and silently corrupt beam-search results.
        raise NotImplementedError("No support for beam search at this point")

    def _get_free_pages(self, num_pages: int):
        """Pop `num_pages` free physical page ids, growing the page pool if needed."""
        while len(self.free_page_ids) < num_pages:

            def expand_blocks(blocks: torch.Tensor):
                # Grow the pool geometrically by `growth_factor`, appending zeroed pages.
                return torch.cat(
                    [
                        blocks,
                        blocks.new_zeros(
                            (
                                float_ceil(blocks.size(0) * self.growth_factor) - blocks.size(0),
                                blocks.size(1),
                                blocks.size(2),
                                blocks.size(3),
                            )
                        ),
                    ],
                    dim=0,
                )

            old_num_blocks = self.key_blocks.size(0)
            self.key_blocks = expand_blocks(self.key_blocks)
            self.value_blocks = expand_blocks(self.value_blocks)
            assert self.key_blocks.size(0) == self.value_blocks.size(0)
            # Register the freshly appended pages as free.
            self.free_page_ids = torch.cat(
                [
                    self.free_page_ids,
                    torch.arange(
                        old_num_blocks,
                        self.key_blocks.size(0),
                        dtype=torch.int32,
                        device=self.device,
                    ),
                ],
                dim=0,
            )
        result = self.free_page_ids[:num_pages]
        assert result.size() == (num_pages,)
        self.free_page_ids = self.free_page_ids[num_pages:]
        return result

    def lazy_initialization(self, key_states: torch.Tensor):
        """Allocate all cache state from the first batch's dtype/device/shape.

        `key_states` is (batch, num_heads, seq_len, head_dim); pages are managed
        per (batch * head) row ("page batch").
        """
        self.dtype, self.device = key_states.dtype, key_states.device
        self.batch_size, self.num_heads, _, self.head_dim = key_states.shape
        self.page_batch = self.batch_size * self.num_heads
        # Enough pages to cover `min_initial_context_length` per row, and at
        # least one page per row.
        initial_num_blocks = max(
            ceil_int_div(self.min_initial_context_length, self.block_size) * self.page_batch,
            self.page_batch,
        )
        self.block_table = -torch.ones(
            self.page_batch,
            self.max_blocks_per_sequence + 1,  # +1 for handling full cache case
            dtype=torch.int32,
            device=self.device,
        )
        # Pages store one head's slice per position: (num_pages, block_size, 1, head_dim).
        self.key_blocks = torch.zeros(
            (initial_num_blocks, self.block_size, 1, self.head_dim),
            dtype=self.dtype,
            device=self.device,
        )
        self.value_blocks = torch.zeros(
            (initial_num_blocks, self.block_size, 1, self.head_dim),
            dtype=self.dtype,
            device=self.device,
        )
        self.free_page_ids = torch.arange(0, initial_num_blocks, dtype=torch.int32, device=self.device)
        self.cache_seq_lengths = torch.zeros(self.page_batch, dtype=torch.int32, device=self.device)
        # recent_info[..., 0] = logical position, recent_info[..., 1] = eviction flag.
        # Zero-init means "position 0, not marked" for not-yet-filled slots.
        self.recent_info = torch.zeros(
            (self.page_batch, self.dms_window_size, 2),
            dtype=torch.int32,
            device=self.device,
        )
        self.recent_info_position = torch.zeros((self.page_batch,), dtype=torch.int32, device=self.device)
        # Every row starts with one allocated page (invariant: the slot at
        # cache_seq_lengths // block_size is always allocated).
        self.block_table[:, 0] = self._get_free_pages(self.block_table.size(0))

    def _handle_page_allocation(self, requires_free_page: torch.Tensor, page_batch_index: torch.Tensor):
        """Allocate one fresh page for every row whose last page just filled up."""
        if requires_free_page.any():
            req_free_pages = page_batch_index[requires_free_page]
            free_pages = self._get_free_pages(len(req_free_pages))
            self.block_table[
                req_free_pages,
                self.cache_seq_lengths[req_free_pages] // self.block_size,
            ] = free_pages

    def _update_single(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        eviction_info: torch.Tensor,
    ):
        """Decode-step update: append exactly one token per (batch * head) row."""
        batch_x_head, seq_len, head_dim = key_states.size()
        page_batch_index = torch.arange(batch_x_head, dtype=torch.int32, device=self.device)
        assert seq_len == 1
        requires_free_page = _aux_update_single(
            key_states=key_states,
            value_states=value_states,
            eviction_info=eviction_info,
            recent_info=self.recent_info,
            recent_info_position=self.recent_info_position,
            block_table=self.block_table,
            key_blocks=self.key_blocks,
            value_blocks=self.value_blocks,
            cache_seq_lenghts=self.cache_seq_lengths,
            page_batch_index=page_batch_index,
        )
        self._handle_page_allocation(requires_free_page=requires_free_page, page_batch_index=page_batch_index)

    # NOTE: Prefill is not yet optimized
    def _update_many(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        eviction_info: torch.Tensor,
        sequence_lengths: torch.Tensor,
    ):
        """Prefill update: append `sequence_lengths[i]` tokens per row, in fixed-size chunks.

        Chunks are additionally capped so that each one never crosses a page
        boundary, which lets `_aux_update_many_handle_single_chunk` run with
        static shapes and at most one new page per row per chunk.
        """
        # Assume key and value states are left padded, e.g., [_, _, _, 1, 2, 3, 4]
        page_batch, seq_len, head_dim = key_states.size()
        assert page_batch == self.page_batch
        assert head_dim == self.head_dim
        assert eviction_info.size() == (page_batch, seq_len)
        assert sequence_lengths.ndim == 1
        assert sequence_lengths.min() > 0, sequence_lengths
        start_positions = seq_len - sequence_lengths
        end_positions = start_positions + sequence_lengths
        page_batch_index = torch.arange(page_batch, dtype=torch.int32, device=self.device)
        while (start_positions < end_positions).any():
            chunk_indexer = torch.arange(self.prefill_chunk_size, dtype=torch.int32, device=self.device)[None, :]
            # Limit each chunk to the free space remaining in the row's current page.
            update_mask = chunk_indexer < (self.block_size - (self.cache_seq_lengths[:, None] % self.block_size))
            chunk_indexer = start_positions[:, None] + chunk_indexer
            update_mask = torch.logical_and(update_mask, chunk_indexer < end_positions[:, None])
            chunk_indexer = torch.clamp(torch.minimum(chunk_indexer, end_positions[:, None] - 1), min=0)
            true_update_size = update_mask.to(torch.int32).sum(dim=1)
            # Clamp the gather index so out-of-range lanes repeat the last valid element
            # (constant-shape variable-length trick, mirrored in the aux kernel).
            chunk_indexer = torch.clamp(
                torch.minimum(
                    chunk_indexer,
                    start_positions[:, None] + true_update_size[:, None] - 1,
                ),
                min=0,
            )
            key_chunk = key_states[page_batch_index[:, None], chunk_indexer]
            value_chunk = value_states[page_batch_index[:, None], chunk_indexer]
            eviction_info_chunk = eviction_info[page_batch_index[:, None], chunk_indexer]
            requires_free_page = _aux_update_many_handle_single_chunk(
                update_key_chunk=key_chunk,
                update_value_chunk=value_chunk,
                eviction_info_chunk=eviction_info_chunk,
                block_table=self.block_table,
                key_blocks=self.key_blocks,
                value_blocks=self.value_blocks,
                cache_seq_lenghts=self.cache_seq_lengths,
                recent_info=self.recent_info,
                recent_info_position=self.recent_info_position,
                page_batch_index=page_batch_index,
                update_mask=update_mask,
                true_update_size=true_update_size,
            )
            self._handle_page_allocation(requires_free_page=requires_free_page, page_batch_index=page_batch_index)
            start_positions[...] += true_update_size

    def get_contiguous_cache(
        self, right_padded: bool = True
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Gather the paged cache into dense tensors.

        Returns `(keys, values, lengths, eviction_info)` with keys/values of shape
        (batch, num_heads, max_length, head_dim), reordered so that positions
        outside the sliding window come first and the window's elements occupy
        the tail in order. With `right_padded=False`, rows are left-padded with
        zeros instead (slow path, prefill only).
        """
        assert self.key_blocks is not None
        num_blocks_per_sequence = self.cache_seq_lengths.max().item() // self.block_size + 1
        blocks_to_retrieve = self.block_table[:, :num_blocks_per_sequence]
        max_length = self.cache_seq_lengths.max().item()

        def handle_one(
            blocks: torch.Tensor,
            blocks_to_retrieve: torch.Tensor,
            max_length: int,
            num_blocks_per_sequence: int,
        ):
            # Flatten each row's pages into one contiguous (page_batch, max_length, head_dim) slab.
            retrieved = blocks[blocks_to_retrieve].reshape(
                self.page_batch,
                num_blocks_per_sequence * self.block_size,
                self.head_dim,
            )
            retrieved = retrieved[:, :max_length, :]
            recent_info_size = torch.clamp(self.cache_seq_lengths, max=self.dms_window_size)
            window_index = torch.arange(self.dms_window_size, device=self.device, dtype=torch.int32)
            adjusted_window_index = torch.minimum(
                window_index[None, :], torch.clamp(recent_info_size[:, None] - 1, min=0)
            )
            # Walk the ring buffer backwards from the newest entry to recover window order.
            last_pos_data_ptr = (self.recent_info_position[:, None] - 1 - adjusted_window_index) % self.dms_window_size
            page_batch_index = torch.arange(self.page_batch, device=self.device, dtype=torch.int32)
            window_positions = self.recent_info[page_batch_index[:, None], last_pos_data_ptr, 0]
            # Window elements go to the tail of the output row.
            permutation_index = self.cache_seq_lengths[:, None] - 1 - adjusted_window_index
            assert (permutation_index >= 0).all()
            permutation = torch.arange(max_length, device=blocks.device, dtype=torch.int32)[None, :]
            permutation = torch.minimum(permutation, self.cache_seq_lengths[:, None] - 1)
            permutation = torch.broadcast_to(permutation, (self.page_batch, max_length))
            # Positions not referenced by the window fill the head of the output row.
            non_window_positions = permutation[:, :, None] != window_positions[:, None, :]
            non_window_positions = non_window_positions.to(torch.int32).min(dim=-1).values.to(torch.bool)
            result_permutation = torch.zeros_like(permutation)
            result_permutation[page_batch_index[:, None], permutation_index] = window_positions
            # Not yet optimized
            num_non_window_positions = non_window_positions.to(torch.int32).sum(dim=-1).cpu().tolist()
            for i in range(self.page_batch):
                result_permutation[i, : num_non_window_positions[i]] = permutation[i, non_window_positions[i]]
            return retrieved[page_batch_index[:, None], result_permutation]

        retrieved_keys = handle_one(
            blocks=self.key_blocks,
            blocks_to_retrieve=blocks_to_retrieve,
            max_length=max_length,
            num_blocks_per_sequence=num_blocks_per_sequence,
        )
        retrieved_values = handle_one(
            blocks=self.value_blocks,
            blocks_to_retrieve=blocks_to_retrieve,
            max_length=max_length,
            num_blocks_per_sequence=num_blocks_per_sequence,
        )
        cache_seq_lenghts = self.cache_seq_lengths
        page_batch_index = torch.arange(self.page_batch, device=self.device, dtype=torch.int32)
        # Gather the eviction flags of the valid window entries, oldest first.
        eviction_info_indexer = torch.arange(self.dms_window_size, device=self.device, dtype=torch.int32)
        eviction_info_indexer = torch.minimum(
            eviction_info_indexer[None, :],
            _aux_get_recent_position_size(
                cache_seq_lenghts=self.cache_seq_lengths,
                dms_window_size=self.dms_window_size,
            )[:, None]
            - 1,
        )
        eviction_info_indexer = (
            _aux_get_first_recent_position(
                recent_info_position=self.recent_info_position,
                cache_seq_lenghts=self.cache_seq_lengths,
                dms_window_size=self.dms_window_size,
            )[:, None]
            + eviction_info_indexer
        )
        eviction_info_indexer = eviction_info_indexer % self.dms_window_size
        eviction_info = self.recent_info[page_batch_index[:, None], eviction_info_indexer, 1]
        if not right_padded:
            # Slow, but only used only in prefill
            def left_pad_one(x: torch.Tensor, lens: torch.Tensor):
                max_length = x.shape[1]
                lens = lens.cpu().tolist()
                left_padded_content = []
                for i in range(self.page_batch):
                    content = x[i, : lens[i]]
                    content = torch.cat(
                        [
                            x.new_zeros(
                                (
                                    max_length - lens[i],
                                    *content.shape[1:],
                                )
                            ),
                            content,
                        ],
                        dim=0,
                    )
                    left_padded_content.append(content)
                return torch.stack(left_padded_content, dim=0)

            retrieved_keys = left_pad_one(retrieved_keys, cache_seq_lenghts)
            retrieved_values = left_pad_one(retrieved_values, cache_seq_lenghts)
            # (fix: dropped a redundant reshape of cache_seq_lenghts here — the
            # identical reshape is applied unconditionally below)
            eviction_info = left_pad_one(
                eviction_info,
                _aux_get_recent_position_size(
                    cache_seq_lenghts=self.cache_seq_lengths,
                    dms_window_size=self.dms_window_size,
                ),
            )
        retrieved_keys = retrieved_keys.reshape(
            self.batch_size, self.num_heads, max_length, self.head_dim
        ).contiguous()
        retrieved_values = retrieved_values.reshape(
            self.batch_size, self.num_heads, max_length, self.head_dim
        ).contiguous()
        cache_seq_lenghts = cache_seq_lenghts.reshape(self.batch_size, self.num_heads)
        eviction_info = eviction_info.reshape(self.batch_size, self.num_heads, -1)
        return retrieved_keys, retrieved_values, cache_seq_lenghts, eviction_info

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        cache_kwargs: dict[str, Any],
    ):
        """Write new key/value states into the cache.

        Expects `cache_kwargs` to carry `eviction_info` (per-token eviction
        decisions), `sequence_lengths` (per (batch, head) valid lengths, or None)
        and `cumulative_length`. Routes to the single-token decode path when
        seq_len == 1, otherwise to the chunked prefill path. Returns (None, None):
        callers are expected to read the cache via the paged accessors or
        `get_contiguous_cache`, not from the return value.
        """
        eviction_info = cache_kwargs["eviction_info"]
        sequence_lengths = cache_kwargs["sequence_lengths"]
        cumulative_length = cache_kwargs["cumulative_length"]
        if self.key_blocks is None:
            self.lazy_initialization(key_states)
        batch, head, seq_len, head_dim = key_states.size()
        assert key_states.size() == value_states.size()
        assert key_states.size()[:3] == eviction_info.size()
        assert sequence_lengths is None or sequence_lengths.size() == (batch, head)
        assert batch * head == self.page_batch
        assert self.head_dim == head_dim
        # Fold (batch, head) into the page-batch dimension used by all internals.
        key_states = key_states.reshape(self.page_batch, seq_len, head_dim)
        value_states = value_states.reshape(self.page_batch, seq_len, head_dim)
        eviction_info = eviction_info.reshape(self.page_batch, seq_len)
        if sequence_lengths is not None:
            sequence_lengths = sequence_lengths.reshape(self.page_batch)
        if seq_len == 1:
            assert sequence_lengths is None or (sequence_lengths == 1).all()
            assert cumulative_length == 1
            self._update_single(
                key_states=key_states,
                value_states=value_states,
                eviction_info=eviction_info,
            )
        else:
            self._update_many(
                key_states=key_states,
                value_states=value_states,
                eviction_info=eviction_info,
                sequence_lengths=sequence_lengths,
            )
        self.cumulative_length += cumulative_length
        return None, None

    def get_block_table(self):
        """Per-row logical-to-physical page mapping (-1 = unallocated)."""
        return self.block_table

    def get_key_blocks(self):
        """Physical key pages: (num_pages, block_size, 1, head_dim)."""
        return self.key_blocks

    def get_value_blocks(self):
        """Physical value pages: (num_pages, block_size, 1, head_dim)."""
        return self.value_blocks

    def get_seq_lengths(self):
        """Per (batch * head) row cache lengths (post-eviction)."""
        return self.cache_seq_lengths

    def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
        """Returns the length and offset of the cache, used to generate the mask."""
        kv_offset = 0
        query_length = cache_position.shape[0]
        past_seen_tokens = self.get_seq_length()
        kv_length = query_length + past_seen_tokens
        return kv_length, kv_offset

    def get_seq_length(self) -> int:
        """Returns the sequence length of the cached states."""
        # Counts all tokens ever seen, not the post-eviction stored lengths.
        return self.cumulative_length

    def get_max_cache_shape(self) -> int:
        """Returns the maximum sequence length of the cache object. DynamicLayer does not have a maximum length."""
        return self.max_context_length