File size: 4,003 Bytes
27871e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
"""
Explicit KV Cache management for efficient inference.
This is critical for Qualcomm deployment and agent control loops.
"""

import torch
from typing import Optional, Tuple
from dataclasses import dataclass


@dataclass
class KVCache:
    """Key-Value cache for transformer inference.

    Layout: [num_layers, batch_size, num_heads, max_seq_len, head_dim]

    This explicit cache enables:
    - Efficient autoregressive decoding
    - Cache offloading for memory management
    - Sliding window attention (future)
    - Agent control loops with cache manipulation
    """

    key_cache: torch.Tensor  # [num_layers, batch, heads, max_len, head_dim]
    value_cache: torch.Tensor  # [num_layers, batch, heads, max_len, head_dim]
    seq_len: int  # Current sequence length in cache

    @property
    def max_seq_len(self) -> int:
        """Maximum number of token positions the cache can hold (dim 3)."""
        return self.key_cache.shape[3]

    @classmethod
    def create(
        cls,
        num_layers: int,
        batch_size: int,
        num_heads: int,
        max_seq_len: int,
        head_dim: int,
        dtype: torch.dtype = torch.float16,
        device: Optional[torch.device] = None,
    ) -> "KVCache":
        """Create an empty KV cache.

        Args:
            num_layers: Number of transformer layers
            batch_size: Batch size
            num_heads: Number of attention heads
            max_seq_len: Maximum sequence length
            head_dim: Dimension per attention head
            dtype: Data type for cache tensors
            device: Device to create cache on (None = torch default device)

        Returns:
            Initialized KVCache with zero tensors and seq_len == 0
        """
        shape = (num_layers, batch_size, num_heads, max_seq_len, head_dim)

        key_cache = torch.zeros(shape, dtype=dtype, device=device)
        value_cache = torch.zeros(shape, dtype=dtype, device=device)

        return cls(key_cache=key_cache, value_cache=value_cache, seq_len=0)

    def update(
        self,
        layer_idx: int,
        key: torch.Tensor,
        value: torch.Tensor,
        position: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Update cache for a specific layer and return full K, V.

        Args:
            layer_idx: Index of the transformer layer
            key: New key tensor [batch, heads, seq_len, head_dim]
            value: New value tensor [batch, heads, seq_len, head_dim]
            position: Starting position for the new tokens

        Returns:
            Tuple of (full_key, full_value) including cached values,
            sliced up to position + seq_len.

        Raises:
            ValueError: If the new tokens would not fit in the cache.
        """
        new_tokens = key.shape[2]
        end_pos = position + new_tokens

        # Fail loudly with a clear message instead of letting the slice
        # assignment below raise an opaque shape-mismatch RuntimeError.
        if end_pos > self.max_seq_len:
            raise ValueError(
                f"KV cache overflow: position {position} + {new_tokens} new "
                f"tokens exceeds max_seq_len {self.max_seq_len}"
            )

        # Store new keys and values
        self.key_cache[layer_idx, :, :, position:end_pos, :] = key
        self.value_cache[layer_idx, :, :, position:end_pos, :] = value

        # Update sequence length (max, so out-of-order rewrites don't shrink it)
        self.seq_len = max(self.seq_len, end_pos)

        # Return full K, V up to current position
        return (
            self.key_cache[layer_idx, :, :, :end_pos, :],
            self.value_cache[layer_idx, :, :, :end_pos, :],
        )

    def get(
        self,
        layer_idx: int,
        end_pos: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get cached K, V for a specific layer.

        Args:
            layer_idx: Index of the transformer layer
            end_pos: End position (defaults to current seq_len)

        Returns:
            Tuple of (key, value) tensors, each
            [batch, heads, end_pos, head_dim]
        """
        if end_pos is None:
            end_pos = self.seq_len

        return (
            self.key_cache[layer_idx, :, :, :end_pos, :],
            self.value_cache[layer_idx, :, :, :end_pos, :],
        )

    def reset(self):
        """Reset the cache to empty state (zero tensors, seq_len = 0)."""
        self.key_cache.zero_()
        self.value_cache.zero_()
        self.seq_len = 0

    @property
    def memory_usage_mb(self) -> float:
        """Calculate memory usage of both cache tensors in megabytes."""
        total_bytes = self.key_cache.numel() * self.key_cache.element_size()
        total_bytes += self.value_cache.numel() * self.value_cache.element_size()
        return total_bytes / (1024 * 1024)