Add GenerationMixin to LIMEForCausalLM
Browse files- modeling_lime.py +2 -1
modeling_lime.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
import torch
|
| 2 |
from torch import nn
|
| 3 |
from transformers import PreTrainedModel
|
|
|
|
| 4 |
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 5 |
from typing import Optional, Tuple, Union
|
| 6 |
from ukraine.research.transformer.transformer import Transformer
|
|
@@ -21,7 +22,7 @@ def make_norm(config: LIMEConfig):
|
|
| 21 |
return nn.RMSNorm(config.d_model)
|
| 22 |
|
| 23 |
|
| 24 |
-
class LIMEForCausalLM(PreTrainedModel):
|
| 25 |
config_class = LIMEConfig
|
| 26 |
base_model_prefix = "lime"
|
| 27 |
_tied_weights_keys = ["transformer.output_fc.weight"]
|
|
|
|
| 1 |
import torch
|
| 2 |
from torch import nn
|
| 3 |
from transformers import PreTrainedModel
|
| 4 |
+
from transformers.generation import GenerationMixin
|
| 5 |
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 6 |
from typing import Optional, Tuple, Union
|
| 7 |
from ukraine.research.transformer.transformer import Transformer
|
|
|
|
| 22 |
return nn.RMSNorm(config.d_model)
|
| 23 |
|
| 24 |
|
| 25 |
+
class LIMEForCausalLM(PreTrainedModel, GenerationMixin):
|
| 26 |
config_class = LIMEConfig
|
| 27 |
base_model_prefix = "lime"
|
| 28 |
_tied_weights_keys = ["transformer.output_fc.weight"]
|