mjschock committed on
Commit
4e04e38
·
verified ·
1 Parent(s): 0e97820

Upload model

Browse files
Files changed (2) hide show
  1. config.json +6 -1
  2. modeling_mamba.py +15 -14
config.json CHANGED
@@ -1,6 +1,10 @@
1
  {
 
 
 
2
  "auto_map": {
3
- "AutoConfig": "configuration_mamba.MambaConfig"
 
4
  },
5
  "d_model": 768,
6
  "fused_add_norm": true,
@@ -10,6 +14,7 @@
10
  "residual_in_fp32": true,
11
  "rms_norm": true,
12
  "ssm_cfg": {},
 
13
  "transformers_version": "4.37.2",
14
  "vocab_size": 50277
15
  }
 
1
  {
2
+ "architectures": [
3
+ "MambaModel"
4
+ ],
5
  "auto_map": {
6
+ "AutoConfig": "configuration_mamba.MambaConfig",
7
+ "AutoModel": "modeling_mamba.MambaModel"
8
  },
9
  "d_model": 768,
10
  "fused_add_norm": true,
 
14
  "residual_in_fp32": true,
15
  "rms_norm": true,
16
  "ssm_cfg": {},
17
+ "torch_dtype": "float16",
18
  "transformers_version": "4.37.2",
19
  "vocab_size": 50277
20
  }
modeling_mamba.py CHANGED
@@ -1,6 +1,7 @@
1
- from typing import Optional
2
 
3
  from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
 
4
  from transformers import GenerationMixin, PreTrainedModel
5
  from transformers.generation import TextStreamer
6
 
@@ -48,18 +49,18 @@ class MambaModelForCausalLM(MambaModel, GenerationMixin):
48
  def generate(
49
  self,
50
  input_ids,
51
- max_length,
52
- top_k=1,
53
- top_p=0.0,
54
- temperature=1.0,
55
- return_dict_in_generate=False,
56
- output_scores=False,
57
- repetition_penalty=1.0,
58
- eos_token_id=None,
59
- teacher_outputs=None,
60
- vocab_size=None,
61
- cg=False,
62
- enable_timing=False,
63
  streamer: Optional[TextStreamer] = None,
64
  **kwargs,
65
  ):
@@ -77,5 +78,5 @@ class MambaModelForCausalLM(MambaModel, GenerationMixin):
77
  vocab_size=vocab_size,
78
  cg=cg,
79
  enable_timing=enable_timing,
80
- streamer = streamer,
81
  )
 
1
+ from typing import Optional, Tuple
2
 
3
  from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
4
+ import torch
5
  from transformers import GenerationMixin, PreTrainedModel
6
  from transformers.generation import TextStreamer
7
 
 
49
  def generate(
50
  self,
51
  input_ids,
52
+ max_length: int = 2048,
53
+ top_k: int = 1,
54
+ top_p: float = 0.0,
55
+ temperature: float = 1.0,
56
+ return_dict_in_generate: bool = False,
57
+ output_scores: bool = False,
58
+ repetition_penalty: float = 1.0,
59
+ eos_token_id: Optional[int] = None,
60
+ teacher_outputs: Optional[torch.Tensor] = None,
61
+ vocab_size: Optional[int] = None,
62
+ cg: bool = False,
63
+ enable_timing: bool = False,
64
  streamer: Optional[TextStreamer] = None,
65
  **kwargs,
66
  ):
 
78
  vocab_size=vocab_size,
79
  cg=cg,
80
  enable_timing=enable_timing,
81
+ streamer=streamer,
82
  )