leideng/QCFuse / srt /sampling /custom_logit_processor.py
leideng's picture
download
raw
4.27 kB
import json
from abc import ABC, abstractmethod
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Dict, List, Optional
import dill
import orjson
import torch
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req
@lru_cache(maxsize=None)
def _cache_from_str(json_str: str):
    """Rebuild the object stored in a serialized JSON payload.

    The payload must be a JSON object whose ``"callable"`` entry holds the
    hex-encoded ``dill`` dump of the object.  Results are memoized per input
    string in an unbounded cache, so a given payload is decoded only once.

    NOTE(review): ``dill.loads`` can execute arbitrary code while
    unpickling; only feed this function strings from trusted sources.
    """
    payload = orjson.loads(json_str)
    pickled_bytes = bytes.fromhex(payload["callable"])
    return dill.loads(pickled_bytes)
class CustomLogitProcessor(ABC):
    """Interface for logit processors that can be shipped as JSON strings.

    Subclasses implement ``__call__``; the class itself (not an instance) is
    what ``to_str`` serializes and what ``from_str`` instantiates.
    """

    @abstractmethod
    def __call__(
        self,
        logits: torch.Tensor,
        custom_param_list: Optional[List[Dict[str, Any]]] = None,
    ) -> torch.Tensor:
        """Transform ``logits`` and return the result.

        Args:
            logits: Tensor of logits to process.
            custom_param_list: Optional list of parameter dicts.

        Returns:
            The processed logits tensor.
        """
        raise NotImplementedError

    @classmethod
    def to_str(cls) -> str:
        """Return a JSON string holding the hex-encoded ``dill`` dump of ``cls``."""
        pickled_hex = dill.dumps(cls).hex()
        return json.dumps({"callable": pickled_hex})

    @classmethod
    def from_str(cls, json_str: str):
        """Decode ``json_str`` (memoized) and return a new instance of the result."""
        processor_cls = _cache_from_str(json_str)
        return processor_cls()
class DisallowedTokensLogitsProcessor(CustomLogitProcessor):
    """Logit processor that sets the logits of given token ids to ``-inf``."""

    def __call__(
        self,
        logits: torch.Tensor,
        custom_param_list: Optional[List[Dict[str, Any]]] = None,
    ) -> torch.Tensor:
        """Mask the disallowed token ids in ``logits`` (in place).

        Args:
            logits: Logits tensor; its last dimension is indexed by token id.
            custom_param_list: Parameter dicts, each with a ``"token_ids"``
                entry.  All entries must carry the same ``"token_ids"``
                value because one mask is applied to every row.

        Returns:
            The same ``logits`` tensor.  It is returned unchanged when
            ``custom_param_list`` is ``None`` or empty.
        """
        # The parameter is declared optional, so a missing or empty list
        # means there is nothing to disallow rather than an error.
        if not custom_param_list:
            return logits

        disallowed_token_ids = custom_param_list[0]["token_ids"]
        # The mask below covers all leading dimensions at once, which is
        # only valid if every request asked for the same token ids.
        assert all(
            disallowed_token_ids == c["token_ids"] for c in custom_param_list
        ), f"{custom_param_list=}"
        logits[..., disallowed_token_ids] = -float("inf")
        return logits
class ThinkingBudgetLogitProcessor(CustomLogitProcessor):
    """A logit processor that controls the length of thinking.

    Once the number of tokens following the thinking-start token reaches the
    request's ``thinking_budget``, the processor forces a new-line token and
    then the thinking-end token.  Subclasses must supply the three token ids
    declared below.
    """

    THINKING_START_TOKEN_ID: int
    THINKING_END_TOKEN_ID: int
    NEW_LINE_TOKEN_ID: int

    def __call__(
        self,
        logits: torch.Tensor,
        custom_param_list: Optional[List[Dict[str, Any]]] = None,
    ) -> torch.Tensor:
        """Enforce the per-request thinking budget on ``logits`` (in place).

        Args:
            logits: Logits tensor; row ``i`` is paired with
                ``custom_param_list[i]``.
            custom_param_list: One dict (or ``None``) per row.  The keys read
                are ``"thinking_budget"`` (non-negative int) and
                ``"__req__"`` (object exposing ``origin_input_ids`` and
                ``output_ids``).

        Returns:
            The same ``logits`` tensor.  Rows whose budget is exhausted have
            every entry set to ``-inf`` except the single forced token,
            which is set to ``0.0``.
        """
        if not custom_param_list:
            return logits

        for i, param_dict in enumerate(custom_param_list):
            if param_dict is None:
                continue

            thinking_budget = param_dict.get("thinking_budget")
            # Skip if thinking_budget is unset, not an integer, or negative
            # (``None`` fails the isinstance check as well).
            if not isinstance(thinking_budget, int) or thinking_budget < 0:
                continue

            req: Req = param_dict.get("__req__")
            # Without the request object there is no token history to
            # inspect, so leave this row untouched instead of crashing.
            if req is None:
                continue
            cur_ids: list[int] = [*req.origin_input_ids, *req.output_ids]

            # Check if out of thinking stage: thinking never started, or it
            # has already been closed.
            if (
                self.THINKING_START_TOKEN_ID not in cur_ids
                or self.THINKING_END_TOKEN_ID in cur_ids
            ):
                continue

            # Count the tokens after the first thinking start token.
            start_index = cur_ids.index(self.THINKING_START_TOKEN_ID)
            num_tokens_after_start = len(cur_ids) - start_index - 1
            if num_tokens_after_start < thinking_budget:
                continue

            # Budget exhausted: emit a new line first, and the thinking end
            # token once the last generated token is that new line.
            if not req.output_ids or req.output_ids[-1] != self.NEW_LINE_TOKEN_ID:
                forced_token_id = self.NEW_LINE_TOKEN_ID
            else:
                forced_token_id = self.THINKING_END_TOKEN_ID

            # Leave the forced token as the only finite logit in the row.
            logits[i, :] = -float("inf")
            logits[i, forced_token_id] = 0.0

        return logits
class Qwen3ThinkingBudgetLogitProcessor(ThinkingBudgetLogitProcessor):
    """Thinking-budget logit processor configured with Qwen3 token ids."""

    # Token ids read by the inherited ``__call__``.
    # NOTE(review): these values are tokenizer-specific; confirm them against
    # the Qwen3 vocabulary before reusing with another model.
    THINKING_START_TOKEN_ID: int = 151667
    THINKING_END_TOKEN_ID: int = 151668
    NEW_LINE_TOKEN_ID: int = 198
class DeepSeekR1ThinkingBudgetLogitProcessor(ThinkingBudgetLogitProcessor):
    """Thinking-budget logit processor configured with DeepSeek-R1 token ids."""

    # Token ids read by the inherited ``__call__``.
    # NOTE(review): these values are tokenizer-specific; confirm them against
    # the DeepSeek-R1 vocabulary before reusing with another model.
    THINKING_START_TOKEN_ID: int = 128798
    THINKING_END_TOKEN_ID: int = 128799
    NEW_LINE_TOKEN_ID: int = 201

Xet Storage Details

Size:
4.27 kB
·
Xet hash:
97fd3ea087d14336b62273f502e9abc56ebdf3e51888ba7c754f86af0e6579c8

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.