roemmele commited on
Commit
78bf983
·
1 Parent(s): 2367728

Fix GenerationMixin inheritance for transformers >= 4.50 (#3)

Browse files

- Fix GenerationMixin inheritance for transformers >= 4.50 (6b1b70b14f47b7cefcbdec037b19316b907d46a6)

Files changed (1) hide show
  1. rnnlm_model/modeling_rnnlm.py +1 -1
rnnlm_model/modeling_rnnlm.py CHANGED
@@ -101,7 +101,7 @@ class GRUKerasCompat(nn.Module):
101
  return output, h.unsqueeze(0)
102
 
103
 
104
- class RNNLMForCausalLM(PreTrainedModel):
105
  """
106
  RNN-based Causal Language Model for text generation.
107
  Compatible with HuggingFace TextGenerationPipeline.
 
101
  return output, h.unsqueeze(0)
102
 
103
 
104
+ class RNNLMForCausalLM(PreTrainedModel, GenerationMixin):
105
  """
106
  RNN-based Causal Language Model for text generation.
107
  Compatible with HuggingFace TextGenerationPipeline.