robinfaro committed on
Commit
936ffab
·
verified ·
1 Parent(s): 9cba846

Upload modeling.py

Browse files
Files changed (1) hide show
  1. modeling.py +5 -4
modeling.py CHANGED
@@ -403,12 +403,13 @@ class MoEGPTForCausalLM(PreTrainedModel):
403
  ]
404
 
405
  @torch.no_grad()
406
- def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
407
  """
408
  Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
409
  the sequence max_new_tokens times, feeding the predictions back into the model each time.
410
  Most likely you'll want to make sure to be in model.eval() mode of operation for this.
411
  """
 
412
  for _ in range(max_new_tokens):
413
  # if the sequence context is growing too long we must crop it at sequence_length
414
  idx_cond = (
@@ -417,7 +418,7 @@ class MoEGPTForCausalLM(PreTrainedModel):
417
  else idx[:, -self.config.sequence_length :]
418
  )
419
  # forward the model to get the logits for the index in the sequence
420
- logits = self(idx_cond, get_logits=True)["logits"]
421
  # pluck the logits at the final step and scale by desired temperature
422
  logits = logits[:, -1, :] / temperature
423
  # optionally crop the logits to only the top k options
@@ -434,7 +435,7 @@ class MoEGPTForCausalLM(PreTrainedModel):
434
  return idx
435
 
436
  @torch.no_grad()
437
- def generate_from_string(self, in_str, max_new_tokens, temperature=1.0, top_k=None):
438
  idx = (
439
  torch.tensor(
440
  self.tokenizer.encode(in_str, allowed_special={"<|endoftext|>"})
@@ -443,7 +444,7 @@ class MoEGPTForCausalLM(PreTrainedModel):
443
  .to(self.lm_head.weight.device)
444
  )
445
  out_idx = (
446
- self.generate(idx, max_new_tokens, temperature, top_k)
447
  .view(-1)
448
  .to("cpu")
449
  .numpy()
 
403
  ]
404
 
405
  @torch.no_grad()
406
+ def generate(self, input_ids, max_new_tokens, date = None, temperature=1.0, top_k=None):
407
  """
408
  Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
409
  the sequence max_new_tokens times, feeding the predictions back into the model each time.
410
  Most likely you'll want to make sure to be in model.eval() mode of operation for this.
411
  """
412
+ idx = input_ids
413
  for _ in range(max_new_tokens):
414
  # if the sequence context is growing too long we must crop it at sequence_length
415
  idx_cond = (
 
418
  else idx[:, -self.config.sequence_length :]
419
  )
420
  # forward the model to get the logits for the index in the sequence
421
+ logits = self(idx_cond, date, get_logits=True).logits
422
  # pluck the logits at the final step and scale by desired temperature
423
  logits = logits[:, -1, :] / temperature
424
  # optionally crop the logits to only the top k options
 
435
  return idx
436
 
437
  @torch.no_grad()
438
+ def generate_from_string(self, in_str, max_new_tokens, date = None, temperature=1.0, top_k=None):
439
  idx = (
440
  torch.tensor(
441
  self.tokenizer.encode(in_str, allowed_special={"<|endoftext|>"})
 
444
  .to(self.lm_head.weight.device)
445
  )
446
  out_idx = (
447
+ self.generate(idx, max_new_tokens, date, temperature, top_k)
448
  .view(-1)
449
  .to("cpu")
450
  .numpy()