# coding=utf-8
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
import torch.utils.checkpoint
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
try:
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import pad_input, unpad_input
except Exception:
    # flash-attn is optional: any import failure disables the flash_attention_2 backend and
    # leaves the public names bound to None so the rest of the module can still be imported.
    flash_attn_func = flash_attn_varlen_func = None
    pad_input = unpad_input = None
    _FLASH_ATTN_AVAILABLE = False
else:
    _FLASH_ATTN_AVAILABLE = True
@dataclass
class PackedSequenceMetadata:
    """Describes how packed, variable-length sequences map onto a padded batch.

    Consumed by the flash-attention varlen path: tokens are gathered with ``indices``,
    attended with ``cu_seqlens``/``max_seqlen`` and scattered back to
    ``(batch_size, seq_len)`` with ``pad_input``.
    """

    # Cumulative sequence boundaries handed to flash_attn_varlen_func for both queries and keys.
    cu_seqlens: torch.Tensor
    # Length of the longest packed segment.
    max_seqlen: int
    # Flat token positions into the (batch * seq_len) token axis; None when inputs are already unpadded.
    indices: Optional[torch.Tensor] = None
    # Padded batch dimensions used to restore the output layout.
    batch_size: Optional[int] = None
    seq_len: Optional[int] = None
class MossTTSNanoGPT2RotaryEmbedding(nn.Module):
    """Interleaved rotary position embedding (RoPE) tables.

    Each frequency is repeated twice along the last axis so that it lines up with the
    (even, odd) channel pairs rotated by ``rotate_half``.
    """

    def __init__(self, dim: int, base: float = 10000.0) -> None:
        super().__init__()
        if dim % 2:
            raise ValueError(f"RoPE head_dim must be even, got {dim}")
        exponents = torch.arange(0, dim, 2, dtype=torch.float32) / dim
        # Non-persistent: the table is derived from (dim, base) and not stored in checkpoints.
        self.register_buffer("inv_freq", 1.0 / (base ** exponents), persistent=False)

    def forward(
        self,
        position_ids: torch.LongTensor,
        *,
        device: torch.device,
        dtype: torch.dtype,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(cos, sin)`` of shape ``(batch, seq, 1, dim)`` cast to ``dtype``.

        A 1-D ``position_ids`` is treated as a single batch row.
        """
        positions = position_ids if position_ids.ndim != 1 else position_ids.unsqueeze(0)
        positions = positions.to(device=device, dtype=self.inv_freq.dtype)
        # Angles are computed in the buffer's float32 precision and only cast at the end.
        angles = torch.einsum("bs,d->bsd", positions, self.inv_freq)
        angles = angles.repeat_interleave(2, dim=-1).unsqueeze(2)
        return angles.cos().to(dtype=dtype), angles.sin().to(dtype=dtype)
def rotate_half(hidden_states: torch.Tensor) -> torch.Tensor:
    """Rotate interleaved channel pairs: ``(x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...)``."""
    x_even, x_odd = hidden_states[..., 0::2], hidden_states[..., 1::2]
    paired = torch.stack([x_odd.neg(), x_even], dim=-1)
    return paired.flatten(start_dim=-2)
def apply_rotary_pos_emb(
    hidden_states: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> torch.Tensor:
    """Apply RoPE: ``x * cos + rotate_half(x) * sin`` (cos/sin must broadcast against ``x``)."""
    rotated = rotate_half(hidden_states)
    in_phase = hidden_states * cos
    quadrature = rotated * sin
    return in_phase + quadrature
class MossTTSNanoGPT2MLP(nn.Module):
    """Position-wise feed-forward network: Linear -> activation -> Linear -> dropout."""

    def __init__(self, config: GPT2Config) -> None:
        super().__init__()
        width = int(config.hidden_size)
        # A falsy ``n_inner`` falls back to the conventional 4x expansion.
        expanded = int(config.n_inner or 4 * width)
        self.fc_in = nn.Linear(width, expanded)
        self.fc_out = nn.Linear(expanded, width)
        self.act = ACT2FN[config.activation_function]
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        projected = self.fc_in(hidden_states)
        return self.dropout(self.fc_out(self.act(projected)))
class MossTTSNanoGPT2Attention(nn.Module):
def __init__(self, config: GPT2Config, layer_idx: int, attn_implementation: str) -> None:
super().__init__()
hidden_size = int(config.hidden_size)
num_heads = int(config.num_attention_heads)
if hidden_size % num_heads != 0:
raise ValueError(f"hidden_size={hidden_size} must be divisible by num_attention_heads={num_heads}")
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.embed_dim = hidden_size
self.layer_idx = layer_idx
self.attn_implementation = attn_implementation
self.attn_dropout = float(config.attn_pdrop)
self.resid_dropout = nn.Dropout(config.resid_pdrop)
self.scale_attn_weights = bool(getattr(config, "scale_attn_weights", True))
self.scale_attn_by_inverse_layer_idx = bool(getattr(config, "scale_attn_by_inverse_layer_idx", False))
self.position_embedding_type = str(getattr(config, "position_embedding_type", "absolute")).lower()
if self.position_embedding_type not in {"absolute", "rope"}:
raise ValueError(f"Unsupported position_embedding_type={self.position_embedding_type!r}")
self.c_attn = nn.Linear(hidden_size, 3 * hidden_size)
self.c_proj = nn.Linear(hidden_size, hidden_size)
self.rotary_emb = None
if self.position_embedding_type == "rope":
self.rotary_emb = MossTTSNanoGPT2RotaryEmbedding(
self.head_dim,
base=float(getattr(config, "rope_base", 10000.0)),
)
def _split_heads(self, tensor: torch.Tensor) -> torch.Tensor:
if tensor.ndim == 3:
batch_size, seq_len, _ = tensor.shape
return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim)
if tensor.ndim == 2:
total_tokens, _ = tensor.shape
return tensor.view(total_tokens, self.num_heads, self.head_dim)
raise ValueError(f"Unsupported tensor rank for attention split: {tensor.ndim}")
def _merge_heads(self, tensor: torch.Tensor) -> torch.Tensor:
if tensor.ndim == 4:
batch_size, seq_len, _, _ = tensor.shape
return tensor.reshape(batch_size, seq_len, self.embed_dim)
if tensor.ndim == 3:
total_tokens, _, _ = tensor.shape
return tensor.reshape(total_tokens, self.embed_dim)
raise ValueError(f"Unsupported tensor rank for attention merge: {tensor.ndim}")
def _causal_attention_mask(
self,
attention_mask: Optional[torch.Tensor],
query_length: int,
key_length: int,
device: torch.device,
) -> torch.Tensor:
query_positions = torch.arange(query_length, device=device, dtype=torch.long)
query_positions = query_positions + max(key_length - query_length, 0)
key_positions = torch.arange(key_length, device=device, dtype=torch.long)
causal = key_positions.unsqueeze(0) <= query_positions.unsqueeze(1)
causal = causal.unsqueeze(0).unsqueeze(0)
if attention_mask is None:
return causal
key_mask = attention_mask[:, None, None, :].to(dtype=torch.bool)
return causal & key_mask
def _eager_attention(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
) -> torch.Tensor:
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
scale = 1.0
if self.scale_attn_weights:
scale /= self.head_dim ** 0.5
if self.scale_attn_by_inverse_layer_idx:
scale /= float(self.layer_idx + 1)
scores = torch.matmul(query, key.transpose(-1, -2)) * scale
causal_mask = self._causal_attention_mask(
attention_mask=attention_mask,
query_length=query.shape[-2],
key_length=key.shape[-2],
device=query.device,
)
scores = scores.masked_fill(~causal_mask, torch.finfo(scores.dtype).min)
probs = torch.softmax(scores, dim=-1)
if self.training and self.attn_dropout > 0:
probs = torch.dropout(probs, self.attn_dropout, train=True)
output = torch.matmul(probs, value)
return output.transpose(1, 2).contiguous()
def _sdpa_attention(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
) -> torch.Tensor:
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
mask = None
query_attention_mask = None
if attention_mask is not None:
query_length = query.shape[-2]
key_length = key.shape[-2]
mask = self._causal_attention_mask(
attention_mask=attention_mask,
query_length=query_length,
key_length=key_length,
device=query.device,
)
query_attention_mask = attention_mask[:, -query_length:].to(dtype=torch.bool, device=query.device)
if not bool(query_attention_mask.all()):
# SDPA can produce NaNs when a query row is fully masked. For padded query positions,
# keep a single aligned key visible, then zero the query output after attention.
mask = mask.expand(query.shape[0], -1, -1, -1).clone()
invalid_batch, invalid_query = torch.nonzero(~query_attention_mask, as_tuple=True)
aligned_key = invalid_query + max(key_length - query_length, 0)
mask[invalid_batch, :, invalid_query, aligned_key] = True
output = torch.nn.functional.scaled_dot_product_attention(
query,
key,
value,
attn_mask=mask,
dropout_p=self.attn_dropout if self.training else 0.0,
is_causal=mask is None,
)
if query_attention_mask is not None and not bool(query_attention_mask.all()):
output = output.masked_fill(~query_attention_mask[:, None, :, None], 0.0)
return output.transpose(1, 2).contiguous()
def _flash_attention(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
packed_metadata: Optional[PackedSequenceMetadata],
) -> torch.Tensor:
if not _FLASH_ATTN_AVAILABLE:
raise ImportError("flash_attn is not installed, but attn_implementation='flash_attention_2' was requested.")
if query.device.type != "cuda":
raise ValueError("flash_attention_2 requires CUDA tensors.")
if query.dtype not in (torch.float16, torch.bfloat16):
raise ValueError(
f"flash_attention_2 requires fp16/bf16 tensors, but received dtype={query.dtype}."
)
dropout_p = self.attn_dropout if self.training else 0.0
if packed_metadata is not None:
if packed_metadata.indices is not None:
query = query.reshape(-1, self.num_heads, self.head_dim).index_select(0, packed_metadata.indices)
key = key.reshape(-1, self.num_heads, self.head_dim).index_select(0, packed_metadata.indices)
value = value.reshape(-1, self.num_heads, self.head_dim).index_select(0, packed_metadata.indices)
output = flash_attn_varlen_func(
query,
key,
value,
packed_metadata.cu_seqlens,
packed_metadata.cu_seqlens,
packed_metadata.max_seqlen,
packed_metadata.max_seqlen,
dropout_p=dropout_p,
causal=True,
)
if packed_metadata.indices is None:
return output
return pad_input(
output,
packed_metadata.indices,
packed_metadata.batch_size,
packed_metadata.seq_len,
)
if attention_mask is None or bool(attention_mask.all()):
return flash_attn_func(
query,
key,
value,
dropout_p=dropout_p,
causal=True,
)
unpadded_query, indices, cu_seqlens, max_seqlen, _ = unpad_input(query, attention_mask)
unpadded_key, _, _, _, _ = unpad_input(key, attention_mask)
unpadded_value, _, _, _, _ = unpad_input(value, attention_mask)
output = flash_attn_varlen_func(
unpadded_query,
unpadded_key,
unpadded_value,
cu_seqlens,
cu_seqlens,
max_seqlen,
max_seqlen,
dropout_p=dropout_p,
causal=True,
)
return pad_input(output, indices, query.shape[0], query.shape[1])
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
packed_metadata: Optional[PackedSequenceMetadata] = None,
layer_past: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, torch.Tensor]]]:
qkv = self.c_attn(hidden_states)
query, key, value = qkv.split(self.embed_dim, dim=-1)
query = self._split_heads(query)
key = self._split_heads(key)
value = self._split_heads(value)
if self.rotary_emb is not None:
if position_ids is None:
raise ValueError("position_ids must be provided when position_embedding_type='rope'.")
cos, sin = self.rotary_emb(
position_ids.to(device=query.device),
device=query.device,
dtype=query.dtype,
)
query = apply_rotary_pos_emb(query, cos, sin)
key = apply_rotary_pos_emb(key, cos, sin)
if layer_past is not None:
past_key, past_value = layer_past
key = torch.cat([past_key.to(device=key.device, dtype=key.dtype), key], dim=1)
value = torch.cat([past_value.to(device=value.device, dtype=value.dtype), value], dim=1)
present = (key, value) if use_cache else None
if self.attn_implementation == "flash_attention_2" and layer_past is None:
attn_output = self._flash_attention(
query=query,
key=key,
value=value,
attention_mask=attention_mask,
packed_metadata=packed_metadata,
)
elif self.attn_implementation == "sdpa":
attn_output = self._sdpa_attention(
query=query,
key=key,
value=value,
attention_mask=attention_mask,
)
else:
attn_output = self._eager_attention(
query=query,
key=key,
value=value,
attention_mask=attention_mask,
)
attn_output = self._merge_heads(attn_output)
attn_output = self.c_proj(attn_output)
return self.resid_dropout(attn_output), present
class MossTTSNanoGPT2Block(nn.Module):
    """Pre-LayerNorm transformer block: ``x + attn(ln_1(x))`` then ``x + mlp(ln_2(x))``."""

    def __init__(self, config: GPT2Config, layer_idx: int, attn_implementation: str) -> None:
        super().__init__()
        width = int(config.hidden_size)
        eps = config.layer_norm_epsilon
        # Registration order (ln_1, attn, ln_2, mlp) is kept so parameter order is unchanged.
        self.ln_1 = nn.LayerNorm(width, eps=eps)
        self.attn = MossTTSNanoGPT2Attention(config, layer_idx=layer_idx, attn_implementation=attn_implementation)
        self.ln_2 = nn.LayerNorm(width, eps=eps)
        self.mlp = MossTTSNanoGPT2MLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        packed_metadata: Optional[PackedSequenceMetadata] = None,
        layer_past: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, torch.Tensor]]]:
        """Return the block output and the attention layer's ``present`` cache entry."""
        normed = self.ln_1(hidden_states)
        attended, present = self.attn(
            normed,
            attention_mask=attention_mask,
            position_ids=position_ids,
            packed_metadata=packed_metadata,
            layer_past=layer_past,
            use_cache=use_cache,
        )
        residual = hidden_states + attended
        fed_forward = self.mlp(self.ln_2(residual))
        return residual + fed_forward, present
class MossTTSNanoGPT2Model(nn.Module):
    """GPT-2 style decoder stack that normally consumes precomputed input embeddings.

    Supports learned absolute position embeddings (``wpe``) or RoPE inside attention, a
    per-layer key/value cache, gradient checkpointing, and packed sequences described by
    ``cu_seqlens``. Outputs at padded query positions are zeroed after every layer.
    """

    def __init__(self, config: GPT2Config, attn_implementation: str = "eager") -> None:
        super().__init__()
        self.config = config
        self.attn_implementation = attn_implementation
        self.position_embedding_type = str(getattr(config, "position_embedding_type", "absolute")).lower()
        if self.position_embedding_type not in {"absolute", "rope"}:
            raise ValueError(f"Unsupported position_embedding_type={self.position_embedding_type!r}")
        hidden_size = int(config.hidden_size)
        self.wte = nn.Embedding(config.vocab_size, hidden_size)
        self.wpe = nn.Embedding(config.n_positions, hidden_size) if self.position_embedding_type == "absolute" else nn.Identity()
        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList(
            [MossTTSNanoGPT2Block(config, layer_idx=index, attn_implementation=attn_implementation) for index in range(config.n_layer)]
        )
        self.ln_f = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.gradient_checkpointing = False
        self._reset_parameters()

    def _reset_parameters(self) -> None:
        """Init Linear/Embedding weights from N(0, initializer_range); zero biases; LayerNorm to identity."""
        init_std = float(self.config.initializer_range)
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=init_std)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0.0, std=init_std)
            elif isinstance(module, nn.LayerNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    @staticmethod
    def _normalize_num_sequences(
        cu_seqlens: torch.Tensor,
        num_sequences: Optional[torch.Tensor],
        device: torch.device,
    ) -> torch.Tensor:
        """Return a 1-D per-row count of packed sequences.

        When ``num_sequences`` is omitted, the count is the number of strictly increasing
        steps in each ``cu_seqlens`` row (so trailing padding entries are ignored).
        """
        if cu_seqlens.ndim == 1:
            cu_seqlens = cu_seqlens.unsqueeze(0)
        if num_sequences is None:
            counts = []
            for boundary in cu_seqlens:
                diffs = boundary[1:] - boundary[:-1]
                counts.append(int((diffs > 0).sum().item()))
            return torch.tensor(counts, dtype=torch.int32, device=device)
        if num_sequences.ndim == 0:
            return num_sequences.unsqueeze(0)
        return num_sequences

    @staticmethod
    def build_packed_position_ids(
        attention_mask: Optional[torch.Tensor],
        cu_seqlens: torch.Tensor,
        num_sequences: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Positions restarting at 0 inside every packed segment.

        Each ``cu_seqlens`` row is assumed to have ``seq_len + 1`` entries (the output has
        ``seq_len`` columns); positions at masked tokens are zeroed when a mask is given.
        """
        if cu_seqlens.ndim == 1:
            cu_seqlens = cu_seqlens.unsqueeze(0)
        batch_size, seq_len = cu_seqlens.shape[0], cu_seqlens.shape[1] - 1
        device = cu_seqlens.device
        position_ids = torch.zeros((batch_size, seq_len), dtype=torch.long, device=device)
        counts = MossTTSNanoGPT2Model._normalize_num_sequences(cu_seqlens, num_sequences, device=device)
        for batch_index in range(batch_size):
            sequence_count = int(counts[batch_index].item())
            boundaries = cu_seqlens[batch_index, : sequence_count + 1].tolist()
            for start, end in zip(boundaries[:-1], boundaries[1:]):
                start = int(start)
                end = int(end)
                if end > start:
                    position_ids[batch_index, start:end] = torch.arange(end - start, device=device)
        if attention_mask is not None:
            position_ids = position_ids * attention_mask.to(dtype=position_ids.dtype)
        return position_ids

    @staticmethod
    def build_packed_metadata(
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        num_sequences: Optional[torch.Tensor],
    ) -> PackedSequenceMetadata:
        """Flatten every non-empty packed segment into flash-attn varlen metadata.

        Raises:
            ValueError: no non-empty segment is described by ``cu_seqlens``.
        """
        if cu_seqlens.ndim == 1:
            cu_seqlens = cu_seqlens.unsqueeze(0)
        device = hidden_states.device
        counts = MossTTSNanoGPT2Model._normalize_num_sequences(cu_seqlens, num_sequences, device=device)
        flat_indices = []
        cumulative = [0]
        max_seqlen = 0
        seq_len = hidden_states.shape[1]
        for batch_index in range(hidden_states.shape[0]):
            sequence_count = int(counts[batch_index].item())
            boundaries = cu_seqlens[batch_index, : sequence_count + 1].tolist()
            for start, end in zip(boundaries[:-1], boundaries[1:]):
                start = int(start)
                end = int(end)
                if end <= start:
                    continue
                # Offsets into the flattened (batch * seq_len) token axis.
                segment_indices = batch_index * seq_len + torch.arange(start, end, device=device)
                flat_indices.append(segment_indices)
                cumulative.append(cumulative[-1] + (end - start))
                max_seqlen = max(max_seqlen, end - start)
        if not flat_indices:
            raise ValueError("cu_seqlens did not describe any non-empty packed sequences.")
        indices = torch.cat(flat_indices, dim=0)
        return PackedSequenceMetadata(
            cu_seqlens=torch.tensor(cumulative, dtype=torch.int32, device=device),
            max_seqlen=max_seqlen,
            indices=indices,
            batch_size=hidden_states.shape[0],
            seq_len=hidden_states.shape[1],
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[tuple[tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: bool = True,
        cu_seqlens: Optional[torch.Tensor] = None,
        num_sequences: Optional[torch.Tensor] = None,
    ) -> BaseModelOutputWithPast:
        """Run the decoder stack.

        Args:
            input_ids: token ids, embedded with ``wte`` only when ``inputs_embeds`` is None.
            past_key_values: per-layer ``(key, value)`` cache; an empty tuple means no cache.
            attention_mask: mask over past + current tokens (True/1 = real token). Defaults
                to all ones covering the cached length plus the current length.
            position_ids: explicit positions; otherwise derived from the mask or ``cu_seqlens``.
            inputs_embeds: ``(batch, seq, hidden)`` embeddings (takes precedence over ids).
            use_cache: return the updated per-layer cache.
            output_attentions: ignored; attention weights are never returned.
            output_hidden_states: also return the input of every layer plus the final output.
            return_dict: return ``BaseModelOutputWithPast`` instead of a 4-tuple.
            cu_seqlens / num_sequences: packed-sequence boundaries (not usable with the cache).

        Raises:
            ValueError: no inputs, ``use_cache`` with packing, or ``use_cache`` with
                gradient checkpointing during training.
        """
        del output_attentions
        if inputs_embeds is None:
            if input_ids is None:
                raise ValueError("Either inputs_embeds or input_ids must be provided.")
            inputs_embeds = self.wte(input_ids)
        use_cache = bool(use_cache)
        if use_cache and cu_seqlens is not None:
            raise ValueError("use_cache=True is not supported together with cu_seqlens packing.")
        if past_key_values is not None and len(past_key_values) == 0:
            # An empty cache behaves like no cache (avoids indexing past_key_values[layer_index]).
            past_key_values = None
        hidden_states = inputs_embeds
        device = hidden_states.device
        batch_size, query_length = hidden_states.shape[0], hidden_states.shape[1]
        past_length = 0 if past_key_values is None else past_key_values[0][0].shape[1]
        if attention_mask is None:
            # The default mask must span the cached keys too; otherwise positions restart at 0 on
            # every decoding step and the key mask no longer matches the concatenated keys.
            attention_mask = torch.ones((batch_size, past_length + query_length), dtype=torch.bool, device=device)
        else:
            attention_mask = attention_mask.to(dtype=torch.bool, device=device)
        query_attention_mask = attention_mask[:, -query_length:]
        if cu_seqlens is not None:
            cu_seqlens = cu_seqlens.to(device=device)
        if num_sequences is not None:
            num_sequences = num_sequences.to(device=device)
        if position_ids is None:
            if cu_seqlens is not None:
                position_ids = self.build_packed_position_ids(
                    attention_mask=query_attention_mask,
                    cu_seqlens=cu_seqlens,
                    num_sequences=num_sequences,
                )
            else:
                # Count only attended tokens (left padding maps to 0); because the mask covers the
                # cache, positions of new tokens continue from ``past_length``.
                position_ids = attention_mask.long().cumsum(dim=-1) - 1
                position_ids = position_ids.masked_fill(~attention_mask, 0)
                position_ids = position_ids[:, -query_length:]
        packed_metadata = None
        if cu_seqlens is not None and self.attn_implementation == "flash_attention_2":
            packed_metadata = self.build_packed_metadata(
                hidden_states=hidden_states,
                cu_seqlens=cu_seqlens,
                num_sequences=num_sequences,
            )
        # NOTE(review): with eager/sdpa, cu_seqlens only resets position ids; attention still
        # crosses packed-segment boundaries because no block-diagonal mask is built. Confirm intent.
        if self.position_embedding_type == "absolute":
            hidden_states = hidden_states + self.wpe(position_ids)
        hidden_states = self.drop(hidden_states)
        hidden_states = hidden_states * query_attention_mask.unsqueeze(-1).to(dtype=hidden_states.dtype)
        all_hidden_states = () if output_hidden_states else None
        presents = [] if use_cache else None
        for layer_index, block in enumerate(self.h):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:
                if use_cache:
                    raise ValueError("use_cache=True is not supported when gradient checkpointing is enabled during training.")

                # ``block`` is bound as a default argument: non-reentrant checkpointing calls this
                # function again during backward, after the loop has moved on, and a plain closure
                # would then recompute every layer with the last block's weights.
                def custom_forward(layer_hidden, layer_mask, layer_positions, _block=block):
                    output, _ = _block(
                        layer_hidden,
                        attention_mask=layer_mask,
                        position_ids=layer_positions,
                        packed_metadata=packed_metadata,
                        layer_past=None,
                        use_cache=False,
                    )
                    return output

                hidden_states = torch.utils.checkpoint.checkpoint(
                    custom_forward,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    use_reentrant=False,
                )
                present = None
            else:
                hidden_states, present = block(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    packed_metadata=packed_metadata,
                    layer_past=None if past_key_values is None else past_key_values[layer_index],
                    use_cache=use_cache,
                )
            hidden_states = hidden_states * query_attention_mask.unsqueeze(-1).to(dtype=hidden_states.dtype)
            if presents is not None:
                presents.append(present)
        hidden_states = self.ln_f(hidden_states)
        hidden_states = hidden_states * query_attention_mask.unsqueeze(-1).to(dtype=hidden_states.dtype)
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)
        if not return_dict:
            return (hidden_states, tuple(presents) if presents is not None else None, all_hidden_states, None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=tuple(presents) if presents is not None else None,
            hidden_states=all_hidden_states,
            attentions=None,
        )