vermouthdky commited on
Commit
7c50ccf
·
verified ·
1 Parent(s): 49da737

Upload 3 files

Browse files
Files changed (2) hide show
  1. modeling_qwen2.py +266 -87
  2. nets.py +174 -0
modeling_qwen2.py CHANGED
@@ -40,8 +40,8 @@ from transformers.utils import (add_start_docstrings,
40
  is_flash_attn_greater_or_equal_2_10, logging,
41
  replace_return_docstrings)
42
 
43
- from ..nets import EnsembleModel
44
  from .configuration_qwen2 import QwenEnPRMConfig as Qwen2Config
 
45
 
46
  if is_flash_attn_2_available():
47
  from transformers.modeling_flash_attention_utils import \
@@ -92,19 +92,30 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
92
  # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
93
  causal_mask = attention_mask
94
  else:
95
- causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
 
 
 
 
 
96
  if sequence_length != 1:
97
  causal_mask = torch.triu(causal_mask, diagonal=1)
98
- causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
 
 
99
  causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
100
  if attention_mask is not None:
101
- causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
 
 
102
  mask_length = attention_mask.shape[-1]
103
- padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
104
- padding_mask = padding_mask == 0
105
- causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
106
- padding_mask, min_dtype
107
  )
 
 
 
 
108
 
109
  return causal_mask
110
 
@@ -138,17 +149,27 @@ class Qwen2RotaryEmbedding(nn.Module):
138
  self.dim = dim
139
  self.max_position_embeddings = max_position_embeddings
140
  self.base = base
141
- inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
 
 
 
 
 
 
142
  self.register_buffer("inv_freq", inv_freq, persistent=False)
143
 
144
  # Build here to make `torch.jit.trace` work.
145
  self._set_cos_sin_cache(
146
- seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
 
 
147
  )
148
 
149
  def _set_cos_sin_cache(self, seq_len, device, dtype):
150
  self.max_seq_len_cached = seq_len
151
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
 
 
152
 
153
  freqs = torch.outer(t, self.inv_freq)
154
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
@@ -216,7 +237,9 @@ class Qwen2MLP(nn.Module):
216
  self.act_fn = ACT2FN[config.hidden_act]
217
 
218
  def forward(self, hidden_state):
219
- return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
 
 
220
 
221
 
222
  # Copied from transformers.models.llama.modeling_llama.repeat_kv
@@ -228,7 +251,9 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
228
  batch, num_key_value_heads, slen, head_dim = hidden_states.shape
229
  if n_rep == 1:
230
  return hidden_states
231
- hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
 
 
232
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
233
 
234
 
@@ -264,10 +289,18 @@ class Qwen2Attention(nn.Module):
264
  f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
265
  f" and `num_heads`: {self.num_heads})."
266
  )
267
- self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
268
- self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
269
- self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
270
- self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
 
 
 
 
 
 
 
 
271
 
272
  self.rotary_emb = Qwen2RotaryEmbedding(
273
  self.head_dim,
@@ -291,9 +324,15 @@ class Qwen2Attention(nn.Module):
291
  key_states = self.k_proj(hidden_states)
292
  value_states = self.v_proj(hidden_states)
293
 
294
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
295
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
296
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
 
 
 
 
 
297
 
298
  kv_seq_len = key_states.shape[-2]
299
  if past_key_value is not None:
@@ -305,17 +344,27 @@ class Qwen2Attention(nn.Module):
305
  )
306
  kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
307
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
308
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
 
309
 
310
  if past_key_value is not None:
311
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
312
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
 
 
 
 
 
313
 
314
  # repeat k/v heads if n_kv_heads < n_heads
315
  key_states = repeat_kv(key_states, self.num_key_value_groups)
316
  value_states = repeat_kv(value_states, self.num_key_value_groups)
317
 
318
- attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
 
319
 
320
  if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
321
  raise ValueError(
@@ -328,8 +377,12 @@ class Qwen2Attention(nn.Module):
328
  attn_weights = attn_weights + causal_mask
329
 
330
  # upcast attention to fp32
331
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
332
- attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
 
 
 
 
333
  attn_output = torch.matmul(attn_weights, value_states)
334
 
335
  if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
@@ -383,9 +436,15 @@ class Qwen2FlashAttention2(Qwen2Attention):
383
  key_states = self.k_proj(hidden_states)
384
  value_states = self.v_proj(hidden_states)
385
 
386
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
387
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
388
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
 
 
 
 
 
389
 
390
  kv_seq_len = key_states.shape[-2]
391
  if past_key_value is not None:
@@ -399,12 +458,16 @@ class Qwen2FlashAttention2(Qwen2Attention):
399
 
400
  # Because the input can be padded, the absolute sequence length depends on the max position id.
401
  rotary_seq_len = (
402
- max(kv_seq_len, position_ids[:, -1].max().item() + 1) if position_ids is not None else kv_seq_len
 
 
403
  )
404
 
405
  cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
406
 
407
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
 
408
 
409
  if past_key_value is not None:
410
  # Activate slicing cache only if the config has a value `sliding_windows` attribute
@@ -430,10 +493,19 @@ class Qwen2FlashAttention2(Qwen2Attention):
430
 
431
  if attention_mask is not None:
432
  attention_mask = attention_mask[:, slicing_tokens:]
433
- attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
 
 
 
434
 
435
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
436
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
 
 
 
 
 
437
 
438
  # repeat k/v heads if n_kv_heads < n_heads
439
  key_states = repeat_kv(key_states, self.num_key_value_groups)
@@ -539,20 +611,34 @@ class Qwen2SdpaAttention(Qwen2Attention):
539
  key_states = self.k_proj(hidden_states)
540
  value_states = self.v_proj(hidden_states)
541
 
542
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
543
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
544
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
 
 
 
 
 
545
 
546
  kv_seq_len = key_states.shape[-2]
547
  if past_key_value is not None:
548
  kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
549
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
550
 
551
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
 
552
 
553
  if past_key_value is not None:
554
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
555
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
 
 
 
 
 
556
 
557
  key_states = repeat_kv(key_states, self.num_key_value_groups)
558
  value_states = repeat_kv(value_states, self.num_key_value_groups)
@@ -607,11 +693,15 @@ class Qwen2DecoderLayer(nn.Module):
607
  f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
608
  "unexpected results may be encountered."
609
  )
