Update modeling_frozen_lm_mlp.py
Browse files
modeling_frozen_lm_mlp.py
CHANGED
|
@@ -55,6 +55,8 @@ class FrozenLMModel(PreTrainedModel):
|
|
| 55 |
"""
|
| 56 |
if input_ids is None:
|
| 57 |
raise ValueError("input_ids must be provided")
|
|
|
|
|
|
|
| 58 |
|
| 59 |
# If input has sequence dimension, pool by averaging
|
| 60 |
if input_ids.dim() == 3: # (batch_size, seq_len, input_dim)
|
|
|
|
| 55 |
"""
|
| 56 |
if input_ids is None:
|
| 57 |
raise ValueError("input_ids must be provided")
|
| 58 |
+
|
| 59 |
+
input_ids = input_ids.float()
|
| 60 |
|
| 61 |
# If input has sequence dimension, pool by averaging
|
| 62 |
if input_ids.dim() == 3: # (batch_size, seq_len, input_dim)
|