"""Qwen3-VL based multi-vector embedding model (OpsColQwen3).

The model runs a ``Qwen3VLModel`` backbone, projects every token's final hidden
state through a linear layer, optionally truncates the projection to ``dims``
features and L2-normalizes it, producing one embedding per input token.
"""

from typing import Optional

import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.models.qwen3_vl import Qwen3VLModel
from transformers.utils import logging

from .configuration_ops_colqwen3 import OpsColQwen3Config

logger = logging.get_logger(__name__)


class OpsColQwen3PreTrainedModel(PreTrainedModel):
    """Base class binding ``OpsColQwen3Config`` to the ``PreTrainedModel`` machinery."""

    config_class = OpsColQwen3Config
    base_model_prefix = "ops_colqwen3"
    supports_gradient_checkpointing = True
    # Module classes that must stay whole on one device when the model is sharded.
    _no_split_modules = ["Qwen3VLVisionBlock", "Qwen3DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True


class OpsColQwen3Model(OpsColQwen3PreTrainedModel):
    """Qwen3-VL backbone plus a linear head emitting one L2-normalized vector per token."""

    # Applied as ``key_mapping`` when loading: checkpoint keys that start with
    # ``language_model`` / ``visual`` are re-rooted under the ``qwen3vl`` submodule.
    _checkpoint_conversion_mapping = {
        r"^language_model": r"qwen3vl.language_model",
        r"^visual": r"qwen3vl.visual",
    }

    def __init__(self, config: OpsColQwen3Config):
        super().__init__(config)
        self.config = config
        self.qwen3vl = Qwen3VLModel(config)
        # Output embedding width. ``from_pretrained`` may lower it; ``forward``
        # then slices the projection. The Linear itself always stays
        # hidden_size -> hidden_size (presumably so its weights keep a fixed shape).
        self.dims = config.text_config.hidden_size
        self.custom_text_proj = nn.Linear(config.text_config.hidden_size, self.dims)
        self.mask_non_image_embeddings = config.mask_non_image_embeddings
        self.post_init()

    @classmethod
    def from_pretrained(cls, *args, config: Optional[OpsColQwen3Config] = None, **kwargs):
        """Load pretrained weights, applying the key mapping and the ``dims`` setting.

        Args:
            *args: Forwarded to ``PreTrainedModel.from_pretrained``.
            config: Optional config to load with. Its ``dims`` attribute, if
                present and not None, sets the output embedding width.
            **kwargs: Forwarded to ``PreTrainedModel.from_pretrained``, except:
                ``key_mapping`` (defaults to ``_checkpoint_conversion_mapping``)
                and ``dims`` (explicit output width; takes precedence over any
                config value).

        Returns:
            The loaded model, with ``model.dims`` set when a width was resolved.
        """
        key_mapping = kwargs.pop("key_mapping", None)
        if key_mapping is None:
            key_mapping = getattr(cls, "_checkpoint_conversion_mapping", None)

        # Precedence: explicit ``dims`` kwarg > passed-in config > loaded config.
        dims = kwargs.pop("dims", None)
        if dims is None and config is not None:
            dims = getattr(config, "dims", None)

        model = super().from_pretrained(*args, config=config, key_mapping=key_mapping, **kwargs)

        if dims is None:
            # No config was passed: honor ``dims`` stored in the checkpoint's config.
            dims = getattr(model.config, "dims", None)
        if dims is not None:
            model.dims = dims
        return model

    @staticmethod
    def _unpad_pixel_values(
        pixel_values: torch.Tensor, image_grid_thw: torch.Tensor
    ) -> Optional[torch.Tensor]:
        """Concatenate the valid patch rows of per-image padded ``pixel_values``.

        A 3-D input is treated as one padded patch sequence per image: the first
        ``prod(image_grid_thw[i])`` rows of sequence ``i`` are kept and all
        sequences are concatenated along dim 0. Any other rank is assumed to be
        already flattened and is returned unchanged.

        Raises:
            ValueError: if a 3-D input and ``image_grid_thw`` disagree on the
                number of images.
        """
        if pixel_values.dim() != 3:
            return pixel_values
        if pixel_values.shape[0] != image_grid_thw.shape[0]:
            raise ValueError(
                f"Got {pixel_values.shape[0]} padded pixel sequences but "
                f"{image_grid_thw.shape[0]} rows in `image_grid_thw`."
            )
        # One host transfer for all lengths instead of a `.item()` sync per image.
        lengths = image_grid_thw.prod(dim=1).tolist()
        unpadded = [sequence[: int(length)] for sequence, length in zip(pixel_values, lengths)]
        return torch.cat(unpadded, dim=0) if unpadded else None

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Embed every input token.

        Args:
            input_ids: Token ids fed to the backbone. They are also compared
                with ``config.image_token_id`` when non-image tokens are masked.
            attention_mask: If given, multiplied into the embeddings so that
                positions where it is 0 become zero vectors (presumably shaped
                (batch, seq) — TODO confirm against callers).
            pixel_values: Image patches. A 3-D tensor is treated as padded per
                image (num_images, max_patches, patch_dim) and unpadded with
                ``image_grid_thw``. Other ranks are passed to the backbone as-is.
            image_grid_thw: Required with ``pixel_values``. Per-image grid sizes
                whose row product is that image's number of valid patch rows.
                Non-tensors are converted to a tensor on the pixels' device.
            **kwargs: Accepted for caller compatibility and ignored.

        Returns:
            The projected, L2-normalized token embeddings, with last dimension
            ``min(self.dims, hidden_size)``. Masked positions are zeroed. When
            images are given and ``mask_non_image_embeddings`` is set, only
            image-token positions remain non-zero.

        Raises:
            ValueError: if ``pixel_values`` is given without ``image_grid_thw``,
                or the padded image count does not match ``image_grid_thw``.
        """
        has_pixel_values = pixel_values is not None
        if has_pixel_values:
            if image_grid_thw is None:
                raise ValueError("`image_grid_thw` must be provided when `pixel_values` is passed.")
            if not torch.is_tensor(image_grid_thw):
                image_grid_thw = torch.as_tensor(image_grid_thw, device=pixel_values.device)
            pixel_values = self._unpad_pixel_values(pixel_values, image_grid_thw)

        # Only the final hidden state is read below, so intermediate hidden
        # states are not requested (avoids holding every layer's activations).
        outputs = self.qwen3vl(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            image_grid_thw=image_grid_thw,
            use_cache=False,
            return_dict=True,
        )

        proj = self.custom_text_proj(outputs.last_hidden_state)
        if self.dims < self.config.text_config.hidden_size:
            proj = proj[..., : self.dims]
        # L2-normalize each token vector. ``normalize`` clamps the denominator at
        # eps (1e-12 by default), so an all-zero vector yields zeros rather than NaN
        # (a NaN would survive the mask multiplications below).
        proj = nn.functional.normalize(proj, p=2, dim=-1)

        if attention_mask is not None:
            proj = proj * attention_mask.unsqueeze(-1)

        if has_pixel_values and self.mask_non_image_embeddings and input_ids is not None:
            image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
            proj = proj * image_mask
        return proj

    @property
    def patch_size(self) -> int:
        """``patch_size`` from the vision encoder's config."""
        return self.qwen3vl.visual.config.patch_size

    @property
    def spatial_merge_size(self) -> int:
        """``spatial_merge_size`` from the vision encoder's config."""
        return self.qwen3vl.visual.config.spatial_merge_size