HV-Khurdula commited on
Commit
1168c80
·
verified ·
1 Parent(s): 0949437

Update rope.py

Browse files

feat: allow support for batched

Files changed (1) hide show
  1. rope.py +40 -28
rope.py CHANGED
@@ -17,32 +17,44 @@ def precompute_freqs_cis(
17
  return torch.stack([freqs.real, freqs.imag], dim=-1)
18
 
19
 
20
- def apply_rotary_emb(
21
- x: torch.Tensor,
22
- freqs_cis: torch.Tensor,
23
- position_ids: torch.Tensor,
24
- num_heads: int,
25
- rot_dim: int = 32,
26
- interleave: bool = False,
27
- ) -> torch.Tensor:
28
- assert rot_dim == freqs_cis.shape[-2] * 2
29
- assert num_heads == x.shape[1]
30
-
31
- x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:]
32
-
33
- if interleave:
34
- xq_r = x_rot.float().reshape(*x_rot.shape[:-1], -1, 2)[..., 0]
35
- xq_i = x_rot.float().reshape(*x_rot.shape[:-1], -1, 2)[..., 1]
 
 
 
 
 
 
 
 
36
  else:
37
- d_q = x_rot.shape[-1] // 2
38
- xq_r, xq_i = x_rot[..., :d_q], x_rot[..., d_q:]
39
-
40
- freqs_cos = freqs_cis[..., 0][position_ids, :].unsqueeze(0).unsqueeze(0)
41
- freqs_sin = freqs_cis[..., 1][position_ids, :].unsqueeze(0).unsqueeze(0)
42
-
43
- # Complex multiplication: (a + bi) * (c + di) = (ac - bd) + (ad + bc)i
44
- xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
45
- xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
46
- xq_out = torch.stack((xq_out_r, xq_out_i), dim=-1).flatten(-2)
47
-
48
- return torch.cat([xq_out.to(x.dtype), x_pass], dim=-1)
 
 
 
 
 
17
  return torch.stack([freqs.real, freqs.imag], dim=-1)
18
 
19
 
20
+ def apply_rotary_emb(x, freqs_cis, position_ids, num_heads, rot_dim=None, interleave=False):
21
+ """
22
+ x: (B, num_heads, q_len, head_dim)
23
+ freqs_cis: (max_seq, rot_half, 2) # [..., cos/sin]
24
+ position_ids: (q_len,) or (B, q_len) or scalar
25
+ num_heads: int (unused here; kept for API compatibility)
26
+ rot_dim: optional; if None we use min(D, 2*rot_half)
27
+ """
28
+ B, H, T, D = x.shape
29
+ rot_half_from_freqs = freqs_cis.size(-2) # available rotary half
30
+ rd = rot_dim or (rot_half_from_freqs * 2)
31
+ rd = min(rd, D) # don't exceed head_dim
32
+
33
+ x_rot, x_pass = x[..., :rd], x[..., rd:] # (B,H,T,rd), (B,H,T,D-rd)
34
+
35
+ # Gather cos/sin for each position; result (B,T,rot_half_from_freqs)
36
+ if torch.is_tensor(position_ids):
37
+ if position_ids.dim() == 2 and position_ids.size(0) == B: # (B,T)
38
+ freq = freqs_cis[position_ids] # (B,T,rot_half,2)
39
+ elif position_ids.dim() == 1 and position_ids.size(0) == T: # (T,)
40
+ freq = freqs_cis[position_ids].unsqueeze(0).expand(B, -1, -1, -1)
41
+ else: # scalar
42
+ pid = position_ids.view(()).long()
43
+ freq = freqs_cis[pid].unsqueeze(0).expand(B, T, -1, -1)
44
  else:
45
+ pid = torch.tensor(position_ids, device=x.device, dtype=torch.long)
46
+ freq = freqs_cis[pid].unsqueeze(0).expand(B, T, -1, -1)
47
+
48
+ # Trim freqs to rd//2 if needed
49
+ rot_half = rd // 2
50
+ cos = freq[..., 0][..., :rot_half].unsqueeze(1) # (B,1,T,rot_half)
51
+ sin = freq[..., 1][..., :rot_half].unsqueeze(1) # (B,1,T,rot_half)
52
+
53
+ # Split real/imag and apply rotation
54
+ x_rot = x_rot.view(B, H, T, rot_half, 2)
55
+ xr, xi = x_rot[..., 0], x_rot[..., 1] # (B,H,T,rot_half)
56
+ yr = xr * cos - xi * sin
57
+ yi = xr * sin + xi * cos
58
+ y = torch.stack((yr, yi), dim=-1).flatten(-2) # (B,H,T,rd)
59
+
60
+ return torch.cat([y.to(x.dtype), x_pass], dim=-1)