Spaces:
Paused
Paused
Add code/cube3d/model/transformers/cache.py
Browse files
code/cube3d/model/transformers/cache.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass, field
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
@dataclass
class Cache:
    """Key/value cache that is written in place, one decoding position at a time.

    On construction the cache probes whether ``Tensor.index_copy_`` works on
    the device that holds ``key_states`` and remembers the answer, so that
    ``update`` can choose a write strategy without probing on every call.

    Attributes:
        key_states: Pre-allocated tensor holding cached keys. ``update``
            writes along dimension 2 of this tensor.
        value_states: Pre-allocated tensor holding cached values. ``update``
            writes along dimension 2 of this tensor.
    """

    key_states: torch.Tensor
    value_states: torch.Tensor
    # Not a constructor argument; filled in by __post_init__.
    # Original author's note: "For CUDA graph support".
    _supports_index_copy: bool = field(init=False)

    def __post_init__(self):
        """Probe the device once and cache the result for ``update``."""
        self._supports_index_copy = self._check_index_copy_support()

    def _check_index_copy_support(self) -> bool:
        """Return whether ``index_copy_`` is implemented on the cache's device.

        A tiny throw-away tensor is created on the same device as
        ``key_states`` and ``index_copy_`` is attempted on it.

        Returns:
            ``True`` if the probe call succeeds, ``False`` if it raises
            ``NotImplementedError``. Any other exception propagates.
        """
        try:
            device = self.key_states.device
            probe = torch.tensor([0, 0], device=device)
            probe.index_copy_(
                0,
                torch.tensor([0], device=device),
                torch.tensor([1], device=device),
            )
            return True
        except NotImplementedError:
            return False

    def update(self, curr_pos_id: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> None:
        """Write new keys and values into the cache at ``curr_pos_id``.

        Both strategies mutate ``key_states`` / ``value_states`` in place;
        the attributes are never rebound to new tensors.

        Args:
            curr_pos_id (torch.Tensor): Current position indices for decoding,
                used as the index along dimension 2.
            k (torch.Tensor): The keys to write.
            v (torch.Tensor): The values to write.
        """
        if self._supports_index_copy:
            # Devices with index_copy_ (the original comment names CUDA/CPU).
            # In-place on purpose: the out-of-place ``index_copy`` would
            # allocate a fresh tensor on every call instead of reusing the
            # pre-allocated buffer.
            self.key_states.index_copy_(2, curr_pos_id, k)
            self.value_states.index_copy_(2, curr_pos_id, v)
        else:
            # Fallback (the original comment names MPS): copy into a
            # one-position slice. Using ``curr_pos_id`` as a slice bound
            # requires it to be convertible to a single integer index.
            self.key_states[:, :, curr_pos_id:curr_pos_id + 1, ...].copy_(k)
            self.value_states[:, :, curr_pos_id:curr_pos_id + 1, ...].copy_(v)
|