# Building a compatible model backend for inference
Transformers models are compatible with inference engines like vLLM and SGLang. Use the same Transformers model anywhere and avoid reimplementing a model from scratch for each inference engine. Even models that aren't natively supported by an inference engine work through the Transformers backend.
This guide shows you how to implement a model in Transformers that works as a backend for any inference engine.
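For example, once a model satisfies the contract below, an engine can serve it without a native reimplementation. A minimal sketch of the vLLM side, assuming a recent vLLM release with the `model_impl="transformers"` option and a hypothetical `my-org/my-model` checkpoint:

```python
from vllm import LLM, SamplingParams

# Ask vLLM to run the Transformers implementation of the model
# rather than a native vLLM reimplementation.
llm = LLM(model="my-org/my-model", model_impl="transformers")  # hypothetical checkpoint
outputs = llm.generate(["The future of inference is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```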
## Model implementation
Follow the model contribution guidelines or the custom model contribution guidelines. The model must have a valid `config.json` in its directory and a valid `auto_map` field pointing to the model class in the config.
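For a model that ships its own code, the `auto_map` entries point at the modules and classes to load. A minimal `config.json` sketch with hypothetical module and class names:

```json
{
  "model_type": "my_model",
  "architectures": ["MyModelForCausalLM"],
  "auto_map": {
    "AutoConfig": "configuration_my_model.MyConfig",
    "AutoModel": "modeling_my_model.MyModel",
    "AutoModelForCausalLM": "modeling_my_model.MyModelForCausalLM"
  }
}
```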
Use the AttentionInterface class for custom and optimized attention functions. This interface unlocks each inference engine's performance features.
Use `ALL_ATTENTION_FUNCTIONS` when defining the attention layer and propagate `**kwargs` from the base `MyModel` class to the attention layers. Set `_supports_attention_backend` to `True` in PreTrainedModel.

Expand the code below for an example.
modeling_my_model.py

```python
from torch import nn

from transformers import PreTrainedModel
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


class MyAttention(nn.Module):
    def forward(self, hidden_states, **kwargs):
        ...
        # Pick up whichever attention implementation was selected at load time
        # (eager, sdpa, flash_attention_2, or a registered custom function).
        attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            **kwargs,
        )
        ...


class MyModel(PreTrainedModel):
    _supports_attention_backend = True
```
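Because the attention function is looked up by name, a new implementation registered through AttentionInterface becomes available to every layer written this way. A minimal sketch, following the signature convention of the built-in attention functions and using a hypothetical `my-org/my-model` checkpoint:

```python
import torch
from transformers import AttentionInterface, AutoModelForCausalLM

def my_custom_attention(module, query, key, value, attention_mask, **kwargs):
    # Toy body that delegates to PyTorch SDPA; an inference engine would
    # plug in its optimized kernel here instead.
    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask
    )
    # Attention functions return (output, weights); weights may be None.
    return attn_output.transpose(1, 2).contiguous(), None

AttentionInterface.register("my_custom_attention", my_custom_attention)

model = AutoModelForCausalLM.from_pretrained(
    "my-org/my-model",  # hypothetical checkpoint
    attn_implementation="my_custom_attention",
)
```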
Enable optional tensor or pipeline parallelism by adding the following keys to PreTrainedConfig.

- `base_model_tp_plan` enables tensor parallelism by mapping fully qualified layer name patterns to tensor parallel styles. Only the `"colwise"` and `"rowwise"` partitioning strategies are supported.
- `base_model_pp_plan` enables pipeline parallelism by mapping direct child layer names to tuples of lists of strings. The first element of the tuple contains the names of the input arguments, and the last element contains the variable names of the layer outputs in the modeling code.
Expand the code below for an example.
configuration_my_model.py

```python
from transformers import PreTrainedConfig


class MyConfig(PreTrainedConfig):
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",  # q/k/v/o projections must be partitioned consistently
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }
```
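With these plans in the config, a recent Transformers release can shard the model automatically when it is loaded under torch.distributed. A minimal sketch, assuming the `tp_plan="auto"` loading path and a hypothetical `my-org/my-model` checkpoint:

```python
# Launch with, e.g.: torchrun --nproc-per-node 2 run_tp.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "my-org/my-model"  # hypothetical checkpoint whose config defines base_model_tp_plan
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    tp_plan="auto",  # shard weights across ranks according to base_model_tp_plan
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("Hello, world!", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=20)[0]))
```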
## Multimodal models
Multimodal models require additional changes beyond the vision language model contribution checklist. These changes ensure multimodal inputs are properly processed.
- The ProcessorMixin class must include the `self.image_token` and `self.image_token_ids` attributes. These placeholder tokens indicate image positions in the input. The same token appears in the input prompt for images and in the model code to scatter image features.
- The ProcessorMixin class must include a `self._get_num_multimodal_tokens` method. This method computes the number of placeholder tokens required for multimodal inputs with given sizes and returns a `MultiModalData` object. Placeholders between `<image>` tokens, such as row or column tokens, don't count as image placeholders. Count only tokens replaced by image features later in the modeling code.
- The ProcessorMixin class must check the value of `return_mm_token_type_ids` and return `mm_token_type_ids`. This indicates whether each position is a text token (`0`), an image placeholder token (`1`), or a video placeholder token (`2`). Multimodal token type id sequences must be contiguous, with no breaks between consecutive tokens, as illustrated after this list. Treat special tokens for beginning, ending, row, and column as placeholders.
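To make the contiguity rule concrete, here is an illustrative (hand-written, not library-generated) example for a prompt with one image whose span includes begin and end special tokens:

```python
# Tokens:            Describe  <im_start>  <image>  <image>  <image>  <im_end>  please
# Valid: the image span forms one unbroken run of 1s
mm_token_type_ids = [0,        1,          1,       1,       1,       1,        0]

# Invalid: the run of 1s is interrupted by a 0 inside the image span
# mm_token_type_ids = [0, 1, 1, 0, 1, 1, 0]
```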
Expand the code below for an example.
processing_my_multimodal_model.py
```python
import numpy as np

from transformers import BatchFeature, ProcessorMixin
from transformers.processing_utils import MultiModalData  # location in recent releases


class MyMultimodalProcessor(ProcessorMixin):

    def __call__(self, images=None, text=None, **kwargs):
        return_tensors = kwargs.pop("return_tensors", None)
        return_mm_token_type_ids = kwargs.pop("return_mm_token_type_ids", False)
        ...  # tokenize `text` into `text_inputs`/`input_ids` and preprocess `images` into `image_inputs`
        if return_mm_token_type_ids:
            # Mark image placeholder positions with 1, text positions with 0.
            mm_token_type_ids = np.zeros_like(input_ids)
            mm_token_type_ids[input_ids == self.image_token_id] = 1
            text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
        return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)

    def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
        """
        Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.

        Args:
            image_sizes (`list[list[int]]`, *optional*):
                The input sizes formatted as (height, width) per each image.

        Returns:
            `MultiModalData`: A `MultiModalData` object holding the number of tokens per each of the
            provided input modalities, along with other useful data.
        """
        vision_data = {}
        if image_sizes is not None:
            num_image_tokens = [256] * len(image_sizes)  # this model always uses 256 placeholder tokens per image
            num_image_patches = [1] * len(image_sizes)  # no patching, so each image is processed as a single base image
            vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
        return MultiModalData(**vision_data)
```
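A quick way to sanity-check the contract, once the elided parts of the hypothetical processor above are filled in:

```python
from PIL import Image

processor = MyMultimodalProcessor.from_pretrained("my-org/my-multimodal-model")  # hypothetical checkpoint
image = Image.new("RGB", (640, 480))
inputs = processor(images=[image], text="<image> Describe this image.", return_mm_token_type_ids=True)
print(inputs["mm_token_type_ids"])  # 1s at image placeholder positions, 0s elsewhere

mm_data = processor._get_num_multimodal_tokens(image_sizes=[(480, 640)])
print(mm_data.num_image_tokens)  # [256] for the fixed-size example above
```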
## Resources

- Read the Transformers backend integration in vLLM blog post for more details.
- Read the Transformers backend integration in SGLang blog post for more details.