0xZohar commited on
Commit
e79b1e4
·
verified ·
1 Parent(s): 49b5170

Add code/cube3d/model/transformers/rope.py

Browse files
code/cube3d/model/transformers/rope.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+
7
def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: torch.Tensor,
    curr_pos_id: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Rotate the input tensor by the precomputed rotary positional factors.

    Args:
        x (torch.Tensor): The input tensor; its last dimension is interpreted
            as interleaved (real, imag) pairs.
        freqs_cis (torch.Tensor): Precomputed complex rotary factors
            (see `precompute_freqs_cis`).
        curr_pos_id (Optional[torch.Tensor]): Optional position ids selecting
            a subset of `freqs_cis`; when None, the last `seq_len` positions
            of `freqs_cis` are used.

    Returns:
        torch.Tensor: `x` with rotary positional embeddings applied, in the
        original dtype.
    """
    # View the last dim as complex pairs: (..., d) -> (..., d // 2) complex.
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    # Pick the rotation factors for the positions covered by x.
    if curr_pos_id is not None:
        rot = freqs_cis[:, curr_pos_id, :]
    else:
        rot = freqs_cis[:, -x.shape[2]:]
    # Broadcast over the head dimension, rotate, and unpack back to real pairs.
    rotated = x_complex * rot.unsqueeze(1)
    return torch.view_as_real(rotated).flatten(3).type_as(x)
32
+
33
+
34
@torch.no_grad()
def precompute_freqs_cis(dim: int, t: torch.Tensor, theta: float = 10000.0) -> torch.Tensor:
    """Precompute the complex rotary-embedding factors (cis = cos + i*sin).

    Useful when every block in the network shares the same positional
    embedding: compute once, reuse everywhere.

    Args:
        dim (int): dimension of a single head of the transformer block
            (must be even).
        t (torch.Tensor): position ids of shape [..., L].
        theta (float, optional): RoPE base frequency. Defaults to 10000.0.

    Returns:
        torch.Tensor: complex tensor of shape [*t.shape, dim // 2]; the entry
        at position p and frequency index j is exp(i * p * theta**(-2j/dim)).

    Raises:
        ValueError: if `dim` is odd.
    """
    # `assert` is stripped under `python -O`; raise so validation always runs.
    if dim % 2 != 0:
        raise ValueError(
            "RoPE only supports embedding dimensions that are multiples of 2"
        )
    # Inverse frequencies theta^(-2j/dim), j = 0 .. dim/2 - 1.
    inv_freq = 1.0 / (
        theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=t.device) / dim)
    )
    # Outer product of flattened positions with the frequencies, restoring the
    # leading shape of `t`: [..., L, dim // 2] angles.
    angles = torch.outer(t.contiguous().view(-1), inv_freq).reshape(*t.shape, -1)
    # polar(1, angle) = cos(angle) + i*sin(angle), i.e. exp(i*angle).
    return torch.polar(torch.ones_like(angles), angles)
57
+
58
+
59
def scaled_dot_product_attention_with_rotary_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    freqs_cis: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    curr_pos_id: Optional[torch.Tensor] = None,
    is_causal: bool = False,
) -> torch.Tensor:
    """
    Scaled dot-product attention with rotary position embeddings applied to
    the query and key tensors.

    Without caching enabled,
        q should be (bs, nh, seqlen, hd).
        k and v should stay unchanged, (bs, nh, seqlen, hd).
    With caching enabled,
        q should be (bs, nh, 1, hd).
        k and v should stay unchanged, (bs, nh, 1, hd).
        causal_mask must be False.
    """
    # Rotate queries at their own position(s); keys over the full sequence.
    q_rot = apply_rotary_emb(q, freqs_cis, curr_pos_id=curr_pos_id)  # (bs, nh, l, hd)
    k_rot = apply_rotary_emb(k, freqs_cis, curr_pos_id=None)  # (bs, nh, s + l, hd)

    # An explicit attention mask takes precedence over the causal flag.
    use_causal = is_causal and attn_mask is None
    return F.scaled_dot_product_attention(
        q_rot,
        k_rot,
        v,
        attn_mask=attn_mask,
        dropout_p=0.0,
        is_causal=use_causal,
    )