| |
| |
| """ |
| Whenever you add an architecture to this page, please also update |
| `tests/models/registry.py` with example HuggingFace models for it. |
| """ |
| import importlib |
| import os |
| import pickle |
| import subprocess |
| import sys |
| import tempfile |
| from abc import ABC, abstractmethod |
| from collections.abc import Set |
| from dataclasses import dataclass, field |
| from functools import lru_cache |
| from typing import Callable, Optional, TypeVar, Union |
|
|
| import cloudpickle |
| import torch.nn as nn |
|
|
| from vllm.logger import init_logger |
|
|
| from .interfaces import (has_inner_state, has_noops, is_attention_free, |
| is_hybrid, supports_cross_encoding, |
| supports_multimodal, supports_pp, |
| supports_transcription, supports_v0_only) |
| from .interfaces_base import is_text_generation_model |
|
|
| logger = init_logger(__name__) |
|
|
| |
# Maps a HF architecture name (as it appears in a model config's
# `architectures` field) to a tuple of
# (module name under `vllm.model_executor.models`, model class name).
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    # NOTE: the two Baichuan entries below differ only in the case of 'C'/'c';
    # both are intentional and map to distinct classes in the same module.
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "UTUV1ForCausalLM": ("utu_v1", "UTUV1ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5_ForCausalLM": ("ernie45", "Ernie4_5_ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "Gemma3nForConditionalGeneration": ("gemma3n", "Gemma3nForConditionalGeneration"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),
    "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),
    "GritLM": ("gritlm", "GritLM"),
    "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1_moe", "HunYuanMoEV1ForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
    "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    # Alternate spelling used by some checkpoints; same implementation.
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconH1ForCausalLM":("falcon_h1", "FalconH1ForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # Both capitalizations of "MPT" resolve to the same class.
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
    "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
    "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
}
|
|
# Architectures served as embedding/pooling models. Several entries reuse a
# decoder module/class; presumably such models are converted into pooling
# models at load time -- the conversion is not visible in this file (TODO
# confirm against the model loader).
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Pull in every text-generation alias that resolves to
        # LlamaForCausalLM (e.g. MistralForCausalLM, InternLMForCausalLM).
        k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),
    # NOTE(review): image-in/image-out geospatial model; confirm why it is
    # listed in the embedding table rather than elsewhere.
    "PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"),
}
|
|
# Cross-encoder / sequence-classification architectures. Same value format:
# (module under `vllm.model_executor.models`, class name).
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    # XLM-RoBERTa shares the RoBERTa implementation.
    "RobertaForSequenceClassification": ("roberta",
                                         "RobertaForSequenceClassification"),
    "XLMRobertaForSequenceClassification": ("roberta",
                                            "RobertaForSequenceClassification"),
    "ModernBertForSequenceClassification": ("modernbert",
                                            "ModernBertForSequenceClassification"),
    # Decoder-based classifiers.
    "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForSequenceClassification"),
    "Qwen3ForSequenceClassification": ("qwen3", "Qwen3ForSequenceClassification"),
}
|
|
# Multimodal architectures (vision/audio/omni inputs). Same value format:
# (module under `vllm.model_executor.models`, class name).
_MULTIMODAL_MODELS = {
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "AyaVisionForConditionalGeneration": ("aya_vision", "AyaVisionForConditionalGeneration"),
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),
    # NOTE: "GLM4V" (upper case) and "Glm4v" (mixed case) are distinct
    # architectures backed by different modules.
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),
    "GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"),
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"),
    "SmolVLMForConditionalGeneration": ("smolvlm","SmolVLMForConditionalGeneration"),
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
    # Mantis reuses the llava module with its own class.
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),
    "MiniMaxVL01ForConditionalGeneration": ("minimax_vl_01", "MiniMaxVL01ForConditionalGeneration"),
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"),
    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "Ovis": ("ovis", "Ovis"),
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),
    "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),
    # Both Qwen2.5-Omni architecture names resolve to the "thinker" class.
    "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),
    "Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),
    "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"),
    # The entries below were grouped separately in the original layout
    # (encoder-decoder style multimodal models, plus Whisper).
    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),
    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),
    "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"),
    "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),
}
|
|
# Draft/speculative-decoding model architectures (EAGLE, MTP, Medusa, etc.).
_SPECULATIVE_DECODING_MODELS = {
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EAGLEModel": ("eagle", "EAGLE"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
    "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
    "MedusaModel": ("medusa", "Medusa"),
    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
|
|
# Generic Transformers-backend implementation; used as the fallback entry
# appended by `_ModelRegistry._normalize_archs` when an architecture is not
# otherwise registered.
_TRANSFORMERS_MODELS = {
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
}
| |
|
|
# Union of all built-in model tables. Merge order matters: on duplicate
# architecture names, entries from later tables override earlier ones.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_MODELS,
}
|
|
| |
| |
| |
| |
# Command used to re-invoke this very module as a worker process; see
# `_run_in_subprocess` (parent side) and `_run` (child side) below.
_SUBPROCESS_COMMAND = [
    sys.executable, "-m", "vllm.model_executor.models.registry"
]
|
|
|
|
@dataclass(frozen=True)
class _ModelInfo:
    """Static capability metadata about a model class, derived from the
    interface predicates in `.interfaces` / `.interfaces_base`.

    Being a plain frozen dataclass, instances can be pickled and therefore
    returned from the inspection subprocess (see `_run_in_subprocess`).
    """

    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_transcription: bool
    supports_v0_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """Build a `_ModelInfo` by probing `model` with the interface
        predicates imported at the top of this module."""
        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            # Hard-coded True -- presumably any model can be converted into
            # a pooling model downstream; TODO confirm this is intentional.
            is_pooling_model=True,
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_transcription=supports_transcription(model),
            supports_v0_only=supports_v0_only(model),
            has_noops=has_noops(model),
        )
