from typing import Optional
import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.models.qwen3_vl import Qwen3VLModel
from transformers.utils import logging

from .configuration_ops_colqwen3 import OpsColQwen3Config

logger = logging.get_logger(__name__)


class OpsColQwen3PreTrainedModel(PreTrainedModel):
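    """Base class hooking OpsColQwen3 models into the `transformers` loading,
    weight-initialization, and device-placement machinery."""
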
    config_class = OpsColQwen3Config
    base_model_prefix = "ops_colqwen3"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen3VLVisionBlock", "Qwen3DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True


class OpsColQwen3Model(OpsColQwen3PreTrainedModel):
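    """ColBERT-style multi-vector embedding model built on a Qwen3-VL backbone:
    `forward` maps every token (text or image patch) to one L2-normalized
    embedding of size `dims`.

    Retrieval scoring over these outputs is the usual late-interaction MaxSim.
    A minimal, illustrative sketch (the random tensors stand in for real
    query/document outputs of `forward`; 128 is an assumed `dims`):

    ```python
    import torch

    query_emb = torch.nn.functional.normalize(torch.randn(1, 12, 128), dim=-1)
    doc_emb = torch.nn.functional.normalize(torch.randn(1, 700, 128), dim=-1)
    sim = torch.einsum("bqd,bkd->bqk", query_emb, doc_emb)  # token-level cosine sims
    scores = sim.max(dim=-1).values.sum(dim=-1)             # MaxSim per batch item
    ```
    """

    # Remap keys from checkpoints saved before the backbone was nested under `qwen3vl`.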
    _checkpoint_conversion_mapping = {
        r"^language_model": r"qwen3vl.language_model",
        r"^visual": "qwen3vl.visual",
    }

    def __init__(self, config: OpsColQwen3Config):
        super().__init__(config)
        self.config = config

        self.qwen3vl = Qwen3VLModel(config)
        # The projection is built at full hidden size; `dims` can be lowered later
        # (see `from_pretrained`), in which case `forward` truncates the output.
        self.dims = config.text_config.hidden_size
        self.custom_text_proj = nn.Linear(config.text_config.hidden_size, self.dims)

        self.mask_non_image_embeddings = config.mask_non_image_embeddings
        self.post_init()

    @classmethod
    def from_pretrained(cls, *args, config: Optional[OpsColQwen3Config] = None, **kwargs):
        # Fall back to the class-level checkpoint key remapping unless the caller
        # supplied an explicit one.
        key_mapping = kwargs.pop("key_mapping", None)
        if key_mapping is None:
            key_mapping = getattr(cls, "_checkpoint_conversion_mapping", None)

        # An explicit `dims` kwarg takes precedence over the value in the config.
        dims = None
        if "dims" in kwargs:
            dims = kwargs.pop("dims")
        elif config is not None:
            dims = config.dims

        model = super().from_pretrained(*args, config=config, **kwargs, key_mapping=key_mapping)
        if dims is not None:
            # `forward` truncates the projected embeddings to this size.
            model.dims = dims
        return model

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Embed a batch of (optionally multimodal) inputs, one vector per token.

        Returns a tensor of shape `(batch_size, sequence_length, dims)` of
        L2-normalized token embeddings; embeddings at padded positions are zeroed.
        """
        has_pixel_values = pixel_values is not None

        if has_pixel_values:
            if image_grid_thw is None:
                raise ValueError("`image_grid_thw` must be provided when `pixel_values` is passed.")
            if not torch.is_tensor(image_grid_thw):
                image_grid_thw = torch.as_tensor(image_grid_thw, device=pixel_values.device)

            # `pixel_values` arrives padded per sample with shape
            # (batch, max_num_patches, patch_dim); row i of `image_grid_thw` is
            # (t, h, w), whose product is the number of valid patches for image i.
            # Strip the padding and concatenate everything along dim 0, the flat
            # layout the Qwen3-VL vision tower expects.
            offsets = image_grid_thw.prod(dim=1)
            unpadded = [
                pixel_sequence[: int(offset.item())]
                for pixel_sequence, offset in zip(pixel_values, offsets)
            ]
            pixel_values = torch.cat(unpadded, dim=0) if unpadded else None

        outputs = self.qwen3vl(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            image_grid_thw=image_grid_thw,
            use_cache=False,
            output_hidden_states=True,
            return_dict=True,
        )

        last_hidden_states = outputs.last_hidden_state
        proj = self.custom_text_proj(last_hidden_states)

        # Truncate to the first `dims` components when a smaller embedding size was
        # requested at load time (the projection itself stays at full hidden size).
        if self.dims < self.config.text_config.hidden_size:
            proj = proj[..., : self.dims]

        # L2-normalize each token embedding so downstream scoring reduces to
        # cosine similarity.
        proj = proj / proj.norm(dim=-1, keepdim=True)

        # Zero out the embeddings of padding tokens.
        if attention_mask is not None:
            proj = proj * attention_mask.unsqueeze(-1)

        # Optionally keep only the image-token embeddings, zeroing out text and
        # special tokens.
        if has_pixel_values and self.mask_non_image_embeddings and input_ids is not None:
            image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
            proj = proj * image_mask

        return proj

    @property
    def patch_size(self) -> int:
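        """Patch size of the underlying vision tower."""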
        return self.qwen3vl.visual.config.patch_size

    @property
    def spatial_merge_size(self) -> int:
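        """Spatial merge factor of the vision tower (number of patches merged
        along each spatial dimension)."""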
        return self.qwen3vl.visual.config.spatial_merge_size