yezdata commited on
Commit
63fe031
·
verified ·
1 Parent(s): a1b2a1c

fix self.config.return_dict

Browse files
Files changed (1) hide show
  1. modeling_emcoder.py +2 -2
modeling_emcoder.py CHANGED
@@ -200,7 +200,7 @@ class EmCoder(PreTrainedModel):
200
  Returns:
201
  Logits of shape (n_samples, B, num_labels).
202
  """
203
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
204
 
205
  x = input_ids if input_ids is not None else kwargs.get("x")
206
  mask = attention_mask if attention_mask is not None else kwargs.get("mask")
@@ -264,7 +264,7 @@ class EmCoder(PreTrainedModel):
264
  **kwargs,
265
  ) -> tuple[torch.Tensor, ...] | SequenceClassifierOutput:
266
  """Standard forward pass without MC Dropout."""
267
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
268
 
269
  x = input_ids if input_ids is not None else kwargs.get("x")
270
  mask = attention_mask if attention_mask is not None else kwargs.get("mask")
 
200
  Returns:
201
  Logits of shape (n_samples, B, num_labels).
202
  """
203
+ return_dict = return_dict if return_dict is not None else True
204
 
205
  x = input_ids if input_ids is not None else kwargs.get("x")
206
  mask = attention_mask if attention_mask is not None else kwargs.get("mask")
 
264
  **kwargs,
265
  ) -> tuple[torch.Tensor, ...] | SequenceClassifierOutput:
266
  """Standard forward pass without MC Dropout."""
267
+ return_dict = return_dict if return_dict is not None else True
268
 
269
  x = input_ids if input_ids is not None else kwargs.get("x")
270
  mask = attention_mask if attention_mask is not None else kwargs.get("mask")