File size: 1,082 Bytes
fadb92b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
from torch import nn


class KVCache(nn.Module):
    def __init__(self, max_batch_size, max_seq_length, model_dim, dtype):
        super().__init__()
        cache_shape = (max_batch_size, max_seq_length, model_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # input_pos: [S], k_val: [B, S, H, D]
        index = input_pos[0].long() + 1
        self.k_cache[:, input_pos, ...] = k_val
        self.v_cache[:, input_pos, ...] = v_val

        return self.k_cache[:, :index], self.v_cache[:, :index]


class VCache(nn.Module):
    def __init__(self, max_batch_size, max_seq_length, num_heads, head_dim, dtype):
        super().__init__()
        cache_shape = (max_batch_size, max_seq_length, num_heads, head_dim)
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

    def update(self, v_val):
        self.v_cache = v_val

    def get(self):
        return self.v_cache