leejunhyeok's picture
change n-gram and ratio check order
d63f78d verified
raw
history blame
3.05 kB
from typing import List, Optional
import torch
from vllm.config import VllmConfig
from vllm.v1.sample.logits_processor import (
AdapterLogitsProcessor,
RequestLogitsProcessor,
)
from vllm.sampling_params import SamplingParams
import os
from collections import Counter
# Length, in tokens, of each n-gram compared by the repetition check.
CHUNK_SIZE = 2 ** 14
# Stride between the start positions of consecutive sampled n-grams.
WINDOW_SIZE = 2 ** 8
# A sampled n-gram seen more than this many times counts as repetition.
MAX_REPETATION_COUNT = 7
class ThinkLogitsProcessor:
    """Per-request logits processor that bounds the thinking phase.

    Until ``think_end_token`` (the ``</think>`` token) appears among the
    generated tokens, every other logit is masked so that ``</think>`` is
    forced as soon as either

    * the thinking budget is used up -- the context left after the prompt,
      minus a reserved answer budget (see ``ratio`` / ``min_answer_budget``), or
    * on generation lengths that are a multiple of ``interval``, a
      ``CHUNK_SIZE``-token n-gram is found repeated too often
      (see ``find_repeated_ngrams``).

    Once ``</think>`` has been generated the logits pass through unchanged.

    Args:
        think_end_token: token id to force; the default presumably is the
            ``</think>`` id in the target model's vocabulary -- confirm.
        max_len: total token budget for the request (prompt + generation).
        ratio: thinking may use roughly this fraction of ``max_len`` minus
            the prompt length; the answer keeps the rest, but never less than
            ``min_answer_budget`` tokens.
    """

    def __init__(self, think_end_token: int = 219406, max_len: int = 131072, ratio: float = 0.95):
        self.think_end_token = think_end_token
        # Minimum number of tokens always reserved for the answer.
        self.min_answer_budget = 4096
        self.max_len = max_len
        self.ratio = ratio
        # The repetition check only runs when the generated length is a
        # multiple of this, to bound its cost.
        self.interval = 4096
        # Incremental-scan state for spotting think_end_token in the history:
        # number of leading history tokens already scanned without a match ...
        self._scanned = 0
        # ... and the index where think_end_token was found (None if not yet).
        self._think_end_pos = None

    def find_repeated_ngrams(self, input_ids, n=512):
        """Return the sampled n-grams that repeat too often.

        input_ids: list of integer tokens
        n: n-gram size
        returns dict of {ngram_tuple: count} for sampled n-grams seen more
        than MAX_REPETATION_COUNT times.

        Only n-grams starting at multiples of WINDOW_SIZE are sampled, so a
        repetition is caught only when its copies line up with that stride.
        Returns an empty dict when ``input_ids`` is shorter than ``n``.
        """
        starts = range(0, len(input_ids) - n + 1, WINDOW_SIZE)
        freq = Counter(tuple(input_ids[i:i + n]) for i in starts)
        return {ng: c for ng, c in freq.items() if c > MAX_REPETATION_COUNT}

    def _think_ended(self, past_token_ids: List[int]) -> bool:
        """Return True if ``think_end_token`` occurs in ``past_token_ids``.

        Same answer as ``self.think_end_token in past_token_ids``, but only
        tokens appended since the previous call are scanned. A request that
        generates L tokens therefore costs O(L) in total instead of O(L^2).
        The state is rebuilt with a full rescan when the history shrinks, or
        when it no longer holds the token at the remembered position.

        NOTE(review): assumes each instance sees one request's history, which
        only grows between calls; a same-length rewrite of already-scanned
        tokens is not detected -- confirm this holds under speculative decoding.
        """
        n = len(past_token_ids)
        pos = self._think_end_pos
        if pos is not None:
            if pos < n and past_token_ids[pos] == self.think_end_token:
                return True
            # History was rewritten underneath us: start over.
            self._think_end_pos = None
            self._scanned = 0
        elif n < self._scanned:
            # History shrank: rescan from the beginning.
            self._scanned = 0
        try:
            self._think_end_pos = past_token_ids.index(self.think_end_token, self._scanned)
        except ValueError:
            self._scanned = n
            return False
        return True

    def _force_think_end(self, logits: torch.Tensor) -> torch.Tensor:
        """Return a new tensor in which only ``think_end_token`` can be chosen.

        Every entry is set to the dtype's most negative finite value and the
        ``think_end_token`` entry to 1.0; ``logits`` itself is not modified.
        """
        forced = torch.full_like(logits, torch.finfo(logits.dtype).min)
        forced[self.think_end_token] = 1.0
        return forced

    def __call__(
        self,
        prompt_token_ids: List[int],
        past_token_ids: List[int],
        logits: torch.Tensor
    ) -> torch.Tensor:
        """Return ``logits`` unchanged, or a copy that forces ``think_end_token``.

        Args:
            prompt_token_ids: the request's prompt tokens (only their count is used).
            past_token_ids: tokens generated so far for this request.
            logits: next-token logits, indexed directly by token id -- assumed
                1-D of vocabulary size, TODO confirm.

        Returns:
            ``logits`` itself when nothing is forced, otherwise a new tensor.
        """
        if self._think_ended(past_token_ids):
            return logits
        generated = len(past_token_ids)
        available = self.max_len - len(prompt_token_ids)
        response_budget = max(self.min_answer_budget, int(available * (1 - self.ratio)))
        # Ratio check: the thinking budget is used up.
        if available - response_budget - generated <= 0:
            return self._force_think_end(logits)
        # n-gram check: if repetition is detected, force </think>.
        if generated % self.interval == 0 and self.find_repeated_ngrams(past_token_ids, n=CHUNK_SIZE):
            return self._force_think_end(logits)
        return logits
