raster2seq / models /kv_cache.py
anas
Initial deployment of Raster2Seq floor plan vectorization API
fadb92b
import torch
from torch import nn
class KVCache(nn.Module):
    """Pre-allocated key and value buffers that are filled position by position.

    Both buffers are zero-initialised tensors of shape
    ``(max_batch_size, max_seq_length, model_dim)`` registered as module
    buffers under the names ``k_cache`` and ``v_cache``.
    """

    def __init__(
        self,
        max_batch_size: int,
        max_seq_length: int,
        model_dim: int,
        dtype: torch.dtype,
    ) -> None:
        """Allocate the zero-filled key and value buffers.

        Args:
            max_batch_size: Size of the buffers' first (batch) dimension.
            max_seq_length: Number of sequence positions the buffers can hold.
            model_dim: Size of the last (feature) dimension.
            dtype: Element type of both buffers.
        """
        super().__init__()
        cache_shape = (max_batch_size, max_seq_length, model_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

    def update(
        self,
        input_pos: torch.Tensor,
        k_val: torch.Tensor,
        v_val: torch.Tensor,
    ) -> "tuple[torch.Tensor, torch.Tensor]":
        """Write new keys/values at ``input_pos`` and return the filled prefix.

        Args:
            input_pos: 1-D integer tensor ``[S]`` of sequence positions to
                overwrite. Must be non-empty.
            k_val: Keys to store; must be assignable (broadcastable) to
                ``k_cache[:, input_pos]``, i.e. ``[max_batch_size, S, model_dim]``.
            v_val: Values to store, same layout as ``k_val``.

        Returns:
            ``(k_cache[:, :end], v_cache[:, :end])`` where ``end`` is the
            largest position in ``input_pos`` plus one. The results are
            slices of the underlying buffers, not copies.
        """
        # Use the largest written position rather than the first one, so that
        # a multi-position update returns every entry that was just stored.
        # For a single-position update this is the same as input_pos[0] + 1.
        end = input_pos.max().long() + 1
        self.k_cache[:, input_pos, ...] = k_val
        self.v_cache[:, input_pos, ...] = v_val
        return self.k_cache[:, :end], self.v_cache[:, :end]
class VCache(nn.Module):
    """Holder for a single value tensor stored as the module buffer ``v_cache``."""

    def __init__(
        self,
        max_batch_size: int,
        max_seq_length: int,
        num_heads: int,
        head_dim: int,
        dtype: torch.dtype,
    ) -> None:
        """Register a zero-filled ``v_cache`` buffer.

        The initial buffer has shape
        ``(max_batch_size, max_seq_length, num_heads, head_dim)`` and the
        given ``dtype``.
        """
        super().__init__()
        initial_values = torch.zeros(
            (max_batch_size, max_seq_length, num_heads, head_dim),
            dtype=dtype,
        )
        self.register_buffer("v_cache", initial_values)

    def update(self, v_val: torch.Tensor) -> None:
        """Replace the stored buffer with ``v_val``.

        The buffer is rebound to ``v_val`` itself (no copy), so the
        shape given at construction time is not enforced here.
        """
        self.v_cache = v_val

    def get(self) -> torch.Tensor:
        """Return the tensor currently held in ``v_cache``."""
        return self.v_cache