# Spaces: Runtime error
import os

import torch
from transformers import DynamicCache
class FinchCache(DynamicCache):
    """``DynamicCache`` subclass whose per-layer KV tensors can be pruned in place.

    ``compress_cache`` keeps only selected token positions of one layer and
    re-rotates the kept keys so that their rotary position embedding matches
    their new, contiguous index in the compressed cache. The cache can be
    written to disk with ``save`` and restored with ``load``.
    """

    def __init__(self) -> None:
        super().__init__()
        # Per-layer lists: entry i holds layer i's key / value tensor.
        self.key_cache = []
        self.value_cache = []

    @staticmethod
    def _rotate_half(x: torch.Tensor) -> torch.Tensor:
        """Split the last dim of ``x`` into halves ``[x1, x2]`` and return ``[-x2, x1]``."""
        half = x.shape[-1] // 2
        x1 = x[..., :half]
        x2 = x[..., half:]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_key_rotary_pos_emb(self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        """Rotate ``key_states`` with the given cos/sin tables: ``k*cos + rotate_half(k)*sin``."""
        return (key_states * cos) + (self._rotate_half(key_states) * sin)

    @staticmethod
    def _rerotate_cos_sin(x: torch.Tensor, inv_freq: torch.Tensor, important_pos_batch: torch.Tensor):
        """Build cos/sin tables that move each kept token to its new position.

        Token ``j`` of the compressed cache originally sat at position
        ``important_pos_batch[b, h, j]`` and will now sit at position ``j``, so
        its key must be rotated by ``j - important_pos_batch[b, h, j]``.

        Args:
            x: tensor whose device type and dtype the result follows (the layer's keys).
            inv_freq: 1-D rotary inverse frequencies of length M.
            important_pos_batch: (B, H, L) original positions of the kept tokens.

        Returns:
            ``(cos, sin)``, each of shape (B, H, L, 2*M) and dtype ``x.dtype``.
        """
        B, H, L = important_pos_batch.shape
        device = important_pos_batch.device
        # torch.autocast is not used with "mps" here; fall back to "cpu" in that case.
        device_type = "cpu" if x.device.type == "mps" else x.device.type

        # New position of every kept token: 0..L-1, broadcast over (B, H).
        new_pos = torch.arange(L, device=device).float()[None, None, :].expand(B, H, L)
        # Remaining rotation per token = new position - original position.
        delta_pos = (new_pos - important_pos_batch).unsqueeze(2)  # (B, H, 1, L)
        # Move inv_freq next to the positions so the product does not mix devices.
        inv_freq = inv_freq.to(device)[None, None, :, None].float().expand(B, H, -1, 1)  # (B, H, M, 1)

        # Angles are computed in float32 with autocast disabled, matching the explicit .float() casts.
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (delta_pos.float() * inv_freq.float()).transpose(2, 3)  # (B, H, L, M)
            emb = torch.cat((freqs, freqs), dim=-1)  # (B, H, L, 2M)
            cos = emb.cos().contiguous()
            sin = emb.sin().contiguous()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

    @staticmethod
    def gather_important_tokens(states: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
        """Select positions ``indices`` along dim 2 of the 4-D tensor ``states``.

        ``indices`` is 3-D (B, H, L') and is broadcast over the last dim of
        ``states``; the result has shape (B, H, L', states.size(3)).
        """
        expanded = indices.unsqueeze(-1).expand(-1, -1, -1, states.size(3))
        return torch.gather(states, 2, expanded).contiguous()

    def compress_cache(self, layer_index: int, important_pos: torch.Tensor, inv_freq: torch.Tensor) -> None:
        """Keep only the tokens at ``important_pos`` in layer ``layer_index``.

        Keys are gathered and re-rotated so their rotary embedding corresponds
        to their new index 0..L'-1; values are gathered unchanged.
        ``_seen_tokens`` is set to the compressed length L'.

        Args:
            layer_index: index into ``key_cache`` / ``value_cache``.
            important_pos: (B, H, L') positions to keep; must be valid
                ``torch.gather`` indices (int64) on the cache's device.
            inv_freq: 1-D rotary inverse frequencies (M,); 2*M must match the
                key head dim for the rotation to broadcast.
        """
        keys = self.key_cache[layer_index]
        new_cos, new_sin = self._rerotate_cos_sin(keys, inv_freq, important_pos)
        # torch.gather already allocates a new tensor, so no extra .clone() is needed.
        gathered_keys = self.gather_important_tokens(keys, important_pos)
        self.key_cache[layer_index] = self._apply_key_rotary_pos_emb(gathered_keys, new_cos, new_sin)
        self.value_cache[layer_index] = self.gather_important_tokens(self.value_cache[layer_index], important_pos)
        self._seen_tokens = important_pos.size(2)

    def save(self, path: str) -> None:
        """Save the cache to disk, moving tensors to CPU.

        This is best-effort: any error is printed and not re-raised.
        """
        try:
            directory = os.path.dirname(path)
            # os.makedirs("") raises, so a bare filename (no directory part)
            # must skip directory creation instead of aborting the save.
            if directory:
                os.makedirs(directory, exist_ok=True)
            torch.save(
                {"key_cache": [k.cpu() for k in self.key_cache], "value_cache": [v.cpu() for v in self.value_cache]},
                path,
            )
        except Exception as e:
            print(f"Error occurred while saving: {e}")

    @classmethod
    def load(cls, path: str, device: str = "cpu") -> "FinchCache":
        """Load a cache written by ``save`` and move its tensors to ``device``.

        ``_seen_tokens`` is restored from dim 2 of the first value tensor
        (0 for an empty cache).
        """
        # ``save`` writes only tensors in lists/dicts, so the restricted
        # weights-only unpickler suffices and avoids executing arbitrary pickles.
        data = torch.load(path, map_location=device, weights_only=True)
        cache = cls()
        cache.key_cache = [k.to(device) for k in data["key_cache"]]
        cache.value_cache = [v.to(device) for v in data["value_cache"]]
        cache._seen_tokens = cache.value_cache[0].size(2) if cache.value_cache else 0
        return cache