|
|
|
|
class _BaseRegisteredModel(ABC):
    """Interface for a registry entry: something that can report capability
    metadata for a model architecture and produce its model class."""

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return capability metadata for the model."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the actual model class, importing it if necessary."""
        raise NotImplementedError
|
|
|
|
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.
    """

    # Capability metadata, computed eagerly at registration time.
    interfaces: _ModelInfo
    # The already-imported model class.
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]):
        """Build an entry for `model_cls`, inspecting it immediately."""
        info = _ModelInfo.from_model_cls(model_cls)
        return _RegisteredModel(interfaces=info, model_cls=model_cls)

    def inspect_model_cls(self) -> _ModelInfo:
        # Metadata was precomputed in `from_model_cls`; just hand it back.
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        # Nothing to import -- the class is already in this process.
        return self.model_cls
|
|
|
|
@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.
    """
    module_name: str
    class_name: str

    def load_model_cls(self) -> type[nn.Module]:
        """Import `module_name` and return its `class_name` attribute."""
        return getattr(importlib.import_module(self.module_name),
                       self.class_name)

    def inspect_model_cls(self) -> _ModelInfo:
        """Inspect the model in a separate process, so the (potentially
        heavy) model module is never imported into this one."""

        def _inspect() -> _ModelInfo:
            return _ModelInfo.from_model_cls(self.load_model_cls())

        return _run_in_subprocess(_inspect)
|
|
|
|
@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[type[nn.Module]]:
    """Load `model`'s class, returning None (with the error logged) on
    failure. Results are memoized per (arch, entry) pair."""
    # Imported lazily to avoid a module-level dependency on the platform.
    from vllm.platforms import current_platform

    # Fail fast if the current platform rejects this architecture.
    current_platform.verify_model_arch(model_arch)

    try:
        loaded_cls = model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        return None

    return loaded_cls
|
|
|
|
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    """Inspect `model`, returning None (with the error logged) on failure.
    Results are memoized per (arch, entry) pair."""
    try:
        info = model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
        return None

    return info
