Commit ·
3a7a4b9
1
Parent(s): 0507803
updating model class
Browse files
- updating attention logic for GPT2Block to select GPT2Attention
- adding consider_aux_loss config to allow users to skip adding aux loss to the total loss
- configuration_lola_gpt2.py +2 -0
- modeling_lola_gpt2.py +10 -55
configuration_lola_gpt2.py
CHANGED
|
@@ -48,6 +48,7 @@ class LOLAConfig(PretrainedConfig):
|
|
| 48 |
num_experts=16,
|
| 49 |
topk=1,
|
| 50 |
router_aux_loss_coef=0.01,
|
|
|
|
| 51 |
**kwargs,
|
| 52 |
):
|
| 53 |
self.vocab_size = vocab_size
|
|
@@ -77,6 +78,7 @@ class LOLAConfig(PretrainedConfig):
|
|
| 77 |
self.bos_token_id = bos_token_id
|
| 78 |
self.eos_token_id = eos_token_id
|
| 79 |
self.router_aux_loss_coef = router_aux_loss_coef
|
|
|
|
| 80 |
|
| 81 |
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
| 82 |
|
|
|
|
| 48 |
num_experts=16,
|
| 49 |
topk=1,
|
| 50 |
router_aux_loss_coef=0.01,
|
| 51 |
+
consider_aux_loss=True,
|
| 52 |
**kwargs,
|
| 53 |
):
|
| 54 |
self.vocab_size = vocab_size
|
|
|
|
| 78 |
self.bos_token_id = bos_token_id
|
| 79 |
self.eos_token_id = eos_token_id
|
| 80 |
self.router_aux_loss_coef = router_aux_loss_coef
|
| 81 |
+
self.consider_aux_loss = consider_aux_loss
|
| 82 |
|
| 83 |
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
| 84 |
|
modeling_lola_gpt2.py
CHANGED
|
@@ -21,12 +21,8 @@ from torch.nn import CrossEntropyLoss
|
|
| 21 |
|
| 22 |
from transformers.modeling_outputs import (
|
| 23 |
BaseModelOutputWithPastAndCrossAttentions,
|
| 24 |
-
MoeCausalLMOutputWithPast
|
| 25 |
-
SequenceClassifierOutputWithPast,
|
| 26 |
-
QuestionAnsweringModelOutput
|
| 27 |
)
|
| 28 |
-
from transformers.modeling_utils import SequenceSummary
|
| 29 |
-
from transformers.pytorch_utils import Conv1D
|
| 30 |
from transformers.utils import (
|
| 31 |
logging
|
| 32 |
)
|
|
@@ -40,7 +36,6 @@ from typing import Optional, Tuple
|
|
| 40 |
import torch
|
| 41 |
from transformers.modeling_outputs import ModelOutput
|
| 42 |
import transformers
|
| 43 |
-
import importlib.util
|
| 44 |
|
| 45 |
|
| 46 |
logger = logging.get_logger(__name__)
|
|
@@ -50,7 +45,7 @@ expert_analysis_callback = lambda _: None
|
|
| 50 |
class LOLADependencyChecker:
|
| 51 |
def __init__(self):
|
| 52 |
self.expected_versions = {
|
| 53 |
-
"transformers": "4.
|
| 54 |
}
|
| 55 |
self.check_dependencies()
|
| 56 |
|
|
@@ -111,6 +106,8 @@ class LOLAModel(GPT2PreTrainedModel):
|
|
| 111 |
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
|
| 112 |
|
| 113 |
self.drop = nn.Dropout(config.embd_pdrop)
|
|
|
|
|
|
|
| 114 |
self.h = nn.ModuleList([
|
| 115 |
GPT2Block(config, layer_idx=i) if i % 2 == 0 else LOLABlock(config, layer_idx=i) for i in range(config.num_hidden_layers)
|
| 116 |
])
|
|
@@ -384,6 +381,7 @@ class LOLABlock(nn.Module):
|
|
| 384 |
|
| 385 |
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
| 386 |
self.attn = GPT2Attention(config, layer_idx=layer_idx)
|
|
|
|
| 387 |
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
| 388 |
self.moe = LOLAMOE(
|
| 389 |
hidden_size,
|
|
@@ -488,53 +486,6 @@ class LOLAMOE(nn.Module):
|
|
| 488 |
expert_analysis_callback(selected_experts)
|
| 489 |
return final_hidden_states, router_logits, aux_loss
|
| 490 |
|
| 491 |
-
class LOLAAttention(GPT2Attention):
|
| 492 |
-
def __init__(self, config, is_cross_attention=False, layer_idx=None):
|
| 493 |
-
super(GPT2Attention, SequenceClassifierOutputWithPast).__init__()
|
| 494 |
-
|
| 495 |
-
max_positions = config.max_position_embeddings
|
| 496 |
-
self.register_buffer(
|
| 497 |
-
"bias",
|
| 498 |
-
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
|
| 499 |
-
1, 1, max_positions, max_positions
|
| 500 |
-
),
|
| 501 |
-
#persistent=False,
|
| 502 |
-
)
|
| 503 |
-
self.register_buffer("masked_bias", torch.tensor(-1e4),
|
| 504 |
-
#persistent=False
|
| 505 |
-
)
|
| 506 |
-
|
| 507 |
-
self.embed_dim = config.hidden_size
|
| 508 |
-
self.num_heads = config.num_attention_heads
|
| 509 |
-
self.head_dim = self.embed_dim // self.num_heads
|
| 510 |
-
self.split_size = self.embed_dim
|
| 511 |
-
if self.head_dim * self.num_heads != self.embed_dim:
|
| 512 |
-
raise ValueError(
|
| 513 |
-
f"embed_dim must be divisible by num_heads (got embed_dim: {self.embed_dim} and num_heads:"
|
| 514 |
-
f" {self.num_heads})."
|
| 515 |
-
)
|
| 516 |
-
|
| 517 |
-
self.scale_attn_weights = config.scale_attn_weights
|
| 518 |
-
self.is_cross_attention = is_cross_attention
|
| 519 |
-
|
| 520 |
-
# Layer-wise attention scaling, reordering, and upcasting
|
| 521 |
-
self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
|
| 522 |
-
self.layer_idx = layer_idx
|
| 523 |
-
self.reorder_and_upcast_attn = config.reorder_and_upcast_attn
|
| 524 |
-
|
| 525 |
-
if self.is_cross_attention:
|
| 526 |
-
self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
|
| 527 |
-
self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
|
| 528 |
-
else:
|
| 529 |
-
self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
|
| 530 |
-
self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
|
| 531 |
-
|
| 532 |
-
self.attn_dropout = nn.Dropout(config.attn_pdrop)
|
| 533 |
-
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
| 534 |
-
|
| 535 |
-
self.pruned_heads = set()
|
| 536 |
-
|
| 537 |
-
|
| 538 |
class LOLALMHeadModel(GPT2LMHeadModel):
|
| 539 |
|
| 540 |
config_class = LOLAConfig
|
|
@@ -545,6 +496,9 @@ class LOLALMHeadModel(GPT2LMHeadModel):
|
|
| 545 |
self.transformer = LOLAModel(config)
|
| 546 |
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
| 547 |
|
|
|
|
|
|
|
|
|
|
| 548 |
# Model parallel
|
| 549 |
self.model_parallel = False
|
| 550 |
self.device_map = None
|
|
@@ -595,7 +549,8 @@ class LOLALMHeadModel(GPT2LMHeadModel):
|
|
| 595 |
# Flatten the tokens
|
| 596 |
loss_fct = CrossEntropyLoss()
|
| 597 |
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
| 598 |
-
|
|
|
|
| 599 |
loss += self.config.router_aux_loss_coef * aux_loss
|
| 600 |
|
| 601 |
if not return_dict:
|
|
|
|
| 21 |
|
| 22 |
from transformers.modeling_outputs import (
|
| 23 |
BaseModelOutputWithPastAndCrossAttentions,
|
| 24 |
+
MoeCausalLMOutputWithPast
|
|
|
|
|
|
|
| 25 |
)
|
|
|
|
|
|
|
| 26 |
from transformers.utils import (
|
| 27 |
logging
|
| 28 |
)
|
|
|
|
| 36 |
import torch
|
| 37 |
from transformers.modeling_outputs import ModelOutput
|
| 38 |
import transformers
|
|
|
|
| 39 |
|
| 40 |
|
| 41 |
logger = logging.get_logger(__name__)
|
|
|
|
| 45 |
class LOLADependencyChecker:
|
| 46 |
def __init__(self):
|
| 47 |
self.expected_versions = {
|
| 48 |
+
"transformers": "4.47.0"
|
| 49 |
}
|
| 50 |
self.check_dependencies()
|
| 51 |
|
|
|
|
| 106 |
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
|
| 107 |
|
| 108 |
self.drop = nn.Dropout(config.embd_pdrop)
|
| 109 |
+
# Force the eager attention implementation so that GPT2Block selects GPT2Attention
|
| 110 |
+
config._attn_implementation='eager'
|
| 111 |
self.h = nn.ModuleList([
|
| 112 |
GPT2Block(config, layer_idx=i) if i % 2 == 0 else LOLABlock(config, layer_idx=i) for i in range(config.num_hidden_layers)
|
| 113 |
])
|
|
|
|
| 381 |
|
| 382 |
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
| 383 |
self.attn = GPT2Attention(config, layer_idx=layer_idx)
|
| 384 |
+
#self.attn = GPT2SdpaAttention(config, layer_idx=layer_idx)
|
| 385 |
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
| 386 |
self.moe = LOLAMOE(
|
| 387 |
hidden_size,
|
|
|
|
| 486 |
expert_analysis_callback(selected_experts)
|
| 487 |
return final_hidden_states, router_logits, aux_loss
|
| 488 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 489 |
class LOLALMHeadModel(GPT2LMHeadModel):
|
| 490 |
|
| 491 |
config_class = LOLAConfig
|
|
|
|
| 496 |
self.transformer = LOLAModel(config)
|
| 497 |
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
| 498 |
|
| 499 |
+
# Whether to add the aux loss to the total loss
|
| 500 |
+
self.consider_aux_loss = config.consider_aux_loss
|
| 501 |
+
logger.debug(f'consider_aux_loss is set to {self.consider_aux_loss}')
|
| 502 |
# Model parallel
|
| 503 |
self.model_parallel = False
|
| 504 |
self.device_map = None
|
|
|
|
| 549 |
# Flatten the tokens
|
| 550 |
loss_fct = CrossEntropyLoss()
|
| 551 |
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
| 552 |
+
# We can avoid adding the aux loss to the total loss if it's not needed (e.g. LoRA without targeting expert-gating)
|
| 553 |
+
if aux_loss is not None and self.consider_aux_loss:
|
| 554 |
loss += self.config.router_aux_loss_coef * aux_loss
|
| 555 |
|
| 556 |
if not return_dict:
|