Sławomir Dadas committed
Commit c582324 · 1 Parent(s): c2b7542

Transformers 5 compatibility fixes

Files changed (3):
  1. config.json +4 -4
  2. configuration.py +2 -4
  3. modeling.py +86 -113
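
All three files converge on the same change: the legacy rope_theta / rope_scaling pair is folded into a single rope_parameters dict. As a quick orientation (not part of the commit), a minimal Python sketch of reading the RoPE settings from either layout of config.json, assuming the file sits in the working directory:

import json

with open("config.json") as f:
    cfg = json.load(f)

if "rope_parameters" in cfg:
    # post-commit layout: one dict holds theta, scaling factor and rope type
    rope = cfg["rope_parameters"]
    rope_theta, rope_type = rope["rope_theta"], rope["rope_type"]
    factor = rope.get("factor", 1.0)
else:
    # legacy layout: rope_theta sits next to an optional rope_scaling dict
    scaling = cfg.get("rope_scaling") or {}
    rope_theta = cfg.get("rope_theta", 10000.0)
    rope_type = scaling.get("type", "default")
    factor = scaling.get("factor", 1.0)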
config.json CHANGED
@@ -26,11 +26,11 @@
   "pack_qkv": true,
   "pad_token_id": 0,
   "position_embedding_type": "rope",
-  "rope_scaling": {
-    "factor": 2.0,
-    "type": "ntk"
+  "rope_parameters": {
+    "rope_theta": 160000,
+    "factor": 2.0,
+    "rope_type": "default"
   },
-  "rope_theta": 160000,
   "transformers_version": "4.56.1",
   "type_vocab_size": 2,
   "unpad_inputs": true,
configuration.py CHANGED
@@ -108,8 +108,7 @@ class NewConfig(PretrainedConfig):
         layer_norm_eps=1e-12,
         # pad_token_id=0,
         position_embedding_type="rope",
-        rope_theta=10000.0,
-        rope_scaling=None,
+        rope_parameters=None,
         classifier_dropout=None,
         pack_qkv=True,
         unpad_inputs=False,
@@ -134,9 +133,8 @@ class NewConfig(PretrainedConfig):
         self.layer_norm_type = layer_norm_type
         self.layer_norm_eps = layer_norm_eps
         self.position_embedding_type = position_embedding_type
-        self.rope_theta = rope_theta
-        self.rope_scaling = rope_scaling
         self.classifier_dropout = classifier_dropout
+        self.rope_parameters = rope_parameters

         self.pack_qkv = pack_qkv
         self.unpad_inputs = unpad_inputs
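
A short usage sketch for the updated constructor (not part of the commit; assumes the repository's configuration.py is importable as a plain module). The rope_theta / rope_scaling keyword arguments are replaced by one rope_parameters dict, mirroring the values written to config.json above:

from configuration import NewConfig

config = NewConfig(
    position_embedding_type="rope",
    rope_parameters={"rope_theta": 160000, "factor": 2.0, "rope_type": "default"},
)
assert config.rope_parameters["rope_type"] == "default"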
modeling.py CHANGED
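
The hunks below drop the old RotaryEmbedding / NTKScalingRotaryEmbedding pair in favour of a single config-driven RotaryEmbedding whose forward takes position_ids instead of seq_len, and add a maybe_autocast helper. A hedged smoke-test sketch of the resulting interface (not part of the commit; assumes configuration.py and modeling.py import cleanly on a transformers version that exports ROPE_INIT_FUNCTIONS, and that NewConfig's other defaults are usable):

import torch
from configuration import NewConfig
from modeling import RotaryEmbedding, apply_rotary_pos_emb

config = NewConfig(rope_parameters={"rope_theta": 160000, "factor": 2.0, "rope_type": "default"})
rotary = RotaryEmbedding(config)

bs, seq_len = 2, 8
head_dim = config.hidden_size // config.num_attention_heads
inputs_embeds = torch.randn(bs, seq_len, config.hidden_size)
position_ids = torch.arange(seq_len).expand(bs, -1)

# cos/sin are returned per position, [bs, seq_len, head_dim] ...
cos, sin = rotary(inputs_embeds, position_ids)
# ... and NewEmbeddings unsqueezes them to [bs, seq_len, 1, head_dim] before use
cos, sin = cos.unsqueeze(2), sin.unsqueeze(2)

q = torch.randn(bs, seq_len, config.num_attention_heads, head_dim)
k = torch.randn(bs, seq_len, config.num_attention_heads, head_dim)
q, k = apply_rotary_pos_emb(q, k, cos, sin)  # rotated, shapes unchanged

In practice, the end-to-end check is simply loading the repository with AutoModel.from_pretrained(..., trust_remote_code=True) in a Transformers 5 environment.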
@@ -16,11 +16,13 @@
 """PyTorch NEW model."""

 import math
-from typing import List, Optional, Tuple, Union
+from contextlib import nullcontext
+from typing import List, Optional, Tuple, Union, Callable

 import torch
 import torch.utils.checkpoint
 from torch import nn
+from transformers import ROPE_INIT_FUNCTIONS

 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import (
@@ -139,6 +141,28 @@ class IndexPutFirstAxis(torch.autograd.Function):
 index_put_first_axis = IndexPutFirstAxis.apply


+def maybe_autocast(
+    device_type: str,
+    dtype: Optional["_dtype"] = None,
+    enabled: bool = True,
+    cache_enabled: bool | None = None,
+):
+    """
+    Context manager that only autocasts if:
+
+    - `autocast` is already enabled in this context
+    - Or this call to `maybe_autocast` has `enabled=True`
+
+    This prevents `autocast` being added to the graph when it is effectively a no-op.
+    Which makes graph splitting in `torch.compile` more flexible as it removes the
+    requirement that partition IDs be monotonically increasing.
+    """
+    if torch.is_autocast_enabled(device_type) or enabled:
+        return torch.autocast(device_type, dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
+    else:
+        return nullcontext()
+
+
 def pad_input(inputs: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int) -> torch.Tensor:
     """Add padding to sequences.

@@ -162,7 +186,7 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)


-def apply_rotary_pos_emb(q, k, cos, sin):
+def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.

     Args:
@@ -170,84 +194,75 @@ def apply_rotary_pos_emb(q, k, cos, sin):
         k (`torch.Tensor`): The key tensor.
         cos (`torch.Tensor`): The cosine part of the rotary embedding.
         sin (`torch.Tensor`): The sine part of the rotary embedding.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
     Returns:
         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
     """
-    cos, sin = cos.to(q.dtype), sin.to(q.dtype)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed


-class RotaryEmbedding(torch.nn.Module):
-    def __init__(self, dim, max_position_embeddings=512, base=10000.0, device=None):
-        super().__init__()
-
-        self.dim = dim
-        self.max_position_embeddings = max_position_embeddings
-        self.base = base
-        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-        # Build here to make `torch.jit.trace` work.
-        self._set_cos_sin_cache(
-            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
-        )
-
-    def _set_cos_sin_cache(self, seq_len, device, dtype):
-        self.max_seq_len_cached = seq_len
-        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32)
-
-        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
-        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
-
-    def forward(self, x, seq_len=None):
-        # x: [bs, num_attention_heads, seq_len, head_size]
-        if seq_len > self.max_seq_len_cached:
-            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
-
-        return (
-            self.cos_cached[:seq_len, ...].to(dtype=x.dtype),
-            self.sin_cached[:seq_len, ...].to(dtype=x.dtype),
-        )
-
-
-class NTKScalingRotaryEmbedding(RotaryEmbedding):
-    """RotaryEmbedding extended with fixed and mixed NTK scaling. https://kexue.fm/archives/9706 """
-
-    def __init__(self, dim, max_position_embeddings=512, base=10000, device=None, scaling_factor=1.0, mixed_b=None):
-        self.scaling_factor = scaling_factor
-        self.mixed_b = mixed_b
-        super().__init__(dim, max_position_embeddings, base, device)
-        max_position_embeddings = max_position_embeddings * self.scaling_factor
-        self._set_cos_sin_cache(max_position_embeddings, self.inv_freq.device, torch.get_default_dtype())
-
-    def _set_cos_sin_cache(self, seq_len, device, dtype):
-        self.max_seq_len_cached = seq_len
-
-        if seq_len > self.max_position_embeddings:
-            base = self.base * (self.scaling_factor if self.mixed_b is None else 1)
-            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
-
-            if self.mixed_b is None:
-                inv_freq = inv_freq / self.scaling_factor ** (2 / self.dim)  # (6)
-            else:
-                a = torch.tensor(self.scaling_factor).log() / (self.dim / 2) ** self.mixed_b  # (13)
-                lambda_1_m = (a * torch.arange(1, self.dim // 2 + 1).float().to(device) ** self.mixed_b).exp()  # (12)
-                inv_freq = inv_freq / lambda_1_m  # (10)
-
-            self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32)
-
-        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
-        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+class RotaryEmbedding(nn.Module):
+    inv_freq: torch.Tensor  # fix linting for `register_buffer`
+
+    def __init__(self, config: NewConfig, device=None):
+        super().__init__()
+        self.max_seq_len_cached = config.max_position_embeddings
+        self.original_max_seq_len = config.max_position_embeddings
+
+        self.config = config
+
+        self.rope_type = self.config.rope_parameters["rope_type"]
+        if self.rope_type == "default":
+            rope_init_fn: Callable = self.compute_default_rope_parameters
+        else:
+            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
+
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
+
+    @staticmethod
+    def compute_default_rope_parameters(
+        config: NewConfig | None = None,
+        device: Optional["torch.device"] = None,
+    ) -> tuple["torch.Tensor", float]:
+        """Computes rope parameters with NTK scaling"""
+        scaling_factor = config.rope_parameters.get("factor", 1.0)
+        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+        base = config.rope_parameters["rope_theta"]
+        mixed_b = config.rope_parameters.get("mixed_b", None)
+
+        base = base * (scaling_factor if mixed_b is None else 1)
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
+        if mixed_b is None:
+            inv_freq = inv_freq / scaling_factor ** (2 / dim)
+        else:
+            a = torch.tensor(scaling_factor).log() / (dim / 2) ** mixed_b
+            lambda_1_m = (a * torch.arange(1, dim // 2 + 1).float().to(device) ** mixed_b).exp()
+            inv_freq = inv_freq / lambda_1_m
+        return inv_freq, 1.0
+
+    @torch.no_grad()
+    def forward(self, x, position_ids):
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+        position_ids_expanded = position_ids[:, None, :].float()
+
+        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos() * self.attention_scaling
+            sin = emb.sin() * self.attention_scaling
+
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


 class RMSNorm(nn.Module):
@@ -291,7 +306,7 @@ class NewEmbeddings(nn.Module):
                 config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
             )
         elif self.position_embedding_type == 'rope':
-            self._init_rope(config)
+            self.rotary_emb = RotaryEmbedding(config)
         else:
            raise ValueError

@@ -308,27 +323,6 @@ class NewEmbeddings(nn.Module):
             "position_ids", torch.arange(config.max_position_embeddings), persistent=False
         )

-    def _init_rope(self, config):
-        kwargs = dict(
-            dim=int(config.hidden_size / config.num_attention_heads),
-            max_position_embeddings=config.max_position_embeddings,
-            base=config.rope_theta
-        )
-        if config.rope_scaling is None:
-            self.rotary_emb = RotaryEmbedding(**kwargs)
-        else:
-            kwargs.update(scaling_factor=config.rope_scaling["factor"])
-            scaling_type = config.rope_scaling["type"]
-            if scaling_type == 'ntk':
-                kwargs.update(mixed_b=config.rope_scaling.get('mixed_b', None))
-                self.rotary_emb = NTKScalingRotaryEmbedding(**kwargs)
-            # elif scaling_type == "linear":
-            #     self.rotary_emb = LinearScalingRotaryEmbedding(**kwargs)
-            # elif scaling_type == "dynamic":
-            #     self.rotary_emb = DynamicNTKScalingRotaryEmbedding(**kwargs)
-            else:
-                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
-
     def forward(
         self,
         unpad_inputs: bool,
@@ -339,8 +333,6 @@ class NewEmbeddings(nn.Module):
         position_ids: Optional[torch.Tensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, Optional[Tuple], Optional[List[int]]]:
-        """
-        """
         if inputs_embeds is None:
             device, input_shape = input_ids.device, input_ids.shape
         else:
@@ -372,24 +364,21 @@ class NewEmbeddings(nn.Module):

         # Set and unpad position_ids
         if position_ids is None:
-            if seq_length > self.position_ids.size(0):
-                self.register_buffer(
-                    "position_ids", torch.arange(seq_length), persistent=False
-                )
+            position_ids = torch.arange(seq_length, device=inputs_embeds.device)
             if unpad_inputs:
                 # [1, cumsum_seq_len]
-                position_ids = torch.cat([self.position_ids[:l] for l in length]).unsqueeze(0)
+                position_ids = torch.cat([position_ids[:l] for l in length]).unsqueeze(0)
             else:
                 # [bs, seq_len]
-                position_ids = self.position_ids[:seq_length].expand(batch_size, -1)
+                position_ids = position_ids[:seq_length].expand(batch_size, -1)
         elif unpad_inputs:
             position_ids = position_ids[attention_mask_bool].unsqueeze(0)  # [1, cumsum_seq_len]

         # Compute rotary embedding
         if self.position_embedding_type == 'rope':
-            rope_cos, rope_sin = self.rotary_emb(inputs_embeds, seq_len=seq_length)
-            rope_cos = rope_cos[position_ids].unsqueeze(2)  # [bs, seq_len, 1, dim]
-            rope_sin = rope_sin[position_ids].unsqueeze(2)  # [bs, seq_len, 1, dim]
+            rope_cos, rope_sin = self.rotary_emb(inputs_embeds, position_ids)
+            rope_cos = rope_cos.unsqueeze(2)  # [bs, seq_len, 1, dim]
+            rope_sin = rope_sin.unsqueeze(2)  # [bs, seq_len, 1, dim]
             rope_embeds = rope_cos, rope_sin
         else:
             rope_embeds = None
@@ -793,22 +782,6 @@ class NewPreTrainedModel(PreTrainedModel):
     base_model_prefix = "new"
     supports_gradient_checkpointing = True

-    def _init_weights(self, module):
-        """Initialize the weights"""
-        if isinstance(module, nn.Linear):
-            # Slightly different from the TF version which uses truncated_normal for initialization
-            # cf https://github.com/pytorch/pytorch/pull/5617
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-            if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
-        elif isinstance(module, nn.LayerNorm):
-            module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
-

 class NewModel(NewPreTrainedModel):
     """