Update modeling_svd_llama.py
Browse files
- modeling_svd_llama.py +5 -5
modeling_svd_llama.py
CHANGED
|
@@ -11,7 +11,7 @@ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
|
|
| 11 |
from transformers.utils import logging
|
| 12 |
from transformers import LlamaForCausalLM
|
| 13 |
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel, LlamaRotaryEmbedding, LlamaRMSNorm, repeat_kv, apply_rotary_pos_emb
|
| 14 |
-
from
|
| 15 |
|
| 16 |
|
| 17 |
logger = logging.get_logger(__name__)
|
|
@@ -21,7 +21,7 @@ _CONFIG_FOR_DOC = "LlamaConfig"
|
|
| 21 |
ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
|
| 22 |
|
| 23 |
class SVDLlamaMLP(nn.Module):
|
| 24 |
-
def __init__(self, config:
|
| 25 |
super().__init__()
|
| 26 |
self.config = config
|
| 27 |
self.hidden_size = config.hidden_size
|
|
@@ -48,7 +48,7 @@ class SVDLlamaMLP(nn.Module):
|
|
| 48 |
class SVDLlamaAttention(nn.Module):
|
| 49 |
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 50 |
|
| 51 |
-
def __init__(self, config:
|
| 52 |
super().__init__()
|
| 53 |
self.config = config
|
| 54 |
self.layer_idx = layer_idx
|
|
@@ -334,14 +334,14 @@ class SVDLLaMASDPA(SVDLlamaAttention):
|
|
| 334 |
|
| 335 |
|
| 336 |
class SVDLlamaDecoderLayer(LlamaDecoderLayer):
|
| 337 |
-
def __init__(self, config:
|
| 338 |
super().__init__(config, layer_idx)
|
| 339 |
self.self_attn = SVDLlamaAttention(config=config, layer_idx=layer_idx)
|
| 340 |
self.mlp = SVDLlamaMLP(config)
|
| 341 |
|
| 342 |
|
| 343 |
class SVDLlamaForCausalLM(LlamaForCausalLM):
|
| 344 |
-
def __init__(self, config:
|
| 345 |
super().__init__(config)
|
| 346 |
self.model = LlamaModel(config)
|
| 347 |
self.model.layers = nn.ModuleList(
|
|
|
|
| 11 |
from transformers.utils import logging
|
| 12 |
from transformers import LlamaForCausalLM
|
| 13 |
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel, LlamaRotaryEmbedding, LlamaRMSNorm, repeat_kv, apply_rotary_pos_emb
|
| 14 |
+
from transformers import LlamaConfig
|
| 15 |
|
| 16 |
|
| 17 |
logger = logging.get_logger(__name__)
|
|
|
|
| 21 |
ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
|
| 22 |
|
| 23 |
class SVDLlamaMLP(nn.Module):
|
| 24 |
+
def __init__(self, config: LlamaConfig):
|
| 25 |
super().__init__()
|
| 26 |
self.config = config
|
| 27 |
self.hidden_size = config.hidden_size
|
|
|
|
| 48 |
class SVDLlamaAttention(nn.Module):
|
| 49 |
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 50 |
|
| 51 |
+
def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
|
| 52 |
super().__init__()
|
| 53 |
self.config = config
|
| 54 |
self.layer_idx = layer_idx
|
|
|
|
| 334 |
|
| 335 |
|
| 336 |
class SVDLlamaDecoderLayer(LlamaDecoderLayer):
|
| 337 |
+
def __init__(self, config: LlamaConfig, layer_idx: int):
|
| 338 |
super().__init__(config, layer_idx)
|
| 339 |
self.self_attn = SVDLlamaAttention(config=config, layer_idx=layer_idx)
|
| 340 |
self.mlp = SVDLlamaMLP(config)
|
| 341 |
|
| 342 |
|
| 343 |
class SVDLlamaForCausalLM(LlamaForCausalLM):
|
| 344 |
+
def __init__(self, config: LlamaConfig):
|
| 345 |
super().__init__(config)
|
| 346 |
self.model = LlamaModel(config)
|
| 347 |
self.model.layers = nn.ModuleList(
|