import os
from collections import Counter
from typing import List, Optional

import torch

from vllm.config import VllmConfig
from vllm.sampling_params import SamplingParams
from vllm.v1.sample.logits_processor import (
    AdapterLogitsProcessor,
    RequestLogitsProcessor,
)

# Repetition-scan knobs: candidate n-grams are sampled every WINDOW_SIZE
# tokens, and a chunk repeated more than MAX_REPETITION_COUNT times is
# treated as a degenerate loop (CHUNK_SIZE is the n-gram length used in
# ThinkLogitsProcessor.__call__).
CHUNK_SIZE = 16384
WINDOW_SIZE = 256
MAX_REPETITION_COUNT = 7


class ThinkLogitsProcessor:
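    """Per-request processor that budgets "thinking" tokens.

    Forces `think_end_token` (a model-specific end-of-think token id) once
    the thinking budget is spent, or once the output starts repeating long
    chunks verbatim.
    """
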
    def __init__(self, think_end_token: int = 219406, max_len: int = 131072, ratio: float = 0.95):
        self.think_end_token = think_end_token
        self.min_answer_budget = 4096  # tokens always reserved for the final answer
        self.max_len = max_len
        self.ratio = ratio
        self.interval = 4096  # how often (in generated tokens) to run the repetition scan

    def find_repeated_ngrams(self, input_ids, n=512):
        """
        input_ids: list of integer tokens
        n: n-gram size
        returns dict of {ngram_tuple: count} for n-grams that occur more than
        MAX_REPETITION_COUNT times; candidates are sampled every WINDOW_SIZE
        tokens so the scan stays cheap on long outputs
        """
        ngrams = [tuple(input_ids[i:i + n]) for i in range(0, len(input_ids) - n + 1, WINDOW_SIZE)]
        freq = Counter(ngrams)
        return {ng: c for ng, c in freq.items() if c > MAX_REPETITION_COUNT}
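
    # Worked example of the scan above: a 4096-token output that is just
    # [7, 8] repeated gives find_repeated_ngrams(ids, n=4) ==
    # {(7, 8, 7, 8): 16}, since the same 4-gram appears at all 16 sampled
    # offsets (range(0, 4093, 256)) and 16 > MAX_REPETITION_COUNT.
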
    def __call__(
        self,
        prompt_token_ids: List[int],
        past_token_ids: List[int],
        logits: torch.Tensor,
    ) -> torch.Tensor:
        # Only intervene while the model is still thinking.
        if self.think_end_token not in past_token_ids:
            # No end-of-think token yet, so every generated token so far is
            # part of the thinking phase.
            tokens_since_think = len(past_token_ids)

            # Reserve part of the remaining context for the answer; the rest
            # is the thinking budget.
            response_budget = max(self.min_answer_budget, int((self.max_len - len(prompt_token_ids)) * (1 - self.ratio)))
            remaining_budget = self.max_len - len(prompt_token_ids) - response_budget - tokens_since_think

            if remaining_budget <= 0:
                # Budget exhausted: mask every logit except the end-of-think
                # token so it is sampled next.
                logits = torch.full_like(logits, torch.finfo(logits.dtype).min)
                logits[self.think_end_token] = 1.0
            elif len(past_token_ids) % self.interval == 0:
                # Every `interval` tokens, check whether the model is looping
                # on a long repeated chunk; if so, end thinking early.
                if self.find_repeated_ngrams(past_token_ids, n=CHUNK_SIZE):
                    logits = torch.full_like(logits, torch.finfo(logits.dtype).min)
                    logits[self.think_end_token] = 1.0

        return logits


class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
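    """Engine-level adapter that builds a ThinkLogitsProcessor per request,
    sized from the engine's max model length and the VLLM_THINK_BUDGET_RATIO
    environment variable."""
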
    def __init__(self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool):
        super().__init__(vllm_config, device, is_pin_memory)
        self.model_max_len = vllm_config.model_config.max_model_len
        assert self.model_max_len, "specify --max-model-len when using the think-budget logits processor"
        self.ratio = float(os.environ.get("VLLM_THINK_BUDGET_RATIO", "0.0"))
        assert 0 < self.ratio <= 1, "set env var VLLM_THINK_BUDGET_RATIO to a value in 0.0 < R <= 1.0 when using the think-budget logits processor"

    def is_argmax_invariant(self) -> bool:
        # Masking logits can change the argmax, so vLLM must not skip this
        # processor under greedy sampling.
        return False

    def new_req_logits_processor(
        self,
        params: SamplingParams,
    ) -> Optional[RequestLogitsProcessor]:
        # The per-request sampling params are not consulted; every request
        # gets the same engine-wide budget configuration.
        return ThinkLogitsProcessor(max_len=self.model_max_len, ratio=self.ratio)
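

if __name__ == "__main__":
    # Minimal usage sketch, not part of the processor itself. It assumes
    # vLLM's custom logits-processor API, which accepts processor classes via
    # the `logits_processors` argument of `LLM`; the model name below is a
    # placeholder. VLLM_THINK_BUDGET_RATIO must be set (e.g. 0.95) or the
    # wrapper's assert will fire.
    from vllm import LLM

    llm = LLM(model="your-reasoning-model", logits_processors=[WrappedPerReqLogitsProcessor])
    outputs = llm.generate(["Question: ..."], SamplingParams(max_tokens=8192))
    print(outputs[0].outputs[0].text)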