Update modeling_motif.py
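Removes the multi-token prediction (`multi_token_heads`) code path: the tokenwise last decoder layers and lm heads in `MotifForCausalLM`, the `multi_token_forward_backward` training procedure, and its plumbing in `forward`. The model now always trains and decodes through a single `lm_head`.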
modeling_motif.py (+3 -121)
@@ -1040,13 +1040,12 @@ class MotifModel(MotifPreTrainedModel):
         super().__init__(config)
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
-        self.multi_token_heads = config.multi_token_heads
 
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
         # NOTE: For multi-token models, the last decoder layers (one for each token index)
         # are implemented as a part of `MotifModelForCausalLM` to enable a custom forward-backward procedure.
 
-        num_hidden_layers = config.num_hidden_layers
+        num_hidden_layers = config.num_hidden_layers
         self.layers = nn.ModuleList([MotifDecoderLayer(config = config, layer_idx=layer_idx) for layer_idx in range(num_hidden_layers)])
         self._attn_implementation = config._attn_implementation
         RMSNorm = MorehRMSNorm
@@ -1338,16 +1337,8 @@ class MotifForCausalLM(MotifPreTrainedModel, GenerationMixin):
         super().__init__(config)
         self.model = MotifModel(config)
         self.vocab_size = config.vocab_size
-        self.multi_token_heads = config.multi_token_heads
 
-        if self.multi_token_heads is None:
-            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-        else:
-            self.tokenwise_last_layers = nn.ModuleList(
-                [MotifDecoderLayer(config, config.num_hidden_layers - 1) for _ in range(self.multi_token_heads)])
-            self.tokenwise_lm_heads = nn.ModuleList(
-                [nn.Linear(config.hidden_size, config.vocab_size, bias=False) for _ in range(self.multi_token_heads)])
-        self.should_skip_separate_backward_pass = self.multi_token_heads is not None
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -1374,101 +1365,7 @@ class MotifForCausalLM(MotifPreTrainedModel, GenerationMixin):
     def get_decoder(self):
         return self.model
 
-    def multi_token_forward_backward(self,
-                                     hidden_states: torch.FloatTensor,
-                                     outputs: MotifModelOutputWithPast,
-                                     labels: torch.LongTensor,
-                                     position_ids: Optional[torch.LongTensor],
-                                     output_attentions: Optional[bool],
-                                     use_cache: Optional[bool],
-                                     cache_position: Optional[torch.LongTensor],
-                                     return_dict: Optional[bool],
-                                     num_logits_to_keep: int = 0) -> CausalLMOutputWithPast:
-        """
-        This implements the main forward-backward procedure for multi-token model training proposed in
-        the paper https://arxiv.org/abs/2404.19737.
-        Essentially,
-        - The multi-token model tries to predict n (instead of 1) tokens at a time.
-        - Applying this only during training and using first-token prediction during inference is still helpful.
-        - The change in architecture: when using n-token prediction, each token index (between 1 and n) has its own
-          (1) last attention layer and (2) lm head.
-        - The change in loss: sum of cross-entropy losses corresponding to each token index.
-        - Custom forward-backward procedure for memory efficiency: refer to the implementation of `multi_head_forward_backward`.
-        """
-        if not return_dict:
-            raise NotImplementedError("return_dict must be True for multi-token training")
-
-        past_key_values = outputs.past_key_values
-        causal_mask = outputs.causal_mask
-        position_embeddings = outputs.position_embeddings
-
-        if labels is not None:
-            labels = labels.to(hidden_states.device)
-
-        def _tokenwise_forward(hidden_states: torch.Tensor, token_idx):
-            ## Model forward
-            layer = self.tokenwise_last_layers[token_idx]
-            lm_head = self.tokenwise_lm_heads[token_idx]
-
-            layer_outputs = layer(
-                hidden_states,
-                attention_mask=causal_mask,
-                position_ids=position_ids,
-                past_key_values=past_key_values,  # TODO: update past_key_values?
-                output_attentions=output_attentions,
-                use_cache=use_cache,
-                cache_position=cache_position,
-                position_embeddings=position_embeddings,
-            )
-            last_hidden_states = layer_outputs[0]
-            if num_logits_to_keep > 0:
-                assert labels is None
-                last_hidden_states = last_hidden_states[:, -num_logits_to_keep:, :]
-            tokenwise_logits = lm_head(last_hidden_states)
-
-            if labels is None:
-                return {
-                    "loss": None,
-                    "logits": tokenwise_logits,
-                }
-
-            ## Compute loss
-            shift_n = token_idx + 1
-            shift_logits = tokenwise_logits[..., :-shift_n, :].contiguous()
-            shift_labels = labels[..., shift_n:].contiguous()
-
-            loss_fct = CrossEntropyLoss()
-            shift_logits = shift_logits.view(-1, self.config.vocab_size)
-            shift_labels = shift_labels.view(-1)
-
-            tokenwise_loss = loss_fct(shift_logits, shift_labels)
-
-            return {
-                "loss": tokenwise_loss,
-                "logits": tokenwise_logits,
-            }
-
-        head_fns = [
-            lambda hidden_states, token_idx=token_idx: _tokenwise_forward(hidden_states, token_idx)
-            for token_idx in range(self.multi_token_heads)
-        ]
-        loss, logits = multi_head_forward_backward(hidden_states,
-                                                   head_fns,
-                                                   return_keys=("loss", "logits"),
-                                                   return_only_first_head=True)
-
-        if not return_dict:
-            output = (logits, ) + outputs[1:]
-            return (loss, ) + output
-
-        return CausalLMOutputWithPast(
-            loss=loss,
-            logits=logits,
-            past_key_values=outputs.past_key_values,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-        )
-
+
     @add_start_docstrings_to_model_forward(MOTIF_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
@@ -1524,8 +1421,6 @@ class MotifForCausalLM(MotifPreTrainedModel, GenerationMixin):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-        outputs_include_causal_mask = self.multi_token_heads is not None
-        outputs_include_position_embeddings = self.multi_token_heads is not None
         outputs: MotifModelOutputWithPast = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
@@ -1537,23 +1432,10 @@ class MotifForCausalLM(MotifPreTrainedModel, GenerationMixin):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             cache_position=cache_position,
-            outputs_include_causal_mask=outputs_include_causal_mask,
-            outputs_include_position_embeddings=outputs_include_position_embeddings,
         )
 
         hidden_states = outputs[0]
 
-        if self.multi_token_heads is not None:
-            return self.multi_token_forward_backward(hidden_states,
-                                                     outputs,
-                                                     labels,
-                                                     position_ids,
-                                                     output_attentions,
-                                                     use_cache,
-                                                     cache_position,
-                                                     return_dict,
-                                                     num_logits_to_keep=num_logits_to_keep)
-
         # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
         hidden_states = hidden_states
         logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
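For reference, the loss shift in the removed `_tokenwise_forward` generalizes the standard next-token shift: head `token_idx` is trained to predict the token `token_idx + 1` positions ahead, so the logits lose their last `shift_n` positions and the labels their first `shift_n`. A small illustrative sketch with toy tensors (not from the repo):

import torch

labels = torch.tensor([[11, 12, 13, 14, 15, 16]])  # toy token ids
logits = torch.randn(1, 6, 32)                      # (batch, seq, vocab)

for token_idx in range(2):                          # two heads, n = 2
    shift_n = token_idx + 1
    shift_logits = logits[..., :-shift_n, :]        # positions 0 .. seq-1-shift_n
    shift_labels = labels[..., shift_n:]            # targets  shift_n .. seq-1
    assert shift_logits.shape[1] == shift_labels.shape[1]
    # token_idx=0 -> targets [12, 13, 14, 15, 16]   (next token)
    # token_idx=1 -> targets [13, 14, 15, 16]       (token after next)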
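The removed code delegates to `multi_head_forward_backward`, whose implementation is not part of this diff. The memory-saving idea from https://arxiv.org/abs/2404.19737 is to backprop each head immediately on the detached trunk output, so only one head's activations are live at a time, and then run a single backward through the shared trunk; this also explains the removed `should_skip_separate_backward_pass` flag (the returned loss has already been backpropagated). Below is a minimal sketch of that idea, assuming each head function returns a dict with "loss" and "logits" keys as the removed call site expects; the body is illustrative, not the repo's actual helper:

import torch

def multi_head_forward_backward(hidden_states, head_fns,
                                return_keys=("loss", "logits"),
                                return_only_first_head=True):
    # Detach the trunk output so each head's backward stops here;
    # per-head gradients accumulate in trunk_out.grad.
    trunk_out = hidden_states.detach().requires_grad_(True)

    total_loss, first_head_out = None, None
    for i, head_fn in enumerate(head_fns):
        out = head_fn(trunk_out)          # {"loss": ..., "logits": ...}
        if i == 0:
            first_head_out = out          # return_only_first_head=True assumed
        if out["loss"] is not None:
            # Backward now, before the next head's forward, so only one
            # head's activations are alive at any point.
            out["loss"].backward()
            total_loss = (out["loss"].detach() if total_loss is None
                          else total_loss + out["loss"].detach())

    if total_loss is not None:
        # Single backward pass through the shared trunk with the summed
        # per-head gradients.
        hidden_states.backward(trunk_out.grad)

    return tuple(total_loss if key == "loss" else first_head_out[key]
                 for key in return_keys)

With n heads this keeps peak activation memory at roughly one head's worth instead of n, at the cost of n small backward passes.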