MsAlEhR committed on
Commit 5c5a7f4 · verified · 1 Parent(s): 1632667

Upload model.py

Files changed (1):
  1. model.py +143 -62
model.py CHANGED
@@ -36,75 +36,156 @@ except Exception:
     HAVE_FLASH_ATTN = False


 def _sdpa_flash_attn_compat(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    v: torch.Tensor,
     *,
-    causal: bool = True,
-    window_size: Optional[Tuple[int, int]] = None,
 ) -> torch.Tensor:
     """
-    Compatibility wrapper that emulates ``flash_attn_func`` using PyTorch's
-    built-in scaled dot product attention. It accepts query, key, and value
-    tensors with shape ``[batch, seq_len, num_heads, head_dim]`` and returns an
-    output tensor of the same shape. When ``window_size`` is provided a banded
-    local attention mask is applied in combination with a causal mask.
-
-    Parameters
-    ----------
-    q, k, v : torch.Tensor
-        Input tensors shaped as (B, S, H, D). Internally these will be
-        transposed to (B, H, S, D) for PyTorch's SDPA API.
-    causal : bool, optional
-        Whether to apply a causal (upper triangular) mask to prevent attending
-        to future positions. Defaults to ``True``.
-    window_size : tuple of two ints, optional
-        If provided, this denotes a symmetric window around each position. A
-        tuple ``(left, right)`` means each position can attend to at most
-        ``left`` tokens to its left and ``right`` tokens to its right. When
-        supplied the causal mask is merged with the band mask.
-
-    Returns
-    -------
-    torch.Tensor
-        Output tensor with shape (B, S, H, D).
     """
-    # Convert from [B, S, H, D] to [B, H, S, D] for SDPA.
-    qh = q.permute(0, 2, 1, 3)
-    kh = k.permute(0, 2, 1, 3)
-    vh = v.permute(0, 2, 1, 3)
-    seq_len = q.shape[1]
-    attn_mask = None
-    # Build an attention mask if a window is specified.
-    if window_size is not None:
-        left, right = window_size
-        # Create a banded mask of allowed positions. ``True`` indicates a
-        # position should be masked out. Shape: [S, S].
-        device = q.device
-        ar = torch.arange(seq_len, device=device)
-        i = ar.view(-1, 1)
-        j = ar.view(1, -1)
-        band = (j >= (i - left)) & (j <= (i + right))
-        mask = ~band  # invert to mark positions outside the band.
-        if causal:
-            # Combine with a standard causal mask.
-            causal_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device).triu(1)
-            mask = mask | causal_mask
-        attn_mask = mask
-        # When we provide an attention mask PyTorch's SDPA expects
-        # is_causal=False; the mask encodes causality explicitly.
-        causal = False
-    # Use scaled_dot_product_attention. This API supports dropout, but we set
-    # dropout_p=0.0 to mirror flash attention's inference behaviour.
     out = F.scaled_dot_product_attention(
         qh, kh, vh,
-        attn_mask=attn_mask,
-        dropout_p=0.0,
-        is_causal=causal,
-    )  # shape: [B, H, S, D]
-    # Convert back to [B, S, H, D].
-    return out.permute(0, 2, 1, 3).contiguous()
 
+
+def _repeat_kv_for_gqa(x: torch.Tensor, repeat: int) -> torch.Tensor:
+    # x: [B, S, Hk, D] -> [B, S, Hq, D], where Hq = Hk * repeat
+    if repeat == 1:
+        return x
+    B, S, Hk, D = x.shape
+    x = x.unsqueeze(2).expand(B, S, repeat, Hk, D)  # [B,S,repeat,Hk,D]
+    return x.reshape(B, S, repeat * Hk, D)
+
+@torch.no_grad()
+def _build_window_mask(
+    Sq: int, Sk: int, left: int, right: int, causal: bool, device: torch.device
+) -> torch.Tensor:
+    """
+    FA2 window semantics:
+        valid j for query i: j ∈ [i + Sk - Sq - left, i + Sk - Sq + right]
+    FA2.1 causal alignment (bottom-right): additionally disallow j > i + Sk - Sq.
+    Return: float mask [1,1,Sq,Sk] with 0 for keep, -inf for mask.
+    """
+    i = torch.arange(Sq, device=device).view(-1, 1)  # [Sq,1]
+    j = torch.arange(Sk, device=device).view(1, -1)  # [1,Sk]
+    shift = Sk - Sq
+    j_min = i + shift - left
+    j_max = i + shift + right
+    allowed = (j >= j_min) & (j <= j_max)
+    if causal:
+        # forbid looking ahead relative to the FA2.1 alignment
+        allowed &= (j <= (i + shift))
+    masked = ~allowed
+    m = torch.full((Sq, Sk), 0.0, device=device)
+    m[masked] = -torch.finfo(m.dtype).max  # -inf
+    return m.view(1, 1, Sq, Sk).contiguous()
+
+@torch.no_grad()
+def _build_causal_mask_fa21(
+    Sq: int, Sk: int, device: torch.device
+) -> torch.Tensor:
+    """
+    FA2.1 causal only (no window): mask positions with j > i + (Sk - Sq).
+    Returns float mask [1,1,Sq,Sk] with 0 for keep, -inf for mask.
+    """
+    i = torch.arange(Sq, device=device).view(-1, 1)
+    j = torch.arange(Sk, device=device).view(1, -1)
+    shift = Sk - Sq
+    allowed = (j <= (i + shift))
+    masked = ~allowed
+    m = torch.full((Sq, Sk), 0.0, device=device)
+    m[masked] = -torch.finfo(m.dtype).max
+    return m.view(1, 1, Sq, Sk).contiguous()
+
 def _sdpa_flash_attn_compat(
+    q: torch.Tensor,  # [B,Sq,Hq,D]
+    k: torch.Tensor,  # [B,Sk,Hk,D]
+    v: torch.Tensor,  # [B,Sk,Hk,D]
     *,
+    dropout_p: float = 0.0,
+    softmax_scale: Optional[float] = None,  # default 1/sqrt(D) if None
+    causal: bool = False,
+    window_size: Tuple[int, int] = (-1, -1),  # (-1, -1) == no window
+    alibi_slopes: Optional[torch.Tensor] = None,  # (Hq,) or (B, Hq)
+    training: Optional[bool] = None,
 ) -> torch.Tensor:
     """
+    SDPA path emulating flash_attn_func semantics (v2):
+      - supports GQA (Hq divisible by Hk)
+      - FA2.1 causal alignment when Sq != Sk
+      - sliding window: j in [i + Sk - Sq - left, i + Sk - Sq + right]
+      - ALiBi additive bias
+    Returns: [B,Sq,Hq,D] with original dtype.
     """
+    assert q.dim() == k.dim() == v.dim() == 4, "Expect [B,S,H,D] tensors"
+    B, Sq, Hq, D = q.shape
+    Bk, Sk, Hk, Dk = k.shape
+    assert (Bk, Sk, Dk) == (B, k.shape[1], D), "Batch/Dim mismatch"
+    assert v.shape[:3] == k.shape[:3] and v.shape[3] == D, "K/V mismatch"
+    assert Hq % Hk == 0, "Hq must be divisible by Hk for GQA/MQA"
+    repeat = Hq // Hk
+
+    # GQA: expand K,V heads to match Q heads so SDPA sees [B,Hq,*,D]
+    k_exp = _repeat_kv_for_gqa(k, repeat)  # [B,Sk,Hq,D]
+    v_exp = _repeat_kv_for_gqa(v, repeat)  # [B,Sk,Hq,D]
+
+    # layout for SDPA: [B,H,S,D]
+    qh = q.permute(0, 2, 1, 3).to(torch.float32)      # [B,Hq,Sq,D]
+    kh = k_exp.permute(0, 2, 1, 3).to(torch.float32)  # [B,Hq,Sk,D]
+    vh = v_exp.permute(0, 2, 1, 3).to(torch.float32)  # [B,Hq,Sk,D]
+    in_dtype = q.dtype
+    device = q.device
+
+    # softmax scale: default 1/sqrt(D); emulate custom s by scaling Q by s*sqrt(D)
+    if softmax_scale is None:
+        softmax_scale = 1.0 / math.sqrt(D)
+    qh = qh * (softmax_scale * math.sqrt(D))
+
+    # Build float mask (+ALiBi) as additive bias; pass is_causal=False to SDPA.
+    left, right = window_size
+    use_window = (left, right) != (-1, -1)
+    attn_bias = None  # [B,Hq,Sq,Sk] float, 0 for keep, -inf for mask, +ALiBi
+
+    if use_window:
+        # Per FA2 semantics; also clamp look-ahead under causal
+        if causal and right > 0:
+            right = 0
+        base = _build_window_mask(Sq, Sk, left, right, causal, device)  # [1,1,Sq,Sk]
+        attn_bias = base.expand(B, Hq, Sq, Sk)
+        is_causal = False
+    elif causal:
+        base = _build_causal_mask_fa21(Sq, Sk, device)  # [1,1,Sq,Sk]
+        attn_bias = base.expand(B, Hq, Sq, Sk)
+        is_causal = False
+    else:
+        is_causal = False
+        attn_bias = None  # fastest path
+
+    # ALiBi: add -(slope * |(i + Sk - Sq) - j|) to logits (i=0..Sq-1, j=0..Sk-1)
+    if alibi_slopes is not None:
+        # make slopes shape [B,Hq,1,1]
+        if alibi_slopes.dim() == 1:
+            # [Hq] -> [1,Hq,1,1]
+            alibi = alibi_slopes.view(1, Hq, 1, 1).to(dtype=torch.float32, device=device)
+            alibi = alibi.expand(B, Hq, 1, 1)
+        elif alibi_slopes.dim() == 2:
+            # [B,Hq] -> [B,Hq,1,1]
+            alibi = alibi_slopes.view(B, Hq, 1, 1).to(dtype=torch.float32, device=device)
+        else:
+            raise ValueError("alibi_slopes must be (Hq,) or (B,Hq)")
+        i = torch.arange(Sq, device=device).view(1, 1, -1, 1)
+        j = torch.arange(Sk, device=device).view(1, 1, 1, -1)
+        shift = Sk - Sq
+        dist = (i + shift - j).abs().to(torch.float32)  # [1,1,Sq,Sk]
+        alibi_term = -(alibi * dist)  # [B,Hq,Sq,Sk]
+        if attn_bias is None:
+            attn_bias = alibi_term
+        else:
+            attn_bias = attn_bias + alibi_term
+
+    # Dropout (train) vs eval
+    if training is None:
+        training = (dropout_p > 0.0) and any(t.requires_grad for t in (q, k, v))
+    dp = dropout_p if training else 0.0
+
     out = F.scaled_dot_product_attention(
         qh, kh, vh,
+        attn_mask=attn_bias,  # float additive mask/bias or None
+        dropout_p=dp,
+        is_causal=is_causal,  # we encode causal via mask/bias when needed
+    )  # [B,Hq,Sq,D] fp32
+
+    return out.permute(0, 2, 1, 3).to(in_dtype).contiguous()  # [B,Sq,Hq,D]
+


 def _attn_dispatch(
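
Below is a small illustrative sketch (not part of the diff above) of the window semantics the new helpers implement: with the bottom-right ("FA2.1") alignment, the window is anchored to the end of the key sequence, so the last query lines up with the last key. It assumes the uploaded file is importable as `model`.

import torch

from model import _build_window_mask  # assumed import path for the uploaded model.py

# Sq=2 queries over Sk=4 keys, window (left=1, right=0), causal=True.
# Query i may attend to keys j in [i + Sk - Sq - left, i + Sk - Sq].
mask = _build_window_mask(Sq=2, Sk=4, left=1, right=0, causal=True,
                          device=torch.device("cpu"))
keep = (mask[0, 0] == 0).int()  # 1 where attention is allowed
print(keep)
# expected:
# [[0, 1, 1, 0],
#  [0, 0, 1, 1]]   -> query 0 sees keys 1..2, query 1 (the last) sees keys 2..3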
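
And a minimal usage sketch of the rewritten wrapper with GQA shapes, a causal sliding window, and ALiBi slopes. Argument names follow the new signature; the import path is an assumption, and the optional parity check requires a CUDA build of flash-attn recent enough to accept `window_size` and `alibi_slopes`.

import torch

from model import _sdpa_flash_attn_compat, HAVE_FLASH_ATTN  # assumed import path

B, Sq, Sk, Hq, Hk, D = 2, 8, 8, 8, 2, 64            # Hq % Hk == 0 -> 4-way GQA
q = torch.randn(B, Sq, Hq, D, dtype=torch.float16)
k = torch.randn(B, Sk, Hk, D, dtype=torch.float16)
v = torch.randn(B, Sk, Hk, D, dtype=torch.float16)
slopes = torch.rand(Hq)                              # one ALiBi slope per query head

out = _sdpa_flash_attn_compat(
    q, k, v,
    causal=True,
    window_size=(4, 0),                              # at most 4 keys to the left
    alibi_slopes=slopes,
    dropout_p=0.0,
)
assert out.shape == (B, Sq, Hq, D) and out.dtype == q.dtype

# Optional parity check against the real kernel when it is available.
if HAVE_FLASH_ATTN and torch.cuda.is_available():
    from flash_attn import flash_attn_func
    ref = flash_attn_func(
        q.cuda(), k.cuda(), v.cuda(),
        causal=True, window_size=(4, 0), alibi_slopes=slopes.cuda(),
    )
    print((out.cuda() - ref).abs().max())            # expected to be small (fp16 tolerance)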