class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
    """vLLM batch-level adapter that attaches a ThinkLogitsProcessor to each request.

    Configuration:
      * the context limit is ``vllm_config.model_config.max_model_len``;
      * the thinking-budget ratio is read from the ``VLLM_THINK_BUDGET_RATIO``
        environment variable and must lie in (0, 1].

    Raises:
        ValueError: if the max model length is unset/zero or the ratio is
            outside (0, 1]. ``float()`` also raises ValueError for a
            non-numeric ``VLLM_THINK_BUDGET_RATIO``.
    """

    def __init__(self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool):
        super().__init__(vllm_config, device, is_pin_memory)
        self.model_max_len = vllm_config.model_config.max_model_len
        # Explicit raises rather than `assert`: asserts vanish under
        # `python -O`. Without this check, the default ratio of 0.0 would be
        # accepted silently and ThinkLogitsProcessor would force </think> on
        # the very first step.
        if not self.model_max_len:
            raise ValueError("specify --max-model-len if using ratiologitprocessor")
        self.ratio = float(os.environ.get("VLLM_THINK_BUDGET_RATIO", "0.0"))
        # Written as `not (...)` so that NaN is rejected as well.
        if not 0 < self.ratio <= 1:
            raise ValueError(
                "specify env var VLLM_THINK_BUDGET_RATIO in 0.0 < R <= 1.0 "
                "if using ratiologitprocessor"
            )

    def is_argmax_invariant(self) -> bool:
        """Return False: forcing </think> can change the highest-scoring token.

        vLLM's logits-processor interface uses this flag to decide whether
        the processor may be skipped for greedy sampling, so it must stay False.
        """
        return False

    def new_req_logits_processor(
        self,
        params: SamplingParams,
    ) -> Optional[RequestLogitsProcessor]:
        """Build the per-request processor.

        ``params`` is not consulted. Every request gets a fresh
        ThinkLogitsProcessor (default ``think_end_token``) built from the
        model's max length and the configured ratio, so no per-request state
        is shared between requests.
        """
        return ThinkLogitsProcessor(max_len=self.model_max_len, ratio=self.ratio)