610
- self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
 
 
611
 
612
  self.mlp = Qwen2MLP(config)
613
  self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
614
- self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
 
615
 
616
  def forward(
617
  self,
@@ -623,7 +713,9 @@ class Qwen2DecoderLayer(nn.Module):
623
  use_cache: Optional[bool] = False,
624
  cache_position: Optional[torch.LongTensor] = None,
625
  **kwargs,
626
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
 
 
627
  """
628
  Args:
629
  hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
@@ -810,9 +902,14 @@ class Qwen2Model(Qwen2PreTrainedModel):
810
  self.padding_idx = config.pad_token_id
811
  self.vocab_size = config.vocab_size
812
 
813
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
 
 
814
  self.layers = nn.ModuleList(
815
- [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
 
 
 
816
  )
817
  self._attn_implementation = config._attn_implementation
818
  self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -841,13 +938,21 @@ class Qwen2Model(Qwen2PreTrainedModel):
841
  return_dict: Optional[bool] = None,
842
  cache_position: Optional[torch.LongTensor] = None,
843
  ) -> Union[Tuple, BaseModelOutputWithPast]:
844
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
 
 
 
845
  output_hidden_states = (
846
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
847
  )
848
  use_cache = use_cache if use_cache is not None else self.config.use_cache
849
 
850
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
 
851
 
852
  if (input_ids is None) ^ (inputs_embeds is not None):
853
  raise ValueError(
@@ -874,15 +979,23 @@ class Qwen2Model(Qwen2PreTrainedModel):
874
  inputs_embeds = self.embed_tokens(input_ids)
875
 
876
  if cache_position is None:
877
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
 
 
878
  cache_position = torch.arange(
879
- past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
 
 
880
  )
881
  if position_ids is None:
882
  position_ids = cache_position.unsqueeze(0)
883
 
884
  causal_mask = self._update_causal_mask(
885
- attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
 
 
 
 
886
  )
887
 
888
  hidden_states = inputs_embeds
@@ -934,10 +1047,18 @@ class Qwen2Model(Qwen2PreTrainedModel):
934
 
935
  next_cache = None
936
  if use_cache:
937
- next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
 
 
 
 
938
 
939
  if not return_dict:
940
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
 
 
 
 
941
  return BaseModelOutputWithPast(
942
  last_hidden_state=hidden_states,
943
  past_key_values=next_cache,
@@ -967,11 +1088,17 @@ class Qwen2Model(Qwen2PreTrainedModel):
967
  # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
968
  # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
969
  # to infer the attention mask.
970
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
 
 
971
  using_static_cache = False # isinstance(past_key_values, StaticCache)
972
 
973
  # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
974
- if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
 
 
 
 
975
  if AttentionMaskConverter._ignore_causal_mask_sdpa(
976
  attention_mask,
977
  inputs_embeds=input_tensor,
@@ -1013,7 +1140,9 @@ class Qwen2Model(Qwen2PreTrainedModel):
1013
  # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1014
  # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1015
  # Details: https://github.com/pytorch/pytorch/issues/110213
1016
- causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
 
 
1017
 
1018
  return causal_mask
1019
 
@@ -1049,7 +1178,9 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
1049
  return self.model
1050
 
1051
  @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
1052
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 
 
1053
  def forward(
1054
  self,
1055
  input_ids: torch.LongTensor = None,
@@ -1090,11 +1221,19 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
1090
  "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1091
  ```"""
1092
 
1093
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
 
 
 
1094
  output_hidden_states = (
1095
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
 
 
 
1096
  )
1097
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1098
 
1099
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1100
  outputs = self.model(
@@ -1157,7 +1296,9 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
1157
  if past_key_values is not None:
1158
  if inputs_embeds is not None: # Exception 1
1159
  input_ids = input_ids[:, -cache_position.shape[0] :]
1160
- elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
 
 
1161
  input_ids = input_ids[:, cache_position]
1162
 
1163
  if attention_mask is not None and position_ids is None:
@@ -1176,7 +1317,11 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
1176
  else:
1177
  model_inputs = {"input_ids": input_ids}
1178
 
1179
- if False and isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
 
 
 
 
1180
  if inputs_embeds is not None:
1181
  batch_size, sequence_length = inputs_embeds.shape
1182
  device = inputs_embeds.device
@@ -1261,7 +1406,9 @@ class Qwen2ForSequenceClassification(Qwen2PreTrainedModel):
1261
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1262
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1263
  """
1264
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
 
1265
 
1266
  transformer_outputs = self.model(
1267
  input_ids,
@@ -1283,19 +1430,25 @@ class Qwen2ForSequenceClassification(Qwen2PreTrainedModel):
1283
  batch_size = inputs_embeds.shape[0]
1284
 
1285
  if self.config.pad_token_id is None and batch_size != 1:
1286
- raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
 
 
1287
  if self.config.pad_token_id is None:
1288
  sequence_lengths = -1
1289
  else:
1290
  if input_ids is not None:
1291
  # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1292
- sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
 
 
1293
  sequence_lengths = sequence_lengths % input_ids.shape[-1]
1294
  sequence_lengths = sequence_lengths.to(logits.device)
1295
  else:
1296
  sequence_lengths = -1
1297
 
1298
- pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
 
 
1299
 
1300
  loss = None
1301
  if labels is not None:
@@ -1303,7 +1456,9 @@ class Qwen2ForSequenceClassification(Qwen2PreTrainedModel):
1303
  if self.config.problem_type is None:
1304
  if self.num_labels == 1:
1305
  self.config.problem_type = "regression"
1306
- elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
 
 
1307
  self.config.problem_type = "single_label_classification"
1308
  else:
1309
  self.config.problem_type = "multi_label_classification"
@@ -1316,7 +1471,9 @@ class Qwen2ForSequenceClassification(Qwen2PreTrainedModel):
1316
  loss = loss_fct(pooled_logits, labels)
1317
  elif self.config.problem_type == "single_label_classification":
1318
  loss_fct = CrossEntropyLoss()
1319
- loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
 
 
1320
  elif self.config.problem_type == "multi_label_classification":
1321
  loss_fct = BCEWithLogitsLoss()
1322
  loss = loss_fct(pooled_logits, labels)
@@ -1384,7 +1541,9 @@ class Qwen2ForTokenClassification(Qwen2PreTrainedModel):
1384
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1385
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1386
  """
1387
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
 
1388
 
1389
  outputs = self.model(
1390
  input_ids,
@@ -1445,7 +1604,7 @@ class Qwen2ForEnsemblePRM(Qwen2PreTrainedModel):
1445
  encoding_dim=config.hidden_size,
1446
  num_ensemble=config.num_ensemble,
1447
  )
1448
- # self.score.init()
1449
  # Initialize weights and apply final processing
1450
  self.post_init()
1451
 
@@ -1462,7 +1621,7 @@ class Qwen2ForEnsemblePRM(Qwen2PreTrainedModel):
1462
  outputs.logits = torch.nn.functional.sigmoid(outputs.logits)
1463
  return outputs
1464
 
1465
- def _compute_loss(self, logits, labels, return_reg_loss=False):
1466
  # NOTE: we only compute the loss for specific position (labels != -100)
1467
  logits = logits.float()
1468
  loss = None
@@ -1471,21 +1630,23 @@ class Qwen2ForEnsemblePRM(Qwen2PreTrainedModel):
1471
  # only support hard labels; no need for soft labels
1472
  loss_fct = BCEWithLogitsLoss(reduction="none")
1473
 
1474
- loss = loss_fct(logits, labels[None].repeat([logits.size(0), 1, 1]).to(logits.dtype))
 
 
1475
  # select loss for specific position
1476
  mask = (labels != -100)[None].repeat([logits.size(0), 1, 1])
1477
  # and randomly mask instances for different ensemble models
1478
- data_aloc_mask = torch.rand(mask.size(0), mask.size(1)) < self.learning_probability
 
 
1479
  mask = mask & data_aloc_mask[:, :, None].to(mask.device)
1480
 
1481
  loss = torch.masked_select(loss, mask)
1482
  loss = loss.mean()
1483
- reg_loss = self.regularization_lambda * self.score.regularization()
1484
- loss += reg_loss
1485
- if not return_reg_loss:
1486
- return loss
1487
- else:
1488
- return (loss, reg_loss)
1489
 
1490
  @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
1491
  def forward(
@@ -1501,7 +1662,9 @@ class Qwen2ForEnsemblePRM(Qwen2PreTrainedModel):
1501
  output_hidden_states: Optional[bool] = None,
1502
  return_dict: Optional[bool] = None,
1503
  ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1504
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
 
1505
 
1506
  transformer_outputs = self.model(
1507
  input_ids,
@@ -1515,7 +1678,9 @@ class Qwen2ForEnsemblePRM(Qwen2PreTrainedModel):
1515
  return_dict=return_dict,
1516
  )
1517
  hidden_states = transformer_outputs[0] # (b, l, h)
1518
- hidden_states = hidden_states[None, :, :, :].repeat(self.score.num_ensemble, 1, 1, 1) # (e, l, h)
 
 
1519
  logits = self.score(hidden_states)
1520
 
1521
  if input_ids is not None:
@@ -1524,13 +1689,17 @@ class Qwen2ForEnsemblePRM(Qwen2PreTrainedModel):
1524
  batch_size = inputs_embeds.shape[0]
1525
 
1526
  if self.config.pad_token_id is None and batch_size != 1:
1527
- raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
 
 
1528
  if self.config.pad_token_id is None:
1529
  sequence_lengths = -1
1530
  else:
1531
  if input_ids is not None:
1532
  # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1533
- sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
 
 
1534
  sequence_lengths = sequence_lengths % input_ids.shape[-1]
1535
  sequence_lengths = sequence_lengths.to(logits.device)
1536
  else:
@@ -1538,7 +1707,9 @@ class Qwen2ForEnsemblePRM(Qwen2PreTrainedModel):
1538
 
1539
  logits = logits.float()
1540
  loss = None
1541
- logits = logits.squeeze(-1) # (ensemble, batch_size, seq_len, 1) -> (ensemble, batch_size, seq_len)
 
 
1542
  if labels is not None:
1543
  if self.config.problem_type is None: # NOTE: no use
1544
  if labels.dtype is not torch.long:
@@ -1550,16 +1721,24 @@ class Qwen2ForEnsemblePRM(Qwen2PreTrainedModel):
1550
  # only support hard labels; no need for soft labels
1551
  loss_fct = BCEWithLogitsLoss(reduction="none")
1552
 
1553
- loss = loss_fct(logits, labels[None].repeat([logits.size(0), 1, 1]).to(logits.dtype))
 
 
1554
  # select loss for specific position
1555
  mask = (labels != -100)[None].repeat([logits.size(0), 1, 1])
1556
  # and randomly mask instances for different ensemble models
1557
- data_aloc_mask = torch.rand(mask.size(0), mask.size(1)) < self.learning_probability
 
 
1558
  mask = mask & data_aloc_mask[:, :, None].to(mask.device)
1559
 
1560
  loss = torch.masked_select(loss, mask)
1561
  loss = loss.mean()
1562
- loss += self.regularization_lambda * labels.size(0) * self.score.regularization()
 
 
 
 
1563
 
1564
  if not return_dict:
1565
  output = (logits,) + transformer_outputs[1:]
 
40
  is_flash_attn_greater_or_equal_2_10, logging,
41
  replace_return_docstrings)
42
 
 
43
  from .configuration_qwen2 import QwenEnPRMConfig as Qwen2Config
44
+ from .nets import EnsembleModel
45
 
46
  if is_flash_attn_2_available():
47
  from transformers.modeling_flash_attention_utils import \
 
92
  # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
93
  causal_mask = attention_mask
94
  else:
95
+ causal_mask = torch.full(
96
+ (sequence_length, target_length),
97
+ fill_value=min_dtype,
98
+ dtype=dtype,
99
+ device=device,
100
+ )
101
  if sequence_length != 1:
102
  causal_mask = torch.triu(causal_mask, diagonal=1)
103
+ causal_mask *= torch.arange(
104
+ target_length, device=device
105
+ ) > cache_position.reshape(-1, 1)
106
  causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
107
  if attention_mask is not None:
108
+ causal_mask = (
109
+ causal_mask.clone()
110
+ ) # copy to contiguous memory for in-place edit
111
  mask_length = attention_mask.shape[-1]
112
+ padding_mask = (
113
+ causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
 
 
114
  )
115
+ padding_mask = padding_mask == 0
116
+ causal_mask[:, :, :, :mask_length] = causal_mask[
117
+ :, :, :, :mask_length
118
+ ].masked_fill(padding_mask, min_dtype)
119
 
120
  return causal_mask
121
 
 
149
  self.dim = dim
150
  self.max_position_embeddings = max_position_embeddings
151
  self.base = base
152
+ inv_freq = 1.0 / (
153
+ self.base
154
+ ** (
155
+ torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device)
156
+ / self.dim
157
+ )
158
+ )
159
  self.register_buffer("inv_freq", inv_freq, persistent=False)
160
 
161
  # Build here to make `torch.jit.trace` work.
162
  self._set_cos_sin_cache(
163
+ seq_len=max_position_embeddings,
164
+ device=self.inv_freq.device,
165
+ dtype=torch.get_default_dtype(),
166
  )
167
 
168
  def _set_cos_sin_cache(self, seq_len, device, dtype):
169
  self.max_seq_len_cached = seq_len
170
+ t = torch.arange(
171
+ self.max_seq_len_cached, device=device, dtype=torch.int64
172
+ ).type_as(self.inv_freq)
173
 
174
  freqs = torch.outer(t, self.inv_freq)
175
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
 
237
  self.act_fn = ACT2FN[config.hidden_act]
238
 
239
  def forward(self, hidden_state):
240
+ return self.down_proj(
241
+ self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)
242
+ )
243
 
