import math
import torch
from cache import FinchCache
from utils import repeat_kv
from transformers.models.llama.modeling_llama import rotate_half
import spaces


def get_compressed_kv_cache(
    model,
    sink_tokens,
    step_size,
    target_token_size,
    context_ids,
    context_attention_mask,
    question_ids,
    question_attention_mask,
):
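    """Build a compressed KV cache for a long context.

    The context is split into chunks; each chunk plus the question is run
    through the model, per-layer question-to-context attention scores are
    computed, and only the top-scoring key/value positions are kept in the
    cache. Sink tokens (and, on the last chunk, the question tokens) are
    always retained.
    """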
    device = model.device
    dtype = model.dtype
    num_chunks = step_size
    context_ids = context_ids.to(device)
    context_attention_mask = context_attention_mask.to(device)
    question_ids = question_ids.to(device)
    question_attention_mask = question_attention_mask.to(device)
    question_len = question_ids.size(1)
    total_len = context_ids.size(1)
    max_context_tokens_allowed = model.config.max_position_embeddings - question_len
    if total_len > max_context_tokens_allowed:
        num_chunks = max(step_size, math.ceil(total_len / max_context_tokens_allowed))
    if total_len <= sink_tokens or num_chunks == 1:
        # If the context is too short or only one chunk is desired, use the entire context.
        context_ids_list = [context_ids]
        context_attention_mask_list = [context_attention_mask]
    else:
        # Calculate how many tokens remain after the sink tokens.
        remainder_len = total_len - sink_tokens
        # Compute the base tokens per chunk and any leftover.
        base = remainder_len // num_chunks
        leftover = remainder_len % num_chunks
        # Build a list of chunk sizes: the first chunk gets the sink tokens plus base tokens.
        chunk_sizes = [sink_tokens + base]
        # Chunks 2 to num_chunks - 1 get base tokens each.
        for _ in range(num_chunks - 2):
            chunk_sizes.append(base)
        # The last chunk gets the remaining tokens (base + leftover).
        if num_chunks > 1:
            chunk_sizes.append(base + leftover)
        # Slice the context using the calculated sizes.
        context_ids_list = []
        context_attention_mask_list = []
        offset = 0
        for size in chunk_sizes:
            end = offset + size
            context_ids_list.append(context_ids[:, offset:end])
            context_attention_mask_list.append(context_attention_mask[:, offset:end])
            offset = end
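
    # Derive the compression factor: roughly how many non-sink context tokens
    # should collapse into one kept token, given the requested target size.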
    len_rest = max(total_len - sink_tokens, 1)
    compression_factor = max(len_rest // target_token_size, 1)
    tokenized_doc_chunks = [
        {"input_ids": ids_chunk, "attention_mask": mask_chunk}
        for ids_chunk, mask_chunk in zip(context_ids_list, context_attention_mask_list)
    ]
    print("Number of chunks:", len(tokenized_doc_chunks))
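
    # Reuse the model's rotary embedding (and its inverse frequencies) so the
    # captured query states can later be rotated to their absolute positions.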
    rotary_emb = model.model.rotary_emb.to(device)
    inv_freq = rotary_emb.inv_freq
    batch_size = question_ids.size(0)
    ones_mask = torch.ones(batch_size, 1, dtype=question_attention_mask.dtype, device=device)

    cache = FinchCache()
    past_cache_len = 0
    past_attention_mask = torch.zeros(batch_size, 0, dtype=question_attention_mask.dtype, device=device)
    num_chunks = len(tokenized_doc_chunks)
    # Shared dictionary for hook outputs.
    query_context_matrices = {}

    # Hook that captures each layer's query projection. _current_chunk_offset is
    # assigned in the chunk loop below and read here as a closure variable, so
    # only the question-token queries (appended after the chunk) are kept.
    def query_hook_fn(module, input, output):
        layer_idx = getattr(module, "layer_idx", None)
        if layer_idx is not None:
            query_states = output.detach()
            bsz, seq_len, hidden_dim = query_states.size()
            num_query_heads = module.num_query_heads
            head_dim = hidden_dim // num_query_heads
            query_states = (
                query_states.view(bsz, seq_len, num_query_heads, head_dim)
                .transpose(1, 2)
                .contiguous()
            )
            # Keep only the new (question) tokens.
            query_context_matrices[layer_idx] = query_states[:, :, _current_chunk_offset:, :].clone()
    # Pre-register hooks for all layers only once.
    hooks = []
    for i, layer in enumerate(model.model.layers):
        layer.self_attn.q_proj.layer_idx = i  # For tracking.
        layer.self_attn.q_proj.num_query_heads = layer.self_attn.config.num_attention_heads
        hook = layer.self_attn.q_proj.register_forward_hook(query_hook_fn)
        hooks.append(hook)
    # Process each document chunk sequentially.
    for j, tokenized_doc_chunk in enumerate(tokenized_doc_chunks):
        current_seq_length = tokenized_doc_chunk["input_ids"].size(1)
        # Record the offset where the hook can read it (closure variable).
        _current_chunk_offset = current_seq_length
        # Clear the dictionary from any previous chunk.
        query_context_matrices.clear()
        # These chunks are already on the device.
        chunk_input_ids = tokenized_doc_chunk["input_ids"].contiguous()
        chunk_attention_mask = tokenized_doc_chunk["attention_mask"].contiguous()
        segment_attention_mask = torch.cat(
            [past_attention_mask, chunk_attention_mask, ones_mask], dim=-1
        ).contiguous()
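        # Append the question to the chunk so its queries can score the keys.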
        current_input_ids = torch.cat([chunk_input_ids, question_ids], dim=-1).contiguous()
        current_attention_mask = torch.cat([segment_attention_mask, question_attention_mask], dim=-1).contiguous()
        past_seen_tokens = cache.get_seq_length() if cache is not None else 0
        cache_position = torch.arange(
            past_seen_tokens + chunk_input_ids.shape[1],
            past_seen_tokens + current_input_ids.shape[1],
            device=device,
        )
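        # Build the 4D causal mask for the question rows over every key position
        # they can see: the compressed past cache, this chunk, and the question.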
        causal_mask = model.model._prepare_4d_causal_attention_mask_with_cache_position(
            current_attention_mask,
            sequence_length=question_ids.size(1),
            target_length=current_attention_mask.size(-1),
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=current_input_ids.size(0),
        ).contiguous()
        with torch.no_grad():
            outputs = model.model(
                input_ids=current_input_ids,
                use_cache=True,
                past_key_values=cache,
            )
        cache = outputs.past_key_values
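        # The forward pass appended this chunk's and the question's keys/values
        # to the cache and triggered the hooks that captured the question queries.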
        len_question = question_ids.size(1)
        # For each transformer layer, compress the cache using question-to-context attention.
        for layer_idx in range(len(model.model.layers)):
            key_matrix = cache.key_cache[layer_idx]
            query_matrix = query_context_matrices[layer_idx]
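            # Absolute positions of the question tokens, given the compressed
            # cache and the current chunk that precede them.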
            layer_cache_pos = torch.arange(
                past_cache_len + current_seq_length,
                past_cache_len + current_seq_length + len_question,
                device=device,
            )
            position_ids = layer_cache_pos.unsqueeze(0)
            cos, sin = rotary_emb(query_matrix, position_ids)
            cos = cos.unsqueeze(1)
            sin = sin.unsqueeze(1)
            query_matrix = (query_matrix * cos) + (rotate_half(query_matrix) * sin)
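            # Expand the grouped key/value heads so they line up with the query
            # heads (grouped-query attention).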
            num_repeats = model.config.num_attention_heads // model.config.num_key_value_heads
            key_matrix = repeat_kv(key_matrix, num_repeats)
            scaling = math.sqrt(model.config.head_dim)
            attention_matrix = torch.matmul(query_matrix, key_matrix.transpose(2, 3)) / scaling
            causal_mask_sliced = causal_mask[:, :, :, : key_matrix.shape[-2]]
            attention_matrix = attention_matrix + causal_mask_sliced
            attention_matrix = torch.nn.functional.softmax(attention_matrix, dim=-1, dtype=torch.float32).to(query_matrix.dtype)
            # Normalization: divide each query row by its number of unmasked key
            # positions, offsetting the bias that earlier keys are visible to
            # more queries when rows are summed below.
            tol = 1e-8
            binary_mask = (torch.abs(causal_mask_sliced.to(torch.float32)) < tol).to(torch.float32)
            non_zero_counts = binary_mask.sum(dim=3, keepdim=True)
            non_zero_counts = torch.clamp_min(non_zero_counts, 1.0).to(attention_matrix.dtype)
            attention_matrix = attention_matrix / non_zero_counts
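            # Intermediate chunks: score only past-cache + chunk keys (the
            # question is re-appended next chunk). Last chunk: keep the question
            # key columns as well.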
            if j != num_chunks - 1:
                attention_matrix = attention_matrix[:, :, :, : past_cache_len + current_seq_length].clone().contiguous()
            else:
                attention_matrix = attention_matrix[:, :, :, : past_cache_len + current_seq_length + len_question].clone().contiguous()
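            # Aggregate relevance per key: sum over query positions, then fold
            # the repeated query heads back into their key/value heads.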
            attention_matrix = torch.sum(attention_matrix, dim=-2)
            attention_matrix = attention_matrix.view(
                attention_matrix.size(0), model.config.num_key_value_heads, num_repeats, -1
            ).sum(dim=2)
            full_context_size = attention_matrix.size(-1)
            # Pin the sink tokens (and, on the last chunk, the question tokens)
            # to +inf so top-k can never evict them.
            attention_matrix[..., :sink_tokens] = float("inf")
            if j == num_chunks - 1:
                attention_matrix[..., -len_question:] = float("inf")
            if j == 0:
                k = int(sink_tokens + (max(0, current_seq_length - sink_tokens) // compression_factor))
                k = min(k + past_cache_len, full_context_size)
            elif j < num_chunks - 1:
                to_keep_new = int(current_seq_length // compression_factor)
                k = min(past_cache_len + to_keep_new, full_context_size)
            else:
                # The final budget also covers the question tokens kept in the cache.
                desired_final = sink_tokens + target_token_size + len_question
                k = desired_final if full_context_size >= desired_final else full_context_size
            k = max(k, sink_tokens)
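            # Keep the k highest-scoring positions per key/value head (sorted to
            # preserve their original order) and compress this layer's cache.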
            selected_indices = torch.topk(attention_matrix, k, dim=-1).indices
            selected_indices, _ = torch.sort(selected_indices, dim=-1)
            cache.compress_cache(layer_idx, selected_indices, inv_freq)
        past_cache_len = cache._seen_tokens
        past_attention_mask = torch.ones(batch_size, past_cache_len, dtype=question_attention_mask.dtype, device=device)
    # Remove the hooks once all chunks are processed.
    for hook in hooks:
        hook.remove()
    return cache
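

# Usage sketch (illustrative): the checkpoint name and token budgets below are
# assumptions, not part of this module; any Llama-style model with rotary
# embeddings and a FinchCache-compatible cache layout should fit this call shape.
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#
#   model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
#   tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
#   context = tokenizer(long_document, return_tensors="pt")
#   question = tokenizer(question_text, return_tensors="pt")
#   cache = get_compressed_kv_cache(
#       model,
#       sink_tokens=4,
#       step_size=2,
#       target_token_size=1024,
#       context_ids=context["input_ids"],
#       context_attention_mask=context["attention_mask"],
#       question_ids=question["input_ids"],
#       question_attention_mask=question["attention_mask"],
#   )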