|
|
from typing import Optional |
|
|
import torch |
|
|
from torch import nn |
|
|
from transformers import PreTrainedModel |
|
|
from transformers.models.qwen3_vl import Qwen3VLModel |
|
|
from transformers.utils import logging |
|
|
|
|
|
from .configuration_ops_colqwen3 import OpsColQwen3Config |
|
|
|
|
|
# Module logger obtained through transformers' logging helpers, named after this module.
logger = logging.get_logger(name=__name__)
|
|
|
|
|
|
|
|
class OpsColQwen3PreTrainedModel(PreTrainedModel):
    """Base class binding ``OpsColQwen3Config`` to the ``PreTrainedModel`` machinery.

    It declares only class-level configuration and capability flags that
    ``transformers`` reads when loading, dispatching attention and placing
    devices. Weights and the forward pass live in subclasses.
    """

    config_class = OpsColQwen3Config
    base_model_prefix = "ops_colqwen3"
    supports_gradient_checkpointing = True

    # Module classes that must not be split across devices under ``device_map``.
    # Qwen3-VL's text decoder layer class is ``Qwen3VLTextDecoderLayer``; the
    # text-only ``Qwen3DecoderLayer`` does not occur inside a Qwen3-VL backbone.
    _no_split_modules = ["Qwen3VLVisionBlock", "Qwen3VLTextDecoderLayer"]

    _skip_keys_device_placement = "past_key_values"

    # Older transformers releases read ``_supports_flash_attn_2``; newer ones read
    # ``_supports_flash_attn``. Declaring both keeps FlashAttention usable on either.
    _supports_flash_attn_2 = True
    _supports_flash_attn = True

    _supports_sdpa = True
    _supports_cache_class = True