244
 
245
  # Copied from transformers.models.llama.modeling_llama.repeat_kv
 
251
  batch, num_key_value_heads, slen, head_dim = hidden_states.shape
252
  if n_rep == 1:
253
  return hidden_states
254
+ hidden_states = hidden_states[:, :, None, :, :].expand(
255
+ batch, num_key_value_heads, n_rep, slen, head_dim
256
+ )
257
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
258
 
259
 
 
289
  f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
290
  f" and `num_heads`: {self.num_heads})."
291
  )
292
+ self.q_proj = nn.Linear(
293
+ self.hidden_size, self.num_heads * self.head_dim, bias=True
294
+ )
295
+ self.k_proj = nn.Linear(
296
+ self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
297
+ )
298
+ self.v_proj = nn.Linear(
299
+ self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
300
+ )
301
+ self.o_proj = nn.Linear(
302
+ self.num_heads * self.head_dim, self.hidden_size, bias=False
303
+ )
304
 
305
  self.rotary_emb = Qwen2RotaryEmbedding(
306
  self.head_dim,
 
324
  key_states = self.k_proj(hidden_states)
325
  value_states = self.v_proj(hidden_states)
326
 
327
+ query_states = query_states.view(
328
+ bsz, q_len, self.num_heads, self.head_dim
329
+ ).transpose(1, 2)
330
+ key_states = key_states.view(
331
+ bsz, q_len, self.num_key_value_heads, self.head_dim
332
+ ).transpose(1, 2)
333
+ value_states = value_states.view(
334
+ bsz, q_len, self.num_key_value_heads, self.head_dim
335
+ ).transpose(1, 2)
336
 
337
  kv_seq_len = key_states.shape[-2]
338
  if past_key_value is not None:
 
344
  )
