0xZohar commited on
Commit
72e5e67
·
verified ·
1 Parent(s): 1e1d4d1

Add code/cube3d/model/transformers/cache.py

Browse files
code/cube3d/model/transformers/cache.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ import torch
3
+
4
+ @dataclass
5
+ class Cache:
6
+ key_states: torch.Tensor
7
+ value_states: torch.Tensor
8
+ _supports_index_copy: bool = field(init=False) # For CUDA graph support
9
+
10
+ def __post_init__(self):
11
+ self._supports_index_copy = self._check_index_copy_support()
12
+
13
+ def _check_index_copy_support(self) -> bool:
14
+ """Verifies support for `index_copy_` on device."""
15
+ try:
16
+ device = self.key_states.device
17
+ dummy = torch.tensor([0, 0], device=device)
18
+ dummy.index_copy_(0, torch.tensor([0], device=device), torch.tensor([1], device=device))
19
+ return True
20
+ except NotImplementedError:
21
+ return False
22
+
23
+ def update(self, curr_pos_id: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> None:
24
+ """
25
+ Updates the cache based on device operator support.
26
+ Args:
27
+ curr_pos_id (torch.Tensor): Current position indices for decoding.
28
+ k (torch.Tensor): The keys to update
29
+ v (torch.Tensor): The values to update
30
+ """
31
+ if self._supports_index_copy: # CUDA/CPU
32
+ self.key_states.index_copy_(2, curr_pos_id, k)
33
+ self.value_states.index_copy_(2, curr_pos_id, v)
34
+ # # 非原地操作:创建新张量并赋值,不修改原始张量
35
+ # self.key_states = self.key_states.index_copy(2, curr_pos_id, k) # 用index_copy(非原地)
36
+ # # self.value_states = self.value_states.index_copy(2, curr_pos_id, v) # 替换index_copy_
37
+ # self.value_states = self.value_states.clone().index_copy(2, curr_pos_id, v)
38
+ else: # MPS
39
+ self.key_states[:, :, curr_pos_id:curr_pos_id +1, ...].copy_(k)
40
+ self.value_states[:, :, curr_pos_id:curr_pos_id +1, ...].copy_(v)
41
+ # self.key_states[:, :, curr_pos_id:curr_pos_id +1, ...].copy_(k) # 原地操作
42
+ # self.value_states[:, :, curr_pos_id:curr_pos_id +1, ...].copy_(v) # 原地操作