leejunhyeok committed
Commit f19aee3 · verified · parent 73e65dd

Update logit_processors/logit_.py

Files changed (1)
  1. logit_processors/logit_.py +7 -3
logit_processors/logit_.py CHANGED
```diff
@@ -9,6 +9,10 @@ from vllm.sampling_params import SamplingParams
 import os
 from collections import Counter
 
+CHUNK_SIZE = 16384
+WINDOW_SIZE = 256
+MAX_REPETITION_COUNT = 7
+
 class ThinkLogitsProcessor:
     def __init__(self, think_end_token = 219406, max_len: int = 131072, ratio: float = 0.95):
         self.think_end_token = think_end_token
@@ -23,9 +27,9 @@ class ThinkLogitsProcessor:
         n: n-gram size
         returns dict of {ngram_tuple: count} for repeated n-grams
         """
-        ngrams = [tuple(input_ids[i:i+n]) for i in range(0, len(input_ids) - n + 1, 256)]
+        ngrams = [tuple(input_ids[i:i+n]) for i in range(0, len(input_ids) - n + 1, WINDOW_SIZE)]
         freq = Counter(ngrams)
-        return {ng: c for ng, c in freq.items() if c > 7}
+        return {ng: c for ng, c in freq.items() if c > MAX_REPETITION_COUNT}
 
     def __call__(
         self,
@@ -38,7 +42,7 @@
         # ngram
         if len(past_token_ids) % self.interval == 0:
             # If repetition detected, force </think>
-            if self.find_repeated_ngrams(past_token_ids, n=16384):
+            if self.find_repeated_ngrams(past_token_ids, n=CHUNK_SIZE):
                 # Set all other logits to -inf except for </think>
                 logits = torch.full_like(logits, torch.finfo(torch.bfloat16).min)
                 logits[self.think_end_token] = 1.0
```
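For intuition on the check these new constants parameterize: the processor samples fixed-size n-gram windows at a regular stride and counts exact duplicates among the sampled windows; enough duplicates are taken as evidence of a degenerate repetition loop. Below is a minimal self-contained sketch with toy sizes (the committed code uses n = CHUNK_SIZE = 16384, stride = WINDOW_SIZE = 256, and a count threshold of MAX_REPETITION_COUNT = 7):

```python
from collections import Counter

def find_repeated_ngrams(input_ids, n, stride, threshold):
    # Sample one n-gram every `stride` positions (not at every offset),
    # then keep only the sampled windows seen more than `threshold` times.
    ngrams = [tuple(input_ids[i:i + n])
              for i in range(0, len(input_ids) - n + 1, stride)]
    freq = Counter(ngrams)
    return {ng: c for ng, c in freq.items() if c > threshold}

# A degenerate loop: the same four tokens repeated over and over.
tokens = [1, 2, 3, 4] * 8
print(find_repeated_ngrams(tokens, n=4, stride=4, threshold=5))
# -> {(1, 2, 3, 4): 8}: the loop aligns with the stride, so the same
#    window is sampled repeatedly and crosses the count threshold.
```

The stride sampling keeps the check cheap on very long sequences, at the cost of missing loops whose period does not line up with the sampled offsets.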
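And a sketch of how such a processor is typically attached in vLLM. This is not part of the commit: the model id is a placeholder, and it assumes the classic vLLM logits-processor interface in which `SamplingParams(logits_processors=...)` takes callables invoked as `processor(past_token_ids, logits)`, matching the `__call__` body in the diff:

```python
# Hypothetical usage sketch (not from this repo's docs).
from vllm import LLM
from vllm.sampling_params import SamplingParams

from logit_processors.logit_ import ThinkLogitsProcessor

llm = LLM(model="org/reasoning-model")  # placeholder model id

params = SamplingParams(
    max_tokens=8192,
    # Each processor receives the token ids generated so far plus the
    # current logits, and may rewrite the logits before sampling.
    logits_processors=[ThinkLogitsProcessor()],
)

outputs = llm.generate(["Think step by step: why is the sky blue?"], params)
print(outputs[0].outputs[0].text)
```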