Update logit_processors/logit_.py
logit_processors/logit_.py
@@ -9,6 +9,10 @@ from vllm.sampling_params import SamplingParams
 import os
 from collections import Counter
 
+CHUNK_SIZE=16384
+WINDOW_SIZE=256
+MAX_REPETATION_COUNT=7
+
 class ThinkLogitsProcessor:
     def __init__(self, think_end_token = 219406, max_len: int = 131072, ratio: float = 0.95):
         self.think_end_token = think_end_token
@@ -23,9 +27,9 @@ class ThinkLogitsProcessor:
         n: n-gram size
         returns dict of {ngram_tuple: count} for repeated n-grams
         """
-        ngrams = [tuple(input_ids[i:i+n]) for i in range(0, len(input_ids) - n + 1,
+        ngrams = [tuple(input_ids[i:i+n]) for i in range(0, len(input_ids) - n + 1, WINDOW_SIZE)]
         freq = Counter(ngrams)
-        return {ng: c for ng, c in freq.items() if c >
+        return {ng: c for ng, c in freq.items() if c > MAX_REPETATION_COUNT}
 
     def __call__(
         self,
@@ -38,7 +42,7 @@ class ThinkLogitsProcessor:
         # ngram
         if len(past_token_ids) % self.interval == 0:
             # If repetation detected, force </think>
-            if self.find_repeated_ngrams(past_token_ids, n=
+            if self.find_repeated_ngrams(past_token_ids, n=CHUNK_SIZE):
                 # Set all other logits to -inf except for </think>
                 logits = torch.full_like(logits, torch.finfo(torch.bfloat16).min)
                 logits[self.think_end_token] = 1.0
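With the new constants, find_repeated_ngrams samples candidate chunks of CHUNK_SIZE consecutive tokens at every WINDOW_SIZE-th starting offset and reports a chunk only once it recurs more than MAX_REPETATION_COUNT times; whenever that happens at one of the interval-aligned decoding steps, __call__ sets every logit to the bfloat16 minimum and then raises </think> to 1.0.

Below is a minimal usage sketch for wiring the processor into a request. It is an illustration under assumptions, not part of this change: the full __call__ signature is not visible in the diff, so it presumes the common vLLM per-request convention of (past_token_ids, logits) -> logits, a vLLM version whose SamplingParams still accepts per-request logits_processors, and a placeholder model name.

    from vllm import LLM, SamplingParams

    from logit_processors.logit_ import ThinkLogitsProcessor

    # Placeholder model name; the default think_end_token (219406) must match
    # the </think> token id of whatever model is actually served.
    llm = LLM(model="<your-reasoning-model>")

    # One processor instance per request is the usual pattern for
    # per-request logits processors.
    think_guard = ThinkLogitsProcessor()

    params = SamplingParams(
        max_tokens=8192,
        logits_processors=[think_guard],  # invoked at every decoding step
    )

    outputs = llm.generate(["<prompt>"], params)

Because every other logit is pushed to the bfloat16 minimum, </think> becomes the near-certain next token at that step regardless of temperature, which is what cuts off the runaway repetition in the thinking phase.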