Sławomir Dadas committed
Commit: c582324
Parent(s): c2b7542

Transformers 5 compatibility fixes

Files changed:
- config.json +4 -4
- configuration.py +2 -4
- modeling.py +86 -113
config.json
CHANGED
@@ -26,11 +26,11 @@
   "pack_qkv": true,
   "pad_token_id": 0,
   "position_embedding_type": "rope",
-  "rope_scaling": {
-    "factor": 2.0,
-    "type": "ntk"
+  "rope_parameters": {
+    "rope_theta": 160000,
+    "factor": 2.0,
+    "rope_type": "default"
   },
-  "rope_theta": 160000,
   "transformers_version": "4.56.1",
   "type_vocab_size": 2,
   "unpad_inputs": true,
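The change above folds the flat `rope_theta` and scaling keys into a single nested `rope_parameters` object. A minimal sketch of the mapping between the two layouts; the old `"type": "ntk"` contents are an assumption based on the NTK-scaling code removed from modeling.py below, not something fully visible in this hunk:

    # Sketch: same RoPE settings expressed in the old and new config layout.
    old_layout = {
        "rope_theta": 160000,
        "rope_scaling": {"type": "ntk", "factor": 2.0},  # assumed old contents
    }

    def to_rope_parameters(cfg: dict) -> dict:
        """Fold the flat keys into the nested `rope_parameters` block used above."""
        scaling = cfg.get("rope_scaling") or {}
        return {
            "rope_theta": cfg["rope_theta"],
            "factor": scaling.get("factor", 1.0),
            "rope_type": "default",  # the model's own init function handles the NTK scaling
        }

    assert to_rope_parameters(old_layout) == {
        "rope_theta": 160000,
        "factor": 2.0,
        "rope_type": "default",
    }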
configuration.py
CHANGED
@@ -108,8 +108,7 @@ class NewConfig(PretrainedConfig):
         layer_norm_eps=1e-12,
         # pad_token_id=0,
         position_embedding_type="rope",
-        rope_theta=10000.0,
-        rope_scaling=None,
+        rope_parameters=None,
         classifier_dropout=None,
         pack_qkv=True,
         unpad_inputs=False,
@@ -134,9 +133,8 @@ class NewConfig(PretrainedConfig):
         self.layer_norm_type = layer_norm_type
         self.layer_norm_eps = layer_norm_eps
         self.position_embedding_type = position_embedding_type
-        self.rope_theta = rope_theta
-        self.rope_scaling = rope_scaling
         self.classifier_dropout = classifier_dropout
+        self.rope_parameters = rope_parameters

         self.pack_qkv = pack_qkv
         self.unpad_inputs = unpad_inputs
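With this change the config object carries one `rope_parameters` dict instead of separate `rope_theta` and `rope_scaling` attributes. A usage sketch, assuming this repository's configuration.py is importable as a local module and the remaining constructor arguments keep their defaults:

    from configuration import NewConfig  # local module from this repository

    config = NewConfig(
        rope_parameters={
            "rope_theta": 160000,
            "factor": 2.0,
            "rope_type": "default",
        },
    )
    print(config.rope_parameters["rope_theta"])  # 160000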
modeling.py
CHANGED
@@ -16,11 +16,13 @@
 """PyTorch NEW model."""

 import math
-from typing import List, Optional, Tuple, Union
+from contextlib import nullcontext
+from typing import List, Optional, Tuple, Union, Callable

 import torch
 import torch.utils.checkpoint
 from torch import nn
+from transformers import ROPE_INIT_FUNCTIONS

 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import (
@@ -139,6 +141,28 @@ class IndexPutFirstAxis(torch.autograd.Function):
 index_put_first_axis = IndexPutFirstAxis.apply


+def maybe_autocast(
+    device_type: str,
+    dtype: Optional["_dtype"] = None,
+    enabled: bool = True,
+    cache_enabled: bool | None = None,
+):
+    """
+    Context manager that only autocasts if:
+
+    - `autocast` is already enabled in this context
+    - Or this call to `maybe_autocast` has `enabled=True`
+
+    This prevents `autocast` being added to the graph when it is effectively a no-op.
+    Which makes graph splitting in `torch.compile` more flexible as it removes the
+    requirement that partition IDs be monotonically increasing.
+    """
+    if torch.is_autocast_enabled(device_type) or enabled:
+        return torch.autocast(device_type, dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
+    else:
+        return nullcontext()
+
+
 def pad_input(inputs: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int) -> torch.Tensor:
     """Add padding to sequences.

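The new `maybe_autocast` helper hands back a real `torch.autocast` context only when autocast is already active (or explicitly requested) and a `nullcontext` otherwise, so forcing float32 does not insert a no-op autocast node into a compiled graph. A standalone sketch of that behaviour; it mirrors the helper added above and assumes a PyTorch version where `torch.is_autocast_enabled` accepts a device type, as the helper itself does:

    from contextlib import nullcontext
    import torch

    def maybe_autocast(device_type, dtype=None, enabled=True, cache_enabled=None):
        if torch.is_autocast_enabled(device_type) or enabled:
            return torch.autocast(device_type, dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
        return nullcontext()

    # Outside any autocast region, enabled=False yields a plain nullcontext.
    print(type(maybe_autocast("cpu", enabled=False)).__name__)  # nullcontext

    # Inside an active autocast region, the helper returns an autocast context that
    # force-disables mixed precision, which is how the rotary forward keeps float32.
    with torch.autocast("cpu", dtype=torch.bfloat16):
        print(type(maybe_autocast("cpu", enabled=False)).__name__)  # autocast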
@@ -162,7 +186,7 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)


-def apply_rotary_pos_emb(q, k, cos, sin):
+def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.

     Args:
@@ -170,84 +194,75 @@ def apply_rotary_pos_emb(q, k, cos, sin):
         k (`torch.Tensor`): The key tensor.
         cos (`torch.Tensor`): The cosine part of the rotary embedding.
         sin (`torch.Tensor`): The sine part of the rotary embedding.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
     Returns:
         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
     """
-    cos, sin = cos.to(q.dtype), sin.to(q.dtype)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed


-class RotaryEmbedding(torch.nn.Module):
-    def __init__(self, dim, max_position_embeddings=512, base=10000.0, device=None):
-        super().__init__()
-
-        self.dim = dim
-        self.max_position_embeddings = max_position_embeddings
-        self.base = base
-        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-        # Build here to make `torch.jit.trace` work.
-        self._set_cos_sin_cache(
-            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
-        )
-
-    def _set_cos_sin_cache(self, seq_len, device, dtype):
-        self.max_seq_len_cached = seq_len
-        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32)
-
-        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
-        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
-
-    def forward(self, x, seq_len=None):
-        # x: [bs, num_attention_heads, seq_len, head_size]
-        if seq_len > self.max_seq_len_cached:
-            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
-
-        return (
-            self.cos_cached[:seq_len, ...].to(dtype=x.dtype),
-            self.sin_cached[:seq_len, ...].to(dtype=x.dtype),
-        )
-
+class RotaryEmbedding(nn.Module):
+    inv_freq: torch.Tensor  # fix linting for `register_buffer`

-class NTKScalingRotaryEmbedding(RotaryEmbedding):
-    """RotaryEmbedding extended with fixed and mixed NTK scaling."""
+    def __init__(self, config: NewConfig, device=None):
+        super().__init__()
+        self.max_seq_len_cached = config.max_position_embeddings
+        self.original_max_seq_len = config.max_position_embeddings

-    def __init__(self, dim, max_position_embeddings=512, base=10000, device=None, scaling_factor=1.0, mixed_b=None):
-        self.scaling_factor = scaling_factor
-        self.mixed_b = mixed_b
-        super().__init__(dim, max_position_embeddings, base, device)
-        max_position_embeddings = max_position_embeddings * self.scaling_factor
-        self._set_cos_sin_cache(max_position_embeddings, self.inv_freq.device, torch.get_default_dtype())
+        self.config = config

-    def _set_cos_sin_cache(self, seq_len, device, dtype):
-        self.max_seq_len_cached = seq_len
+        self.rope_type = self.config.rope_parameters["rope_type"]
+        if self.rope_type == "default":
+            rope_init_fn: Callable = self.compute_default_rope_parameters
+        else:
+            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

-        if seq_len > self.max_position_embeddings:
-            base = self.base * (self.scaling_factor if self.mixed_b is None else 1)
-            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

-            if self.mixed_b is None:
-                inv_freq = inv_freq / self.scaling_factor ** (2 / self.dim)
-            else:
-                a = torch.tensor(self.scaling_factor).log() / (self.dim / 2) ** self.mixed_b
-                lambda_1_m = (a * torch.arange(1, self.dim // 2 + 1).float().to(device) ** self.mixed_b).exp()
-                inv_freq = inv_freq / lambda_1_m
+    @staticmethod
+    def compute_default_rope_parameters(
+        config: NewConfig | None = None,
+        device: Optional["torch.device"] = None,
+    ) -> tuple["torch.Tensor", float]:
+        """Computes rope parameters with NTK scaling"""
+        scaling_factor = config.rope_parameters.get("factor", 1.0)
+        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+        base = config.rope_parameters["rope_theta"]
+        mixed_b = config.rope_parameters.get("mixed_b", None)
+
+        base = base * (scaling_factor if mixed_b is None else 1)
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
+        if mixed_b is None:
+            inv_freq = inv_freq / scaling_factor ** (2 / dim)
+        else:
+            a = torch.tensor(scaling_factor).log() / (dim / 2) ** mixed_b
+            lambda_1_m = (a * torch.arange(1, dim // 2 + 1).float().to(device) ** mixed_b).exp()
+            inv_freq = inv_freq / lambda_1_m
+        return inv_freq, 1.0

-            self.register_buffer("inv_freq", inv_freq, persistent=False)
+    @torch.no_grad()
+    def forward(self, x, position_ids):
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+        position_ids_expanded = position_ids[:, None, :].float()

-        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32)
+        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos() * self.attention_scaling
+            sin = emb.sin() * self.attention_scaling

-        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
-        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


 class RMSNorm(nn.Module):
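`compute_default_rope_parameters` reproduces the NTK frequency scaling of the removed `NTKScalingRotaryEmbedding`, but derives the inverse frequencies once from `config.rope_parameters` instead of rebuilding cached cos/sin tables. A standalone numeric sketch of the same math, with an illustrative head dimension rather than this checkpoint's actual value:

    import torch

    def ntk_inv_freq(dim: int, rope_theta: float, factor: float) -> torch.Tensor:
        # Without `mixed_b`: stretch the base by the factor, then divide the
        # frequencies by factor ** (2 / dim), as in the method above.
        base = rope_theta * factor
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        return inv_freq / factor ** (2 / dim)

    plain = 1.0 / (160000 ** (torch.arange(0, 64, 2).float() / 64))
    scaled = ntk_inv_freq(dim=64, rope_theta=160000, factor=2.0)
    print(torch.all(scaled < plain).item())  # True: every frequency is lowered, stretching the context window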
@@ -291,7 +306,7 @@ class NewEmbeddings(nn.Module):
                 config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
             )
         elif self.position_embedding_type == 'rope':
-            self._init_rope(config)
+            self.rotary_emb = RotaryEmbedding(config)
         else:
             raise ValueError

@@ -308,27 +323,6 @@ class NewEmbeddings(nn.Module):
             "position_ids", torch.arange(config.max_position_embeddings), persistent=False
         )

-    def _init_rope(self, config):
-        kwargs = dict(
-            dim=int(config.hidden_size / config.num_attention_heads),
-            max_position_embeddings=config.max_position_embeddings,
-            base=config.rope_theta
-        )
-        if config.rope_scaling is None:
-            self.rotary_emb = RotaryEmbedding(**kwargs)
-        else:
-            kwargs.update(scaling_factor=config.rope_scaling["factor"])
-            scaling_type = config.rope_scaling["type"]
-            if scaling_type == 'ntk':
-                kwargs.update(mixed_b=config.rope_scaling.get('mixed_b', None))
-                self.rotary_emb = NTKScalingRotaryEmbedding(**kwargs)
-            # elif scaling_type == "linear":
-            #     self.rotary_emb = LinearScalingRotaryEmbedding(**kwargs)
-            # elif scaling_type == "dynamic":
-            #     self.rotary_emb = DynamicNTKScalingRotaryEmbedding(**kwargs)
-            else:
-                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
-
     def forward(
         self,
         unpad_inputs: bool,
@@ -339,8 +333,6 @@ class NewEmbeddings(nn.Module):
         position_ids: Optional[torch.Tensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, Optional[Tuple], Optional[List[int]]]:
-        """
-        """
         if inputs_embeds is None:
             device, input_shape = input_ids.device, input_ids.shape
         else:
@@ -372,24 +364,21 @@ class NewEmbeddings(nn.Module):

         # Set and unpad position_ids
         if position_ids is None:
-            if seq_length > self.position_ids.shape[0]:
-                self.register_buffer(
-                    "position_ids", torch.arange(seq_length), persistent=False
-                )
+            position_ids = torch.arange(seq_length, device=inputs_embeds.device)
             if unpad_inputs:
                 # [1, cumsum_seq_len]
-                position_ids = torch.cat([self.position_ids[:l] for l in length]).unsqueeze(0)
+                position_ids = torch.cat([position_ids[:l] for l in length]).unsqueeze(0)
             else:
                 # [bs, seq_len]
-                position_ids = self.position_ids[:seq_length].expand(batch_size, -1)
+                position_ids = position_ids[:seq_length].expand(batch_size, -1)
         elif unpad_inputs:
             position_ids = position_ids[attention_mask_bool].unsqueeze(0)  # [1, cumsum_seq_len]

         # Compute rotary embedding
         if self.position_embedding_type == 'rope':
-            rope_cos, rope_sin = self.rotary_emb(inputs_embeds, seq_length)
-            rope_cos = rope_cos[position_ids].unsqueeze(2)  # [bs, seq_len, 1, dim]
-            rope_sin = rope_sin[position_ids].unsqueeze(2)  # [bs, seq_len, 1, dim]
+            rope_cos, rope_sin = self.rotary_emb(inputs_embeds, position_ids)
+            rope_cos = rope_cos.unsqueeze(2)  # [bs, seq_len, 1, dim]
+            rope_sin = rope_sin.unsqueeze(2)  # [bs, seq_len, 1, dim]
             rope_embeds = rope_cos, rope_sin
         else:
             rope_embeds = None
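In the new code path `position_ids` is created on the fly instead of re-registering a buffer inside `forward`, and the rotary module already returns per-position cos/sin tensors of shape [bs, seq_len, dim], so only an `unsqueeze(2)` is needed to broadcast over heads. A shape-only sketch with illustrative sizes and a free-standing cos/sin computation, not the class above:

    import torch

    batch_size, seq_length, num_heads, head_dim = 2, 8, 4, 16
    position_ids = torch.arange(seq_length).expand(batch_size, -1)            # [bs, seq_len]

    inv_freq = 1.0 / (160000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
    freqs = position_ids[:, :, None].float() * inv_freq[None, None, :]        # [bs, seq_len, dim/2]
    emb = torch.cat((freqs, freqs), dim=-1)                                   # [bs, seq_len, dim]
    rope_cos, rope_sin = emb.cos(), emb.sin()

    q = torch.randn(batch_size, seq_length, num_heads, head_dim)              # [bs, seq_len, heads, dim]
    q_scaled = q * rope_cos.unsqueeze(2)                                      # [bs, seq_len, 1, dim] broadcasts over heads
    print(q_scaled.shape)  # torch.Size([2, 8, 4, 16])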
@@ -793,22 +782,6 @@ class NewPreTrainedModel(PreTrainedModel):
     base_model_prefix = "new"
     supports_gradient_checkpointing = True

-    def _init_weights(self, module):
-        """Initialize the weights"""
-        if isinstance(module, nn.Linear):
-            # Slightly different from the TF version which uses truncated_normal for initialization
-            # cf https://github.com/pytorch/pytorch/pull/5617
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-            if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
-        elif isinstance(module, nn.LayerNorm):
-            module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
-

 class NewModel(NewPreTrainedModel):
     """
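The custom `_init_weights` override is dropped here, so initialization falls back to whatever the installed transformers base class provides. A minimal smoke-test sketch for the updated remote code; the model id below is a placeholder, not a real checkpoint name:

    import torch
    from transformers import AutoModel, AutoTokenizer

    model_id = "path/or/repo-with-this-remote-code"  # placeholder
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModel.from_pretrained(model_id, trust_remote_code=True)

    batch = tokenizer(["a short test sentence"], return_tensors="pt")
    with torch.no_grad():
        outputs = model(**batch)
    print(outputs.last_hidden_state.shape)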