Upload model.py
model.py
CHANGED
@@ -36,75 +36,156 @@ except Exception:

Before (lines 36-110; "-" marks removed lines; removed lines not captured by this view are elided as "..."):

    HAVE_FLASH_ATTN = False


def _sdpa_flash_attn_compat(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    v: torch.Tensor,
    *,
-    ...
) -> torch.Tensor:
    """
-    ...
-    Parameters
-    ----------
-    q, k, v : torch.Tensor
-        Input tensors shaped as (B, S, H, D). Internally these will be
-        transposed to (B, H, S, D) for PyTorch's SDPA API.
-    causal : bool, optional
-        Whether to apply a causal (upper triangular) mask to prevent attending
-        to future positions. Defaults to ``True``.
-    window_size : tuple of two ints, optional
-        If provided, this denotes a symmetric window around each position. A
-        tuple ``(left, right)`` means each position can attend to at most
-        ``left`` tokens to its left and ``right`` tokens to its right. When
-        supplied the causal mask is merged with the band mask.
-
-    Returns
-    -------
-    torch.Tensor
-        Output tensor with shape (B, S, H, D).
    """
-    ...
    out = F.scaled_dot_product_attention(
        qh, kh, vh,
-        attn_mask=
-        dropout_p=
-        is_causal=
-    ) #
-
-    return out.permute(0, 2, 1, 3).contiguous()


def _attn_dispatch(
After (lines 36-191; "+" marks added lines):

    HAVE_FLASH_ATTN = False


+
+def _repeat_kv_for_gqa(x: torch.Tensor, repeat: int) -> torch.Tensor:
+    # x: [B, S, Hk, D] -> [B, S, Hq, D], where Hq = Hk * repeat.
+    # Query head h maps to kv head h // repeat (flash-attn / HF repeat_kv grouping).
+    if repeat == 1:
+        return x
+    B, S, Hk, D = x.shape
+    x = x.unsqueeze(3).expand(B, S, Hk, repeat, D)  # [B,S,Hk,repeat,D]
+    return x.reshape(B, S, Hk * repeat, D)
+
+@torch.no_grad()
+def _build_window_mask(
+    Sq: int, Sk: int, left: int, right: int, causal: bool, device: torch.device
+) -> torch.Tensor:
+    """
+    FA2 window semantics:
+      valid j for query i: j ∈ [i + Sk - Sq - left, i + Sk - Sq + right]
+    FA2.1 causal alignment (bottom-right): additionally disallow j > i + Sk - Sq.
+    Return: float mask [1,1,Sq,Sk] with 0 for keep, -inf for mask.
+    """
+    i = torch.arange(Sq, device=device).view(-1, 1)  # [Sq,1]
+    j = torch.arange(Sk, device=device).view(1, -1)  # [1,Sk]
+    shift = Sk - Sq
+    j_min = i + shift - left
+    j_max = i + shift + right
+    allowed = (j >= j_min) & (j <= j_max)
+    if causal:
+        # forbid looking ahead relative to FA2.1 alignment
+        allowed &= (j <= (i + shift))
+    masked = ~allowed
+    m = torch.full((Sq, Sk), 0.0, device=device)
+    m[masked] = -torch.finfo(m.dtype).max  # effectively -inf
+    return m.view(1, 1, Sq, Sk).contiguous()
+
+@torch.no_grad()
+def _build_causal_mask_fa21(
+    Sq: int, Sk: int, device: torch.device
+) -> torch.Tensor:
+    """
+    FA2.1 causal only (no window): mask positions with j > i + (Sk - Sq).
+    Returns float mask [1,1,Sq,Sk] with 0 keep, -inf mask.
+    """
+    i = torch.arange(Sq, device=device).view(-1, 1)
+    j = torch.arange(Sk, device=device).view(1, -1)
+    shift = Sk - Sq
+    allowed = (j <= (i + shift))
+    masked = ~allowed
+    m = torch.full((Sq, Sk), 0.0, device=device)
+    m[masked] = -torch.finfo(m.dtype).max
+    return m.view(1, 1, Sq, Sk).contiguous()
+
def _sdpa_flash_attn_compat(
+    q: torch.Tensor,  # [B,Sq,Hq,D]
+    k: torch.Tensor,  # [B,Sk,Hk,D]
+    v: torch.Tensor,  # [B,Sk,Hk,D]
    *,
+    dropout_p: float = 0.0,
+    softmax_scale: Optional[float] = None,  # default 1/sqrt(D) if None
+    causal: bool = False,
+    window_size: Tuple[int, int] = (-1, -1),  # (-1,-1) == no window
+    alibi_slopes: Optional[torch.Tensor] = None,  # (Hq,) or (B,Hq)
+    training: Optional[bool] = None,
) -> torch.Tensor:
    """
+    SDPA path emulating flash_attn_func semantics (v2):
+      - supports GQA (Hq divisible by Hk)
+      - FA2.1 causal alignment when Sq != Sk
+      - sliding window: j in [i + Sk - Sq - left, i + Sk - Sq + right]
+      - ALiBi additive bias
+    Returns: [B,Sq,Hq,D] with original dtype.
    """
+    assert q.dim() == k.dim() == v.dim() == 4, "Expect [B,S,H,D] tensors"
+    B, Sq, Hq, D = q.shape
+    Bk, Sk, Hk, Dk = k.shape
+    assert (Bk, Dk) == (B, D), "Batch/Dim mismatch"
+    assert v.shape[:3] == k.shape[:3] and v.shape[3] == D, "K/V mismatch"
+    assert Hq % Hk == 0, "Hq must be divisible by Hk for GQA/MQA"
+    repeat = Hq // Hk
+
+    # GQA: expand K,V heads to match Q heads so SDPA sees [B,Hq,*,D]
+    k_exp = _repeat_kv_for_gqa(k, repeat)  # [B,Sk,Hq,D]
+    v_exp = _repeat_kv_for_gqa(v, repeat)  # [B,Sk,Hq,D]
+
+    # layout for SDPA: [B,H,S,D]
+    qh = q.permute(0, 2, 1, 3).to(torch.float32)      # [B,Hq,Sq,D]
+    kh = k_exp.permute(0, 2, 1, 3).to(torch.float32)  # [B,Hq,Sk,D]
+    vh = v_exp.permute(0, 2, 1, 3).to(torch.float32)  # [B,Hq,Sk,D]
+    in_dtype = q.dtype
+    device = q.device
+
+    # softmax scale: default 1/sqrt(D); emulate custom s by scaling Q by s*sqrt(D)
+    if softmax_scale is None:
+        softmax_scale = 1.0 / math.sqrt(D)
+    qh = qh * (softmax_scale * math.sqrt(D))
+
+    # Build float mask (+ALiBi) as additive bias; pass is_causal=False to SDPA.
+    left, right = window_size
+    use_window = (left, right) != (-1, -1)
+    attn_bias = None  # [B,Hq,Sq,Sk] float, 0 for keep, -inf for mask, +ALiBi
+
+    if use_window:
+        # Per FA2 semantics; also clamp look-ahead under causal
+        if causal and right > 0:
+            right = 0
+        base = _build_window_mask(Sq, Sk, left, right, causal, device)  # [1,1,Sq,Sk]
+        attn_bias = base.expand(B, Hq, Sq, Sk)
+        is_causal = False
+    elif causal:
+        base = _build_causal_mask_fa21(Sq, Sk, device)  # [1,1,Sq,Sk]
+        attn_bias = base.expand(B, Hq, Sq, Sk)
+        is_causal = False
+    else:
+        is_causal = False
+        attn_bias = None  # fastest path
+
+    # ALiBi: add -(slope * |(i + Sk - Sq) - j|) to logits (i=0..Sq-1, j=0..Sk-1)
+    if alibi_slopes is not None:
+        # make slopes shape [B,Hq,1,1]
+        if alibi_slopes.dim() == 1:
+            # [Hq] -> [1,Hq,1,1]
+            alibi = alibi_slopes.view(1, Hq, 1, 1).to(dtype=torch.float32, device=device)
+            alibi = alibi.expand(B, Hq, 1, 1)
+        elif alibi_slopes.dim() == 2:
+            # [B,Hq] -> [B,Hq,1,1]
+            alibi = alibi_slopes.view(B, Hq, 1, 1).to(dtype=torch.float32, device=device)
+        else:
+            raise ValueError("alibi_slopes must be (Hq,) or (B,Hq)")
+        i = torch.arange(Sq, device=device).view(1, 1, -1, 1)
+        j = torch.arange(Sk, device=device).view(1, 1, 1, -1)
+        shift = Sk - Sq
+        dist = (i + shift - j).abs().to(torch.float32)  # [1,1,Sq,Sk]
+        alibi_term = -(alibi * dist)  # [B,Hq,Sq,Sk]
+        if attn_bias is None:
+            attn_bias = alibi_term
+        else:
+            attn_bias = attn_bias + alibi_term
+
+    # Dropout (train) vs eval
+    if training is None:
+        training = (dropout_p > 0.0) and any(t.requires_grad for t in (q, k, v))
+    dp = dropout_p if training else 0.0
+
    out = F.scaled_dot_product_attention(
        qh, kh, vh,
+        attn_mask=attn_bias,  # float additive mask/bias or None
+        dropout_p=dp,
+        is_causal=is_causal,  # we encode causal via mask/bias when needed
+    )  # [B,Hq,Sq,D] fp32
+
+    return out.permute(0, 2, 1, 3).to(in_dtype).contiguous()  # [B,Sq,Hq,D]
+


def _attn_dispatch(
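
To make the window and "bottom-right" causal semantics described in the _build_window_mask docstring concrete, here is a small illustration. It is an editor's sketch, not part of the commit, and it assumes model.py is importable on the current path and runs on CPU; with Sq=3 queries over Sk=5 keys, query i may see key j only if j <= i + (Sk - Sq), further limited to `left` positions on its left.

# Editor's sketch: visualize which key positions each query may attend to.
import torch

from model import _build_window_mask  # assumes model.py is on the import path

m = _build_window_mask(Sq=3, Sk=5, left=2, right=0, causal=True, device=torch.device("cpu"))
allowed = (m[0, 0] == 0)   # True where attention is permitted, [3, 5]
print(allowed.int())
# -> rows: [1, 1, 1, 0, 0]   query 0 sees keys 0..2 (i + Sk - Sq = 2, left = 2)
#          [0, 1, 1, 1, 0]   query 1 sees keys 1..3
#          [0, 0, 1, 1, 1]   query 2 sees keys 2..4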
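
A minimal usage sketch for the SDPA fallback itself, again an editor's addition rather than part of the commit: it assumes _sdpa_flash_attn_compat is importable from model.py and exercises hypothetical GQA shapes, a causal sliding window, and per-head ALiBi slopes, checking the documented output shape and dtype.

# Editor's sketch: call the SDPA fallback with GQA shapes, window, and ALiBi.
import torch

from model import _sdpa_flash_attn_compat  # assumes model.py is on the import path

B, Sq, Sk = 2, 8, 8
Hq, Hk, D = 8, 2, 64                       # 8 query heads grouped onto 2 KV heads
q = torch.randn(B, Sq, Hq, D, dtype=torch.bfloat16)
k = torch.randn(B, Sk, Hk, D, dtype=torch.bfloat16)
v = torch.randn(B, Sk, Hk, D, dtype=torch.bfloat16)
slopes = torch.tensor([2.0 ** -h for h in range(1, Hq + 1)])  # (Hq,) ALiBi slopes

out = _sdpa_flash_attn_compat(
    q, k, v,
    causal=True,
    window_size=(4, 0),                    # each query sees itself plus 4 keys to its left
    alibi_slopes=slopes,
)
assert out.shape == (B, Sq, Hq, D) and out.dtype == q.dtype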