# modeling_ops_colqwen3.py
from typing import Optional

import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.models.qwen3_vl import Qwen3VLModel
from transformers.utils import logging

from .configuration_ops_colqwen3 import OpsColQwen3Config

logger = logging.get_logger(__name__)

class OpsColQwen3PreTrainedModel(PreTrainedModel):
    """Base class wiring OpsColQwen3Config into the Transformers loading machinery."""

    config_class = OpsColQwen3Config
    base_model_prefix = "ops_colqwen3"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen3VLVisionBlock", "Qwen3DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True

class OpsColQwen3Model(OpsColQwen3PreTrainedModel):
    """ColBERT-style multi-vector embedding model built on a Qwen3-VL backbone.

    `forward` returns one L2-normalized vector per input token, suitable for
    late-interaction (MaxSim) retrieval scoring.
    """

    # Remap flat checkpoint keys onto the wrapped `qwen3vl` submodule.
    _checkpoint_conversion_mapping = {
        r"^language_model": "qwen3vl.language_model",
        r"^visual": "qwen3vl.visual",
    }

    def __init__(self, config: OpsColQwen3Config):
        super().__init__(config)
        self.config = config
        self.qwen3vl = Qwen3VLModel(config)
        # Output embedding size; defaults to the full hidden size and may be
        # lowered (e.g. via `from_pretrained(dims=...)`) to truncate embeddings.
        self.dims = config.text_config.hidden_size
        self.custom_text_proj = nn.Linear(config.text_config.hidden_size, self.dims)
        self.mask_non_image_embeddings = config.mask_non_image_embeddings
        self.post_init()

    @classmethod
    def from_pretrained(cls, *args, config: Optional[OpsColQwen3Config] = None, **kwargs):
        # Use the class-level key mapping unless the caller supplies one.
        key_mapping = kwargs.pop("key_mapping", None)
        if key_mapping is None:
            key_mapping = getattr(cls, "_checkpoint_conversion_mapping", None)

        # An explicit `dims` kwarg takes precedence over `config.dims`; either
        # overrides the default embedding size once the weights are loaded.
        dims = None
        if "dims" in kwargs:
            dims = kwargs.pop("dims")
        elif config is not None:
            dims = config.dims

        model = super().from_pretrained(*args, config=config, key_mapping=key_mapping, **kwargs)
        if dims is not None:
            model.dims = dims
        return model
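
    # A minimal loading sketch (the hub id and `dims` value below are
    # illustrative, not pinned by this file):
    #
    #     model = OpsColQwen3Model.from_pretrained(
    #         "frozenc/Ops-Colqwen3-4B",  # hypothetical checkpoint id
    #         dims=320,                   # optional: truncate embeddings to 320 dims
    #         torch_dtype=torch.bfloat16,
    #     )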

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Return one L2-normalized embedding per token, shape (batch, seq_len, dims)."""
        has_pixel_values = pixel_values is not None
        if has_pixel_values:
            if image_grid_thw is None:
                raise ValueError("`image_grid_thw` must be provided when `pixel_values` is passed.")
            if not torch.is_tensor(image_grid_thw):
                image_grid_thw = torch.as_tensor(image_grid_thw, device=pixel_values.device)
            # Each image contributes t*h*w patches; drop the per-image padding and
            # flatten to the (total_patches, feature) layout the backbone expects.
            offsets = image_grid_thw.prod(dim=1)
            unpadded = [
                pixel_sequence[: int(offset.item())]
                for pixel_sequence, offset in zip(pixel_values, offsets)
            ]
            pixel_values = torch.cat(unpadded, dim=0) if unpadded else None

        outputs = self.qwen3vl(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            image_grid_thw=image_grid_thw,
            use_cache=False,
            output_hidden_states=True,
            return_dict=True,
        )
        last_hidden_states = outputs.last_hidden_state

        # Project, optionally truncate to `self.dims`, then L2-normalize per token.
        proj = self.custom_text_proj(last_hidden_states)
        if self.dims < self.config.text_config.hidden_size:
            proj = proj[..., : self.dims]
        proj = proj / proj.norm(dim=-1, keepdim=True)

        # Zero out padding tokens and, if configured, everything except image tokens.
        if attention_mask is not None:
            proj = proj * attention_mask.unsqueeze(-1)
        if has_pixel_values and self.mask_non_image_embeddings and input_ids is not None:
            image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
            proj = proj * image_mask
        return proj

    # Convenience accessors for the vision tower geometry: a t x h x w patch grid
    # yields t*h*w / spatial_merge_size**2 tokens after spatial merging.
    @property
    def patch_size(self) -> int:
        return self.qwen3vl.visual.config.patch_size

    @property
    def spatial_merge_size(self) -> int:
        return self.qwen3vl.visual.config.spatial_merge_size
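
# A minimal late-interaction scoring sketch (not part of this checkpoint's API;
# the function name is illustrative). ColBERT-style retrieval scores a query
# against a document by summing, over query tokens, the maximum similarity to
# any document token. Both inputs are assumed to come from
# `OpsColQwen3Model.forward`, which already L2-normalizes and zero-masks tokens.
def maxsim_score(
    query_embeddings: torch.Tensor,  # (num_queries, query_len, dims)
    doc_embeddings: torch.Tensor,    # (num_docs, doc_len, dims)
) -> torch.Tensor:
    """Return a (num_queries, num_docs) matrix of MaxSim scores."""
    # Pairwise token similarities: (num_queries, num_docs, query_len, doc_len).
    sim = torch.einsum("qnd,cmd->qcnm", query_embeddings, doc_embeddings)
    # For each query token, keep the best-matching document token, then sum
    # over query tokens; padded (zeroed) tokens contribute nothing.
    return sim.max(dim=-1).values.sum(dim=-1)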