TerenceLau committed on
Commit ae96e35 · verified · 1 Parent(s): c626103

Upload model

config.json CHANGED
@@ -1,9 +1,12 @@
 {
-  "_name_or_path": "/data/sparrow/results/checkpoint-15000",
   "architectures": [
     "SparrowModel"
   ],
   "attention_bias": false,
+  "auto_map": {
+    "AutoConfig": "configuration_sparrow.SparrowConfig",
+    "AutoModelForCausalLM": "modelling_sparrow.SparrowModel"
+  },
   "dropout": 0.0,
   "flash_attn": true,
   "hidden_dim": 512,
@@ -12,7 +15,7 @@
   "max_seq_len": 512,
   "mlp_bias": false,
   "model_type": "sparrow",
-  "norm_eps": "1e-5",
+  "norm_eps": 1e-05,
   "num_attention_heads": 16,
   "num_hidden_layers": 8,
   "num_key_value_heads": 16,
configuration_sparrow.py ADDED
@@ -0,0 +1,39 @@
+ from typing import Optional
+ from transformers import PretrainedConfig
+
+ class SparrowConfig(PretrainedConfig):
+     model_type = "sparrow"
+
+     def __init__(
+         self,
+         hidden_size: int = 512,
+         num_hidden_layers: int = 8,
+         num_attention_heads: int = 16,
+         num_key_value_heads: Optional[int] = None,
+         max_seq_len: int = 512,
+         attention_bias: bool = False,
+         flash_attn: bool = True,
+         vocab_size: int = 32000,
+         hidden_dim: Optional[int] = None,
+         intermediate_dim: int = 2048,
+         norm_eps: float = 1e-5,
+         mlp_bias: bool = False,
+         dropout: float = 0.0,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         # attention args
+         self.hidden_size = hidden_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads
+         self.max_seq_len = max_seq_len
+         self.attention_bias = attention_bias
+         self.flash_attn = flash_attn
+         # mlp args
+         self.vocab_size = vocab_size
+         self.hidden_dim = hidden_dim if hidden_dim is not None else hidden_size
+         self.intermediate_dim = intermediate_dim
+         self.norm_eps = norm_eps
+         self.mlp_bias = mlp_bias
+         self.dropout = dropout
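
The constructor defaults mirror the values committed in config.json above. A quick sketch of how the derived fields fall out (assuming the file is importable from the current working directory):

from configuration_sparrow import SparrowConfig

cfg = SparrowConfig()
print(cfg.num_key_value_heads)                      # 16 -- defaults to num_attention_heads when not given
print(cfg.hidden_dim)                               # 512 -- falls back to hidden_size
print(cfg.hidden_size // cfg.num_attention_heads)   # 32  -- the per-head dim SparrowAttention derives
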
modelling_sparrow.py ADDED
@@ -0,0 +1,246 @@
+ import math
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from transformers import PreTrainedModel
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+
+ from .configuration_sparrow import SparrowConfig  # relative import so the Hub's remote-code loader can resolve it
+
+ ## RoPE - from https://arxiv.org/pdf/2104.09864v5
+ def rotate_half(x):
+     x1, x2 = x.chunk(2, dim=-1)
+     return torch.cat((-x2, x1), dim=-1)
+
+ def apply_rotate_pos_emb(q, k, cos, sin, unsqueeze_dim=2):
+
+     cos = cos.unsqueeze(unsqueeze_dim)
+     sin = sin.unsqueeze(unsqueeze_dim)
+
+     q_embed = (q*cos) + (rotate_half(q)*sin)
+     k_embed = (k*cos) + (rotate_half(k)*sin)
+
+     return q_embed, k_embed
+
+ class RotaryEmbedding(nn.Module):
+     def __init__(self, dim, max_seq_len=2048):
+         super(RotaryEmbedding, self).__init__()
+         self.hidden_size = dim
+         self.max_seq_len = max_seq_len
+         inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+         t = torch.arange(max_seq_len).float().unsqueeze(1)
+         freqs = t @ inv_freq.unsqueeze(0)
+         freqs = torch.cat((freqs, freqs), dim=-1)
+
+         self.register_buffer("cos_cached", freqs.cos())
+         self.register_buffer("sin_cached", freqs.sin())
+
+     def forward(self, q, k):
+         cos = self.cos_cached[:q.shape[1], :].unsqueeze(0)
+         sin = self.sin_cached[:q.shape[1], :].unsqueeze(0)
+         return apply_rotate_pos_emb(q, k, cos, sin)
+
+
+ ## RMSNorm
+ class RMSNorm(nn.Module):
+     def __init__(self, dim: int, eps: float=1.0e-6):
+         super(RMSNorm, self).__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(dim))
+
+     def normalize(self, x):
+         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+     def forward(self, x):
+         output = self.normalize(x).type_as(x)
+         return output * self.weight
+
+ def repeat_kv(x, n_rep):
+     batch, length, num_key_value_heads, head_dim = x.shape
+     if n_rep == 1:
+         return x
+
+     x = x[:, :, :, None, :].expand(batch, length, num_key_value_heads, n_rep, head_dim)
+     return x.reshape(batch, length, num_key_value_heads * n_rep, head_dim)
+
+ ## SparrowAttention
+ class SparrowAttention(nn.Module):
+     '''Multi-head self-attention with RoPE, grouped key/value heads (repeat_kv),
+     an optional KV cache, and a scaled_dot_product_attention fast path.'''
+     def __init__(self, config: SparrowConfig=None):
+         super(SparrowAttention, self).__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.num_hidden_layers = config.num_hidden_layers
+         self.num_attention_heads = config.num_attention_heads
+         self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_attention_heads)
+         self.num_key_value_heads = config.num_key_value_heads
+         self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
+         self.vocab_size = config.vocab_size
+         self.dropout = config.dropout
+         self.rotary_emb = RotaryEmbedding(dim=self.head_dim)
+
+         self.wq = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=self.config.attention_bias)
+         self.wk = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.config.attention_bias)
+         self.wv = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.config.attention_bias)
+         self.wo = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=self.config.attention_bias)
+         self.k_cache, self.v_cache = None, None
+         self.attention_dropout = nn.Dropout(self.dropout)
+         self.residual_dropout = nn.Dropout(self.dropout)
+
+     def forward(self, x: torch.Tensor, use_kv_cache=False):
+         b, s = x.shape[:2]
+         if use_kv_cache and not self.training:  # only use the cache at inference time
+             if self.k_cache is None or self.k_cache.shape[1] != s-1:
+                 q, k, v = self.wq(x), self.wk(x), self.wv(x)
+             else:
+                 token = x[:, -1:, :]
+                 q = torch.cat((torch.zeros_like(x[:, :-1, :]), self.wq(token)), dim=1)
+                 k = torch.cat((self.k_cache, self.wk(token)), dim=1)
+                 v = torch.cat((self.v_cache, self.wv(token)), dim=1)
+
+             self.k_cache, self.v_cache = k, v
+         else:
+             q, k, v = self.wq(x), self.wk(x), self.wv(x)
+
+         q = q.view(b, s, self.num_attention_heads, self.head_dim)
+         k = k.view(b, s, self.num_key_value_heads, self.head_dim)
+         v = v.view(b, s, self.num_key_value_heads, self.head_dim)
+         q, k = self.rotary_emb(q, k)
+         k, v = repeat_kv(k, self.num_key_value_groups), repeat_kv(v, self.num_key_value_groups)
+
+         q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
+
+         if self.config.flash_attn:
+             output = F.scaled_dot_product_attention(q, k, v, attn_mask=None,
+                                                     dropout_p=self.dropout if self.training else 0.0,
+                                                     is_causal=True)
+         else:
+             mask = torch.full((1, 1, self.config.max_seq_len, self.config.max_seq_len), float("-inf"), device=x.device)
+             mask = torch.triu(mask, diagonal=1)
+             scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
+             scores = scores + mask[:, :, :s, :s]
+             scores = F.softmax(scores.float(), dim=-1).type_as(q)
+             scores = self.attention_dropout(scores)
+             output = torch.matmul(scores, v)
+
+         output = output.transpose(1, 2).contiguous().view(b, s, -1)
+         output = self.wo(output)
+         output = self.residual_dropout(output)
+         return output
+
+ class SparrowLinear(nn.Module):
+     def __init__(self, config: SparrowConfig=None):
+         super(SparrowLinear, self).__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.intermediate_dim = config.intermediate_dim
+         self.gate = nn.Linear(self.hidden_size, self.intermediate_dim, bias=self.config.mlp_bias)
+         self.up = nn.Linear(self.hidden_size, self.intermediate_dim, bias=self.config.mlp_bias)
+         self.out = nn.Linear(self.intermediate_dim, self.hidden_size, bias=self.config.mlp_bias)
+
+     def forward(self, x):
+         return self.out(F.silu(self.gate(x)) * self.up(x))  # SwiGLU feed-forward
+
+ class SparrowDecoderLayer(nn.Module):
+     def __init__(self, config: SparrowConfig=None, layer_idx: int=None):
+         super(SparrowDecoderLayer, self).__init__()
+         self.hidden_size = config.hidden_size
+         self.attention = SparrowAttention(config=config)
+         self.linear = SparrowLinear(config=config)
+         self.input_norm = RMSNorm(dim=config.hidden_size)
+         self.pos_attn_norm = RMSNorm(dim=config.hidden_size)
+         self.layer_idx = layer_idx
+
+     def forward(self, x, use_kv_cache):
+         residual = x
+         x = self.input_norm(x)
+         residual, x = x, self.attention(x=x, use_kv_cache=use_kv_cache) + residual
+         x = self.linear(self.pos_attn_norm(x))
+         x = x + residual
+         return x
+
+ class SparrowModel(PreTrainedModel):
+     config_class = SparrowConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.config = config
+         self.vocab_size = self.config.vocab_size
+         self.num_hidden_layers = self.config.num_hidden_layers
+         self.token_embedding = nn.Embedding(self.config.vocab_size, self.config.hidden_size)
+         self.dropout = nn.Dropout(self.config.dropout)
+
+         self.decoder = nn.ModuleList()
+         for layer_idx in range(self.num_hidden_layers):
+             self.decoder.append(SparrowDecoderLayer(config=self.config, layer_idx=layer_idx))
+
+         self.norm = RMSNorm(dim=self.config.hidden_size)
+         self.output = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=self.config.mlp_bias)
+         self.token_embedding.weight = self.output.weight  # tie input embedding and output projection weights
+         self.apply(self.weights_init)
+         self.loss = None
+
+         for pn, p in self.named_parameters():
+             if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
+                 torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * self.config.num_hidden_layers))
+
+     def weights_init(self, module):
+         if isinstance(module, nn.Linear):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if module.padding_idx is not None:
+                 module.weight.data[module.padding_idx].zero_()
+
+     def forward(self, input_ids, labels=None, use_kv_cache=False):
+         x = self.dropout(self.token_embedding(input_ids))
+
+         for idx, layer in enumerate(self.decoder):
+             x = layer(x=x, use_kv_cache=use_kv_cache)
+
+         if labels is not None:
+             logits = self.output(x)
+             self.loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=0)
+         else:
+             logits = self.output(x[:, [-1], :])
+             self.loss = None
+
+         return CausalLMOutputWithPast(self.loss, logits)
+
+     @torch.inference_mode()
+     def generate(self, input_ids, eos, max_new_tokens, temperature=0.7, top_k=None, stream=True, repetition_penalty=1.,
+                  use_kv_cache=True):
+
+         s = input_ids.shape[1]
+         while input_ids.shape[1] < max_new_tokens - 1:
+             inference_res = self(input_ids, labels=None, use_kv_cache=use_kv_cache)
+             logits = inference_res.logits
+             logits = logits[:, -1, :]
+
+             for token in set(input_ids.tolist()[0]):
+                 logits[:, token] /= repetition_penalty
+
+             if temperature == 0.0:
+                 _, idx_next = torch.topk(logits, k=1, dim=-1)
+             else:
+                 logits = logits / temperature
+                 if top_k is not None:
+                     v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+                     logits[logits < v[:, [-1]]] = -float('Inf')
+
+                 probs = F.softmax(logits, dim=-1)
+                 idx_next = torch.multinomial(probs, num_samples=1, generator=None)
+
+             if idx_next == eos:
+                 break
+
+             input_ids = torch.cat((input_ids, idx_next), dim=1)
+             if stream:
+                 yield input_ids[:, s:]
+
+         if not stream:
+             yield input_ids[:, s:]
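
Because generate() is written as a generator that yields the running completion after each step, callers iterate it rather than read a return value. A minimal usage sketch, assuming a SparrowModel instance loaded as shown earlier; the prompt ids and eos id are placeholders, not part of this commit:

import torch

prompt_ids = torch.tensor([[1, 42, 7]])   # placeholder token ids from whichever tokenizer the model was trained with
eos_id = 2                                # hypothetical end-of-sequence id

model.eval()  # the KV cache path is only taken outside of training mode
completion = prompt_ids[:, :0]
for partial in model.generate(prompt_ids, eos=eos_id, max_new_tokens=64, temperature=0.7, top_k=40):
    completion = partial                  # tokens generated so far, prompt excluded
print(completion[0].tolist())
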
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:177c43fde459c6e4f610c58e282c98b8563f387f0db5bb1db3495948206cc82e
+ size 204011452
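
pytorch_model.bin is stored through Git LFS, so the diff records only the pointer file (object hash and size, about 204 MB); the tensor data itself lives in LFS storage and is fetched on download. If only the weight file is needed, it can be pulled directly, for example (repo id again a placeholder):

import torch
from huggingface_hub import hf_hub_download

weights_path = hf_hub_download(repo_id="<user>/sparrow", filename="pytorch_model.bin")  # hypothetical repo id
state_dict = torch.load(weights_path, map_location="cpu")
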