mksethi commited on
Commit
b45ca52
·
verified ·
1 Parent(s): 976b1c1

Update modeling_frozen_lm_mlp.py

Browse files
Files changed (1) hide show
  1. modeling_frozen_lm_mlp.py +2 -0
modeling_frozen_lm_mlp.py CHANGED
@@ -55,6 +55,8 @@ class FrozenLMModel(PreTrainedModel):
55
  """
56
  if input_ids is None:
57
  raise ValueError("input_ids must be provided")
 
 
58
 
59
  # If input has sequence dimension, pool by averaging
60
  if input_ids.dim() == 3: # (batch_size, seq_len, input_dim)
 
55
  """
56
  if input_ids is None:
57
  raise ValueError("input_ids must be provided")
58
+
59
+ input_ids = input_ids.float()
60
 
61
  # If input has sequence dimension, pool by averaging
62
  if input_ids.dim() == 3: # (batch_size, seq_len, input_dim)