| import torch |
| from torch import nn |
| from transformers import AutoModel, AutoTokenizer |
| import torch.nn.functional as F |
|
|
|
|
| class VlmBackboneBase(nn.Module): |
| """ |
| 공통 VLM 비전 백본 인터페이스. |
| - forward_vision(pixel_values) -> (image_tokens, padding_mask) |
| image_tokens: (B, L, D) |
| padding_mask: (B, L) (True == pad) |
| """ |
|
|
| def __init__(self) -> None: |
| super().__init__() |
|
|
| def forward_fused( |
| self, |
| pixel_values: torch.Tensor, |
| input_ids: torch.Tensor, |
| attention_mask: torch.Tensor, |
| ) -> tuple[torch.Tensor, torch.Tensor | None]: |
| raise NotImplementedError |
|
|
|
|
class InternVL3_5_Backbone(VlmBackboneBase):
    """InternVL-3.5 backbone producing fused vision+language token memory.

    Loads the full VLM (vision tower + LLM) with ``AutoModel.from_pretrained``
    (remote code), projects the LLM's last hidden states to the DETR width
    with an identity-initialized linear layer, and can optionally rebuild the
    IMG_CONTEXT (image) tokens into a multi-level "token FPN" memory.
    """
    def __init__(
        self,
        model_name: str,
        device: str,
        dtype: torch.dtype,
        *,
        use_token_fpn: bool = False,
        token_fpn_levels: tuple[int, ...] = (16, 8, 4, 2),
        token_fpn_include_text: bool = True,
    ) -> None:
        """
        Args:
            model_name: HuggingFace model id / local path of the InternVL
                checkpoint (also used to load its tokenizer).
            device: target device string. Stored only — this constructor does
                not move the model; the caller is responsible for placement.
            dtype: parameter dtype for both the VLM and the projection layer.
            use_token_fpn: if True, ``forward_fused`` returns the multi-level
                memory built by ``_build_token_fpn_memory``.
            token_fpn_levels: spatial side lengths per FPN level; must start
                at 16 and halve at every step, e.g. ``(16, 8, 4, 2)``.
            token_fpn_include_text: if True, non-image tokens (text/special)
                are prepended to each sample's FPN memory.
        """
        super().__init__()
        self.device = device
        self.dtype = dtype
        self.use_token_fpn = use_token_fpn
        self.token_fpn_levels = token_fpn_levels
        self.token_fpn_include_text = token_fpn_include_text
        # Full InternVL model; trust_remote_code is required because InternVL
        # ships custom modeling code. device_map=None keeps placement manual.
        self.vlm = AutoModel.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype=dtype,
            low_cpu_mem_usage=False,
            device_map=None,
            _attn_implementation="flash_attention_2"
        )

        # NOTE(review): hard-coded widths — assumes this checkpoint's LLM
        # hidden size is 1024; verify against self.vlm.config.
        self.hidden_size_llm = 1024
        self.hidden_size_detr = 1024


        # Projection LLM hidden -> DETR hidden. Identity init (requires the
        # two widths to be equal) so training starts from a no-op mapping.
        self.fused_proj = nn.Linear(
            self.hidden_size_llm,
            self.hidden_size_detr,
            bias=True,
            device=None,
            dtype=dtype,
        )
        nn.init.eye_(self.fused_proj.weight)
        nn.init.zeros_(self.fused_proj.bias)


        # Resolve the IMG_CONTEXT placeholder id and register it on the model;
        # InternVL uses this id to splice vision embeddings into the LLM input.
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False)
        IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>"
        self.img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
        self.vlm.img_context_token_id = self.img_context_token_id


    def _build_token_fpn_memory(
        self,
        memory_last: torch.Tensor,
        input_ids: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Build an "FPN-like" multi-level token memory from IMG_CONTEXT token embeddings.

        - Extract IMG_CONTEXT tokens per sample
        - Reshape per patch into (num_patches, 16, 16, D)
        - Pool to multiple spatial levels (e.g., 16->8->4->2)
        - Flatten and concatenate levels into one sequence
        - Optionally append non-image tokens (text + special tokens)

        Returns:
            memory: (B, L, D)
            padding_mask: (B, L) with True == pad
        """
        B, T, D = memory_last.shape
        device = memory_last.device


        # Validate the level chain: must start at 16 and halve at every step
        # (each pooling stage below divides the side length by exactly 2).
        levels = tuple(int(x) for x in self.token_fpn_levels)
        if len(levels) == 0 or levels[0] != 16:
            raise ValueError(f"token_fpn_levels must start with 16, got {levels}")
        for a, b in zip(levels, levels[1:]):
            if a % 2 != 0 or b != a // 2:
                raise ValueError(f"token_fpn_levels must be like (16,8,4,2,...) got {levels}")


        # Boolean mask of IMG_CONTEXT positions in the token sequence.
        is_img = input_ids.eq(self.img_context_token_id)


        # Samples can yield different sequence lengths, so build per-sample
        # memories first and right-pad them to a common length afterwards.
        per_sample_memory: list[torch.Tensor] = []
        max_len = 0
        for i in range(B):
            img_tokens = memory_last[i][is_img[i]]
            n_img = img_tokens.shape[0]
            if n_img == 0:
                # No image tokens for this sample: fall back to the raw
                # (unpooled) sequence so the sample still contributes memory.
                mem_i = memory_last[i]
            else:
                # Assumes InternVL tiling: each tile contributes exactly
                # 16 x 16 = 256 IMG_CONTEXT tokens — TODO confirm for the
                # configured checkpoint / downsample ratio.
                if n_img % 256 != 0:
                    raise ValueError(f"IMG_CONTEXT token count must be multiple of 256, got {n_img}")
                num_patches = n_img // 256


                # (n_img, D) -> (num_patches, D, 16, 16) for 2D pooling.
                patch_feat = img_tokens.view(num_patches, 16, 16, D).permute(0, 3, 1, 2).contiguous()


                # Progressively average-pool down the level chain, flattening
                # each requested level back into a token sequence.
                level_tokens: list[torch.Tensor] = []
                feat = patch_feat
                cur = 16
                for lvl in levels:
                    # Halve the spatial side until this level's size is reached.
                    while cur > lvl:
                        feat = F.avg_pool2d(feat, kernel_size=2, stride=2)
                        cur //= 2
                    # (num_patches, D, h, w) -> (num_patches*h*w, D)
                    h, w = feat.shape[-2:]
                    lvl_tok = feat.permute(0, 2, 3, 1).reshape(num_patches * h * w, D).contiguous()
                    level_tokens.append(lvl_tok)


                mem_i = torch.cat(level_tokens, dim=0)


                if self.token_fpn_include_text:
                    # NOTE(review): ~is_img also selects padding positions —
                    # attention_mask is not consulted here, so pad tokens end
                    # up in the memory as if they were text; confirm intended.
                    txt_tokens = memory_last[i][~is_img[i]]
                    mem_i = torch.cat([txt_tokens, mem_i], dim=0)


            per_sample_memory.append(mem_i)
            max_len = max(max_len, mem_i.shape[0])


        # Right-pad all samples to max_len; padding_mask marks pads with True.
        padded = memory_last.new_zeros((B, max_len, D))
        padding_mask = torch.ones((B, max_len), device=device, dtype=torch.bool)
        for i, mem_i in enumerate(per_sample_memory):
            seq_len = mem_i.shape[0]
            padded[i, :seq_len] = mem_i
            padding_mask[i, :seq_len] = False


        return padded, padding_mask


    def forward_fused(
        self,
        pixel_values: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        patch_mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run the VLM on tiled images + text and return projected memory.

        Args:
            pixel_values: (B, P, C, H, W) tiled patches, or (B, C, H, W).
            input_ids: (B, T) token ids containing IMG_CONTEXT placeholders.
            attention_mask: (B, T) mask forwarded to the VLM.
            patch_mask: optional (B, P) tile-validity mask. Note: ignored for
                4D input (it is overwritten with None below).

        Returns:
            (memory, padding_mask): memory is (B, L, D_detr); padding_mask is
            (B, L) with True == pad when use_token_fpn, otherwise None.
        """
        # Flatten tiled input (B, P, C, H, W) -> (B*P, C, H, W) for the vision tower.
        if pixel_values.dim() == 5:
            bsz, num_patches, channels, height, width = pixel_values.shape
            pixel_values = pixel_values.view(-1, channels, height, width)
            if patch_mask is not None:
                patch_mask = patch_mask.view(bsz * num_patches)
        else:
            bsz = pixel_values.shape[0]
            num_patches = 1
            patch_mask = None


        # image_flags: per-tile validity consumed by InternVL's forward
        # (invalid tiles contribute no vision features). All-ones by default.
        if patch_mask is not None:
            image_flags = patch_mask.to(pixel_values.device, dtype=torch.long)
        else:
            image_flags = torch.ones(pixel_values.shape[0], dtype=torch.long, device=pixel_values.device)

        outputs = self.vlm(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            image_flags=image_flags,
            output_hidden_states=True,
            return_dict=True
        )


        # Last LLM hidden states (B, T, llm_hidden), projected to DETR width.
        memory = outputs.hidden_states[-1]
        memory = self.fused_proj(memory)
        if self.use_token_fpn:
            memory, padding_mask = self._build_token_fpn_memory(
                memory_last=memory,
                input_ids=input_ids,
            )
            return memory, padding_mask


        # No FPN: return the full projected sequence; no padding mask is
        # produced here (attention_mask is not converted to one).
        return memory, None
|
|