# Source listing: leideng/QCFuse, srt/sampling/sampling_params.py (9.19 kB)
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Sampling parameters for text generation."""
import logging
import sre_parse
from typing import Any, Dict, List, Optional, Union
# Temperatures in [0, _SAMPLING_EPS) are treated as greedy sampling by
# SamplingParams.__init__.
_SAMPLING_EPS = 1e-6
# Value substituted for top_k == -1, i.e. sample from the whole vocabulary.
TOP_K_ALL = 1 << 30
# Module-level logger; used to warn about unhandled regex tokens.
logger = logging.getLogger(__name__)
class SamplingParams:
    """
    The sampling parameters.

    See docs/backend/sampling_params.md or
    https://docs.sglang.ai/backend/sampling_params.html
    for the documentation.
    """

    def __init__(
        self,
        max_new_tokens: int = 128,
        stop: Optional[Union[str, List[str]]] = None,
        stop_token_ids: Optional[List[int]] = None,
        stop_regex: Optional[Union[str, List[str]]] = None,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        min_p: float = 0.0,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        repetition_penalty: float = 1.0,
        min_new_tokens: int = 0,
        n: int = 1,
        json_schema: Optional[str] = None,
        regex: Optional[str] = None,
        ebnf: Optional[str] = None,
        structural_tag: Optional[str] = None,
        ignore_eos: bool = False,
        skip_special_tokens: bool = True,
        spaces_between_special_tokens: bool = True,
        no_stop_trim: bool = False,
        custom_params: Optional[Dict[str, Any]] = None,
        stream_interval: Optional[int] = None,
        logit_bias: Optional[Dict[str, float]] = None,
        sampling_seed: int = 42,
    ) -> None:
        """Store the sampling options and canonicalize two special cases.

        * ``stop_token_ids`` is stored as a ``set``, or ``None`` when it is
          empty or not given.
        * A temperature in ``[0, _SAMPLING_EPS)`` is rewritten to
          ``temperature=1.0`` with ``top_k=1`` (greedy sampling).
        * ``top_k == -1`` is rewritten to ``TOP_K_ALL`` (whole vocabulary).

        No range validation happens here; call :meth:`verify` for that.
        """
        self.max_new_tokens = max_new_tokens
        self.stop_strs = stop
        if stop_token_ids:
            self.stop_token_ids = set(stop_token_ids)
        else:
            self.stop_token_ids = None
        self.stop_regex_strs = stop_regex
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.min_p = min_p
        self.frequency_penalty = frequency_penalty
        self.presence_penalty = presence_penalty
        self.repetition_penalty = repetition_penalty
        self.min_new_tokens = min_new_tokens
        self.regex = regex
        self.n = n
        self.json_schema = json_schema
        self.ebnf = ebnf
        self.structural_tag = structural_tag
        self.ignore_eos = ignore_eos
        self.skip_special_tokens = skip_special_tokens
        self.spaces_between_special_tokens = spaces_between_special_tokens
        self.no_stop_trim = no_stop_trim
        self.custom_params = custom_params
        self.stream_interval = stream_interval
        self.logit_bias = logit_bias
        self.sampling_seed = sampling_seed

        # Process some special cases
        if 0 <= self.temperature < _SAMPLING_EPS:
            # top_k = 1 means greedy sampling
            self.temperature = 1.0
            self.top_k = 1
        if self.top_k == -1:
            self.top_k = TOP_K_ALL  # whole vocabulary

    def verify(self, vocab_size):
        """Check that every stored parameter lies in its accepted range.

        Args:
            vocab_size: Exclusive upper bound for the token ids used as
                ``logit_bias`` keys.

        Raises:
            ValueError: If a parameter is out of range, a ``logit_bias`` key
                is not an integer token id in ``[0, vocab_size)``,
                ``sampling_seed`` is ``None``, or more than one of
                ``json_schema`` / ``regex`` / ``ebnf`` is set.
        """
        if self.temperature < 0.0:
            raise ValueError(
                f"temperature must be non-negative, got {self.temperature}."
            )
        if not 0.0 < self.top_p <= 1.0:
            raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
        if not 0.0 <= self.min_p <= 1.0:
            raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.")
        # __init__ already rewrote -1 to TOP_K_ALL, so anything below 1 that
        # is still present here (including -1 assigned later) is invalid.
        if self.top_k < 1:
            raise ValueError(
                f"top_k must be -1 (disable) or at least 1, got {self.top_k}."
            )
        if not -2.0 <= self.frequency_penalty <= 2.0:
            raise ValueError(
                "frequency_penalty must be in [-2, 2], got "
                f"{self.frequency_penalty}."
            )
        if not -2.0 <= self.presence_penalty <= 2.0:
            raise ValueError(
                "presence_penalty must be in [-2, 2], got "
                f"{self.presence_penalty}."
            )
        if not 0.0 <= self.repetition_penalty <= 2.0:
            raise ValueError(
                "repetition_penalty must be in [0, 2], got "
                f"{self.repetition_penalty}."
            )
        if not 0 <= self.min_new_tokens:
            raise ValueError(
                "min_new_tokens must be in [0, max_new_tokens], got "
                f"{self.min_new_tokens}."
            )
        if self.max_new_tokens is not None:
            if self.max_new_tokens < 0:
                raise ValueError(
                    f"max_new_tokens must be at least 0, got {self.max_new_tokens}."
                )
            if not self.min_new_tokens <= self.max_new_tokens:
                raise ValueError(
                    f"min_new_tokens must be in [0, max_new_tokens({self.max_new_tokens})], got "
                    f"{self.min_new_tokens}."
                )
        if self.logit_bias is not None:
            for token_id in self.logit_bias:
                # Keys may arrive as strings (e.g. from JSON); report a
                # non-numeric key as a logit_bias problem rather than
                # leaking the bare int() conversion error.
                try:
                    token_index = int(token_id)
                except (TypeError, ValueError):
                    raise ValueError(
                        "logit_bias must have integer token ids as keys, got "
                        f"{token_id!r}."
                    ) from None
                if not 0 <= token_index < vocab_size:
                    raise ValueError(
                        f"logit_bias must have keys in [0, {vocab_size - 1}], got "
                        f"{token_id}."
                    )
        if self.sampling_seed is None:
            raise ValueError("sampling_seed should not be None")

        grammars = [
            self.json_schema,
            self.regex,
            self.ebnf,
        ]  # since mutually exclusive, only one can be set
        if sum(x is not None for x in grammars) > 1:
            raise ValueError("Only one of regex, json_schema, or ebnf can be set.")

    def normalize(self, tokenizer):
        """Turn the stop conditions into lists and record their maximum lengths.

        Sets ``stop_strs`` / ``stop_str_max_len`` and ``stop_regex_strs`` /
        ``stop_regex_max_len`` on the instance.

        Args:
            tokenizer: Object whose ``encode(text, add_special_tokens=False)``
                result is measured with ``len`` to get a stop string's
                length, or ``None`` to measure stop strings in characters.
        """
        # Process stop strings
        if self.stop_strs is None:
            self.stop_strs = []
            self.stop_str_max_len = 0
        else:
            if isinstance(self.stop_strs, str):
                self.stop_strs = [self.stop_strs]

            stop_str_max_len = 0
            for stop_str in self.stop_strs:
                if tokenizer is not None:
                    stop_str_ids = tokenizer.encode(stop_str, add_special_tokens=False)
                    stop_str_max_len = max(stop_str_max_len, len(stop_str_ids))
                else:
                    stop_str_max_len = max(stop_str_max_len, len(stop_str))
            self.stop_str_max_len = stop_str_max_len

        # Process stop regex strings
        if self.stop_regex_strs is None:
            self.stop_regex_strs = []
            self.stop_regex_max_len = 0
        else:
            if isinstance(self.stop_regex_strs, str):
                self.stop_regex_strs = [self.stop_regex_strs]

            stop_regex_max_len = 0
            for stop_regex in self.stop_regex_strs:
                stop_regex_max_len = max(
                    stop_regex_max_len, get_max_seq_length(stop_regex)
                )
            self.stop_regex_max_len = stop_regex_max_len
def get_max_seq_length(regex_str: str):
    """Return a strict upper bound on the length of a match of ``regex_str``.

    The bound is the maximum number of tokens that would need to be buffered
    to match the regex: in the worst case, one character that needs to be
    buffered corresponds to one token. Unbounded repetitions and regex
    constructs that are not recognised each contribute ``MAX_LEN``.
    """
    return _max_length_from_subpattern(sre_parse.parse(regex_str))


# Contribution of a construct with no finite (or no known) maximum length.
MAX_LEN = 2**30


def _max_length_from_subpattern(subpattern: sre_parse.SubPattern):
    """Sum the maximum match length of every item in a parsed subpattern."""
    total = 0
    for token, value in subpattern:
        if token in {
            sre_parse.LITERAL,  # `value` is any one character
            sre_parse.NOT_LITERAL,  # Any one character except `value`, e.g. [^a]
            sre_parse.IN,  # Any character within `value`
            sre_parse.ANY,  # "."
        }:
            total += 1
        elif token == sre_parse.SUBPATTERN:
            # EG: (a\d+) ->
            # [(SUBPATTERN,
            #   (1, 0, 0, [(LITERAL, 97),
            #              (MAX_REPEAT, (1, MAXREPEAT, [(IN, [(CATEGORY, CATEGORY_DIGIT)])]))]))]
            _, _, _, inner_subpattern = value
            total += _max_length_from_subpattern(inner_subpattern)
        elif token == sre_parse.BRANCH:
            # An alternation can be as long as its longest alternative.
            _, branches = value
            total += max(_max_length_from_subpattern(branch) for branch in branches)
        elif token in {sre_parse.MAX_REPEAT, sre_parse.MIN_REPEAT}:
            _, max_num_repeat, inner_subpattern = value
            if max_num_repeat == sre_parse.MAXREPEAT:
                # Unbounded repetition (*, +, {n,})
                total += MAX_LEN
            else:
                total += max_num_repeat * _max_length_from_subpattern(inner_subpattern)
        elif token == sre_parse.AT:
            # These are zero-width assertions like ^, $, and \b that don't add to the max
            # length
            total += 0
        else:
            # Unknown construct: stay conservative and assume it is unbounded.
            logger.warning("Got unhandled regex token: %s", token)
            total += MAX_LEN
    return total

# Xet Storage Details
# Size: 9.19 kB
# Xet hash: 7e30894d5475850e177bf4ac59fa058b0b6fc7497847087c15c1f32915bf6512
# Xet efficiently stores files, intelligently splitting them into unique
# chunks and accelerating uploads and downloads.