|
|
|
|
|
|
|
|
class OpsColQwen3Model(OpsColQwen3PreTrainedModel):
    """Multi-vector (one embedding per token) encoder on top of ``Qwen3VLModel``.

    The backbone's last hidden states are passed through ``custom_text_proj``,
    optionally truncated to the first ``dims`` channels, L2-normalised per token
    and then masked by the attention mask (and, optionally, by image tokens).
    """

    # Rewrites checkpoint keys that start with ``language_model`` / ``visual`` so
    # they load into the ``qwen3vl`` submodule of this wrapper.
    _checkpoint_conversion_mapping = {
        r"^language_model": r"qwen3vl.language_model",
        r"^visual": r"qwen3vl.visual",
    }

    def __init__(self, config: OpsColQwen3Config):
        super().__init__(config)
        self.config = config

        self.qwen3vl = Qwen3VLModel(config)

        # Output width. It defaults to the full hidden size; ``from_pretrained`` may
        # lower it, in which case ``forward`` keeps only the leading ``dims`` channels.
        self.dims = config.text_config.hidden_size
        self.custom_text_proj = nn.Linear(config.text_config.hidden_size, self.dims)

        self.mask_non_image_embeddings = config.mask_non_image_embeddings
        self.post_init()

    @classmethod
    def from_pretrained(cls, *args, config: Optional[OpsColQwen3Config] = None, **kwargs):
        """Load a pretrained model, optionally truncating the embedding width.

        Args:
            *args: Forwarded to ``PreTrainedModel.from_pretrained``.
            config: Optional config. When ``dims`` is not passed explicitly, the
                config's ``dims`` attribute (if it has one) sets the output width.
            **kwargs: Forwarded to ``PreTrainedModel.from_pretrained``, except:
                ``dims``, an explicit output width that takes precedence over
                ``config.dims``; and ``key_mapping``, which defaults to
                ``_checkpoint_conversion_mapping`` when absent or ``None``.

        Returns:
            The loaded model. ``model.dims`` is set when a width was requested.

        Raises:
            ValueError: If the requested ``dims`` is smaller than 1.
        """
        key_mapping = kwargs.pop("key_mapping", None)
        if key_mapping is None:
            key_mapping = getattr(cls, "_checkpoint_conversion_mapping", None)

        if "dims" in kwargs:
            dims = kwargs.pop("dims")
        elif config is not None:
            # ``getattr`` tolerates configs that do not define ``dims``.
            dims = getattr(config, "dims", None)
        else:
            dims = None

        model = super().from_pretrained(*args, config=config, key_mapping=key_mapping, **kwargs)

        if dims is not None:
            hidden_size = model.config.text_config.hidden_size
            if dims < 1:
                raise ValueError(f"`dims` must be a positive integer, got {dims}.")
            if dims > hidden_size:
                # ``forward`` cannot widen the projection; record the real output width.
                logger.warning(
                    "Requested dims=%s exceeds hidden size %s; using %s.", dims, hidden_size, hidden_size
                )
                dims = hidden_size
            model.dims = dims
        return model

    @staticmethod
    def _unpad_pixel_values(
        pixel_values: torch.Tensor, image_grid_thw: torch.Tensor
    ) -> Optional[torch.Tensor]:
        """Flatten per-image padded patches into one patch sequence.

        A 3-D ``pixel_values`` is treated as padded per image: for image ``i``,
        only the first ``prod(image_grid_thw[i])`` rows are kept. The kept rows of
        all images are concatenated along dim 0. Tensors of any other rank are
        assumed to be flattened already and are returned unchanged. Returns
        ``None`` when there is nothing to concatenate.
        """
        if pixel_values.dim() != 3:
            return pixel_values
        # A single host sync instead of one ``.item()`` call per image.
        patch_counts = image_grid_thw.prod(dim=1).tolist()
        pieces = [patches[: int(count)] for patches, count in zip(pixel_values, patch_counts)]
        return torch.cat(pieces, dim=0) if pieces else None

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Compute L2-normalised per-token embeddings.

        Args:
            input_ids: Token ids for the backbone.
            attention_mask: Mask broadcast over the last axis of the output. Masked
                positions become zero vectors.
            pixel_values: Image patches, either padded per image (3-D) or already
                flattened. Requires ``image_grid_thw``.
            image_grid_thw: Per-image patch grid; converted to a tensor if needed.
            **kwargs: Accepted for API compatibility and ignored.

        Returns:
            Projected, normalised, masked embeddings. Presumably shaped
            ``(batch, seq_len, dims)``, assuming the backbone returns
            ``(batch, seq_len, hidden)``; confirm against callers.

        Raises:
            ValueError: If ``pixel_values`` is given without ``image_grid_thw``.
        """
        has_pixel_values = pixel_values is not None

        if has_pixel_values:
            if image_grid_thw is None:
                raise ValueError("`image_grid_thw` must be provided when `pixel_values` is passed.")
            if not torch.is_tensor(image_grid_thw):
                image_grid_thw = torch.as_tensor(image_grid_thw, device=pixel_values.device)
            pixel_values = self._unpad_pixel_values(pixel_values, image_grid_thw)

        outputs = self.qwen3vl(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            image_grid_thw=image_grid_thw,
            use_cache=False,
            # Only ``last_hidden_state`` is read below. Collecting every layer's
            # hidden states would only hold extra activations in memory.
            output_hidden_states=False,
            return_dict=True,
        )

        proj = self.custom_text_proj(outputs.last_hidden_state)

        # Keep the leading ``dims`` channels when a smaller width was requested.
        if self.dims < self.config.text_config.hidden_size:
            proj = proj[..., : self.dims]

        # Clamp the norm so an all-zero vector yields zeros rather than NaN.
        # A NaN would survive the multiplicative masks below.
        norm = proj.norm(dim=-1, keepdim=True).clamp_min(torch.finfo(proj.dtype).tiny)
        proj = proj / norm

        if attention_mask is not None:
            proj = proj * attention_mask.unsqueeze(-1)

        # Optionally keep only positions holding the image placeholder token.
        if has_pixel_values and self.mask_non_image_embeddings and input_ids is not None:
            image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
            proj = proj * image_mask

        return proj

    @property
    def patch_size(self) -> int:
        """Patch size read from the vision tower's config."""
        return self.qwen3vl.visual.config.patch_size

    @property
    def spatial_merge_size(self) -> int:
        """Spatial merge size read from the vision tower's config."""
        return self.qwen3vl.visual.config.spatial_merge_size
|
|
|