Commit 6be42f2
Parent(s): 7da2fc9

Removed non-flash classes

Files changed:
- configuration_phi.py: +0 -4
- modeling_phi.py: +64 -339
configuration_phi.py
CHANGED

@@ -29,8 +29,6 @@ class PhiConfig(PretrainedConfig):
         n_head_kv: Optional[int] = None,
         rotary_dim: Optional[int] = 32,
         activation_function: Optional[str] = "gelu_new",
-        flash_attn: bool = False,
-        flash_rotary: bool = False,
         fused_dense: bool = False,
         attn_pdrop: float = 0.0,
         embd_pdrop: float = 0.0,
@@ -50,8 +48,6 @@ class PhiConfig(PretrainedConfig):
         self.n_head_kv = n_head_kv
         self.rotary_dim = min(rotary_dim, n_embd // n_head)
         self.activation_function = activation_function
-        self.flash_attn = flash_attn
-        self.flash_rotary = flash_rotary
         self.fused_dense = fused_dense
         self.attn_pdrop = attn_pdrop
         self.embd_pdrop = embd_pdrop
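Note: the two deleted kwargs were the only switches between the flash and non-flash code paths. A minimal sketch of the post-commit behavior (assuming `PhiConfig` still forwards unknown kwargs to `PretrainedConfig`; that forwarding is an assumption, not shown in this diff):

# Sketch under the assumption above; not part of the commit itself.
# Older serialized configs that still carry the flags would be absorbed
# by PretrainedConfig, but nothing in the model code reads them anymore.
from configuration_phi import PhiConfig

config = PhiConfig(rotary_dim=32)
print(getattr(config, "flash_attn", "removed"))  # "removed" unless an old
                                                 # config file still sets it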
modeling_phi.py
CHANGED
@@ -19,16 +19,10 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
 
 from .configuration_phi import PhiConfig
 
-try:
-    from flash_attn.bert_padding import pad_input, unpad_input
-    from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
-    from flash_attn.modules.mha import FlashCrossAttention, FlashSelfAttention
-    from flash_attn.ops.fused_dense import FusedDense
-except:
-    pad_input, unpad_input = None, None
-    FlashRotaryEmbedding = None
-    FlashSelfAttention, FlashCrossAttention = None, None
-    FusedDense = None
+from flash_attn.bert_padding import pad_input, unpad_input
+from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
+from flash_attn.modules.mha import FlashCrossAttention, FlashSelfAttention
+from flash_attn.ops.fused_dense import FusedDense
 
 
 @dataclass
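With the `try`/`except` fallback gone, flash-attn becomes a hard dependency of `modeling_phi.py`: the module now raises at import time instead of degrading to the pure-PyTorch classes. A downstream pre-flight check might look like this (a sketch; the message text is an assumption, not part of the commit):

import importlib.util

# Fail early with a clear message instead of an ImportError deep inside
# modeling_phi.py (sketch; not part of the commit).
if importlib.util.find_spec("flash_attn") is None:
    raise ImportError("This model now requires the flash-attn package")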
@@ -168,128 +162,6 @@ def _apply_rotary_emb_qkv(
     )
 
 
-class RotaryEmbedding(nn.Module):
-    """Rotary positional embedding (RoPE).
-
-    Reference:
-        RoFormer: Enhanced Transformer with Rotary Position Embedding.
-        https://arxiv.org/pdf/2104.09864.pdf.
-
-    """
-
-    def __init__(
-        self,
-        dim: int,
-        base: int = 10000,
-        scale_base: Optional[float] = None,
-        pos_idx_in_fp32: bool = True,
-        max_position_embeddings: int = 2048,
-        device: Optional[str] = None,
-        **kwargs,
-    ) -> None:
-        super().__init__()
-
-        if scale_base is not None:
-            raise NotImplementedError
-
-        self.dim = dim
-        self.base = float(base)
-        self.scale_base = scale_base
-        self.pos_idx_in_fp32 = pos_idx_in_fp32
-        self.max_position_embeddings = max_position_embeddings
-        self.device = device
-
-        # Generate and save the inverse frequency buffer (non-trainable)
-        inv_freq = self._compute_inv_freq(device)
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-        # Generate and save the scale buffer (non-trainable)
-        scale = (
-            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
-            if scale_base is not None
-            else None
-        )
-        self.register_buffer("scale", scale, persistent=False)
-
-        # Initialize cached attributes since ONNX can't rely on dynamic initialization
-        self._update_cos_sin_cache(max_position_embeddings, device=device, dtype=torch.float32)
-
-    def _compute_inv_freq(self, device: Optional[str] = None) -> torch.FloatTensor:
-        return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
-
-    def _update_cos_sin_cache(
-        self,
-        seqlen: int,
-        device: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-    ) -> None:
-        self._seq_len_cached = seqlen
-
-        # fp32 is preferred since the output of `torch.arange` can be quite large
-        # and bf16 would lose a lot of precision
-        if self.pos_idx_in_fp32:
-            t = torch.arange(seqlen, device=device, dtype=torch.float32)
-            if self.inv_freq.dtype != torch.float32:
-                inv_freq = self._compute_inv_freq(device=device)
-            else:
-                inv_freq = self.inv_freq
-        else:
-            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
-            inv_freq = self.inv_freq
-
-        # `torch.outer` is preferred since `torch.einsum` converts from fp32 to fp16 if used with AMP
-        freqs = torch.outer(t, inv_freq)
-        if self.scale is None:
-            self._cos_cached = torch.cos(freqs).to(dtype)
-            self._sin_cached = torch.sin(freqs).to(dtype)
-        else:
-            power = (
-                torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
-            ) / self.scale_base
-            scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
-
-            # Force the scale multiplication to happen in fp32
-            self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
-            self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
-            self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
-            self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
-
-    def forward(
-        self,
-        qkv: torch.Tensor,
-        kv: Optional[torch.Tensor] = None,
-        seqlen_offset: int = 0,
-        **kwargs,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if (
-            self._seq_len_cached < qkv.shape[1] + seqlen_offset
-            or self._cos_cached.device != qkv.device
-            or self._cos_cached.dtype != qkv.dtype
-            or (self.training and self._cos_cached.is_inference())
-        ):
-            self._update_cos_sin_cache(qkv.shape[1] + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
-
-        if kv is None:
-            return _apply_rotary_emb_qkv(
-                qkv,
-                self._cos_cached[seqlen_offset:],
-                self._sin_cached[seqlen_offset:],
-            )
-        else:
-            q = _apply_rotary_emb(
-                qkv,
-                self._cos_cached[seqlen_offset:],
-                self._sin_cached[seqlen_offset:],
-            )
-            kv = _apply_rotary_emb_kv(
-                kv,
-                self._cos_cached[seqlen_offset:],
-                self._sin_cached[seqlen_offset:],
-            )
-
-            return q, kv
-
-
 class MLP(nn.Module):
     """Multi-Layer Perceptron.
 
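The deleted class was the pure-PyTorch RoPE fallback; `FlashRotaryEmbedding` is expected to produce equivalent rotations via flash-attn's kernels. For reference, the cos/sin cache it maintained reduces to a few lines (standalone sketch using the defaults above, with a toy `dim=8`):

import torch

dim, base, seqlen = 8, 10000.0, 4
# Inverse frequencies, as in the deleted _compute_inv_freq
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
t = torch.arange(seqlen, dtype=torch.float32)  # position indices kept in fp32
freqs = torch.outer(t, inv_freq)               # (seqlen, dim // 2)
cos_cached, sin_cached = torch.cos(freqs), torch.sin(freqs)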
@@ -324,139 +196,6 @@ class MLP(nn.Module):
         return hidden_states
 
 
-class SelfAttention(nn.Module):
-    """Self-attention layer (compatible with PyTorch).
-
-    Reference:
-        https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
-
-    """
-
-    def __init__(
-        self,
-        causal: bool = True,
-        softmax_scale: Optional[float] = None,
-        attention_dropout: float = 0.0,
-    ) -> None:
-        super().__init__()
-
-        self.causal = causal
-        self.softmax_scale = softmax_scale
-        self.drop = nn.Dropout(attention_dropout)
-
-    @torch.autocast("cpu", enabled=False)
-    @torch.autocast("cuda", enabled=False)
-    def forward(
-        self,
-        qkv: torch.FloatTensor,
-        causal: bool = None,
-        key_padding_mask: Optional[torch.BoolTensor] = None,
-        **kwargs,
-    ) -> torch.FloatTensor:
-        batch_size, seqlen = qkv.shape[0], qkv.shape[1]
-        q, k, v = qkv.unbind(dim=2)
-
-        q = q.to(torch.float32)
-        k = k.to(torch.float32)
-
-        causal = self.causal if causal is None else causal
-        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
-
-        # Autocast is manually disabled to avoid `torch.einsum` performing the operation
-        # using float16, which might lead to overflow
-        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
-
-        if key_padding_mask is not None:
-            padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device)
-            padding_mask.masked_fill_(key_padding_mask, 0.0)
-
-            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
-
-        if causal:
-            causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
-            scores = scores + causal_mask.to(dtype=scores.dtype)
-
-        attention = torch.softmax(scores, dim=-1).to(v.dtype)
-        attention = self.drop(attention)
-
-        output = torch.einsum("bhts,bshd->bthd", attention, v)
-
-        return output
-
-
-class CrossAttention(nn.Module):
-    """Cross-attention layer (compatible with PyTorch).
-
-    Reference:
-        https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
-
-    """
-
-    def __init__(
-        self,
-        causal: bool = True,
-        softmax_scale: Optional[float] = None,
-        attention_dropout: float = 0.0,
-    ) -> None:
-        super().__init__()
-
-        self.causal = causal
-        self.softmax_scale = softmax_scale
-        self.drop = nn.Dropout(attention_dropout)
-
-    @torch.autocast("cpu", enabled=False)
-    @torch.autocast("cuda", enabled=False)
-    def forward(
-        self,
-        q: torch.FloatTensor,
-        kv: torch.FloatTensor,
-        causal: bool = None,
-        key_padding_mask: Optional[torch.BoolTensor] = None,
-        **kwargs,
-    ) -> torch.FloatTensor:
-        batch_size, seqlen_q = q.shape[0], q.shape[1]
-        seqlen_k = kv.shape[1]
-
-        if kv.shape[3] != q.shape[2]:
-            kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
-        k, v = kv.unbind(dim=2)
-
-        q = q.to(torch.float32)
-        k = k.to(torch.float32)
-
-        causal = self.causal if causal is None else causal
-        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
-
-        # Autocast is manually disabled to avoid `torch.einsum` performing the operation
-        # using float16, which might lead to overflow
-        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
-
-        if key_padding_mask is not None:
-            padding_mask = torch.full(
-                (batch_size, seqlen_k),
-                -10000.0,
-                dtype=scores.dtype,
-                device=scores.device,
-            )
-            padding_mask.masked_fill_(key_padding_mask, 0.0)
-
-            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
-
-        if causal:
-            rows = rearrange(torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1")
-            cols = torch.arange(seqlen_k, device=k.device, dtype=torch.long)
-            causal_mask = cols > rows + seqlen_k - seqlen_q
-
-            scores = scores.masked_fill(causal_mask, -10000.0)
-
-        attention = torch.softmax(scores, dim=-1).to(v.dtype)
-        attention = self.drop(attention)
-
-        output = torch.einsum("bhts,bshd->bthd", attention, v)
-
-        return output
-
-
 def _find_mha_dims(
     config: PretrainedConfig,
     n_head: Optional[int] = None,
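Both deleted classes implemented vanilla softmax attention in fp32; flash-attn computes the same quantity without materializing the full score matrix. Stripped of dropout and padding, the removed `SelfAttention.forward` boils down to this (sketch with toy random tensors):

import math
import torch

q = torch.randn(1, 5, 2, 16)  # (batch, seqlen, n_head, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)

scores = torch.einsum("bthd,bshd->bhts", q, k) / math.sqrt(q.shape[-1])
causal_mask = torch.triu(torch.full((5, 5), -10000.0), 1)  # mask future positions
attention = torch.softmax(scores + causal_mask, dim=-1)
output = torch.einsum("bhts,bshd->bthd", attention, v)     # (1, 5, 2, 16)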
@@ -532,14 +271,8 @@ class MHA(nn.Module):
         # Rotary embedding
         self.rotary_dim = rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0)
         if self.rotary_dim > 0:
-            rotary_cls = FlashRotaryEmbedding if config.flash_rotary else RotaryEmbedding
-            if rotary_cls is None:
-                rotary_cls = RotaryEmbedding
-
+            rotary_cls = FlashRotaryEmbedding
             rotary_kwargs = {}
-            if rotary_cls is RotaryEmbedding:
-                rotary_kwargs["max_position_embeddings"] = config.n_positions
-
             self.rotary_emb = rotary_cls(
                 self.rotary_dim,
                 base=rotary_base,
@@ -563,13 +296,8 @@ class MHA(nn.Module):
         self.out_proj = linear_cls(hidden_size, hidden_size, bias=bias, device=device, dtype=dtype)
 
         # Attention
-        attn_cls = FlashSelfAttention if config.flash_attn else SelfAttention
-        if attn_cls is None:
-            attn_cls = SelfAttention
-
-        cross_attn_cls = FlashCrossAttention if config.flash_attn else CrossAttention
-        if cross_attn_cls is None:
-            cross_attn_cls = CrossAttention
+        attn_cls = FlashSelfAttention
+        cross_attn_cls = FlashCrossAttention
 
         self.inner_attn = attn_cls(
             causal=causal,
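With the fallbacks gone, the `None` checks are dead code and class selection collapses to the flash implementations. Spelled out, the resulting construction looks like this (sketch; the keyword arguments mirror the surrounding code, and the `FlashSelfAttention` signature shown is an assumption about flash-attn's `modules.mha` API):

from flash_attn.modules.mha import FlashSelfAttention

# Causal decoder attention with the model's attention dropout;
# softmax_scale=None is expected to default to 1/sqrt(head_dim).
inner_attn = FlashSelfAttention(
    causal=True,
    softmax_scale=None,
    attention_dropout=0.0,  # config.attn_pdrop
)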
@@ -582,7 +310,6 @@ class MHA(nn.Module):
             attention_dropout=config.attn_pdrop,
         )
 
-        self.flash_attn = config.flash_attn and attn_cls is FlashSelfAttention
         self.layer_idx = layer_idx
         self.return_residual = return_residual
         self.checkpointing = checkpointing
@@ -596,24 +323,23 @@ class MHA(nn.Module):
         if self.rotary_dim > 0:
             qkv = self.rotary_emb(qkv)
 
-        if self.flash_attn:
-            batch_size, seqlen = qkv.shape[0], qkv.shape[1]
-
-            cu_seqlens, max_seqlen = None, None
-            if key_padding_mask is not None:
-                # If `key_padding_mask` is supplied, we need to unpad the input and retrieve
-                # the `cu_seqlens` and `max_seqlen` to be used by `flash-attn`
-                qkv, indices, cu_seqlens, max_seqlen = unpad_input(qkv, key_padding_mask)
-
-            if self.checkpointing:
-                attn_output = torch.utils.checkpoint.checkpoint(
-                    self.inner_attn, qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
-                )
-            else:
-                attn_output = self.inner_attn(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen).to(qkv.device)
-
-            # If `key_padding_mask` is supplied, we need to pad the output back to the original shape
-            return pad_input(attn_output, indices, batch_size, seqlen) if key_padding_mask is not None else attn_output
+        batch_size, seqlen = qkv.shape[0], qkv.shape[1]
+
+        cu_seqlens, max_seqlen = None, None
+        if key_padding_mask is not None:
+            # If `key_padding_mask` is supplied, we need to unpad the input and retrieve
+            # the `cu_seqlens` and `max_seqlen` to be used by `flash-attn`
+            qkv, indices, cu_seqlens, max_seqlen = unpad_input(qkv, key_padding_mask)
+
+        if self.checkpointing:
+            attn_output = torch.utils.checkpoint.checkpoint(
+                self.inner_attn, qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
+            )
+        else:
+            attn_output = self.inner_attn(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen).to(qkv.device)
+
+        # If `key_padding_mask` is supplied, we need to pad the output back to the original shape
+        return pad_input(attn_output, indices, batch_size, seqlen) if key_padding_mask is not None else attn_output
 
         if self.checkpointing:
             return torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, key_padding_mask=key_padding_mask)
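`unpad_input`/`pad_input` convert between padded `(batch, seqlen, ...)` tensors and the packed variable-length layout the flash-attn kernels consume. Their bookkeeping, illustrated with plain torch (a sketch of the semantics, not the flash_attn implementation):

import torch

mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]], dtype=torch.bool)   # (batch, seqlen)
x = torch.arange(8.0).reshape(2, 4, 1)                  # toy hidden states

indices = torch.nonzero(mask.flatten()).flatten()       # positions of real tokens
packed = x.reshape(-1, 1)[indices]                      # (5, 1): padding dropped
seqlens = mask.sum(dim=-1)                              # tensor([3, 2])
cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0), (1, 0))  # [0, 3, 5]
max_seqlen = int(seqlens.max())                         # 3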
@@ -644,54 +370,53 @@ class MHA(nn.Module):
         if past_key_values is not None:
             kv = _update_kv_cache(kv, past_key_values, self.layer_idx)
 
-        if self.flash_attn:
-            batch_size, seqlen_q = q.shape[0], q.shape[1]
-            seqlen_k = kv.shape[1]
-
-            cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = (
-                None,
-                None,
-                None,
-                None,
-            )
-            if key_padding_mask is not None:
-                kv, _, cu_seqlens_k, max_seqlen_k = unpad_input(kv, key_padding_mask)
-
-                if seqlen_q == 1:
-                    key_padding_mask = torch.ones(batch_size, 1, device=q.device)
-                elif seqlen_q != seqlen_k:
-                    key_padding_mask = key_padding_mask[:, -seqlen_q:]
-
-                q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, key_padding_mask)
-
-            if self.checkpointing:
-                attn_output = torch.utils.checkpoint.checkpoint(
-                    self.inner_cross_attn,
-                    q,
-                    kv,
-                    causal=causal,
-                    cu_seqlens=cu_seqlens_q,
-                    max_seqlen=max_seqlen_q,
-                    cu_seqlens_k=cu_seqlens_k,
-                    max_seqlen_k=max_seqlen_k,
-                )
-            else:
-                attn_output = self.inner_cross_attn(
-                    q,
-                    kv,
-                    causal=causal,
-                    cu_seqlens=cu_seqlens_q,
-                    max_seqlen=max_seqlen_q,
-                    cu_seqlens_k=cu_seqlens_k,
-                    max_seqlen_k=max_seqlen_k,
-                )
-
-            return (
-                pad_input(attn_output, indices_q, batch_size, max_seqlen_q)
-                if key_padding_mask is not None
-                else attn_output
-            )
+        batch_size, seqlen_q = q.shape[0], q.shape[1]
+        seqlen_k = kv.shape[1]
+
+        cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = (
+            None,
+            None,
+            None,
+            None,
+        )
+        if key_padding_mask is not None:
+            kv, _, cu_seqlens_k, max_seqlen_k = unpad_input(kv, key_padding_mask)
+
+            if seqlen_q == 1:
+                key_padding_mask = torch.ones(batch_size, 1, device=q.device)
+            elif seqlen_q != seqlen_k:
+                key_padding_mask = key_padding_mask[:, -seqlen_q:]
+
+            q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, key_padding_mask)
+
+        if self.checkpointing:
+            attn_output = torch.utils.checkpoint.checkpoint(
+                self.inner_cross_attn,
+                q,
+                kv,
+                causal=causal,
+                cu_seqlens=cu_seqlens_q,
+                max_seqlen=max_seqlen_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_k=max_seqlen_k,
+            )
+        else:
+            attn_output = self.inner_cross_attn(
+                q,
+                kv,
+                causal=causal,
+                cu_seqlens=cu_seqlens_q,
+                max_seqlen=max_seqlen_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_k=max_seqlen_k,
+            )
+
+        return (
+            pad_input(attn_output, indices_q, batch_size, max_seqlen_q)
+            if key_padding_mask is not None
+            else attn_output
+        )
 
         if self.checkpointing:
             return torch.utils.checkpoint.checkpoint(
                 self.inner_cross_attn,
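One subtlety in the cross-attention path above: with a KV cache, `q` covers only the newest tokens, so the key padding mask must be re-aligned to the tail of the sequence before unpadding the queries. A sketch of that alignment (variable names mirror the hunk; dtypes simplified to bool):

import torch

batch_size, seqlen_q, seqlen_k = 2, 1, 7
key_padding_mask = torch.ones(batch_size, seqlen_k, dtype=torch.bool)

if seqlen_q == 1:
    # Single-token decode step: the lone query is always a real token
    key_padding_mask = torch.ones(batch_size, 1, dtype=torch.bool)
elif seqlen_q != seqlen_k:
    # Partial prefill: keep only the mask columns for the newest queries
    key_padding_mask = key_padding_mask[:, -seqlen_q:]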