"""PyTorch PenguinVL model."""

import importlib.util
import math
import os.path as osp
import re
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint

from transformers import Qwen3ForCausalLM, Qwen3Model
from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import CausalLMOutputWithPast
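
# If the relative imports fail (e.g. when this file is loaded standalone from a checkpoint
# directory with trust_remote_code), fall back to loading the sibling modules by file path.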
try:
    from .configuration_penguinvl import PenguinVLQwen3Config
except ModuleNotFoundError:
    spec = importlib.util.spec_from_file_location(
        "configuration_penguinvl",
        osp.join(osp.dirname(__file__), "configuration_penguinvl.py"),
    )
    configuration_penguinvl = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(configuration_penguinvl)
    PenguinVLQwen3Config = getattr(configuration_penguinvl, "PenguinVLQwen3Config")

try:
    from .configuration_penguinvl_encoder import PenguinVLVisionEncoderConfig
    from .modeling_penguinvl_encoder import PenguinVLVisionEncoderModel
except ModuleNotFoundError:
    enc_spec = importlib.util.spec_from_file_location(
        "configuration_penguinvl_encoder",
        osp.join(osp.dirname(__file__), "configuration_penguinvl_encoder.py"),
    )
    configuration_penguinvl_encoder = importlib.util.module_from_spec(enc_spec)
    enc_spec.loader.exec_module(configuration_penguinvl_encoder)
    PenguinVLVisionEncoderConfig = getattr(configuration_penguinvl_encoder, "PenguinVLVisionEncoderConfig")

    enc_model_spec = importlib.util.spec_from_file_location(
        "modeling_penguinvl_encoder",
        osp.join(osp.dirname(__file__), "modeling_penguinvl_encoder.py"),
    )
    modeling_penguinvl_encoder = importlib.util.module_from_spec(enc_model_spec)
    enc_model_spec.loader.exec_module(modeling_penguinvl_encoder)
    PenguinVLVisionEncoderModel = getattr(modeling_penguinvl_encoder, "PenguinVLVisionEncoderModel")


def build_mlp(depth, hidden_size, output_hidden_size):
    modules = [nn.Linear(hidden_size, output_hidden_size)]
    for _ in range(1, depth):
        modules.append(nn.GELU())
        modules.append(nn.Linear(output_hidden_size, output_hidden_size))
    return nn.Sequential(*modules)


def build_vision_projector(config, **kwargs):
    projector_type = getattr(config, 'vision_projector_type', 'linear')
    if projector_type == "linear":
        return nn.Linear(config.mm_hidden_size, config.hidden_size)
    elif projector_type.startswith("mlp"):
        return MlpGeluProjector(config.vision_encoder_config.hidden_size, config.hidden_size, projector_type)
    elif projector_type.startswith("dmlp"):
        # Assumed dispatch for the 8x token-downsampling projector defined below (e.g. "dmlp2x_gelu").
        return MlpGeluDownsampleProjector(config.vision_encoder_config.hidden_size, config.hidden_size, projector_type)
    else:
        raise ValueError(f'Unknown projector type: {projector_type}')
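
# Projector type strings encode the MLP depth, e.g. "mlp2x_gelu" -> Linear, GELU, Linear;
# the "dmlp" variant below additionally merges groups of 8 visual tokens before the MLP.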


class MlpGeluProjector(nn.Module):

    def __init__(self, mm_hidden_size, hidden_size, projector_type):
        super().__init__()

        mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
        mlp_depth = int(mlp_gelu_match.group(1))

        self.readout = build_mlp(mlp_depth, mm_hidden_size, hidden_size)

    def forward(self, x):
        x = self.readout(x)
        return x


class MlpGeluDownsampleProjector(nn.Module):
    def __init__(self, mm_hidden_size, hidden_size, projector_type):
        super().__init__()
        # Merge groups of 8 visual tokens into one before the MLP readout.
        self.downsample = nn.Linear(mm_hidden_size * 8, mm_hidden_size)

        mlp_gelu_match = re.match(r"^dmlp(\d+)x_gelu$", projector_type)
        mlp_depth = int(mlp_gelu_match.group(1))

        self.readout = build_mlp(mlp_depth, mm_hidden_size, hidden_size)

    def forward(self, x):
        B, S, D = x.shape

        # Keep the largest multiple of 8 tokens, concatenate each group of 8 along the feature
        # dimension, project back to D, then run the GELU MLP readout.
        group = 8
        S8 = (S // group) * group
        x = x[:, :S8, :]
        x = x.reshape(B, S8 // group, group * D)
        x = self.downsample(x)
        x = self.readout(x)
        return x
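
# Illustrative shapes (assuming D = mm_hidden_size and 57 input tokens): tokens 0..55 are kept,
# reshaped to (B, 7, 8 * D), projected to (B, 7, D) by `downsample`, and mapped to
# (B, 7, hidden_size) by `readout`; the trailing token is discarded.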


class VLMMetaModel:

    def __init__(self, config):
        super(VLMMetaModel, self).__init__(config)
        if config.vision_encoder is not None:
            # Load the encoder from a pretrained path, then store its config on this config
            # and clear the path so the merged config is self-contained.
            encoder_config = PenguinVLVisionEncoderConfig.from_pretrained(config.vision_encoder)
            self.vision_encoder = PenguinVLVisionEncoderModel.from_pretrained(
                config.vision_encoder,
                config=encoder_config,
                attn_implementation=self.config._attn_implementation,
                torch_dtype=self.dtype,
            )
            self.config.vision_encoder_config = self.vision_encoder.config
            self.config.vision_encoder = None
        elif config.vision_encoder_config is not None:
            # Build the encoder from the config embedded in this model's config.
            self.vision_encoder = PenguinVLVisionEncoderModel.from_config(
                self.config.vision_encoder_config,
                attn_implementation=self.config._attn_implementation,
                torch_dtype=self.dtype,
            )
        else:
            raise ValueError("Vision encoder is not provided in config")

        self.vision_projector = build_vision_projector(config)

    def get_vision_encoder(self):
        return self.vision_encoder

    def get_vision_projector(self):
        return self.vision_projector


class PenguinVLQwen3Model(VLMMetaModel, Qwen3Model):

    config_class = PenguinVLQwen3Config

    def __init__(self, config: PenguinVLQwen3Config):
        super(PenguinVLQwen3Model, self).__init__(config)


class VLMMetaForCausalLM(ABC):

    @abstractmethod
    def get_model(self):
        pass

    def get_vision_encoder(self):
        return self.get_model().get_vision_encoder()

    def get_vision_projector(self):
        return self.get_model().get_vision_projector()

    def encode_images(
        self,
        pixel_values: torch.FloatTensor,
        grid_sizes: torch.LongTensor,
        merge_sizes: torch.LongTensor,
    ) -> torch.FloatTensor:
        """Run the vision encoder and project its features to the language-model hidden size."""
        mm_features = self.get_model().get_vision_encoder()(
            pixel_values=pixel_values,
            grid_sizes=grid_sizes,
            merge_sizes=merge_sizes,
        )
        mm_features = self.get_model().vision_projector(mm_features)
        return mm_features
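
    # The encoder output is flattened across the batch; _get_valid_visual_tokens below drops the
    # rows belonging to text-only samples (modal == "text") so they are never spliced into prompts.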

    def _get_valid_visual_tokens(
        self,
        mm_features: torch.FloatTensor,
        batched_num_patches: torch.LongTensor,
        modals: List[str],
    ):
        valid_masks = []
        for num_patches, modal in zip(batched_num_patches, modals):
            valid_mask = torch.full((num_patches,), modal != "text", dtype=torch.bool, device=mm_features.device)
            valid_masks.append(valid_mask)
        mm_features = mm_features[torch.cat(valid_masks)]
        return mm_features
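
    # prepare_inputs_labels_for_multimodal below flattens the (B, N) batch, replaces image
    # placeholder positions in the token embeddings with the projected visual features, and
    # reshapes everything back to (B, N, C) before handing off to the language model.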

    def prepare_inputs_labels_for_multimodal(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        grid_sizes: Optional[torch.LongTensor] = None,
        merge_sizes: Optional[torch.LongTensor] = None,
        modals: Optional[List[str]] = None,
    ):
        vision_encoder = self.get_vision_encoder()
        # Nothing to splice: no encoder, no images, or a single-token decoding step.
        if vision_encoder is None or pixel_values is None or input_ids.shape[1] == 1:
            return input_ids, attention_mask, position_ids, past_key_values, None, labels

        # Flatten the batch so visual features can be scattered with a single boolean mask.
        B, N = input_ids.shape
        input_ids = input_ids.view(B * N)
        if attention_mask is not None:
            attention_mask = attention_mask.view(B * N)
        if position_ids is not None:
            position_ids = position_ids.view(B * N)
        if labels is not None:
            labels = labels.view(B * N)

        image_selected = None
        if pixel_values is not None:
            # One merged visual token covers merge_size x merge_size grid cells.
            batched_num_patches = grid_sizes.prod(dim=1).div(merge_sizes ** 2).long()
            mm_features = self.encode_images(pixel_values, grid_sizes, merge_sizes)
            mm_features = mm_features.to(input_ids.device)
            mm_features = self._get_valid_visual_tokens(mm_features, batched_num_patches, modals)

            # Mark image placeholder positions and neutralize their token ids for embedding lookup.
            image_selected = (input_ids == self.config.image_token_index)
            input_ids[image_selected] = 0

            num_vision_tokens = image_selected.sum()
            if mm_features.size(0) != num_vision_tokens:
                print(f"Number of vision features ({mm_features.size(0)}) does not match the number of image tokens ({num_vision_tokens}). Please check the inputs.")
                mm_features = mm_features[:num_vision_tokens]

        # Embed the text tokens, then overwrite the placeholder positions with visual features.
        inputs_embeds = self.get_model().embed_tokens(input_ids).clone()
        if image_selected is not None:
            inputs_embeds[image_selected] = inputs_embeds[image_selected] * 0.0 + mm_features

        # Restore the (B, N, C) batch layout expected by the language model.
        C = inputs_embeds.shape[-1]
        inputs_embeds = inputs_embeds.reshape(B, -1, C)
        if attention_mask is not None:
            attention_mask = attention_mask.view(B, -1)
        if labels is not None:
            labels = labels.view(B, -1)
        if position_ids is not None:
            position_ids = position_ids.view(B, -1)

        # input_ids is returned as None because the model consumes inputs_embeds instead.
        return None, attention_mask, position_ids, past_key_values, inputs_embeds, labels


class PenguinVLQwen3ForCausalLM(Qwen3ForCausalLM, VLMMetaForCausalLM):

    config_class = PenguinVLQwen3Config

    def __init__(self, config, **kwargs):
        # Skip Qwen3ForCausalLM.__init__ (which would build a plain Qwen3Model) and initialize
        # from its parent so the multimodal backbone below is used instead.
        super(Qwen3ForCausalLM, self).__init__(config)
        self.model = PenguinVLQwen3Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing.
        self.post_init()

    def get_model(self):
        return self.model

    @classmethod
    def _load_pretrained_model(
        cls,
        model,
        state_dict,
        checkpoint_files,
        pretrained_model_name_or_path,
        ignore_mismatched_sizes=False,
        sharded_metadata=None,
        device_map=None,
        disk_offload_folder=None,
        offload_state_dict=None,
        dtype=None,
        hf_quantizer=None,
        keep_in_fp32_regex=None,
        device_mesh=None,
        key_mapping=None,
        weights_only=True,
    ):
        """
        Override that remaps nested vision-encoder keys before delegating to the parent's loader:
        'model.vision_encoder.vision_encoder.*' becomes 'model.vision_encoder.*'.
        """

        # Case 1: an in-memory state dict was passed directly; remap its keys in place.
        if state_dict is not None:
            needs_remapping = any(k.startswith('model.vision_encoder.vision_encoder.') for k in state_dict.keys())
            if needs_remapping:
                print("Detected nested encoder keys, remapping 'model.vision_encoder.vision_encoder.*' -> 'model.vision_encoder.*'")
                new_state_dict = {}
                for k, v in state_dict.items():
                    if k.startswith('model.vision_encoder.vision_encoder.'):
                        new_key = k.replace('model.vision_encoder.vision_encoder.', 'model.vision_encoder.')
                        new_state_dict[new_key] = v
                    else:
                        new_state_dict[k] = v
                state_dict = new_state_dict

        # Case 2: weights come from checkpoint files; build a key_mapping for the parent loader.
        if checkpoint_files is not None and key_mapping is None:
            from transformers.modeling_utils import load_state_dict
            checkpoint = {}
            checkpoint_files_list = checkpoint_files if isinstance(checkpoint_files, list) else [checkpoint_files]
            for ckpt_file in checkpoint_files_list:
                ckpt = load_state_dict(ckpt_file, map_location="cpu", weights_only=weights_only)
                checkpoint.update(ckpt)
            needs_remapping = any(k.startswith('model.vision_encoder.vision_encoder.') for k in checkpoint.keys())

            if needs_remapping:
                print("Detected nested encoder keys in checkpoint, adding key mapping for vision_encoder")
                key_mapping = {}
                for k in checkpoint.keys():
                    if k.startswith('model.vision_encoder.vision_encoder.'):
                        new_key = k.replace('model.vision_encoder.vision_encoder.', 'model.vision_encoder.')
                        key_mapping[k] = new_key
            del checkpoint

        return super()._load_pretrained_model(
            model=model,
            state_dict=state_dict,
            checkpoint_files=checkpoint_files,
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            sharded_metadata=sharded_metadata,
            device_map=device_map,
            disk_offload_folder=disk_offload_folder,
            offload_state_dict=offload_state_dict,
            dtype=dtype,
            hf_quantizer=hf_quantizer,
            keep_in_fp32_regex=keep_in_fp32_regex,
            device_mesh=device_mesh,
            key_mapping=key_mapping,
            weights_only=weights_only,
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        # Multimodal inputs.
        pixel_values: Optional[torch.FloatTensor] = None,
        grid_sizes: Optional[torch.LongTensor] = None,
        merge_sizes: Optional[torch.LongTensor] = None,
        modals: Optional[List[str]] = None,
        **loss_kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        if inputs_embeds is None:
            (
                input_ids,
                attention_mask,
                position_ids,
                past_key_values,
                inputs_embeds,
                labels,
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                labels=labels,
                pixel_values=pixel_values,
                grid_sizes=grid_sizes,
                merge_sizes=merge_sizes,
                modals=modals,
            )

        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            num_logits_to_keep=num_logits_to_keep,
            **loss_kwargs,
        )

    @torch.no_grad()
    def generate(
        self,
        # Multimodal inputs.
        pixel_values: Optional[torch.FloatTensor] = None,
        grid_sizes: Optional[torch.LongTensor] = None,
        merge_sizes: Optional[torch.LongTensor] = None,
        modals: Optional[List[str]] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        input_ids = kwargs.pop("input_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        position_ids = kwargs.pop("position_ids", None)
        past_key_values = kwargs.pop("past_key_values", None)

        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        if pixel_values is not None:
            (
                input_ids,
                attention_mask,
                position_ids,
                past_key_values,
                inputs_embeds,
                labels,
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                labels=None,
                pixel_values=pixel_values,
                grid_sizes=grid_sizes,
                merge_sizes=merge_sizes,
                modals=modals,
            )
        else:
            inputs_embeds = self.get_model().embed_tokens(input_ids)

        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        images = kwargs.pop("images", None)
        _inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if images is not None:
            _inputs['images'] = images
        return _inputs
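

# Example usage (illustrative sketch only; the checkpoint path and the processor-produced tensors
# below are assumptions, not part of this file):
#
#     model = PenguinVLQwen3ForCausalLM.from_pretrained(
#         "path/to/penguinvl-checkpoint", torch_dtype=torch.bfloat16
#     ).eval()
#     output_ids = model.generate(
#         input_ids=input_ids,              # prompt ids containing image placeholder tokens
#         attention_mask=attention_mask,
#         pixel_values=pixel_values,        # flattened image patches
#         grid_sizes=grid_sizes,            # per-image patch grid dimensions
#         merge_sizes=merge_sizes,          # spatial merge factor per image
#         modals=["image"],
#         max_new_tokens=128,
#     )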