| import json | |
| from abc import ABC, abstractmethod | |
| from functools import lru_cache | |
| from typing import TYPE_CHECKING, Any, Dict, List, Optional | |
| import dill | |
| import orjson | |
| import torch | |
| if TYPE_CHECKING: | |
| from sglang.srt.managers.schedule_batch import Req | |
| def _cache_from_str(json_str: str): | |
| """Deserialize a json string to a Callable object. | |
| This function is cached to avoid redundant deserialization. | |
| """ | |
| data = orjson.loads(json_str) | |
| return dill.loads(bytes.fromhex(data["callable"])) | |
class CustomLogitProcessor(ABC):
    """Abstract base class for callable functions.

    Subclasses implement ``__call__`` to transform a logits tensor. The
    class itself (not an instance) is what gets serialized by ``to_str``;
    ``from_str`` restores the class and returns a fresh instance of it.
    """

    @abstractmethod
    def __call__(
        self,
        logits: torch.Tensor,
        custom_param_list: Optional[List[Dict[str, Any]]] = None,
    ) -> torch.Tensor:
        """Define the callable behavior.

        Args:
            logits: Logits tensor to process.
            custom_param_list: Optional list of per-request parameter
                dictionaries.

        Returns:
            The processed logits tensor.
        """
        raise NotImplementedError

    @classmethod
    def to_str(cls) -> str:
        """Serialize the callable function to a JSON-compatible string.

        Returns:
            A JSON object string whose ``"callable"`` entry holds the
            hex-encoded ``dill`` dump of ``cls``.
        """
        return json.dumps({"callable": dill.dumps(cls).hex()})

    @classmethod
    def from_str(cls, json_str: str):
        """Deserialize a callable function from a JSON string.

        Args:
            json_str: A string in the format produced by ``to_str``.

        Returns:
            A new instance of the deserialized class, constructed with no
            arguments.
        """
        # The deserialized object is the processor class; calling it
        # instantiates the processor.
        return _cache_from_str(json_str)()
class DisallowedTokensLogitsProcessor(CustomLogitProcessor):
    """Logit processor that masks out a set of disallowed token ids.

    The ids are read from the ``"token_ids"`` entry of the first parameter
    dictionary, and every other dictionary in the list must carry an equal
    ``"token_ids"`` value.
    """

    def __call__(
        self,
        logits: torch.Tensor,
        custom_param_list: Optional[List[Dict[str, Any]]] = None,
    ) -> torch.Tensor:
        """Set the logits of the disallowed token ids to negative infinity.

        Args:
            logits: Logits tensor; it is modified in place along its last
                dimension.
            custom_param_list: Parameter dictionaries, each with a
                ``"token_ids"`` entry. When ``None`` or empty, there is
                nothing to disallow and ``logits`` is returned unchanged.

        Returns:
            The same ``logits`` tensor that was passed in.

        Raises:
            AssertionError: If the dictionaries do not all share the same
                ``"token_ids"`` value.
        """
        # The parameter defaults to None, so guard before indexing.
        if not custom_param_list:
            return logits

        disallowed_token_ids = custom_param_list[0]["token_ids"]
        assert all(
            disallowed_token_ids == c["token_ids"] for c in custom_param_list
        ), f"{custom_param_list=}"
        logits[..., disallowed_token_ids] = -float("inf")
        return logits
class ThinkingBudgetLogitProcessor(CustomLogitProcessor):
    """A logit processor that controls the length of thinking.

    For each request whose parameters carry a non-negative integer
    ``"thinking_budget"``, the processor counts the tokens that follow the
    first thinking-start token. Once that count reaches the budget and no
    thinking-end token has appeared, it forces a new-line token and then
    the thinking-end token by masking every other logit in that row.

    This base class only declares the three token-id attributes below;
    subclasses must assign them.
    """

    THINKING_START_TOKEN_ID: int
    THINKING_END_TOKEN_ID: int
    NEW_LINE_TOKEN_ID: int

    def __call__(
        self,
        logits: torch.Tensor,
        custom_param_list: Optional[List[Dict[str, Any]]] = None,
    ) -> torch.Tensor:
        """Force the end of thinking for requests that exhausted their budget.

        Args:
            logits: Logits tensor indexed as ``logits[i, :]``, where row
                ``i`` belongs to ``custom_param_list[i]``. Modified in place.
            custom_param_list: Per-request parameter dictionaries. Each may
                hold ``"thinking_budget"`` and ``"__req__"``, an object that
                exposes ``origin_input_ids`` and ``output_ids``. ``None``
                entries are skipped.

        Returns:
            The same ``logits`` tensor that was passed in.
        """
        if not custom_param_list:
            return logits

        for i, param_dict in enumerate(custom_param_list):
            if param_dict is None:
                continue

            thinking_budget = param_dict.get("thinking_budget")
            # Skip if thinking_budget is unset, or not an integer, or negative
            if not isinstance(thinking_budget, int) or thinking_budget < 0:
                continue

            req = param_dict.get("__req__")
            # Without the request object the token history is unavailable,
            # so the row is left untouched instead of failing on None.
            if req is None:
                continue
            cur_ids = [*req.origin_input_ids, *req.output_ids]

            # Check if out of thinking stage
            if (
                self.THINKING_START_TOKEN_ID not in cur_ids
                or self.THINKING_END_TOKEN_ID in cur_ids
            ):
                continue

            # Count the number of tokens after the first thinking start token
            start_index = cur_ids.index(self.THINKING_START_TOKEN_ID)
            num_tokens_after_start = len(cur_ids) - start_index - 1
            if num_tokens_after_start < thinking_budget:
                continue

            # Ensure a new line token precedes the thinking end token: emit
            # the new line first unless it was the last generated token.
            if not req.output_ids or req.output_ids[-1] != self.NEW_LINE_TOKEN_ID:
                forced_token_id = self.NEW_LINE_TOKEN_ID
            else:
                forced_token_id = self.THINKING_END_TOKEN_ID

            # Leave the forced token as the only finite logit in this row.
            logits[i, :] = -float("inf")
            logits[i, forced_token_id] = 0.0

        return logits
class Qwen3ThinkingBudgetLogitProcessor(ThinkingBudgetLogitProcessor):
    """Thinking-budget logit processor configured for Qwen3 models."""

    # Token ids consumed by ThinkingBudgetLogitProcessor.__call__.
    # NOTE(review): values are presumably taken from the Qwen3 tokenizer
    # vocabulary; verify against the tokenizer in use.
    THINKING_START_TOKEN_ID: int = 151_667
    THINKING_END_TOKEN_ID: int = 151_668
    NEW_LINE_TOKEN_ID: int = 198
class DeepSeekR1ThinkingBudgetLogitProcessor(ThinkingBudgetLogitProcessor):
    """Thinking-budget logit processor configured for DeepSeek-R1 models."""

    # Token ids consumed by ThinkingBudgetLogitProcessor.__call__.
    # NOTE(review): values are presumably taken from the DeepSeek-R1
    # tokenizer vocabulary; verify against the tokenizer in use.
    THINKING_START_TOKEN_ID: int = 128_798
    THINKING_END_TOKEN_ID: int = 128_799
    NEW_LINE_TOKEN_ID: int = 201
Xet Storage Details
- Size: 4.27 kB
- Xet hash: 97fd3ea087d14336b62273f502e9abc56ebdf3e51888ba7c754f86af0e6579c8
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.