345
  kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
346
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
347
+ query_states, key_states = apply_rotary_pos_emb(
348
+ query_states, key_states, cos, sin, position_ids
349
+ )
350
 
351
  if past_key_value is not None:
352
+ cache_kwargs = {
353
+ "sin": sin,
354
+ "cos": cos,
355
+ "cache_position": cache_position,
356
+ } # Specific to RoPE models
357
+ key_states, value_states = past_key_value.update(
358
+ key_states, value_states, self.layer_idx, cache_kwargs
359
+ )
360
 
361
  # repeat k/v heads if n_kv_heads < n_heads
362
  key_states = repeat_kv(key_states, self.num_key_value_groups)
363
  value_states = repeat_kv(value_states, self.num_key_value_groups)
364
 
365
+ attn_weights = torch.matmul(
366
+ query_states, key_states.transpose(2, 3)
367
+ ) / math.sqrt(self.head_dim)
368
 
369
  if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
370
  raise ValueError(
 
377
  attn_weights = attn_weights + causal_mask
378
 
379
  # upcast attention to fp32
380
+ attn_weights = nn.functional.softmax(
381
+ attn_weights, dim=-1, dtype=torch.float32
382
+ ).to(query_states.dtype)
383
+ attn_weights = nn.functional.dropout(
384
+ attn_weights, p=self.attention_dropout, training=self.training
385
+ )
386
  attn_output = torch.matmul(attn_weights, value_states)
387
 
388
  if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
 
436
  key_states = self.k_proj(hidden_states)
437
  value_states = self.v_proj(hidden_states)
438
 
439
+ query_states = query_states.view(
440
+ bsz, q_len, self.num_heads, self.head_dim
441
+ ).transpose(1, 2)
442
+ key_states = key_states.view(
443
+ bsz, q_len, self.num_key_value_heads, self.head_dim
444
+ ).transpose(1, 2)
445
+ value_states = value_states.view(
446
+ bsz, q_len, self.num_key_value_heads, self.head_dim
447
+ ).transpose(1, 2)
448
 
449
  kv_seq_len = key_states.shape[-2]
450
  if past_key_value is not None:
 
458
 
459
  # Because the input can be padded, the absolute sequence length depends on the max position id.
460
  rotary_seq_len = (
461
+ max(kv_seq_len, position_ids[:, -1].max().item() + 1)
462
+ if position_ids is not None
463
+ else kv_seq_len
464
  )
465
 
466
  cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
467
 
468
+ query_states, key_states = apply_rotary_pos_emb(
469
+ query_states, key_states, cos, sin, position_ids
470
+ )
471
 
472
  if past_key_value is not None:
473
  # Activate slicing cache only if the config has a value `sliding_windows` attribute
 
493
 
494
  if attention_mask is not None:
495
  attention_mask = attention_mask[:, slicing_tokens:]
496
+ attention_mask = torch.cat(
497
+ [attention_mask, torch.ones_like(attention_mask[:, -1:])],
498
+ dim=-1,
499
+ )
500
 
501
+ cache_kwargs = {
502
+ "sin": sin,
503
+ "cos": cos,
504
+ "cache_position": cache_position,
505
+ } # Specific to RoPE models
506
+ key_states, value_states = past_key_value.update(
507
+ key_states, value_states, self.layer_idx, cache_kwargs
508
+ )
509
 
510
  # repeat k/v heads if n_kv_heads < n_heads
511
  key_states = repeat_kv(key_states, self.num_key_value_groups)
 
611
  key_states = self.k_proj(hidden_states)
612
  value_states = self.v_proj(hidden_states)
613
 
614
+ query_states = query_states.view(
615
+ bsz, q_len, self.num_heads, self.head_dim
616
+ ).transpose(1, 2)
617
+ key_states = key_states.view(
618
+ bsz, q_len, self.num_key_value_heads, self.head_dim
619
+ ).transpose(1, 2)
620
+ value_states = value_states.view(
621
+ bsz, q_len, self.num_key_value_heads, self.head_dim
622
+ ).transpose(1, 2)
623
 
624
  kv_seq_len = key_states.shape[-2]
625
  if past_key_value is not None:
626
  kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
627
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
628
 
629
+ query_states, key_states = apply_rotary_pos_emb(
630
+ query_states, key_states, cos, sin, position_ids
631
+ )
632
 
633
  if past_key_value is not None:
634
+ cache_kwargs = {
635
+ "sin": sin,
636
+ "cos": cos,
637
+ "cache_position": cache_position,
638
+ } # Specific to RoPE models
639
+ key_states, value_states = past_key_value.update(
640
+ key_states, value_states, self.layer_idx, cache_kwargs
641
+ )
642
 
643
  key_states = repeat_kv(key_states, self.num_key_value_groups)
644
  value_states = repeat_kv(value_states, self.num_key_value_groups)
 
693
  f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
694
  "unexpected results may be encountered."
695
  )
