Upload model
- config.json +6 -1
- modeling_mamba.py +74 -80
config.json
CHANGED
@@ -1,6 +1,10 @@
 {
+  "architectures": [
+    "MambaLMHeadModel"
+  ],
   "auto_map": {
-    "AutoConfig": "configuration_mamba.MambaConfig"
+    "AutoConfig": "configuration_mamba.MambaConfig",
+    "AutoModelForCausalLM": "modeling_mamba.MambaLMHeadModel"
   },
   "bias": false,
   "conv_bias": true,
@@ -14,6 +18,7 @@
   "model_type": "mamba",
   "n_layer": 24,
   "pad_vocab_size_multiple": 8,
+  "torch_dtype": "float32",
   "transformers_version": "4.37.2",
   "vocab_size": 50280
 }
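The new architectures and auto_map entries are what let transformers resolve the custom classes shipped in this repository. A minimal loading sketch, assuming a placeholder repo id (the actual Hub repo id is not shown in this commit):

# Placeholder repo id used only for illustration; substitute the actual Hub repository.
from transformers import AutoConfig, AutoModelForCausalLM

repo_id = "your-username/your-mamba-model"

# trust_remote_code=True lets transformers import configuration_mamba.MambaConfig and
# modeling_mamba.MambaLMHeadModel from the repository's own Python files, as declared in auto_map.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)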
modeling_mamba.py
CHANGED
@@ -380,23 +380,17 @@ class MambaModel(MambaPretrainedModel):
             **kwargs,
         )
 
-        # self.embedding = nn.Embedding(
-        #     num_embeddings=config.vocab_size,
-        #     embedding_dim=config.d_model,
-        # )
-
-
         self.embedding = nn.Embedding(
-            num_embeddings=config.vocab_size,
-            embedding_dim=config.d_model,
+            num_embeddings=self.config.vocab_size,
+            embedding_dim=self.config.d_model,
         )
 
         self.layers = nn.ModuleList(
-            [ResidualBlock(config) for _ in range(self.config.n_layer)]
+            [ResidualBlock(self.config) for _ in range(self.config.n_layer)]
         )
         # self.layers = nn.ModuleList([MambaBlock(config) for _ in range(config.n_layer)])
         # # self.norm_f = RMSNorm(d_model=embedding_dim)
-        self.norm_f = RMSNorm(config.d_model)
+        self.norm_f = RMSNorm(self.config.d_model)
 
         # self.gradient_checkpointing = False
         # # self.post_init()
@@ -454,54 +448,54 @@ class MambaModel(MambaPretrainedModel):
     # def set_input_embeddings(self, value):
     #     self.embed_out = value
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        output_hidden_states=False,
+        return_dict: Optional[bool] = None,
+        **kwargs,
+        # ) -> BaseModelOutput:
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        batch_size = input_ids.shape[0]
+        hidden_size = self.config.hidden_size
+        hidden_states: Tuple[Tensor[(batch_size, sequence_length, hidden_size)]] = ()
+        sequence_length = input_ids.shape[1]
+        output_hidden_states = output_hidden_states or self.config.output_hidden_states
+
+        last_hidden_state = self.embedding(input_ids)
+        assert last_hidden_state.shape == (
+            batch_size,
+            sequence_length,
+            hidden_size,
+        ), f"{last_hidden_state.shape} != {(batch_size, sequence_length, hidden_size)}"
+        hidden_states += (last_hidden_state,)
+
+        for layer in self.layers:
+            last_hidden_state = layer(last_hidden_state)
+            assert last_hidden_state.shape == (
+                batch_size,
+                sequence_length,
+                hidden_size,
+            ), f"{last_hidden_state.shape} != {(batch_size, sequence_length, hidden_size)}"
+            hidden_states += (last_hidden_state,)
+
+        last_hidden_state = self.norm_f(last_hidden_state)
+        assert last_hidden_state.shape == (
+            batch_size,
+            sequence_length,
+            hidden_size,
+        ), f"{last_hidden_state.shape} != {(batch_size, sequence_length, hidden_size)}"
+        hidden_states += (last_hidden_state,)
+
+        assert (
+            len(hidden_states) == self.config.n_layer + 2
+        ), f"{len(hidden_states)} != {self.config.n_layer + 2}"
+
+        # return BaseModelOutput(
+        return BaseModelOutputWithPast(
+            hidden_states=hidden_states if output_hidden_states else None,
+            last_hidden_state=last_hidden_state,
+        )
 
 
 # Influences:
@@ -538,31 +532,31 @@ class MambaLMHeadModel(MambaPretrainedModel):
         # Initialize weights and apply final processing
        self.post_init()
 
-
-
-
-
-
-
-
-
-
-
-
-
+    def forward(
+        self, input_ids, output_hidden_states=False, **kwargs
+    ) -> CausalLMOutput:
+        batch_size = input_ids.shape[0]
+        sequence_length = input_ids.shape[1]
+        vocab_size = self.config.vocab_size
+        output_hidden_states = output_hidden_states or self.config.output_hidden_states
+
+        outputs = self.backbone(
+            input_ids=input_ids,
+            output_hidden_states=output_hidden_states,
+        )
 
-
+        last_hidden_state = outputs.last_hidden_state
 
-
-
-
-
-
+        logits: torch.FloatTensor[batch_size, sequence_length, vocab_size] = (
+            self.lm_head(
+                last_hidden_state,
+            )
+        )
 
-
-
-
-
+        return CausalLMOutput(
+            hidden_states=outputs.hidden_states if output_hidden_states else None,
+            logits=logits,
+        )
 
     # # def prepare_inputs_for_generation(
     # # self, input_ids, attention_mask=None, **model_kwargs