roemmele commited on
Commit
6b1b70b
·
verified ·
1 Parent(s): 2367728

Fix GenerationMixin inheritance for transformers >= 4.50

Browse files

RNNLMForCausalLM now explicitly inherits from GenerationMixin (in addition to PreTrainedModel), fixing an AttributeError on _extract_generation_mode_kwargs when calling generate() with transformers >= 4.50.

One-line change: class RNNLMForCausalLM(PreTrainedModel, GenerationMixin):

Files changed (1) hide show
  1. rnnlm_model/modeling_rnnlm.py +1 -1
rnnlm_model/modeling_rnnlm.py CHANGED
@@ -101,7 +101,7 @@ class GRUKerasCompat(nn.Module):
101
  return output, h.unsqueeze(0)
102
 
103
 
104
- class RNNLMForCausalLM(PreTrainedModel):
105
  """
106
  RNN-based Causal Language Model for text generation.
107
  Compatible with HuggingFace TextGenerationPipeline.
 
101
  return output, h.unsqueeze(0)
102
 
103
 
104
+ class RNNLMForCausalLM(PreTrainedModel, GenerationMixin):
105
  """
106
  RNN-based Causal Language Model for text generation.
107
  Compatible with HuggingFace TextGenerationPipeline.