696
+ self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](
697
+ config, layer_idx
698
+ )
699
 
700
  self.mlp = Qwen2MLP(config)
701
  self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
702
+ self.post_attention_layernorm = Qwen2RMSNorm(
703
+ config.hidden_size, eps=config.rms_norm_eps
704
+ )
705
 
706
  def forward(
707
  self,
 
713
  use_cache: Optional[bool] = False,
714
  cache_position: Optional[torch.LongTensor] = None,
715
  **kwargs,
716
+ ) -> Tuple[
717
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
718
+ ]:
719
  """
720
  Args:
721
  hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
 
902
  self.padding_idx = config.pad_token_id
903
  self.vocab_size = config.vocab_size
904
 
905
+ self.embed_tokens = nn.Embedding(
906
+ config.vocab_size, config.hidden_size, self.padding_idx
907
+ )
908
  self.layers = nn.ModuleList(
909
+ [
910
+ Qwen2DecoderLayer(config, layer_idx)
911
+ for layer_idx in range(config.num_hidden_layers)
912
+ ]
913
  )
914
  self._attn_implementation = config._attn_implementation
915
  self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
938
  return_dict: Optional[bool] = None,
939
  cache_position: Optional[torch.LongTensor] = None,
940
  ) -> Union[Tuple, BaseModelOutputWithPast]:
941
+ output_attentions = (
942
+ output_attentions
943
+ if output_attentions is not None
944
+ else self.config.output_attentions
945
+ )
946
  output_hidden_states = (
947
+ output_hidden_states
948
+ if output_hidden_states is not None
949
+ else self.config.output_hidden_states
950
  )
951
  use_cache = use_cache if use_cache is not None else self.config.use_cache
952
 
953
+ return_dict = (
954
+ return_dict if return_dict is not None else self.config.use_return_dict
955
+ )
956
 
957
  if (input_ids is None) ^ (inputs_embeds is not None):
958
  raise ValueError(
 
979
  inputs_embeds = self.embed_tokens(input_ids)
980
 
981
  if cache_position is None:
982
+ past_seen_tokens = (
983
+ past_key_values.get_seq_length() if past_key_values is not None else 0
984
+ )
985
  cache_position = torch.arange(
986
+ past_seen_tokens,
987
+ past_seen_tokens + inputs_embeds.shape[1],
988
+ device=inputs_embeds.device,
989
  )
990
  if position_ids is None:
991
  position_ids = cache_position.unsqueeze(0)
992
 
993
  causal_mask = self._update_causal_mask(
994
+ attention_mask,
995
+ inputs_embeds,
996
+ cache_position,
997
+ past_key_values,
998
+ output_attentions,
999
  )
1000
 
1001
  hidden_states = inputs_embeds
 
1047
 
1048
  next_cache = None
1049
  if use_cache:
1050
+ next_cache = (
1051
+ next_decoder_cache.to_legacy_cache()
1052
+ if use_legacy_cache
1053
+ else next_decoder_cache
1054
+ )
1055
 
1056
  if not return_dict:
1057
+ return tuple(
1058
+ v
1059
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
1060
+ if v is not None
1061
+ )
1062
  return BaseModelOutputWithPast(
1063
  last_hidden_state=hidden_states,
1064
  past_key_values=next_cache,
 
1088
  # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
1089
  # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
1090
  # to infer the attention mask.
1091
+ past_seen_tokens = (
1092
+ past_key_values.get_seq_length() if past_key_values is not None else 0
1093
+ )
1094
  using_static_cache = False # isinstance(past_key_values, StaticCache)
1095
 
1096
  # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
1097
+ if (
1098
+ self.config._attn_implementation == "sdpa"
1099
+ and not using_static_cache
1100
+ and not output_attentions
1101
+ ):
1102
  if AttentionMaskConverter._ignore_causal_mask_sdpa(
1103
  attention_mask,
1104
  inputs_embeds=input_tensor,
 
1140
  # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1141
  # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1142
  # Details: https://github.com/pytorch/pytorch/issues/110213
1143
+ causal_mask = AttentionMaskConverter._unmask_unattended(
1144
+ causal_mask, min_dtype
1145
+ )
1146
 
1147
  return causal_mask
1148
 
 
1178
  return self.model
1179
 
1180
  @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
1181
+ @replace_return_docstrings(
1182
+ output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
1183
+ )
1184
  def forward(
1185
  self,
1186
  input_ids: torch.LongTensor = None,
 
1221
  "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1222
  ```"""
