TerenceLau committed on
Commit
12621a6
·
verified ·
1 Parent(s): 6f69439

Upload model

Browse files
Files changed (2) hide show
  1. config.json +1 -2
  2. modelling_sparrow.py +39 -2
config.json CHANGED
@@ -4,8 +4,7 @@
4
  ],
5
  "attention_bias": false,
6
  "auto_map": {
7
- "AutoConfig": "configuration_sparrow.SparrowConfig",
8
- "AutoModelForCausalLM": "modelling_sparrow.SparrowModel"
9
  },
10
  "dropout": 0.0,
11
  "flash_attn": true,
 
4
  ],
5
  "attention_bias": false,
6
  "auto_map": {
7
+ "AutoConfig": "modelling_sparrow.SparrowConfig"
 
8
  },
9
  "dropout": 0.0,
10
  "flash_attn": true,
modelling_sparrow.py CHANGED
@@ -1,12 +1,49 @@
1
  import math
 
 
2
  import torch
3
  import torch.nn as nn
4
  import torch.nn.functional as F
5
 
6
- from transformers import PreTrainedModel
7
  from transformers.modeling_outputs import CausalLMOutputWithPast
8
 
9
- from configuration_sparrow import SparrowConfig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  ## RoPE - from https://arxiv.org/pdf/2104.09864v5
12
  def rotate_half(x):
 
1
  import math
2
+ from typing import Optional
3
+
4
  import torch
5
  import torch.nn as nn
6
  import torch.nn.functional as F
7
 
8
+ from transformers import PreTrainedModel, PretrainedConfig
9
  from transformers.modeling_outputs import CausalLMOutputWithPast
10
 
11
class SparrowConfig(PretrainedConfig):
    """Configuration for the Sparrow causal language model.

    Holds the attention/transformer hyperparameters and the MLP/embedding
    hyperparameters consumed by ``modelling_sparrow``. Any unknown keyword
    arguments are forwarded to ``PretrainedConfig``.
    """

    model_type = "sparrow"

    def __init__(
        self,
        hidden_size: int = 512,
        num_hidden_layers: int = 8,
        num_attention_heads: int = 16,
        num_key_value_heads: Optional[int] = None,
        max_seq_len: int = 512,
        attention_bias: bool = False,
        flash_attn: bool = True,
        vocab_size: int = 32000,
        hidden_dim: Optional[int] = None,
        intermediate_dim: int = 2048,
        norm_eps: float = 1e-5,
        mlp_bias: bool = False,
        dropout: float = 0.0,
        **kwargs,
    ):
        """Create a Sparrow configuration.

        Args:
            hidden_size: Width of the transformer residual stream.
            num_hidden_layers: Number of transformer blocks.
            num_attention_heads: Number of query heads.
            num_key_value_heads: Number of key/value heads; defaults to
                ``num_attention_heads`` when ``None`` (i.e. standard MHA).
            max_seq_len: Maximum sequence length supported.
            attention_bias: Whether attention projections carry a bias term.
            flash_attn: Whether to use the flash-attention kernel path.
            vocab_size: Size of the token vocabulary.
            hidden_dim: MLP input width; defaults to ``hidden_size`` when
                ``None``.
            intermediate_dim: MLP intermediate (expansion) width.
            norm_eps: Epsilon used by the normalization layers.
            mlp_bias: Whether MLP linear layers carry a bias term.
            dropout: Dropout probability.
            **kwargs: Passed through to ``PretrainedConfig``.
        """
        super().__init__(**kwargs)

        # Attention / transformer hyperparameters.
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        # GQA: fall back to full multi-head attention when unspecified.
        self.num_key_value_heads = (
            num_attention_heads if num_key_value_heads is None else num_key_value_heads
        )
        self.max_seq_len = max_seq_len
        self.attention_bias = attention_bias
        self.flash_attn = flash_attn

        # MLP / embedding hyperparameters.
        self.vocab_size = vocab_size
        # MLP width tracks the model width unless explicitly overridden.
        self.hidden_dim = hidden_size if hidden_dim is None else hidden_dim
        self.intermediate_dim = intermediate_dim
        self.norm_eps = norm_eps
        self.mlp_bias = mlp_bias
        self.dropout = dropout
47
 
48
  ## RoPE - from https://arxiv.org/pdf/2104.09864v5
49
  def rotate_half(x):