Commit
·
7c0c66d
1
Parent(s):
23b970e
Revert "Removed non-flash classes"
Browse filesThis reverts commit 6be42f2bc584ea7d59e85502f1b8ccd538fe50e5.
- configuration_phi.py +4 -0
- modeling_phi.py +339 -64
configuration_phi.py
CHANGED
|
@@ -29,6 +29,8 @@ class PhiConfig(PretrainedConfig):
|
|
| 29 |
n_head_kv: Optional[int] = None,
|
| 30 |
rotary_dim: Optional[int] = 32,
|
| 31 |
activation_function: Optional[str] = "gelu_new",
|
|
|
|
|
|
|
| 32 |
fused_dense: bool = False,
|
| 33 |
attn_pdrop: float = 0.0,
|
| 34 |
embd_pdrop: float = 0.0,
|
|
@@ -48,6 +50,8 @@ class PhiConfig(PretrainedConfig):
|
|
| 48 |
self.n_head_kv = n_head_kv
|
| 49 |
self.rotary_dim = min(rotary_dim, n_embd // n_head)
|
| 50 |
self.activation_function = activation_function
|
|
|
|
|
|
|
| 51 |
self.fused_dense = fused_dense
|
| 52 |
self.attn_pdrop = attn_pdrop
|
| 53 |
self.embd_pdrop = embd_pdrop
|
|
|
|
| 29 |
n_head_kv: Optional[int] = None,
|
| 30 |
rotary_dim: Optional[int] = 32,
|
| 31 |
activation_function: Optional[str] = "gelu_new",
|
| 32 |
+
flash_attn: bool = False,
|
| 33 |
+
flash_rotary: bool = False,
|
| 34 |
fused_dense: bool = False,
|
| 35 |
attn_pdrop: float = 0.0,
|
| 36 |
embd_pdrop: float = 0.0,
|
|
|
|
| 50 |
self.n_head_kv = n_head_kv
|
| 51 |
self.rotary_dim = min(rotary_dim, n_embd // n_head)
|
| 52 |
self.activation_function = activation_function
|
| 53 |
+
self.flash_attn = flash_attn
|
| 54 |
+
self.flash_rotary = flash_rotary
|
| 55 |
self.fused_dense = fused_dense
|
| 56 |
self.attn_pdrop = attn_pdrop
|
| 57 |
self.embd_pdrop = embd_pdrop
|
modeling_phi.py
CHANGED
|
@@ -19,10 +19,16 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
| 19 |
|
| 20 |
from .configuration_phi import PhiConfig
|
| 21 |
|
| 22 |
-
|
| 23 |
-
from flash_attn.
|
| 24 |
-
from flash_attn.
|
| 25 |
-
from flash_attn.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
|
| 28 |
@dataclass
|
|
@@ -162,6 +168,128 @@ def _apply_rotary_emb_qkv(
|
|
| 162 |
)
|
| 163 |
|
| 164 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
class MLP(nn.Module):
|
| 166 |
"""Multi-Layer Perceptron.
|
| 167 |
|
|
@@ -196,6 +324,139 @@ class MLP(nn.Module):
|
|
| 196 |
return hidden_states
|
| 197 |
|
| 198 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
def _find_mha_dims(
|
| 200 |
config: PretrainedConfig,
|
| 201 |
n_head: Optional[int] = None,
|
|
@@ -271,8 +532,14 @@ class MHA(nn.Module):
|
|
| 271 |
# Rotary embedding
|
| 272 |
self.rotary_dim = rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0)
|
| 273 |
if self.rotary_dim > 0:
|
| 274 |
-
rotary_cls = FlashRotaryEmbedding
|
|
|
|
|
|
|
|
|
|
| 275 |
rotary_kwargs = {}
|
|
|
|
|
|
|
|
|
|
| 276 |
self.rotary_emb = rotary_cls(
|
| 277 |
self.rotary_dim,
|
| 278 |
base=rotary_base,
|
|
@@ -296,8 +563,13 @@ class MHA(nn.Module):
|
|
| 296 |
self.out_proj = linear_cls(hidden_size, hidden_size, bias=bias, device=device, dtype=dtype)
|
| 297 |
|
| 298 |
# Attention
|
| 299 |
-
attn_cls = FlashSelfAttention
|
| 300 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 301 |
|
| 302 |
self.inner_attn = attn_cls(
|
| 303 |
causal=causal,
|
|
@@ -310,6 +582,7 @@ class MHA(nn.Module):
|
|
| 310 |
attention_dropout=config.attn_pdrop,
|
| 311 |
)
|
| 312 |
|
|
|
|
| 313 |
self.layer_idx = layer_idx
|
| 314 |
self.return_residual = return_residual
|
| 315 |
self.checkpointing = checkpointing
|
|
@@ -323,23 +596,24 @@ class MHA(nn.Module):
|
|
| 323 |
if self.rotary_dim > 0:
|
| 324 |
qkv = self.rotary_emb(qkv)
|
| 325 |
|
| 326 |
-
|
|
|
|
| 327 |
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
|
| 341 |
-
|
| 342 |
-
|
| 343 |
|
| 344 |
if self.checkpointing:
|
| 345 |
return torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, key_padding_mask=key_padding_mask)
|
|
@@ -370,53 +644,54 @@ class MHA(nn.Module):
|
|
| 370 |
if past_key_values is not None:
|
| 371 |
kv = _update_kv_cache(kv, past_key_values, self.layer_idx)
|
| 372 |
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = (
|
| 377 |
-
None,
|
| 378 |
-
None,
|
| 379 |
-
None,
|
| 380 |
-
None,
|
| 381 |
-
)
|
| 382 |
-
if key_padding_mask is not None:
|
| 383 |
-
kv, _, cu_seqlens_k, max_seqlen_k = unpad_input(kv, key_padding_mask)
|
| 384 |
-
|
| 385 |
-
if seqlen_q == 1:
|
| 386 |
-
key_padding_mask = torch.ones(batch_size, 1, device=q.device)
|
| 387 |
-
elif seqlen_q != seqlen_k:
|
| 388 |
-
key_padding_mask = key_padding_mask[:, -seqlen_q:]
|
| 389 |
-
|
| 390 |
-
q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, key_padding_mask)
|
| 391 |
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
causal=causal,
|
| 398 |
-
cu_seqlens=cu_seqlens_q,
|
| 399 |
-
max_seqlen=max_seqlen_q,
|
| 400 |
-
cu_seqlens_k=cu_seqlens_k,
|
| 401 |
-
max_seqlen_k=max_seqlen_k,
|
| 402 |
)
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 412 |
)
|
| 413 |
|
| 414 |
-
return (
|
| 415 |
-
pad_input(attn_output, indices_q, batch_size, max_seqlen_q)
|
| 416 |
-
if key_padding_mask is not None
|
| 417 |
-
else attn_output
|
| 418 |
-
)
|
| 419 |
-
|
| 420 |
if self.checkpointing:
|
| 421 |
return torch.utils.checkpoint.checkpoint(
|
| 422 |
self.inner_cross_attn,
|
|
|
|
| 19 |
|
| 20 |
from .configuration_phi import PhiConfig
|
| 21 |
|
| 22 |
+
try:
|
| 23 |
+
from flash_attn.bert_padding import pad_input, unpad_input
|
| 24 |
+
from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
|
| 25 |
+
from flash_attn.modules.mha import FlashCrossAttention, FlashSelfAttention
|
| 26 |
+
from flash_attn.ops.fused_dense import FusedDense
|
| 27 |
+
except:
|
| 28 |
+
pad_input, unpad_input = None, None
|
| 29 |
+
FlashRotaryEmbedding = None
|
| 30 |
+
FlashSelfAttention, FlashCrossAttention = None, None
|
| 31 |
+
FusedDense = None
|
| 32 |
|
| 33 |
|
| 34 |
@dataclass
|
|
|
|
| 168 |
)
|
| 169 |
|
| 170 |
|
| 171 |
+
class RotaryEmbedding(nn.Module):
|
| 172 |
+
"""Rotary positional embedding (RoPE).
|
| 173 |
+
|
| 174 |
+
Reference:
|
| 175 |
+
RoFormer: Enhanced Transformer with Rotary Position Embedding.
|
| 176 |
+
https://arxiv.org/pdf/2104.09864.pdf.
|
| 177 |
+
|
| 178 |
+
"""
|
| 179 |
+
|
| 180 |
+
def __init__(
|
| 181 |
+
self,
|
| 182 |
+
dim: int,
|
| 183 |
+
base: int = 10000,
|
| 184 |
+
scale_base: Optional[float] = None,
|
| 185 |
+
pos_idx_in_fp32: bool = True,
|
| 186 |
+
max_position_embeddings: int = 2048,
|
| 187 |
+
device: Optional[str] = None,
|
| 188 |
+
**kwargs,
|
| 189 |
+
) -> None:
|
| 190 |
+
super().__init__()
|
| 191 |
+
|
| 192 |
+
if scale_base is not None:
|
| 193 |
+
raise NotImplementedError
|
| 194 |
+
|
| 195 |
+
self.dim = dim
|
| 196 |
+
self.base = float(base)
|
| 197 |
+
self.scale_base = scale_base
|
| 198 |
+
self.pos_idx_in_fp32 = pos_idx_in_fp32
|
| 199 |
+
self.max_position_embeddings = max_position_embeddings
|
| 200 |
+
self.device = device
|
| 201 |
+
|
| 202 |
+
# Generate and save the inverse frequency buffer (non-trainable)
|
| 203 |
+
inv_freq = self._compute_inv_freq(device)
|
| 204 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 205 |
+
|
| 206 |
+
# Generate and save the scale buffer (non-trainable)
|
| 207 |
+
scale = (
|
| 208 |
+
(torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
|
| 209 |
+
if scale_base is not None
|
| 210 |
+
else None
|
| 211 |
+
)
|
| 212 |
+
self.register_buffer("scale", scale, persistent=False)
|
| 213 |
+
|
| 214 |
+
# Initialize cached attributes since ONNX can't rely on dynamic initialization
|
| 215 |
+
self._update_cos_sin_cache(max_position_embeddings, device=device, dtype=torch.float32)
|
| 216 |
+
|
| 217 |
+
def _compute_inv_freq(self, device: Optional[str] = None) -> torch.FloatTensor:
|
| 218 |
+
return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
|
| 219 |
+
|
| 220 |
+
def _update_cos_sin_cache(
|
| 221 |
+
self,
|
| 222 |
+
seqlen: int,
|
| 223 |
+
device: Optional[str] = None,
|
| 224 |
+
dtype: Optional[torch.dtype] = None,
|
| 225 |
+
) -> None:
|
| 226 |
+
self._seq_len_cached = seqlen
|
| 227 |
+
|
| 228 |
+
# fp32 is preferred since the output of `torch.arange` can be quite large
|
| 229 |
+
# and bf16 would lose a lot of precision
|
| 230 |
+
if self.pos_idx_in_fp32:
|
| 231 |
+
t = torch.arange(seqlen, device=device, dtype=torch.float32)
|
| 232 |
+
if self.inv_freq.dtype != torch.float32:
|
| 233 |
+
inv_freq = self._compute_inv_freq(device=device)
|
| 234 |
+
else:
|
| 235 |
+
inv_freq = self.inv_freq
|
| 236 |
+
else:
|
| 237 |
+
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
| 238 |
+
inv_freq = self.inv_freq
|
| 239 |
+
|
| 240 |
+
# `torch.outer` is preferred since `torch.einsum` converts from fp32 to fp16 if used with AMP
|
| 241 |
+
freqs = torch.outer(t, inv_freq)
|
| 242 |
+
if self.scale is None:
|
| 243 |
+
self._cos_cached = torch.cos(freqs).to(dtype)
|
| 244 |
+
self._sin_cached = torch.sin(freqs).to(dtype)
|
| 245 |
+
else:
|
| 246 |
+
power = (
|
| 247 |
+
torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
|
| 248 |
+
) / self.scale_base
|
| 249 |
+
scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
|
| 250 |
+
|
| 251 |
+
# Force the scale multiplication to happen in fp32
|
| 252 |
+
self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
|
| 253 |
+
self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
|
| 254 |
+
self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
|
| 255 |
+
self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
|
| 256 |
+
|
| 257 |
+
def forward(
|
| 258 |
+
self,
|
| 259 |
+
qkv: torch.Tensor,
|
| 260 |
+
kv: Optional[torch.Tensor] = None,
|
| 261 |
+
seqlen_offset: int = 0,
|
| 262 |
+
**kwargs,
|
| 263 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 264 |
+
if (
|
| 265 |
+
self._seq_len_cached < qkv.shape[1] + seqlen_offset
|
| 266 |
+
or self._cos_cached.device != qkv.device
|
| 267 |
+
or self._cos_cached.dtype != qkv.dtype
|
| 268 |
+
or (self.training and self._cos_cached.is_inference())
|
| 269 |
+
):
|
| 270 |
+
self._update_cos_sin_cache(qkv.shape[1] + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
|
| 271 |
+
|
| 272 |
+
if kv is None:
|
| 273 |
+
return _apply_rotary_emb_qkv(
|
| 274 |
+
qkv,
|
| 275 |
+
self._cos_cached[seqlen_offset:],
|
| 276 |
+
self._sin_cached[seqlen_offset:],
|
| 277 |
+
)
|
| 278 |
+
else:
|
| 279 |
+
q = _apply_rotary_emb(
|
| 280 |
+
qkv,
|
| 281 |
+
self._cos_cached[seqlen_offset:],
|
| 282 |
+
self._sin_cached[seqlen_offset:],
|
| 283 |
+
)
|
| 284 |
+
kv = _apply_rotary_emb_kv(
|
| 285 |
+
kv,
|
| 286 |
+
self._cos_cached[seqlen_offset:],
|
| 287 |
+
self._sin_cached[seqlen_offset:],
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
return q, kv
|
| 291 |
+
|
| 292 |
+
|
| 293 |
class MLP(nn.Module):
|
| 294 |
"""Multi-Layer Perceptron.
|
| 295 |
|
|
|
|
| 324 |
return hidden_states
|
| 325 |
|
| 326 |
|
| 327 |
+
class SelfAttention(nn.Module):
|
| 328 |
+
"""Self-attention layer (compatible with PyTorch).
|
| 329 |
+
|
| 330 |
+
Reference:
|
| 331 |
+
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
|
| 332 |
+
|
| 333 |
+
"""
|
| 334 |
+
|
| 335 |
+
def __init__(
|
| 336 |
+
self,
|
| 337 |
+
causal: bool = True,
|
| 338 |
+
softmax_scale: Optional[float] = None,
|
| 339 |
+
attention_dropout: float = 0.0,
|
| 340 |
+
) -> None:
|
| 341 |
+
super().__init__()
|
| 342 |
+
|
| 343 |
+
self.causal = causal
|
| 344 |
+
self.softmax_scale = softmax_scale
|
| 345 |
+
self.drop = nn.Dropout(attention_dropout)
|
| 346 |
+
|
| 347 |
+
@torch.autocast("cpu", enabled=False)
|
| 348 |
+
@torch.autocast("cuda", enabled=False)
|
| 349 |
+
def forward(
|
| 350 |
+
self,
|
| 351 |
+
qkv: torch.FloatTensor,
|
| 352 |
+
causal: bool = None,
|
| 353 |
+
key_padding_mask: Optional[torch.BoolTensor] = None,
|
| 354 |
+
**kwargs,
|
| 355 |
+
) -> torch.FloatTensor:
|
| 356 |
+
batch_size, seqlen = qkv.shape[0], qkv.shape[1]
|
| 357 |
+
q, k, v = qkv.unbind(dim=2)
|
| 358 |
+
|
| 359 |
+
q = q.to(torch.float32)
|
| 360 |
+
k = k.to(torch.float32)
|
| 361 |
+
|
| 362 |
+
causal = self.causal if causal is None else causal
|
| 363 |
+
softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
|
| 364 |
+
|
| 365 |
+
# Autocast is manually disabled to avoid `torch.einsum` performing the operation
|
| 366 |
+
# using float16, which might lead to overflow
|
| 367 |
+
scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
|
| 368 |
+
|
| 369 |
+
if key_padding_mask is not None:
|
| 370 |
+
padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device)
|
| 371 |
+
padding_mask.masked_fill_(key_padding_mask, 0.0)
|
| 372 |
+
|
| 373 |
+
scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
|
| 374 |
+
|
| 375 |
+
if causal:
|
| 376 |
+
causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
|
| 377 |
+
scores = scores + causal_mask.to(dtype=scores.dtype)
|
| 378 |
+
|
| 379 |
+
attention = torch.softmax(scores, dim=-1).to(v.dtype)
|
| 380 |
+
attention = self.drop(attention)
|
| 381 |
+
|
| 382 |
+
output = torch.einsum("bhts,bshd->bthd", attention, v)
|
| 383 |
+
|
| 384 |
+
return output
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
class CrossAttention(nn.Module):
|
| 388 |
+
"""Cross-attention layer (compatible with PyTorch).
|
| 389 |
+
|
| 390 |
+
Reference:
|
| 391 |
+
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
|
| 392 |
+
|
| 393 |
+
"""
|
| 394 |
+
|
| 395 |
+
def __init__(
|
| 396 |
+
self,
|
| 397 |
+
causal: bool = True,
|
| 398 |
+
softmax_scale: Optional[float] = None,
|
| 399 |
+
attention_dropout: float = 0.0,
|
| 400 |
+
) -> None:
|
| 401 |
+
super().__init__()
|
| 402 |
+
|
| 403 |
+
self.causal = causal
|
| 404 |
+
self.softmax_scale = softmax_scale
|
| 405 |
+
self.drop = nn.Dropout(attention_dropout)
|
| 406 |
+
|
| 407 |
+
@torch.autocast("cpu", enabled=False)
|
| 408 |
+
@torch.autocast("cuda", enabled=False)
|
| 409 |
+
def forward(
|
| 410 |
+
self,
|
| 411 |
+
q: torch.FloatTensor,
|
| 412 |
+
kv: torch.FloatTensor,
|
| 413 |
+
causal: bool = None,
|
| 414 |
+
key_padding_mask: Optional[torch.BoolTensor] = None,
|
| 415 |
+
**kwargs,
|
| 416 |
+
) -> torch.FloatTensor:
|
| 417 |
+
batch_size, seqlen_q = q.shape[0], q.shape[1]
|
| 418 |
+
seqlen_k = kv.shape[1]
|
| 419 |
+
|
| 420 |
+
if kv.shape[3] != q.shape[2]:
|
| 421 |
+
kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
|
| 422 |
+
k, v = kv.unbind(dim=2)
|
| 423 |
+
|
| 424 |
+
q = q.to(torch.float32)
|
| 425 |
+
k = k.to(torch.float32)
|
| 426 |
+
|
| 427 |
+
causal = self.causal if causal is None else causal
|
| 428 |
+
softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
|
| 429 |
+
|
| 430 |
+
# Autocast is manually disabled to avoid `torch.einsum` performing the operation
|
| 431 |
+
# using float16, which might lead to overflow
|
| 432 |
+
scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
|
| 433 |
+
|
| 434 |
+
if key_padding_mask is not None:
|
| 435 |
+
padding_mask = torch.full(
|
| 436 |
+
(batch_size, seqlen_k),
|
| 437 |
+
-10000.0,
|
| 438 |
+
dtype=scores.dtype,
|
| 439 |
+
device=scores.device,
|
| 440 |
+
)
|
| 441 |
+
padding_mask.masked_fill_(key_padding_mask, 0.0)
|
| 442 |
+
|
| 443 |
+
scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
|
| 444 |
+
|
| 445 |
+
if causal:
|
| 446 |
+
rows = rearrange(torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1")
|
| 447 |
+
cols = torch.arange(seqlen_k, device=k.device, dtype=torch.long)
|
| 448 |
+
causal_mask = cols > rows + seqlen_k - seqlen_q
|
| 449 |
+
|
| 450 |
+
scores = scores.masked_fill(causal_mask, -10000.0)
|
| 451 |
+
|
| 452 |
+
attention = torch.softmax(scores, dim=-1).to(v.dtype)
|
| 453 |
+
attention = self.drop(attention)
|
| 454 |
+
|
| 455 |
+
output = torch.einsum("bhts,bshd->bthd", attention, v)
|
| 456 |
+
|
| 457 |
+
return output
|
| 458 |
+
|
| 459 |
+
|
| 460 |
def _find_mha_dims(
|
| 461 |
config: PretrainedConfig,
|
| 462 |
n_head: Optional[int] = None,
|
|
|
|
| 532 |
# Rotary embedding
|
| 533 |
self.rotary_dim = rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0)
|
| 534 |
if self.rotary_dim > 0:
|
| 535 |
+
rotary_cls = FlashRotaryEmbedding if config.flash_rotary else RotaryEmbedding
|
| 536 |
+
if rotary_cls is None:
|
| 537 |
+
rotary_cls = RotaryEmbedding
|
| 538 |
+
|
| 539 |
rotary_kwargs = {}
|
| 540 |
+
if rotary_cls is RotaryEmbedding:
|
| 541 |
+
rotary_kwargs["max_position_embeddings"] = config.n_positions
|
| 542 |
+
|
| 543 |
self.rotary_emb = rotary_cls(
|
| 544 |
self.rotary_dim,
|
| 545 |
base=rotary_base,
|
|
|
|
| 563 |
self.out_proj = linear_cls(hidden_size, hidden_size, bias=bias, device=device, dtype=dtype)
|
| 564 |
|
| 565 |
# Attention
|
| 566 |
+
attn_cls = FlashSelfAttention if config.flash_attn else SelfAttention
|
| 567 |
+
if attn_cls is None:
|
| 568 |
+
attn_cls = SelfAttention
|
| 569 |
+
|
| 570 |
+
cross_attn_cls = FlashCrossAttention if config.flash_attn else CrossAttention
|
| 571 |
+
if cross_attn_cls is None:
|
| 572 |
+
cross_attn_cls = CrossAttention
|
| 573 |
|
| 574 |
self.inner_attn = attn_cls(
|
| 575 |
causal=causal,
|
|
|
|
| 582 |
attention_dropout=config.attn_pdrop,
|
| 583 |
)
|
| 584 |
|
| 585 |
+
self.flash_attn = config.flash_attn and attn_cls is FlashSelfAttention
|
| 586 |
self.layer_idx = layer_idx
|
| 587 |
self.return_residual = return_residual
|
| 588 |
self.checkpointing = checkpointing
|
|
|
|
| 596 |
if self.rotary_dim > 0:
|
| 597 |
qkv = self.rotary_emb(qkv)
|
| 598 |
|
| 599 |
+
if self.flash_attn:
|
| 600 |
+
batch_size, seqlen = qkv.shape[0], qkv.shape[1]
|
| 601 |
|
| 602 |
+
cu_seqlens, max_seqlen = None, None
|
| 603 |
+
if key_padding_mask is not None:
|
| 604 |
+
# If `key_padding_mask` is supplied, we need to unpad the input and retrieve
|
| 605 |
+
# the `cu_seqlens` and `max_seqlen` to be used by `flash-attn`
|
| 606 |
+
qkv, indices, cu_seqlens, max_seqlen = unpad_input(qkv, key_padding_mask)
|
| 607 |
|
| 608 |
+
if self.checkpointing:
|
| 609 |
+
attn_output = torch.utils.checkpoint.checkpoint(
|
| 610 |
+
self.inner_attn, qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
|
| 611 |
+
)
|
| 612 |
+
else:
|
| 613 |
+
attn_output = self.inner_attn(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen).to(qkv.device)
|
| 614 |
|
| 615 |
+
# If `key_padding_mask` is supplied, we need to pad the output back to the original shape
|
| 616 |
+
return pad_input(attn_output, indices, batch_size, seqlen) if key_padding_mask is not None else attn_output
|
| 617 |
|
| 618 |
if self.checkpointing:
|
| 619 |
return torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, key_padding_mask=key_padding_mask)
|
|
|
|
| 644 |
if past_key_values is not None:
|
| 645 |
kv = _update_kv_cache(kv, past_key_values, self.layer_idx)
|
| 646 |
|
| 647 |
+
if self.flash_attn:
|
| 648 |
+
batch_size, seqlen_q = q.shape[0], q.shape[1]
|
| 649 |
+
seqlen_k = kv.shape[1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 650 |
|
| 651 |
+
cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = (
|
| 652 |
+
None,
|
| 653 |
+
None,
|
| 654 |
+
None,
|
| 655 |
+
None,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 656 |
)
|
| 657 |
+
if key_padding_mask is not None:
|
| 658 |
+
kv, _, cu_seqlens_k, max_seqlen_k = unpad_input(kv, key_padding_mask)
|
| 659 |
+
|
| 660 |
+
if seqlen_q == 1:
|
| 661 |
+
key_padding_mask = torch.ones(batch_size, 1, device=q.device)
|
| 662 |
+
elif seqlen_q != seqlen_k:
|
| 663 |
+
key_padding_mask = key_padding_mask[:, -seqlen_q:]
|
| 664 |
+
|
| 665 |
+
q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, key_padding_mask)
|
| 666 |
+
|
| 667 |
+
if self.checkpointing:
|
| 668 |
+
attn_output = torch.utils.checkpoint.checkpoint(
|
| 669 |
+
self.inner_cross_attn,
|
| 670 |
+
q,
|
| 671 |
+
kv,
|
| 672 |
+
causal=causal,
|
| 673 |
+
cu_seqlens=cu_seqlens_q,
|
| 674 |
+
max_seqlen=max_seqlen_q,
|
| 675 |
+
cu_seqlens_k=cu_seqlens_k,
|
| 676 |
+
max_seqlen_k=max_seqlen_k,
|
| 677 |
+
)
|
| 678 |
+
else:
|
| 679 |
+
attn_output = self.inner_cross_attn(
|
| 680 |
+
q,
|
| 681 |
+
kv,
|
| 682 |
+
causal=causal,
|
| 683 |
+
cu_seqlens=cu_seqlens_q,
|
| 684 |
+
max_seqlen=max_seqlen_q,
|
| 685 |
+
cu_seqlens_k=cu_seqlens_k,
|
| 686 |
+
max_seqlen_k=max_seqlen_k,
|
| 687 |
+
)
|
| 688 |
+
|
| 689 |
+
return (
|
| 690 |
+
pad_input(attn_output, indices_q, batch_size, max_seqlen_q)
|
| 691 |
+
if key_padding_mask is not None
|
| 692 |
+
else attn_output
|
| 693 |
)
|
| 694 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 695 |
if self.checkpointing:
|
| 696 |
return torch.utils.checkpoint.checkpoint(
|
| 697 |
self.inner_cross_attn,
|