yezdata committed on
Commit
1085007
·
verified ·
1 Parent(s): a35b4ea

Switch the Auto class registration from AutoModel to AutoModelForSequenceClassification, and reshape labels to match the logits' shape so the BCE loss computation broadcasts correctly

Browse files
Files changed (1) hide show
  1. modeling_emcoder.py +15 -12
modeling_emcoder.py CHANGED
@@ -2,7 +2,7 @@ import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
  from .rope_embeddings import RotaryEmbedding
5
- from transformers import PreTrainedModel, AutoConfig, AutoModel
6
  from transformers.modeling_outputs import SequenceClassifierOutput
7
 
8
  from .configuration_emcoder import EmCoderConfig
@@ -202,12 +202,12 @@ class EmCoder(PreTrainedModel):
202
  """
203
  return_dict = return_dict if return_dict is not None else True
204
 
205
- x = input_ids if input_ids is not None else kwargs.get("x")
206
- mask = attention_mask if attention_mask is not None else kwargs.get("mask")
207
-
208
- if x is None or mask is None:
209
- raise ValueError("input_ids (x) and attention_mask (mask) must be provided")
210
 
 
 
 
211
  if max_batch_size is None:
212
  max_batch_size = n_samples
213
 
@@ -241,7 +241,9 @@ class EmCoder(PreTrainedModel):
241
  loss = None
242
  if labels is not None:
243
  loss_fct = nn.BCEWithLogitsLoss()
244
- loss = loss_fct(all_logits.mean(dim=0), labels.to(all_logits.dtype))
 
 
245
 
246
  if not return_dict:
247
  output = (all_logits,)
@@ -266,11 +268,11 @@ class EmCoder(PreTrainedModel):
266
  """Standard forward pass without MC Dropout."""
267
  return_dict = return_dict if return_dict is not None else True
268
 
269
- x = input_ids if input_ids is not None else kwargs.get("x")
270
- mask = attention_mask if attention_mask is not None else kwargs.get("mask")
271
 
272
  if x is None or mask is None:
273
- raise ValueError("input_ids (x) and attention_mask (mask) must be provided")
274
 
275
  features = self.encoder(x, mask)
276
 
@@ -281,7 +283,8 @@ class EmCoder(PreTrainedModel):
281
  loss = None
282
  if labels is not None:
283
  loss_fct = nn.BCEWithLogitsLoss()
284
- loss = loss_fct(logits, labels.to(logits.dtype))
 
285
 
286
  if not return_dict:
287
  output = (logits,)
@@ -296,6 +299,6 @@ class EmCoder(PreTrainedModel):
296
 
297
  try:
298
  AutoConfig.register("emcoder", EmCoderConfig)
299
- AutoModel.register(EmCoderConfig, EmCoder)
300
  except ValueError:
301
  pass
 
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
  from .rope_embeddings import RotaryEmbedding
5
+ from transformers import PreTrainedModel, AutoConfig, AutoModelForSequenceClassification
6
  from transformers.modeling_outputs import SequenceClassifierOutput
7
 
8
  from .configuration_emcoder import EmCoderConfig
 
202
  """
203
  return_dict = return_dict if return_dict is not None else True
204
 
205
+ x = input_ids if input_ids is not None else kwargs.get("input_ids")
206
+ mask = attention_mask if attention_mask is not None else kwargs.get("attention_mask")
 
 
 
207
 
208
+ if x is None or mask is None:
209
+ raise ValueError("input_ids and attention_mask must be provided")
210
+
211
  if max_batch_size is None:
212
  max_batch_size = n_samples
213
 
 
241
  loss = None
242
  if labels is not None:
243
  loss_fct = nn.BCEWithLogitsLoss()
244
+ logits_mean = all_logits.mean(dim=0) # (B, num_labels)
245
+ target_labels = labels.to(dtype=all_logits.dtype).view(logits_mean.shape)
246
+ loss = loss_fct(logits_mean, target_labels)
247
 
248
  if not return_dict:
249
  output = (all_logits,)
 
268
  """Standard forward pass without MC Dropout."""
269
  return_dict = return_dict if return_dict is not None else True
270
 
271
+ x = input_ids if input_ids is not None else kwargs.get("input_ids")
272
+ mask = attention_mask if attention_mask is not None else kwargs.get("attention_mask")
273
 
274
  if x is None or mask is None:
275
+ raise ValueError("input_ids and attention_mask must be provided")
276
 
277
  features = self.encoder(x, mask)
278
 
 
283
  loss = None
284
  if labels is not None:
285
  loss_fct = nn.BCEWithLogitsLoss()
286
+ target_labels = labels.to(dtype=logits.dtype).view(logits.shape)
287
+ loss = loss_fct(logits, target_labels)
288
 
289
  if not return_dict:
290
  output = (logits,)
 
299
 
300
  try:
301
  AutoConfig.register("emcoder", EmCoderConfig)
302
+ AutoModelForSequenceClassification.register(EmCoderConfig, EmCoder)
303
  except ValueError:
304
  pass