|
|
|
|
| @dataclass |
| class _ModelRegistry: |
| |
| models: dict[str, _BaseRegisteredModel] = field(default_factory=dict) |
|
|
| def get_supported_archs(self) -> Set[str]: |
| return self.models.keys() |
|
|
| def register_model( |
| self, |
| model_arch: str, |
| model_cls: Union[type[nn.Module], str], |
| ) -> None: |
| """ |
| Register an external model to be used in vLLM. |
| |
| `model_cls` can be either: |
| |
| - A [`torch.nn.Module`][] class directly referencing the model. |
| - A string in the format `<module>:<class>` which can be used to |
| lazily import the model. This is useful to avoid initializing CUDA |
| when importing the model and thus the related error |
| `RuntimeError: Cannot re-initialize CUDA in forked subprocess`. |
| """ |
| if not isinstance(model_arch, str): |
| msg = f"`model_arch` should be a string, not a {type(model_arch)}" |
| raise TypeError(msg) |
|
|
| if model_arch in self.models: |
| logger.warning( |
| "Model architecture %s is already registered, and will be " |
| "overwritten by the new model class %s.", model_arch, |
| model_cls) |
|
|
| if isinstance(model_cls, str): |
| split_str = model_cls.split(":") |
| if len(split_str) != 2: |
| msg = "Expected a string in the format `<module>:<class>`" |
| raise ValueError(msg) |
|
|
| model = _LazyRegisteredModel(*split_str) |
| elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module): |
| model = _RegisteredModel.from_model_cls(model_cls) |
| else: |
| msg = ("`model_cls` should be a string or PyTorch model class, " |
| f"not a {type(model_arch)}") |
| raise TypeError(msg) |
|
|
| self.models[model_arch] = model |
|
|
| def _raise_for_unsupported(self, architectures: list[str]): |
| all_supported_archs = self.get_supported_archs() |
|
|
| if any(arch in all_supported_archs for arch in architectures): |
| raise ValueError( |
| f"Model architectures {architectures} failed " |
| "to be inspected. Please check the logs for more details.") |
|
|
| raise ValueError( |
| f"Model architectures {architectures} are not supported for now. " |
| f"Supported architectures: {all_supported_archs}") |
|
|
| def _try_load_model_cls(self, |
| model_arch: str) -> Optional[type[nn.Module]]: |
| if model_arch not in self.models: |
| return None |
|
|
| return _try_load_model_cls(model_arch, self.models[model_arch]) |
|
|
| def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]: |
| if model_arch not in self.models: |
| return None |
|
|
| return _try_inspect_model_cls(model_arch, self.models[model_arch]) |
|
|
| def _normalize_archs( |
| self, |
| architectures: Union[str, list[str]], |
| ) -> list[str]: |
| if isinstance(architectures, str): |
| architectures = [architectures] |
| if not architectures: |
| logger.warning("No model architectures are specified") |
|
|
| |
| normalized_arch = list( |
| filter(lambda model: model in self.models, architectures)) |
|
|
| |
| if len(normalized_arch) != len(architectures): |
| normalized_arch.append("TransformersForCausalLM") |
| return normalized_arch |
|
|
| def inspect_model_cls( |
| self, |
| architectures: Union[str, list[str]], |
| ) -> tuple[_ModelInfo, str]: |
| architectures = self._normalize_archs(architectures) |
|
|
| for arch in architectures: |
| model_info = self._try_inspect_model_cls(arch) |
| if model_info is not None: |
| return (model_info, arch) |
|
|
| return self._raise_for_unsupported(architectures) |
|
|
| def resolve_model_cls( |
| self, |
| architectures: Union[str, list[str]], |
| ) -> tuple[type[nn.Module], str]: |
| architectures = self._normalize_archs(architectures) |
|
|
| for arch in architectures: |
| model_cls = self._try_load_model_cls(arch) |
| if model_cls is not None: |
| return (model_cls, arch) |
|
|
| return self._raise_for_unsupported(architectures) |
|
|
| def is_text_generation_model( |
| self, |
| architectures: Union[str, list[str]], |
| ) -> bool: |
| model_cls, _ = self.inspect_model_cls(architectures) |
| return model_cls.is_text_generation_model |
|
|
| def is_pooling_model( |
| self, |
| architectures: Union[str, list[str]], |
| ) -> bool: |
| model_cls, _ = self.inspect_model_cls(architectures) |
| return model_cls.is_pooling_model |
|
|
| def is_cross_encoder_model( |
| self, |
| architectures: Union[str, list[str]], |
| ) -> bool: |
| model_cls, _ = self.inspect_model_cls(architectures) |
| return model_cls.supports_cross_encoding |
|
|
| def is_multimodal_model( |
| self, |
| architectures: Union[str, list[str]], |
| ) -> bool: |
| model_cls, _ = self.inspect_model_cls(architectures) |
| return model_cls.supports_multimodal |
|
|
| def is_pp_supported_model( |
| self, |
| architectures: Union[str, list[str]], |
| ) -> bool: |
| model_cls, _ = self.inspect_model_cls(architectures) |
| return model_cls.supports_pp |
|
|
| def model_has_inner_state( |
| self, |
| architectures: Union[str, list[str]], |
| ) -> bool: |
| model_cls, _ = self.inspect_model_cls(architectures) |
| return model_cls.has_inner_state |
|
|
| def is_attention_free_model( |
| self, |
| architectures: Union[str, list[str]], |
| ) -> bool: |
| model_cls, _ = self.inspect_model_cls(architectures) |
| return model_cls.is_attention_free |
|
|
| def is_hybrid_model( |
| self, |
| architectures: Union[str, list[str]], |
| ) -> bool: |
| model_cls, _ = self.inspect_model_cls(architectures) |
| return model_cls.is_hybrid |
|
|
| def is_noops_model( |
| self, |
| architectures: Union[str, list[str]], |
| ) -> bool: |
| model_cls, _ = self.inspect_model_cls(architectures) |
| return model_cls.has_noops |
|
|
| def is_transcription_model( |
| self, |
| architectures: Union[str, list[str]], |
| ) -> bool: |
| model_cls, _ = self.inspect_model_cls(architectures) |
| return model_cls.supports_transcription |
|
|
| def is_v1_compatible( |
| self, |
| architectures: Union[str, list[str]], |
| ) -> bool: |
| model_cls, _ = self.inspect_model_cls(architectures) |
| return not model_cls.supports_v0_only |
|
|
|
|
# The global registry, pre-populated with every built-in model as a lazy
# entry so that importing this module does not import any model file.
ModelRegistry = _ModelRegistry({
    model_arch:
    _LazyRegisteredModel(
        module_name=f"vllm.model_executor.models.{mod_relname}",
        class_name=cls_name,
    )
    for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items()
})
|
|
# Generic return type of the callable passed to `_run_in_subprocess`.
_T = TypeVar("_T")
|
|
|
|
def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Execute `fn` in a fresh interpreter and return its unpickled result.

    The child process (see `_run`) receives `(fn, output_filepath)` via
    stdin and writes the pickled result to `output_filepath`.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero status; its
            stderr is included in the message.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")

        # cloudpickle (not plain pickle) so that lambdas/closures survive
        # the trip to the child process.
        payload = cloudpickle.dumps((fn, output_filepath))

        proc = subprocess.run(_SUBPROCESS_COMMAND,
                              input=payload,
                              capture_output=True)

        try:
            proc.check_returncode()
        except Exception as e:
            raise RuntimeError("Error raised in subprocess:\n"
                               f"{proc.stderr.decode()}") from e

        with open(output_filepath, "rb") as f:
            return pickle.load(f)
|
|
|
|
def _run() -> None:
    """Subprocess entry point: read `(fn, output_file)` from stdin, call
    `fn`, and pickle the result to `output_file` (see `_run_in_subprocess`)."""
    # Load plugins first so that plugin-provided models resolve here too.
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    fn, output_file = pickle.loads(sys.stdin.buffer.read())

    result = fn()

    with open(output_file, "wb") as f:
        pickle.dump(result, f)
|
|
|
|
if __name__ == "__main__":
    # Invoked via `_SUBPROCESS_COMMAND` as the inspection worker process.
    _run()
|
|