1223
 
1224
+ output_attentions = (
1225
+ output_attentions
1226
+ if output_attentions is not None
1227
+ else self.config.output_attentions
1228
+ )
1229
  output_hidden_states = (
1230
+ output_hidden_states
1231
+ if output_hidden_states is not None
1232
+ else self.config.output_hidden_states
1233
+ )
1234
+ return_dict = (
1235
+ return_dict if return_dict is not None else self.config.use_return_dict
1236
  )
 
1237
 
1238
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1239
  outputs = self.model(
 
1296
  if past_key_values is not None:
1297
  if inputs_embeds is not None: # Exception 1
1298
  input_ids = input_ids[:, -cache_position.shape[0] :]
1299
+ elif (
1300
+ input_ids.shape[1] != cache_position.shape[0]
1301
+ ): # Default case (the "else", a no op, is Exception 2)
1302
  input_ids = input_ids[:, cache_position]
1303
 
1304
  if attention_mask is not None and position_ids is None:
 
1317
  else:
1318
  model_inputs = {"input_ids": input_ids}
1319
 
1320
+ if (
1321
+ False
1322
+ and isinstance(past_key_values, StaticCache)
1323
+ and attention_mask.ndim == 2
1324
+ ):
1325
  if inputs_embeds is not None:
1326
  batch_size, sequence_length = inputs_embeds.shape
1327
  device = inputs_embeds.device
 
1406
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1407
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1408
  """
1409
+ return_dict = (
1410
+ return_dict if return_dict is not None else self.config.use_return_dict
1411
+ )
1412
 
1413
  transformer_outputs = self.model(
1414
  input_ids,
 
1430
  batch_size = inputs_embeds.shape[0]
1431
 
1432
  if self.config.pad_token_id is None and batch_size != 1:
1433
+ raise ValueError(
1434
+ "Cannot handle batch sizes > 1 if no padding token is defined."
1435
+ )
1436
  if self.config.pad_token_id is None:
1437
  sequence_lengths = -1
1438
  else:
1439
  if input_ids is not None:
1440
  # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1441
+ sequence_lengths = (
1442
+ torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1443
+ )
1444
  sequence_lengths = sequence_lengths % input_ids.shape[-1]
1445
  sequence_lengths = sequence_lengths.to(logits.device)
1446
  else:
1447
  sequence_lengths = -1
1448
 
1449
+ pooled_logits = logits[
1450
+ torch.arange(batch_size, device=logits.device), sequence_lengths
1451
+ ]
1452
 
1453
  loss = None
1454
  if labels is not None:
 
1456
  if self.config.problem_type is None:
1457
  if self.num_labels == 1:
1458
  self.config.problem_type = "regression"
1459
+ elif self.num_labels > 1 and (
1460
+ labels.dtype == torch.long or labels.dtype == torch.int
1461
+ ):
1462
  self.config.problem_type = "single_label_classification"
1463
  else:
1464
  self.config.problem_type = "multi_label_classification"
 
1471
  loss = loss_fct(pooled_logits, labels)
1472
  elif self.config.problem_type == "single_label_classification":
1473
  loss_fct = CrossEntropyLoss()
1474
+ loss = loss_fct(
1475
+ pooled_logits.view(-1, self.num_labels), labels.view(-1)
1476
+ )
1477
  elif self.config.problem_type == "multi_label_classification":
1478
  loss_fct = BCEWithLogitsLoss()
1479
  loss = loss_fct(pooled_logits, labels)
 
1541
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1542
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1543
  """
1544
+ return_dict = (
1545
+ return_dict if return_dict is not None else self.config.use_return_dict
1546
+ )
1547
 
1548
  outputs = self.model(
1549
  input_ids,
 
1604
  encoding_dim=config.hidden_size,
1605
  num_ensemble=config.num_ensemble,
1606
  )
1607
+ self.score.init()
1608
  # Initialize weights and apply final processing
1609
  self.post_init()
1610
 
 
1621
  outputs.logits = torch.nn.functional.sigmoid(outputs.logits)
1622
  return outputs
1623
 
1624
+ def _compute_loss(self, logits, labels):
1625
  # NOTE: we only compute the loss for specific position (labels != -100)
1626
  logits = logits.float()
1627
  loss = None
 
1630
  # only support hard labels; not need for soft labels
1631
  loss_fct = BCEWithLogitsLoss(reduction="none")
1632
 
1633
+ loss = loss_fct(
1634
+ logits, labels[None].repeat([logits.size(0), 1, 1]).to(logits.dtype)
1635
+ )
1636
  # select loss for specific position
1637
  mask = (labels != -100)[None].repeat([logits.size(0), 1, 1])
1638
  # and random mask instance for different ensemble model
1639
+ data_aloc_mask = (
1640
+ torch.rand(mask.size(0), mask.size(1)) < self.learning_probability
1641
+ )
1642
  mask = mask & data_aloc_mask[:, :, None].to(mask.device)
1643
 
1644
  loss = torch.masked_select(loss, mask)
1645
  loss = loss.mean()
1646
+ loss += (
1647
+ self.regularization_lambda * labels.size(0) * self.score.regularization()
1648
+ )
1649
+ return loss
 
 
1650
 
1651
  @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
1652
  def forward(
 
1662
  output_hidden_states: Optional[bool] = None,
1663
  return_dict: Optional[bool] = None,
1664
  ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1665
+ return_dict = (
1666
+ return_dict if return_dict is not None else self.config.use_return_dict
1667
+ )
1668
 
1669
  transformer_outputs = self.model(
1670
  input_ids,
 
1678
  return_dict=return_dict,
1679
  )
1680
  hidden_states = transformer_outputs[0] # (b, l, h)
1681
+ hidden_states = hidden_states[None, :, :, :].repeat(
1682
+ self.score.num_ensemble, 1, 1, 1
1683
+ ) # (e, b, l, h)
1684
  logits = self.score(hidden_states)
1685
 
1686
  if input_ids is not None:
 
1689
  batch_size = inputs_embeds.shape[0]
1690
 
1691
  if self.config.pad_token_id is None and batch_size != 1:
1692
+ raise ValueError(
1693
+ "Cannot handle batch sizes > 1 if no padding token is defined."
1694
+ )
1695
  if self.config.pad_token_id is None:
1696
  sequence_lengths = -1
1697
  else:
1698
  if input_ids is not None:
1699
  # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1700
+ sequence_lengths = (
1701
+ torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1702
+ )
1703
  sequence_lengths = sequence_lengths % input_ids.shape[-1]
1704
  sequence_lengths = sequence_lengths.to(logits.device)
1705
  else:
 
1707
 
1708
  logits = logits.float()
1709
  loss = None
1710
+ logits = logits.squeeze(
1711
+ -1
1712
+ ) # (ensemble, batch_size, seq_len, 1) -> (ensemble, batch_size, seq_len)
1713
  if labels is not None:
1714
  if self.config.problem_type is None: # NOTE: no use
1715
  if labels.dtype is not torch.long:
 
1721
  # only support hard labels; not need for soft labels
1722
  loss_fct = BCEWithLogitsLoss(reduction="none")
1723
 
1724
+ loss = loss_fct(
1725
+ logits, labels[None].repeat([logits.size(0), 1, 1]).to(logits.dtype)
1726
+ )
1727
  # select loss for specific position
1728
  mask = (labels != -100)[None].repeat([logits.size(0), 1, 1])
1729
  # and random mask instance for differnet ensemble model
1730
+ data_aloc_mask = (
1731
+ torch.rand(mask.size(0), mask.size(1)) < self.learning_probability
1732
+ )
1733
  mask = mask & data_aloc_mask[:, :, None].to(mask.device)
1734
 
1735
  loss = torch.masked_select(loss, mask)
1736
  loss = loss.mean()
1737
+ loss += (
1738
+ self.regularization_lambda
1739
+ * labels.size(0)
1740
+ * self.score.regularization()
1741
+ )
1742
 
1743
  if not return_dict:
1744
  output = (logits,) + transformer_outputs[1:]
nets.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Garena Online Private Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Deep networks."""
16
+
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from torch import nn
21
+
22
+
23
def init_weights(m):
    """Initialize `nn.Linear` / `EnsembleFC` modules in place.

    Weights get a truncated normal N(0, std) with std = 1 / (2 * sqrt(fan_in)),
    resampled until every entry lies within mean +/- 2*std; biases are zeroed.
    Intended to be used via ``module.apply(init_weights)``.
    """

    @torch.no_grad()
    def truncated_normal_init(t, mean=0.0, std=0.01):
        # Draw N(mean, std), then keep resampling the entries that fall
        # outside [mean - 2*std, mean + 2*std].
        t.data.normal_(mean, std)
        while True:
            cond = torch.logical_or(t.data < mean - 2 * std, t.data > mean + 2 * std)
            if not torch.sum(cond):
                break
            resampled = torch.empty(t.shape, device=t.device, dtype=t.dtype)
            resampled.normal_(mean, std)
            # Write back in place: rebinding `t` (as `t = torch.where(...)` did
            # originally) creates a new tensor and leaves the out-of-range
            # values in the actual parameter, silently skipping truncation.
            t.data.copy_(torch.where(cond, resampled, t.data))
        return t

    if type(m) is nn.Linear or isinstance(m, EnsembleFC):
        truncated_normal_init(m.weight, std=1 / (2 * np.sqrt(m.in_features)))
        if m.bias is not None:
            m.bias.data.fill_(0.0)
42
+
43
+
44
def init_weights_uniform(m):
    """Fan-in uniform init: W ~ U(-1/sqrt(in_features), 1/sqrt(in_features)), bias 0."""
    input_dim = m.in_features
    # `torch.nn.init.uniform` is deprecated; the supported in-place API is `uniform_`.
    torch.nn.init.uniform_(m.weight, -1 / np.sqrt(input_dim), 1 / np.sqrt(input_dim))
    if m.bias is not None:
        m.bias.data.fill_(0.0)
49
+
50
+
51
class Swish(nn.Module):
    """Swish activation: ``x * sigmoid(x)`` (also known as SiLU)."""

    def __init__(self):
        super(Swish, self).__init__()

    def forward(self, x):
        # `F.sigmoid` is deprecated; `torch.sigmoid` is numerically identical.
        return x * torch.sigmoid(x)
58
+
59
+
60
class MLPModel(nn.Module):
    """Two-hidden-layer MLP that maps an encoding vector to a single scalar score."""

    def __init__(self, encoding_dim, hidden_dim=128, activation="relu") -> None:
        super(MLPModel, self).__init__()
        self.hidden_size = hidden_dim
        self.output_dim = 1

        # Layer creation order is significant: `apply(init_weights)` below
        # consumes RNG state per layer, so construction must stay in this order.
        self.nn1 = nn.Linear(encoding_dim, hidden_dim)
        self.nn2 = nn.Linear(hidden_dim, hidden_dim)
        self.nn_out = nn.Linear(hidden_dim, self.output_dim)

        self.apply(init_weights)

        if activation == "swish":
            self.activation = Swish()
        elif activation == "relu":
            self.activation = nn.ReLU()
        else:
            raise ValueError(f"Unknown activation {activation}")

    def get_params(self) -> torch.Tensor:
        """Return every parameter flattened and concatenated into one 1-D tensor."""
        return torch.cat([p.view(-1) for p in self.parameters()])

    def forward(self, encoding: torch.Tensor) -> torch.Tensor:
        """Score the encoding: two activated hidden layers, then a linear head."""
        hidden = self.activation(self.nn1(encoding))
        hidden = self.activation(self.nn2(hidden))
        return self.nn_out(hidden)

    def init(self):
        """Snapshot the current parameters as the anchor for `regularization`."""
        self.init_params = self.get_params().data.clone()
        if torch.cuda.is_available():
            self.init_params = self.init_params.cuda()

    def regularization(self):
        """Prior towards independent initialization."""
        return ((self.get_params() - self.init_params) ** 2).mean()
99
+
100
+
101
class EnsembleFC(nn.Module):
    """Batched fully-connected layer: one independent linear map per ensemble member.

    Takes input of shape ``(ensemble, batch, length, in_features)`` and produces
    ``(ensemble, batch, length, out_features)``, computing ``x @ W_e + b_e`` for
    each ensemble member ``e``.
    """

    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    ensemble_size: int
    weight: torch.Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        ensemble_size: int,
        bias: bool = True,
        dtype=torch.float32,
    ) -> None:
        super(EnsembleFC, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.ensemble_size = ensemble_size
        # Parameters are allocated uninitialized (torch.empty); the owning
        # module is expected to run `init_weights` before use (see EnsembleModel).
        self.weight = nn.Parameter(
            torch.empty(ensemble_size, in_features, out_features, dtype=dtype)
        )
        if bias:
            self.bias = nn.Parameter(
                torch.empty(ensemble_size, out_features, dtype=dtype)
            )
        else:
            self.register_parameter("bias", None)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        input = input.to(self.weight.dtype)
        # (e, b, l, h) x (e, h, m) -> (e, b, l, m)
        wx = torch.einsum("eblh,ehm->eblm", input, self.weight)
        if self.bias is None:
            # Fix: the original indexed `self.bias[:, None, None, :]`
            # unconditionally and crashed when constructed with bias=False.
            return wx
        return torch.add(wx, self.bias[:, None, None, :])  # w times x + b
132
+
133
+
134
class EnsembleModel(nn.Module):
    """Ensemble of independent two-hidden-layer MLP scorers, one scalar head each."""

    def __init__(
        self,
        encoding_dim,
        num_ensemble,
        hidden_dim=128,
        activation="relu",
        dtype=torch.float32,
    ) -> None:
        super(EnsembleModel, self).__init__()
        self.num_ensemble = num_ensemble
        self.hidden_dim = hidden_dim
        self.output_dim = 1

        # Construction order matters: `apply(init_weights)` draws RNG per layer.
        self.nn1 = EnsembleFC(encoding_dim, hidden_dim, num_ensemble, dtype=dtype)
        self.nn2 = EnsembleFC(hidden_dim, hidden_dim, num_ensemble, dtype=dtype)
        self.nn_out = EnsembleFC(hidden_dim, self.output_dim, num_ensemble, dtype=dtype)

        self.apply(init_weights)

        if activation == "swish":
            self.activation = Swish()
        elif activation == "relu":
            self.activation = nn.ReLU()
        else:
            raise ValueError(f"Unknown activation {activation}")

    def get_params(self) -> torch.Tensor:
        """Return every parameter flattened and concatenated into one 1-D tensor."""
        return torch.cat([p.view(-1) for p in self.parameters()])

    def forward(self, encoding: torch.Tensor) -> torch.Tensor:
        """Score each ensemble member's slice of the encoding batch."""
        hidden = self.activation(self.nn1(encoding))
        hidden = self.activation(self.nn2(hidden))
        return self.nn_out(hidden)

    def init(self):
        """Snapshot the freshly-initialized parameters for `regularization`."""
        self.init_params = self.get_params().data.clone()
        if torch.cuda.is_available():
            self.init_params = self.init_params.cuda()

    def regularization(self):
        """Prior towards independent initialization."""
        return ((self.get_params() - self.init_params) ** 2).mean()