leideng/QCFuse / srt /utils /hf_transformers_utils.py
leideng's picture
download
raw
17.7 kB
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for Huggingface Transformers."""
import contextlib
import json
import os
import tempfile
import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional, Type, Union
import torch
from huggingface_hub import snapshot_download
from transformers import (
AutoConfig,
AutoProcessor,
AutoTokenizer,
GenerationConfig,
PretrainedConfig,
PreTrainedTokenizer,
PreTrainedTokenizerBase,
PreTrainedTokenizerFast,
)
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from sglang.srt.configs import (
ChatGLMConfig,
DbrxConfig,
DeepseekVL2Config,
DotsOCRConfig,
DotsVLMConfig,
ExaoneConfig,
FalconH1Config,
KimiVLConfig,
LongcatFlashConfig,
MultiModalityConfig,
NemotronHConfig,
Olmo3Config,
Qwen3NextConfig,
Step3VLConfig,
)
from sglang.srt.configs.deepseek_ocr import DeepseekVLV2Config
from sglang.srt.configs.internvl import InternVLChatConfig
from sglang.srt.connector import create_remote_connector
from sglang.srt.utils import is_remote_url, logger, lru_cache_frozenset
_CONFIG_REGISTRY: List[Type[PretrainedConfig]] = [
ChatGLMConfig,
DbrxConfig,
ExaoneConfig,
DeepseekVL2Config,
MultiModalityConfig,
KimiVLConfig,
InternVLChatConfig,
Step3VLConfig,
LongcatFlashConfig,
Olmo3Config,
Qwen3NextConfig,
FalconH1Config,
DotsVLMConfig,
DotsOCRConfig,
NemotronHConfig,
DeepseekVLV2Config,
]
_CONFIG_REGISTRY = {
config_cls.model_type: config_cls for config_cls in _CONFIG_REGISTRY
}
for name, cls in _CONFIG_REGISTRY.items():
with contextlib.suppress(ValueError):
AutoConfig.register(name, cls)
def download_from_hf(
model_path: str,
allow_patterns: Optional[Union[str, list]] = None,
):
if os.path.exists(model_path):
return model_path
if not allow_patterns:
allow_patterns = ["*.json", "*.bin", "*.model"]
return snapshot_download(model_path, allow_patterns=allow_patterns)
def get_hf_text_config(config: PretrainedConfig):
"""Get the "sub" config relevant to llm for multi modal models.
No op for pure text models.
"""
if config.architectures is not None:
class_name = config.architectures[0]
if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"):
# We support non-hf version of llava models, so we do not want to
# read the wrong values from the unused default text_config.
# NOTE(HandH1998): We set `torch_dtype` of config to `torch.float16` for the weights, as
# `torch.float16` is default used for image features in `python/sglang/srt/models/llava.py`.
setattr(config, "torch_dtype", torch.float16)
return config
if hasattr(config, "text_config"):
# The code operates under the assumption that text_config should have
# `num_attention_heads` (among others). Assert here to fail early
# if transformers config doesn't align with this assumption.
assert hasattr(config.text_config, "num_attention_heads")
return config.text_config
if hasattr(config, "llm_config"):
# PointsV1.5 Chat Model
assert hasattr(config.llm_config, "num_attention_heads")
return config.llm_config
if hasattr(config, "language_config"):
return config.language_config
if hasattr(config, "thinker_config"):
# qwen2.5 omni
thinker_config = config.thinker_config
if hasattr(thinker_config, "text_config"):
setattr(
thinker_config.text_config,
"torch_dtype",
getattr(thinker_config, "torch_dtype", None),
)
return thinker_config.text_config
return thinker_config
else:
return config
# Temporary hack for DeepSeek-V3.2 model
def _load_deepseek_v32_model(
model_path: str,
trust_remote_code: bool = False,
revision: Optional[str] = None,
**kwargs,
):
# first get the local path
local_path = download_from_hf(model_path)
# then load the config file in json
config_file = os.path.join(local_path, "config.json")
if not os.path.exists(config_file):
raise RuntimeError(f"Can't find config file in {local_path}.")
with open(config_file, "r") as f:
config_json = json.load(f)
config_json["architectures"] = ["DeepseekV3ForCausalLM"]
config_json["model_type"] = "deepseek_v3"
tmp_path = os.path.join(tempfile.gettempdir(), "_tmp_config_folder")
os.makedirs(tmp_path, exist_ok=True)
unique_path = os.path.join(tmp_path, f"deepseek_v32_{os.getpid()}")
with open(unique_path, "w") as f:
json.dump(config_json, f)
return AutoConfig.from_pretrained(
unique_path, trust_remote_code=trust_remote_code, revision=revision, **kwargs
)
@lru_cache_frozenset(maxsize=32)
def get_config(
model: str,
trust_remote_code: bool,
revision: Optional[str] = None,
model_override_args: Optional[dict] = None,
**kwargs,
):
is_gguf = check_gguf_file(model)
if is_gguf:
kwargs["gguf_file"] = model
model = Path(model).parent
if is_remote_url(model):
# BaseConnector implements __del__() to clean up the local dir.
# Since config files need to exist all the time, so we DO NOT use
# with statement to avoid closing the client.
client = create_remote_connector(model)
client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
model = client.get_local_dir()
try:
config = AutoConfig.from_pretrained(
model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
)
if "deepseek-ai/DeepSeek-OCR" in model:
config.model_type = "deepseek-ocr"
# Due to an unknown reason, Hugging Face’s AutoConfig mistakenly recognizes the configuration of deepseek-ocr as deepseekvl2.
# This is a temporary workaround and will require further optimization.
except ValueError as e:
if not "deepseek_v32" in str(e):
raise e
config = _load_deepseek_v32_model(
model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
)
if (
config.architectures is not None
and config.architectures[0] == "Phi4MMForCausalLM"
):
# Phi4MMForCausalLM uses a hard-coded vision_config. See:
# https://github.com/vllm-project/vllm/blob/6071e989df1531b59ef35568f83f7351afb0b51e/vllm/model_executor/models/phi4mm.py#L71
# We set it here to support cases where num_attention_heads is not divisible by the TP size.
from transformers import SiglipVisionConfig
vision_config = {
"hidden_size": 1152,
"image_size": 448,
"intermediate_size": 4304,
"model_type": "siglip_vision_model",
"num_attention_heads": 16,
"num_hidden_layers": 26,
# Model is originally 27-layer, we only need the first 26 layers for feature extraction.
"patch_size": 14,
}
config.vision_config = SiglipVisionConfig(**vision_config)
text_config = get_hf_text_config(config=config)
if isinstance(model, str) and text_config is not None:
for key, val in text_config.__dict__.items():
if not hasattr(config, key) and getattr(text_config, key, None) is not None:
setattr(config, key, val)
if config.model_type in _CONFIG_REGISTRY:
config_class = _CONFIG_REGISTRY[config.model_type]
config = config_class.from_pretrained(model, revision=revision)
# NOTE(HandH1998): Qwen2VL requires `_name_or_path` attribute in `config`.
setattr(config, "_name_or_path", model)
if isinstance(model, str) and config.model_type == "internvl_chat":
for key, val in config.llm_config.__dict__.items():
if not hasattr(config, key):
setattr(config, key, val)
if config.model_type == "multi_modality":
config.update({"architectures": ["MultiModalityCausalLM"]})
if model_override_args:
config.update(model_override_args)
# Special architecture mapping check for GGUF models
if is_gguf:
if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
raise RuntimeError(f"Can't get gguf config for {config.model_type}.")
model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
config.update({"architectures": [model_type]})
return config
@lru_cache_frozenset(maxsize=32)
def get_generation_config(
model: str,
trust_remote_code: bool,
revision: Optional[str] = None,
**kwargs,
):
try:
return GenerationConfig.from_pretrained(
model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
)
except OSError as e:
return None
# Qwen-1M related
def get_sparse_attention_config(
model: str,
sparse_attention_config_filename: str = "sparse_attention_config.json",
) -> Dict[str, Any]:
is_local = os.path.isdir(model)
if not is_local:
# Download the config files.
model = download_from_hf(model, allow_patterns=["*.json"])
config_file = os.path.join(model, sparse_attention_config_filename)
if not os.path.exists(config_file):
return {}
# Load the sparse attention config.
with open(config_file) as f:
config = json.load(f)
return config
# Models don't use the same configuration key for determining the maximum
# context length. Store them here so we can sanely check them.
# NOTE: The ordering here is important. Some models have two of these and we
# have a preference for which value gets used.
CONTEXT_LENGTH_KEYS = [
"max_sequence_length",
"seq_length",
"max_seq_len",
"model_max_length",
"max_position_embeddings",
]
def get_context_length(config):
"""Get the context length of a model from a huggingface model configs."""
text_config = config
rope_scaling = getattr(text_config, "rope_scaling", None)
if rope_scaling:
rope_scaling_factor = rope_scaling.get("factor", 1)
if "original_max_position_embeddings" in rope_scaling:
rope_scaling_factor = 1
if rope_scaling.get("rope_type", None) == "llama3":
rope_scaling_factor = 1
else:
rope_scaling_factor = 1
for key in CONTEXT_LENGTH_KEYS:
val = getattr(text_config, key, None)
if val is not None:
return int(rope_scaling_factor * val)
return 2048
# A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file.
_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"
def get_tokenizer(
tokenizer_name: str,
*args,
tokenizer_mode: str = "auto",
trust_remote_code: bool = False,
tokenizer_revision: Optional[str] = None,
**kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
"""Gets a tokenizer for the given model name via Huggingface."""
if tokenizer_name.endswith(".json"):
from sglang.srt.tokenizer.tiktoken_tokenizer import TiktokenTokenizer
return TiktokenTokenizer(tokenizer_name)
if tokenizer_mode == "slow":
if kwargs.get("use_fast", False):
raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
kwargs["use_fast"] = False
# TODO(Xinyuan): Remove this once we have a proper tokenizer for Devstral
if tokenizer_name == "mistralai/Devstral-Small-2505":
tokenizer_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
is_gguf = check_gguf_file(tokenizer_name)
if is_gguf:
kwargs["gguf_file"] = tokenizer_name
tokenizer_name = Path(tokenizer_name).parent
if is_remote_url(tokenizer_name):
# BaseConnector implements __del__() to clean up the local dir.
# Since config files need to exist all the time, so we DO NOT use
# with statement to avoid closing the client.
client = create_remote_connector(tokenizer_name)
client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
tokenizer_name = client.get_local_dir()
try:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name,
*args,
trust_remote_code=trust_remote_code,
tokenizer_revision=tokenizer_revision,
clean_up_tokenization_spaces=False,
**kwargs,
)
except TypeError as e:
# The LLaMA tokenizer causes a protobuf error in some environments.
err_msg = (
"Failed to load the tokenizer. If you are using a LLaMA V1 model "
f"consider using '{_FAST_LLAMA_TOKENIZER}' instead of the "
"original tokenizer."
)
raise RuntimeError(err_msg) from e
except ValueError as e:
# If the error pertains to the tokenizer class not existing or not
# currently being imported, suggest using the --trust-remote-code flag.
if not trust_remote_code and (
"does not exist or is not currently imported." in str(e)
or "requires you to execute the tokenizer file" in str(e)
):
err_msg = (
"Failed to load the tokenizer. If the tokenizer is a custom "
"tokenizer not yet available in the HuggingFace transformers "
"library, consider setting `trust_remote_code=True` in LLM "
"or using the `--trust-remote-code` flag in the CLI."
)
raise RuntimeError(err_msg) from e
else:
raise e
if not isinstance(tokenizer, PreTrainedTokenizerFast):
warnings.warn(
"Using a slow tokenizer. This might cause a significant "
"slowdown. Consider using a fast tokenizer instead."
)
attach_additional_stop_token_ids(tokenizer)
return tokenizer
# Some models doesn't have an available processor, e.g.: InternVL
def get_tokenizer_from_processor(processor):
if isinstance(processor, PreTrainedTokenizerBase):
return processor
return processor.tokenizer
def get_processor(
tokenizer_name: str,
*args,
tokenizer_mode: str = "auto",
trust_remote_code: bool = False,
tokenizer_revision: Optional[str] = None,
use_fast: Optional[bool] = True,
**kwargs,
):
# pop 'revision' from kwargs if present.
revision = kwargs.pop("revision", tokenizer_revision)
config = AutoConfig.from_pretrained(
tokenizer_name,
trust_remote_code=trust_remote_code,
revision=revision,
**kwargs,
)
# fix: for Qwen2-VL and Sarashina2Vision models, inject default 'size' if not provided.
if config.model_type in {"qwen2_vl", "sarashina2_vision"}:
if "size" not in kwargs:
kwargs["size"] = {"shortest_edge": 3136, "longest_edge": 1003520}
if config.model_type not in {"llava", "clip"}:
kwargs["use_fast"] = use_fast
try:
if "InternVL3_5" in tokenizer_name:
processor = AutoTokenizer.from_pretrained(
tokenizer_name,
*args,
trust_remote_code=trust_remote_code,
revision=revision,
**kwargs,
)
else:
processor = AutoProcessor.from_pretrained(
tokenizer_name,
*args,
trust_remote_code=trust_remote_code,
revision=revision,
**kwargs,
)
except ValueError as e:
error_message = str(e)
if "does not have a slow version" in error_message:
logger.info(
f"Processor {tokenizer_name} does not have a slow version. Automatically use fast version"
)
kwargs["use_fast"] = True
processor = AutoProcessor.from_pretrained(
tokenizer_name,
*args,
trust_remote_code=trust_remote_code,
revision=revision,
**kwargs,
)
else:
raise e
tokenizer = get_tokenizer_from_processor(processor)
attach_additional_stop_token_ids(tokenizer)
return processor
def attach_additional_stop_token_ids(tokenizer):
# Special handling for stop token <|eom_id|> generated by llama 3 tool use.
if "<|eom_id|>" in tokenizer.get_added_vocab():
tokenizer.additional_stop_token_ids = set(
[tokenizer.get_added_vocab()["<|eom_id|>"]]
)
else:
tokenizer.additional_stop_token_ids = None
def check_gguf_file(model: Union[str, os.PathLike]) -> bool:
"""Check if the file is a GGUF model."""
model = Path(model)
if not model.is_file():
return False
elif model.suffix == ".gguf":
return True
with open(model, "rb") as f:
header = f.read(4)
return header == b"GGUF"

Xet Storage Details

Size:
17.7 kB
·
Xet hash:
f5f6c4d43e70b98545902a410d4c88003455701669edc881e7af051bddb533aa

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.