fxmeng committed on
Commit
57f39d9
·
verified ·
1 Parent(s): c74ae5e

Upload modeling_llamamla.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_llamamla.py +56 -0
modeling_llamamla.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple, Union
2
+
3
+ import torch
4
+ from torch import nn
5
+ from transformers.cache_utils import Cache
6
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
7
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
8
+ from transformers.processing_utils import Unpack
9
+ from transformers.utils import LossKwargs
10
+
11
+ from transformers.models.llama.modeling_llama import (
12
+ LlamaModel,
13
+ LlamaDecoderLayer,
14
+ LlamaPreTrainedModel,
15
+ LlamaForCausalLM
16
+ )
17
+
18
+ from .configuration_llamamla import LlamaMLAConfig
19
+ from .mla import MLAAttention, eager_attention_forward
20
+
21
+
22
class LlamaMLADecoderLayer(LlamaDecoderLayer):
    """Llama decoder layer whose self-attention is replaced by MLA.

    All other wiring (MLP, layer norms, residual connections) is inherited
    unchanged from ``LlamaDecoderLayer``; only the attention module differs.
    """

    def __init__(self, config: LlamaMLAConfig, layer_idx: int):
        # Build the stock layer first, then overwrite its attention module
        # with the MLA implementation from this package.
        super().__init__(config, layer_idx)
        self.self_attn = MLAAttention(config, layer_idx)
27
+
28
+
29
class LlamaMLAPreTrainedModel(LlamaPreTrainedModel):
    """Base ``PreTrainedModel`` for the MLA-attention Llama family."""

    # Configuration class associated with this model family (picked up by
    # the transformers loading machinery, e.g. from_pretrained).
    config_class = LlamaMLAConfig
    # Per transformers convention: modules that must not be split across
    # devices during model sharding.
    _no_split_modules = ["LlamaMLADecoderLayer"]
33
+
34
+
35
class LlamaMLAModel(LlamaMLAPreTrainedModel, LlamaModel):
    """``LlamaModel`` variant whose decoder stack uses MLA attention layers."""

    def __init__(self, config: LlamaMLAConfig):
        # Let the parent constructor set up embeddings, norms, etc.
        super().__init__(config)

        # Replace the decoder stack built by the parent with MLA-equipped
        # layers, one per hidden layer in the configuration.
        mla_layers = [
            LlamaMLADecoderLayer(config, idx)
            for idx in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(mla_layers)
43
+
44
+
45
class LlamaMLAForCausalLM(LlamaMLAPreTrainedModel, LlamaForCausalLM):
    """Causal language model head on top of the MLA-attention Llama backbone."""

    def __init__(self, config):
        # Run the stock causal-LM construction, then swap the backbone
        # (`self.model`) for the MLA variant defined in this module.
        super().__init__(config)
        self.model = LlamaMLAModel(config)
50
+
51
+
52
# Public API of this module (names re-exported on `import *`).
__all__ = [
    "LlamaMLAForCausalLM",
    "LlamaMLAModel",
    "LlamaMLAPreTrainedModel",
]