| # Copyright 2024 SGLang Team | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| # Adapted from | |
| # https://github.com/vllm-project/vllm/blob/6071e989df1531b59ef35568f83f7351afb0b51e/vllm/model_executor/models/phi4mm.py | |
| # https://huggingface.co/microsoft/Phi-4-multimodal-instruct/blob/main/processing_phi4mm.py | |
| import logging | |
| import math | |
| import re | |
| from collections.abc import Iterable | |
| from typing import List, Optional, Tuple | |
| import numpy as np | |
| import torch | |
| from torch import nn | |
| from transformers import PretrainedConfig | |
| from sglang.srt.layers.quantization import QuantizationConfig | |
| from sglang.srt.managers.mm_utils import ( | |
| MultiModalityDataPaddingPatternMultimodalTokens, | |
| general_mm_embed_routine, | |
| ) | |
| from sglang.srt.managers.schedule_batch import ( | |
| Modality, | |
| MultimodalDataItem, | |
| MultimodalInputs, | |
| ) | |
| from sglang.srt.model_executor.forward_batch_info import ForwardBatch | |
| from sglang.srt.model_loader.weight_utils import default_weight_loader | |
| from sglang.srt.models.idefics2 import Idefics2VisionTransformer | |
| from sglang.srt.models.llama import LlamaForCausalLM | |
| from sglang.srt.models.phi4mm_audio import AudioEmbedding | |
logger = logging.getLogger(__name__)

# Identifier of the vision encoder whose processing parameters are listed below.
SIGLIP_NAME = "siglip-so400m-patch14-448"

