tensorfiend committed on
Commit
b4b0382
·
verified ·
1 Parent(s): 9b021de

Upload modeling_dotlm.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_dotlm.py +33 -0
modeling_dotlm.py CHANGED
@@ -382,3 +382,36 @@ class DotLMForCausalLM(PreTrainedModel, GenerationMixin):
382
  (k.index_select(0, beam_idx), v.index_select(0, beam_idx))
383
  for (k, v) in past_key_values
384
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
382
  (k.index_select(0, beam_idx), v.index_select(0, beam_idx))
383
  for (k, v) in past_key_values
384
  )
385
+
386
+ @torch.no_grad()
387
+ def generate(self, input_ids=None, max_new_tokens=256, temperature=1.0,
388
+ top_k=None, do_sample=True, eos_token_id=None, **kwargs):
389
+ """Custom autoregressive generate that bypasses GenerationMixin internals."""
390
+ self._ensure_rope_and_mask()
391
+ kv_cache = None
392
+ curr_ids = input_ids
393
+
394
+ for _ in range(max_new_tokens):
395
+ if curr_ids.size(1) > self.config.context_len:
396
+ curr_ids = curr_ids[:, -self.config.context_len:]
397
+
398
+ model_input = curr_ids if kv_cache is None else curr_ids[:, -1:]
399
+ out = self.forward(model_input, past_key_values=kv_cache, use_cache=True, return_dict=True)
400
+ kv_cache = out.past_key_values
401
+
402
+ logits = out.logits[:, -1, :]
403
+ if do_sample:
404
+ logits = logits / max(temperature, 1e-8)
405
+ if top_k is not None:
406
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
407
+ logits[logits < v[:, [-1]]] = -float("Inf")
408
+ probs = F.softmax(logits, dim=-1)
409
+ next_token = torch.multinomial(probs, num_samples=1)
410
+ else:
411
+ next_token = logits.argmax(dim=-1, keepdim=True)
412
+
413
+ curr_ids = torch.cat([curr_ids, next_token], dim=1)
414
+ if eos_token_id is not None and (next_token == eos_token_id).all():
415
+ break
416
+
417
+ return curr_ids