anarlavrenov commited on
Commit
3f5bb73
·
verified ·
1 Parent(s): 31f4c0d

Add GenerationMixin to LIMEForCausalLM

Browse files
Files changed (1) hide show
  1. modeling_lime.py +2 -1
modeling_lime.py CHANGED
@@ -1,6 +1,7 @@
1
  import torch
2
  from torch import nn
3
  from transformers import PreTrainedModel
 
4
  from transformers.modeling_outputs import CausalLMOutputWithPast
5
  from typing import Optional, Tuple, Union
6
  from ukraine.research.transformer.transformer import Transformer
@@ -21,7 +22,7 @@ def make_norm(config: LIMEConfig):
21
  return nn.RMSNorm(config.d_model)
22
 
23
 
24
- class LIMEForCausalLM(PreTrainedModel):
25
  config_class = LIMEConfig
26
  base_model_prefix = "lime"
27
  _tied_weights_keys = ["transformer.output_fc.weight"]
 
1
  import torch
2
  from torch import nn
3
  from transformers import PreTrainedModel
4
+ from transformers.generation import GenerationMixin
5
  from transformers.modeling_outputs import CausalLMOutputWithPast
6
  from typing import Optional, Tuple, Union
7
  from ukraine.research.transformer.transformer import Transformer
 
22
  return nn.RMSNorm(config.d_model)
23
 
24
 
25
+ class LIMEForCausalLM(PreTrainedModel, GenerationMixin):
26
  config_class = LIMEConfig
27
  base_model_prefix = "lime"
28
  _tied_weights_keys = ["transformer.output_fc.weight"]