# Image-processing parameters, looked up by vision-encoder identifier.
VISION_ENCODER_TO_PROCESSING_CONFIG = {
    SIGLIP_NAME: dict(
        vit_image_size=448,
        vit_patch_size=14,
        token_compression_factor=2,
    ),
}
class Phi4MMImageEncoder(nn.Module):
    """Image embedding.

    Encodes image crops with an ``Idefics2VisionTransformer``, shrinks each
    crop's patch grid with 2x2 average pooling, stitches the crops of every
    image into a single token sequence delimited by learnable separator
    embeddings (``sub_GN`` / ``glb_GN``), and projects the result with a small
    MLP (``img_projection``).
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
        model_dir: str = "",
    ) -> None:
        """Build the vision tower, separators and projection MLP.

        Args:
            config: Model config. ``n_embd`` (or ``hidden_size``),
                ``vision_config`` and ``vocab_size`` are read from it.
            quant_config: Accepted for interface compatibility; not used by
                this constructor.
            prefix: Accepted for interface compatibility; not used here.
            model_dir: Accepted for interface compatibility; not used here.
        """
        super().__init__()

        # n_embd or hidden_size, whichever the config defines.
        hidden_size = config.n_embd if hasattr(config, "n_embd") else config.hidden_size

        self.type_feature = "patch"
        self.img_processor = Idefics2VisionTransformer(
            config=config.vision_config, require_post_norm=False
        )

        # The row count of the position-embedding table is treated as a square
        # patch grid; its side length drives every size computed below.
        pe_weight = self.img_processor.embeddings.position_embedding.weight
        num_positions, embed_dim = pe_weight.size()
        grid_side = math.isqrt(num_positions)
        assert (
            grid_side**2 == num_positions
        ), f"position embedding size {num_positions} is not square"
        if grid_side % 2 != 0:
            # Odd grid: pad one row and one column so 2x2 pooling divides evenly.
            self.img_processor_padding = nn.ReflectionPad2d((0, 1, 0, 1))
            grid_side += 1

        image_dim_out = embed_dim
        # ((448/14)//2)**2
        self.num_img_tokens = (grid_side // 2) ** 2
        self.image_dim_out = image_dim_out
        self.img_sizes = None
        self.image_attention_mask = None

        # global_gn and sub_gn for hd transform, serves as line separator
        self.use_hd_transform = True
        self.with_learnable_separator = True
        self.hd_transform_order = "sub_glb"
        self.freeze_img_processor = False
        self.crop_size = 448

        # image token compression
        self.image_token_compression_cls = "avg_pool_2d"
        self.image_token_compression = nn.AvgPool2d(kernel_size=2, stride=2)
        self.base_feat_height_reduction = 1
        # Side length of one crop's feature grid after the 2x2 pooling.
        self.base_feat_height_target = grid_side // 2

        # with_hd_transform and with_learnable_separator should have same value
        assert (
            self.use_hd_transform == self.with_learnable_separator
        ), "use_hd_transform and with_learnable_separator should have same value"
        assert self.use_hd_transform, "learnable separator is only for hd transform"

        # Separator embeddings share the (spatially merged) feature width.
        merged_dim = self.image_dim_out * self.base_feat_height_reduction**2
        self.glb_GN = nn.Parameter(torch.zeros([1, 1, merged_dim]))
        self.sub_GN = nn.Parameter(torch.zeros([1, 1, 1, merged_dim]))

        # Projection MLP: Linear followed by (GELU, Linear) per extra layer.
        dim_projection = hidden_size
        depth = 2
        layers = [nn.Linear(merged_dim, dim_projection)]
        for _ in range(1, depth):
            layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)])
        self.img_projection = nn.Sequential(*layers)

        self.vocab_size = config.vocab_size
        self.img_features = None
        self.use_out_place_operations = False

    def get_img_features(
        self, img_embeds: torch.FloatTensor, attention_mask=None
    ) -> torch.FloatTensor:
        """Encode crops and apply optional padding / token compression.

        Args:
            img_embeds: Input forwarded unchanged to ``self.img_processor``.
            attention_mask: Forwarded as ``patch_attention_mask``; may be None.

        Returns:
            Tensor of shape ``(N, tokens, channels)``. When padding or
            compression is active, ``tokens`` is the flattened grid that
            remains after those steps.
        """
        patch_feature = self.img_processor(
            img_embeds, patch_attention_mask=attention_mask
        )

        use_token_compression = self.image_token_compression is not None
        use_padding = getattr(self, "img_processor_padding", None) is not None
        if not (use_token_compression or use_padding):
            return patch_feature

        # Token sequence -> square grid, then NCHW for the 2D modules.
        width = math.isqrt(patch_feature.size(1))
        patch_feature = patch_feature.view(-1, width, width, patch_feature.size(-1))
        patch_feature = patch_feature.permute(0, 3, 1, 2)

        if use_padding:
            patch_feature = self.img_processor_padding(patch_feature)
        if use_token_compression:
            patch_feature = self.image_token_compression(patch_feature)

        # Back to NHWC and flatten the grid into a token axis again.
        patch_feature = patch_feature.permute(0, 2, 3, 1)
        return patch_feature.view(
            -1,
            patch_feature.size(1) * patch_feature.size(2),
            patch_feature.size(-1),
        )

    @staticmethod
    def _fold_spatial(
        crops: torch.Tensor, num_crops: int, side: int, channels: int, reduction: int
    ) -> torch.Tensor:
        """Merge each ``reduction x reduction`` neighbourhood into channels.

        Returns a tensor of shape
        ``(num_crops, side // reduction, side // reduction, reduction**2 * channels)``.
        """
        cells = side // reduction
        return (
            crops.reshape(num_crops, side, side, channels)
            .reshape(num_crops, cells, reduction, cells, reduction, channels)
            .permute(0, 1, 3, 2, 4, 5)
            .reshape(num_crops, cells, cells, reduction * reduction * channels)
        )

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        image_sizes: torch.Tensor,
        image_attention_mask: torch.Tensor,
    ) -> list[torch.FloatTensor]:
        """Process images and return vision embeddings.

        Args:
            pixel_values: ``(num_images, num_crops, c, h, w)``; crop 0 of each
                image is used as the global view, the rest as sub-crops.
            image_sizes: ``[[h1, w1], [h2, w2]]`` as a tensor or nested list.
                Each size is integer-divided by ``self.crop_size`` to obtain
                the sub-crop grid.
            image_attention_mask: ``num_images x num_crops x 32 x 32`` mask, or
                None / empty to treat every position as valid.

        Returns:
            One tensor per image of shape ``(num_tokens, projection_dim)``.

        Raises:
            AssertionError: If the feature grid size differs from
                ``self.base_feat_height_target`` or the assembled sequence
                length differs from the expected length.
            NotImplementedError: If ``self.hd_transform_order`` is unknown.
        """
        # eg
        # pixel_values: torch.Size([1, 7, 3, 448, 448])
        # image_sizes: tensor([[ 896, 1344]], device='cuda:0')
        # output: torch.Size([1, 1841, 3072])
        projection_param = next(self.img_projection.parameters())
        target_device = projection_param.device
        target_dtype = projection_param.dtype

        has_mask = image_attention_mask is not None and len(image_attention_mask) > 0

        # Unpacking doubles as a check that pixel_values is 5-D.
        bs, _num_crops, _c, _h, _w = pixel_values.shape

        # Convert the mask on its own device, then move it once; a missing
        # mask is passed through as None.
        patch_mask = None
        if image_attention_mask is not None:
            patch_mask = image_attention_mask.flatten(0, 1).to(
                device=target_device, dtype=torch.bool
            )
        img_features = self.get_img_features(pixel_values.flatten(0, 1), patch_mask)

        base_feat_height_target = self.base_feat_height_target
        base_resolution = self.crop_size
        reduction = self.base_feat_height_reduction

        base_feat_height = base_feat_width = math.isqrt(img_features.shape[1])
        assert (
            base_feat_height == base_feat_height_target
            and base_feat_width == base_feat_height_target
        ), (
            f"base_feat_height: {base_feat_height}, "
            f"base_feat_width: {base_feat_width}, "
            f"expect {base_feat_height_target} features for hd transform"
        )

        C = self.image_dim_out
        H = base_feat_height
        cells = H // reduction
        merged_dim = reduction * reduction * C

        # bs x max_num_crops x (H*H) x C
        img_features = img_features.view(bs, -1, H * H, C)

        # training is tensor, inference is list
        img_sizes = image_sizes
        if isinstance(img_sizes, torch.Tensor):
            img_sizes = img_sizes.view(-1, 2)

        output_imgs = []
        for idx in range(bs):
            height_px, width_px = img_sizes[idx]
            # Number of sub-crops along each axis, as plain ints for reshape.
            h = int(height_px) // base_resolution
            w = int(width_px) // base_resolution
            num_sub_crops = h * w

            # Global view: fold, append one separator per row, then flatten.
            glb_img = self._fold_spatial(img_features[idx, :1], 1, H, C, reduction)
            temp_glb_GN = self.sub_GN.repeat(1, cells, 1, 1)
            glb_img = torch.cat([glb_img, temp_glb_GN], dim=2).reshape(
                1, -1, merged_dim
            )

            # Sub-crops: drop padding crops, fold, then tile the h x w crops
            # into one (h*cells) x (w*cells) grid.
            sub_img = self._fold_spatial(
                img_features[idx, 1 : num_sub_crops + 1],
                num_sub_crops,
                H,
                C,
                reduction,
            )
            sub_img = (
                sub_img.reshape(1, h, w, cells, cells, -1)
                .permute(0, 1, 3, 2, 4, 5)
                .reshape(1, h * cells, w * cells, merged_dim)
            )

            if has_mask:
                # The mask is sampled with stride 2 to line up with the
                # 2x2-pooled features, then tiled like sub_img.
                stitched_mask = (
                    image_attention_mask[idx, 1 : num_sub_crops + 1, 0::2, 0::2]
                    .reshape(1, h, w, cells, cells)
                    .permute(0, 1, 3, 2, 4)
                    .reshape(1, h * cells, w * cells)
                )
                useful_height = int(stitched_mask[0, :, 0].sum().item())
                useful_width = int(stitched_mask[0, 0, :].sum().item())
                sub_img = sub_img[:, :useful_height, :useful_width]
                temp_sub_GN = self.sub_GN.repeat(1, useful_height, 1, 1)
                temp_len = (
                    int(
                        image_attention_mask[idx, : num_sub_crops + 1, 0::2, 0::2]
                        .sum()
                        .item()
                    )
                    + (useful_height + 1)
                    + cells
                )
            else:
                temp_sub_GN = self.sub_GN.repeat(1, h * cells, 1, 1)
                temp_len = int(
                    (h * w + 1) * self.num_img_tokens + 1 + (h + 1) * cells
                )

            sub_img = torch.cat([sub_img, temp_sub_GN], dim=2).reshape(
                1, -1, merged_dim
            )

            # Join global and sub-crop sequences around the glb_GN separator.
            if self.hd_transform_order == "glb_sub":
                merged = torch.cat([glb_img, self.glb_GN, sub_img], dim=1)
            elif self.hd_transform_order == "sub_glb":
                merged = torch.cat([sub_img, self.glb_GN, glb_img], dim=1)
            else:
                raise NotImplementedError(
                    f"hd_transform_order = {self.hd_transform_order}, not implemented"
                )

            # temp_len = int((h*w+1)*144 + 1 + (h+1)*12)
            assert temp_len == merged.shape[1], (
                f"temp_len: {temp_len}, "
                f"output_imgs[-1].shape[1]: {merged.shape[1]}"
            )
            output_imgs.append(merged)

        return [
            self.img_projection(img.to(target_device).to(target_dtype)).squeeze(0)
            for img in output_imgs
        ]
class Phi4MMForCausalLM(nn.Module):
    """Phi-4-multimodal: a Llama language model with image and audio embedders.

    Image items are embedded by ``Phi4MMImageEncoder`` and audio items by
    ``AudioEmbedding``; ``general_mm_embed_routine`` combines them with the
    token inputs and runs the language model.
    """

    # Fused projection name -> names of the projections it packs.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    # Module names that should_apply_lora() accepts.
    lora_pattern = re.compile(
        r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
    )

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the language model, vision encoder and audio embedder.

        Args:
            config: Model config; ``embd_layer`` and ``_name_or_path`` are read
                here in addition to what the sub-modules read.
            quant_config: Passed to the language model and vision encoder.
            prefix: Passed to the language model.
        """
        super().__init__()

        self.language_model = LlamaForCausalLM(
            config=config, quant_config=quant_config, prefix=prefix
        )

        self.vision_encoder = Phi4MMImageEncoder(
            config,
            quant_config,
            prefix="model.vision_embed_tokens",
            model_dir=config._name_or_path,
        )

        # The audio embedding settings come either from a dedicated
        # "audio_embd_layer" dict or from the top-level "embedding_cls".
        if isinstance(config.embd_layer["audio_embd_layer"], dict):
            embedding_config = {
                "embedding_cls": config.embd_layer["audio_embd_layer"]["embedding_cls"],
                **config.embd_layer["audio_embd_layer"],
            }
        else:
            embedding_config = {"embedding_cls": config.embd_layer["embedding_cls"]}

        self.embed_tokens_extend = AudioEmbedding(config, **embedding_config)

    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
        """Embed image items and return all embeddings concatenated on dim 0.

        Pixel values, attention masks and image sizes of all items are each
        concatenated along dim 0 before being passed to the vision encoder.
        Items without an ``image_attention_mask`` attribute contribute no mask.
        """
        dtype = next(self.vision_encoder.parameters()).dtype
        pixel_values = torch.cat([item.feature for item in items], dim=0).type(dtype)
        image_attention_mask = torch.cat(
            [
                item.image_attention_mask
                for item in items
                if hasattr(item, "image_attention_mask")
            ],
            dim=0,
        )
        image_sizes = torch.cat([item.image_sizes for item in items], dim=0)
        image_embeds = self.vision_encoder(
            pixel_values, image_sizes, image_attention_mask
        )
        return torch.cat(image_embeds).type(dtype)

    def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
        """Embed audio items and return all embeddings concatenated on dim 0.

        Each item is embedded separately; features are moved to the device and
        dtype of the audio embedder's first parameter.
        """
        embed_param = next(self.embed_tokens_extend.parameters())
        device = embed_param.device
        dtype = embed_param.dtype

        audio_embeds = []
        for item in items:
            # item.feature: (num_audios_in_a_sequence, T, D)
            # item.audio_attention_mask: (num_audios_in_a_sequence, T, D) BoolTensor or None
            # The mask may be absent or explicitly None; both mean "no mask".
            mask = getattr(item, "audio_attention_mask", None)
            audio_embeds.append(
                self.embed_tokens_extend(
                    audio_features=item.feature.to(device).type(dtype),
                    audio_attention_mask=(
                        mask.to(device) if mask is not None else None
                    ),
                )
            )

        return torch.cat(audio_embeds).type(dtype)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        **kwargs: object,
    ) -> torch.Tensor:
        """Run the model via ``general_mm_embed_routine`` and return its result.

        Extra keyword arguments are accepted and ignored.
        """
        return general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.language_model,
            data_embedding_funcs={
                Modality.IMAGE: self.get_image_feature,
                Modality.AUDIO: self.get_audio_feature,
            },
            positions=positions,
        )

    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
        """Pad ``input_ids`` using the multimodal-token padding pattern."""
        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
        return pattern.pad_input_tokens(input_ids, mm_inputs)

    def should_apply_lora(self, module_name: str) -> bool:
        """Return True when ``module_name`` matches ``lora_pattern``."""
        return bool(self.lora_pattern.match(module_name))

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this model's parameters.

        Checkpoint names are remapped to this module's naming, separate q/k/v
        attention weights are loaded as shards of the fused ``qkv_proj``
        parameter, and weights matching the skip list are ignored. Names with
        no matching parameter are skipped; a warning is logged for them unless
        the name contains "lora".

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Raises:
            KeyError: If a q/k/v weight maps to a fused parameter name that
                does not exist in this model.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
            (".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
            (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
        ]
        # Ordered from most to least specific; the first matching prefix wins.
        prefix_mapping = {
            "model.embed_tokens_extend.audio_embed.audio_projection.vision.": "embed_tokens_extend.audio_projection_for_vision.",
            "model.embed_tokens_extend.audio_embed.audio_projection.speech.": "embed_tokens_extend.audio_projection.",
            "model.embed_tokens_extend.audio_embed.": "embed_tokens_extend.",
            "model.embed_tokens_extend.image_embed.": "vision_encoder.",
            "model.": "language_model.model.",
        }

        skip_list = [
            "img_processor.encoder.layers.26",
            "img_processor.head",
            "img_processor.post_layernorm",
        ]

        def _should_skip(name: str) -> bool:
            return any(substr in name for substr in skip_list)

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            # Skip the last layer
            if _should_skip(name):
                continue

            for old_prefix, new_prefix in prefix_mapping.items():
                if name.startswith(old_prefix):
                    # Rewrite only the leading prefix; a plain str.replace
                    # would also rewrite later occurrences of the same text.
                    name = new_prefix + name[len(old_prefix) :]
                    break

            # Adapt to VisionAttention
            name = name.replace("self_attn.out_proj", "self_attn.proj")
            name = name.replace("base_layer.", "")

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict.get(name)
                if param is None:
                    if "lora" not in name:
                        logger.warning("%s not found in model parameters", name)
                    continue
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
| EntryClass = [Phi4MMForCausalLM] | |
Xet Storage Details
- Size:
- 20.6 kB
- Xet hash:
- 73bbf4f78016342f39426d97a143a504f6143d7dbef6faeea3994228437db5b9
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.