TerenceLau committed on
Commit
5446460
·
verified ·
1 Parent(s): ae79607

Update modelling_sparrow.py

Browse files
Files changed (1) hide show
  1. modelling_sparrow.py +74 -38
modelling_sparrow.py CHANGED
@@ -214,14 +214,7 @@ class SparrowModel(PreTrainedModel):
214
  self.decoder.append(SparrowDecoderLayer(config=self.config, layer_idx=layer_idx))
215
 
216
  self.norm = RMSNorm(dim=self.config.hidden_size)
217
- self.output = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=self.config.mlp_bias)
218
- self.token_embedding.weight = self.output.weight
219
  self.apply(self.weights_init)
220
- self.loss = None
221
-
222
- for pn, p in self.named_parameters():
223
- if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
224
- torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * self.config.num_hidden_layers))
225
 
226
  def weights_init(self, module):
227
  if isinstance(module, nn.Linear):
@@ -233,51 +226,94 @@ class SparrowModel(PreTrainedModel):
233
  if module.padding_idx is not None:
234
  module.weight.data[module.padding_idx].zero_()
235
 
236
- def forward(self, input_ids, labels, use_kv_cache=False):
237
  x = self.dropout(self.token_embedding(input_ids))
238
 
239
  for idx, layer in enumerate(self.decoder):
240
  x = layer(x=x, use_kv_cache=use_kv_cache)
241
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
  if labels is not None:
243
  logits = self.output(x)
244
- self.loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=0)
245
  else:
246
  logits = self.output(x[:, [-1], :])
247
  self.loss = None
248
 
249
  return CausalLMOutputWithPast(self.loss, logits)
250
 
251
- @torch.inference_mode
252
- def generate(self, input_ids, eos, max_new_tokens, temperature=0.7, top_k=None, stream=True, repetition_penalty=1.,
253
- use_kv_cache=True):
254
-
255
  s = input_ids.shape[1]
256
- while input_ids.shape[1] < max_new_tokens - 1:
257
- inference_res = self(input_ids, labels=None, use_kv_cache=use_kv_cache)
258
- logits = inference_res.logits
259
- logits = logits[:, -1, :]
260
-
261
- for token in set(input_ids.tolist()[0]):
262
- logits[:, token] /= repetition_penalty
263
-
264
- if temperature == 0.0:
265
- _, idx_next = torch.topk(logits, k=1, dim=-1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
  else:
267
- logits = logits / temperature
268
- if top_k is not None:
269
  v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
270
- logits[logits < v[:, [-1]]] = -float('Inf')
271
-
272
- probs = F.softmax(logits, dim=-1)
273
- idx_next = torch.multinomial(probs, num_samples=1, generator=None)
274
-
275
- if idx_next == eos:
276
  break
277
-
278
- input_ids = torch.cat((input_ids, idx_next), dim=1)
279
- if stream:
280
- yield input_ids[:, s:]
281
-
282
- if not stream:
283
- yield input_ids[:, s:]
 
214
  self.decoder.append(SparrowDecoderLayer(config=self.config, layer_idx=layer_idx))
215
 
216
  self.norm = RMSNorm(dim=self.config.hidden_size)
 
 
217
  self.apply(self.weights_init)
 
 
 
 
 
218
 
219
  def weights_init(self, module):
220
  if isinstance(module, nn.Linear):
 
226
  if module.padding_idx is not None:
227
  module.weight.data[module.padding_idx].zero_()
228
 
229
+ def forward(self, input_ids, use_kv_cache=False):
230
  x = self.dropout(self.token_embedding(input_ids))
231
 
232
  for idx, layer in enumerate(self.decoder):
233
  x = layer(x=x, use_kv_cache=use_kv_cache)
234
 
235
+ return self.norm(x)
236
+
237
+ class SparrowModelForCausalLM(SparrowModel):
238
+ def __init__(self, config):
239
+ super().__init__(config)
240
+ self.output = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=self.config.mlp_bias)
241
+ self.token_embedding.weight = self.output.weight
242
+ self.loss = None
243
+
244
+ for pn, p in self.named_parameters():
245
+ if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
246
+ torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * self.config.num_hidden_layers))
247
+
248
+ def forward(self, input_ids, labels=None, use_kv_cache=False):
249
+ x = super().forward(input_ids, use_kv_cache)
250
+
251
  if labels is not None:
252
  logits = self.output(x)
253
+ self.loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=-100)
254
  else:
255
  logits = self.output(x[:, [-1], :])
256
  self.loss = None
257
 
258
  return CausalLMOutputWithPast(self.loss, logits)
259
 
260
+ @torch.no_grad()
261
+ def generate(self, input_ids, eos=1, max_new_tokens=50, temperature=0.7, top_k=None, repetition_penalty=1.,
262
+ use_kv_cache=True, use_beam_search=False, beam_size=3):
 
263
  s = input_ids.shape[1]
264
+
265
+ if use_beam_search:
266
+ sequences = [(input_ids, 0)] # List of (sequence, cumulative log probability)
267
+ for _ in range(max_new_tokens - 1):
268
+ all_candidates = []
269
+ for seq, score in sequences:
270
+ inference_res = self(seq, labels=None, use_kv_cache=use_kv_cache)
271
+ logits = inference_res.logits[:, -1, :]
272
+
273
+ if repetition_penalty != 1.0:
274
+ for token in set(seq.tolist()[0]):
275
+ logits[:, token] /= repetition_penalty
276
+
277
+ logits = logits / temperature if temperature > 0 else logits
278
+ probs = F.log_softmax(logits, dim=-1)
279
+ top_log_prob, idx_next = torch.topk(probs, beam_size, dim=-1)
280
+
281
+ for i in range(beam_size):
282
+ next_seq = torch.cat((seq, idx_next[:, i].unsqueeze(1)), dim=1)
283
+ next_score = score + top_log_prob[:, i].item()
284
+ all_candidates.append((next_seq, next_score))
285
+
286
+ sequences = sorted(all_candidates, key=lambda x: x[1], reverse=True)[:beam_size]
287
+ if all(seq[0][:, -1].item() == eos for seq in sequences):
288
+ break
289
+
290
+ best_seq = sequences[0][0]
291
+ return best_seq.tolist()[0][s:]
292
+
293
+ # Greedy search (default)
294
+ generated_tokens = []
295
+ while len(generated_tokens) < max_new_tokens - 1:
296
+ inference_res = self(input_ids, labels=None, use_kv_cache=use_kv_cache)
297
+ logits = inference_res.logits[:, -1, :]
298
+
299
+ if repetition_penalty != 1.0:
300
+ for token in set(input_ids.tolist()[0]):
301
+ logits[:, token] /= repetition_penalty
302
+
303
+ if temperature == 0.0:
304
+ idx_next = torch.argmax(logits, dim=-1, keepdim=True)
305
  else:
306
+ logits = logits / temperature
307
+ if top_k is not None:
308
  v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
309
+ logits[logits < v[:, [-1]]] = -float('Inf')
310
+ probs = F.softmax(logits, dim=-1)
311
+ idx_next = torch.multinomial(probs, num_samples=1)
312
+
313
+ if idx_next.item() == eos:
 
314
  break
315
+
316
+ input_ids = torch.cat((input_ids, idx_next), dim=1)
317
+ generated_tokens.append(idx_next.item())
318
+
319
+ return generated_tokens