Update modeling_motif.py
Browse files- modeling_motif.py +3 -12
modeling_motif.py
CHANGED
|
@@ -558,7 +558,7 @@ class MotifAttention(nn.Module):
|
|
| 558 |
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 559 |
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
| 560 |
|
| 561 |
-
attn_output = self.o_proj(attn_output)
|
| 562 |
|
| 563 |
if not output_attentions:
|
| 564 |
attn_weights = None
|
|
@@ -1285,7 +1285,7 @@ class MotifModel(MotifPreTrainedModel):
|
|
| 1285 |
all_self_attns += (layer_outputs[1], )
|
| 1286 |
|
| 1287 |
# <|_2_|>
|
| 1288 |
-
hidden_states = self.norm(hidden_states)
|
| 1289 |
|
| 1290 |
# add hidden states from the last decoder layer
|
| 1291 |
if output_hidden_states:
|
|
@@ -1461,15 +1461,6 @@ class MotifForCausalLM(MotifPreTrainedModel, GenerationMixin):
|
|
| 1461 |
# Initialize weights and apply final processing
|
| 1462 |
self.post_init()
|
| 1463 |
|
| 1464 |
-
# <|_3_|>
|
| 1465 |
-
if config.muP:
|
| 1466 |
-
self.lm_head.__do_scale_tager_mu_dim_base_model__=True
|
| 1467 |
-
|
| 1468 |
-
# <|_4_|>
|
| 1469 |
-
self.lm_head_alpha = 1
|
| 1470 |
-
if config.wesar_weights:
|
| 1471 |
-
self.lm_head_alpha = nn.Parameter(torch.tensor(1).float())
|
| 1472 |
-
|
| 1473 |
if getattr(config, "tie_word_embeddings", True):
|
| 1474 |
logger.info('tie embeddings')
|
| 1475 |
self.tie_weights()
|
|
@@ -1676,7 +1667,7 @@ class MotifForCausalLM(MotifPreTrainedModel, GenerationMixin):
|
|
| 1676 |
num_logits_to_keep=num_logits_to_keep)
|
| 1677 |
|
| 1678 |
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
| 1679 |
-
hidden_states = hidden_states
|
| 1680 |
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
| 1681 |
logits = logits.float()
|
| 1682 |
|
|
|
|
| 558 |
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 559 |
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
| 560 |
|
| 561 |
+
attn_output = self.o_proj(attn_output)
|
| 562 |
|
| 563 |
if not output_attentions:
|
| 564 |
attn_weights = None
|
|
|
|
| 1285 |
all_self_attns += (layer_outputs[1], )
|
| 1286 |
|
| 1287 |
# <|_2_|>
|
| 1288 |
+
hidden_states = self.norm(hidden_states)
|
| 1289 |
|
| 1290 |
# add hidden states from the last decoder layer
|
| 1291 |
if output_hidden_states:
|
|
|
|
| 1461 |
# Initialize weights and apply final processing
|
| 1462 |
self.post_init()
|
| 1463 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1464 |
if getattr(config, "tie_word_embeddings", True):
|
| 1465 |
logger.info('tie embeddings')
|
| 1466 |
self.tie_weights()
|
|
|
|
| 1667 |
num_logits_to_keep=num_logits_to_keep)
|
| 1668 |
|
| 1669 |
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
| 1670 |
+
hidden_states = hidden_states
|
| 1671 |
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
| 1672 |
logits = logits.float()
|
| 1673 |
|