Spaces:
Runtime error
Runtime error
| import torch | |
| from torch import nn | |
class KVCache(nn.Module):
    """Pre-allocated key/value cache for incremental attention decoding.

    Both caches are zero-initialised tensors of shape
    ``(max_batch_size, max_seq_length, model_dim)`` registered as module
    buffers under the names ``"k_cache"`` and ``"v_cache"``.
    """

    def __init__(self, max_batch_size, max_seq_length, model_dim, dtype):
        super().__init__()
        cache_shape = (max_batch_size, max_seq_length, model_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        """Write keys/values at ``input_pos`` and return the filled prefix.

        Args:
            input_pos: 1-D integer tensor of shape ``[S]`` holding the sequence
                positions being written.
            k_val: keys, shape ``[B, S, model_dim]``; must broadcast into
                ``k_cache[:, input_pos]`` (a batch of 1 is broadcast across
                all ``max_batch_size`` rows).
            v_val: values, same shape as ``k_val``.

        Returns:
            ``(k, v)``: slices (views) of the two caches covering sequence
            positions ``0 .. max(input_pos)`` inclusive, for all batch rows.
        """
        # Use the largest written position, not the first one: when several
        # positions are written in one call (e.g. a prefill with
        # input_pos = arange(S)), slicing to input_pos[0] + 1 would drop
        # every position after the first. For single-step decoding (S == 1)
        # both choices give the same result.
        end = int(input_pos.max()) + 1
        self.k_cache[:, input_pos, ...] = k_val
        self.v_cache[:, input_pos, ...] = v_val
        return self.k_cache[:, :end], self.v_cache[:, :end]
class VCache(nn.Module):
    """Value cache held in a single registered buffer named ``"v_cache"``.

    The buffer starts as zeros of shape
    ``(max_batch_size, max_seq_length, num_heads, head_dim)``; ``update``
    swaps it out wholesale and ``get`` hands back whatever is stored.
    """

    def __init__(self, max_batch_size, max_seq_length, num_heads, head_dim, dtype):
        super().__init__()
        initial = torch.zeros(
            max_batch_size, max_seq_length, num_heads, head_dim, dtype=dtype
        )
        self.register_buffer("v_cache", initial)

    def update(self, v_val):
        """Replace the cached tensor with ``v_val``.

        Assigning to a registered buffer name goes through
        ``nn.Module.__setattr__``, which rebinds the buffer to ``v_val``
        itself (no copy). The cache then aliases ``v_val`` and adopts its
        shape, dtype and device.
        """
        self.v_cache = v_val

    def get(self):
        """Return the currently stored tensor (the buffer itself, not a copy)."""
        return self.v_cache