import torch
import torch.nn.functional as F
from torch import nn
from transformers import AutoModel, AutoTokenizer


class VlmBackboneBase(nn.Module):
"""
공통 VLM 비전 백본 인터페이스.
- forward_vision(pixel_values) -> (image_tokens, padding_mask)
image_tokens: (B, L, D)
padding_mask: (B, L) (True == pad)
"""
def __init__(self) -> None:
        super().__init__()

    def forward_fused(
self,
pixel_values: torch.Tensor,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor | None]:
        raise NotImplementedError


class InternVL3_5_Backbone(VlmBackboneBase):
def __init__(
self,
model_name: str,
device: str,
dtype: torch.dtype,
*,
use_token_fpn: bool = False,
token_fpn_levels: tuple[int, ...] = (16, 8, 4, 2),
token_fpn_include_text: bool = True,
) -> None:
super().__init__()
self.device = device
self.dtype = dtype
self.use_token_fpn = use_token_fpn
self.token_fpn_levels = token_fpn_levels
self.token_fpn_include_text = token_fpn_include_text
        self.vlm = AutoModel.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype=dtype,
            low_cpu_mem_usage=False,
            device_map=None,
            attn_implementation="flash_attention_2",
        )
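        # Note: "flash_attention_2" requires the flash-attn package and a fp16/bf16
        # dtype; pass attn_implementation="sdpa" instead if it is unavailable.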
        self.hidden_size_llm = 1024   # InternVL3_5 language-model hidden dim (hard-coded; must match model_name)
        self.hidden_size_detr = 1024  # DETR d_model
self.fused_proj = nn.Linear(
self.hidden_size_llm,
self.hidden_size_detr,
bias=True,
device=None,
dtype=dtype,
)
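        # Identity weight + zero bias: the projection starts as a pass-through,
        # so the DETR head initially sees the raw LLM features.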
nn.init.eye_(self.fused_proj.weight)
nn.init.zeros_(self.fused_proj.bias)
# Set img_context_token_id for the model
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False)
IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>"
self.img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
        self.vlm.img_context_token_id = self.img_context_token_id

    def _build_token_fpn_memory(
self,
memory_last: torch.Tensor, # (B, T, D)
input_ids: torch.Tensor, # (B, T)
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Build an "FPN-like" multi-level token memory from IMG_CONTEXT token embeddings.
- Extract IMG_CONTEXT tokens per sample
- Reshape per patch into (num_patches, 16, 16, D)
- Pool to multiple spatial levels (e.g., 16->8->4->2)
- Flatten and concatenate levels into one sequence
- Optionally append non-image tokens (text + special tokens)
Returns:
memory: (B, L, D)
padding_mask: (B, L) with True == pad
"""
B, T, D = memory_last.shape
device = memory_last.device
# Validate levels (must be descending powers of 2 from 16)
levels = tuple(int(x) for x in self.token_fpn_levels)
if len(levels) == 0 or levels[0] != 16:
raise ValueError(f"token_fpn_levels must start with 16, got {levels}")
for a, b in zip(levels, levels[1:]):
if a % 2 != 0 or b != a // 2:
raise ValueError(f"token_fpn_levels must be like (16,8,4,2,...) got {levels}")
is_img = input_ids.eq(self.img_context_token_id) # (B, T)
per_sample_memory: list[torch.Tensor] = []
max_len = 0
for i in range(B):
img_tokens = memory_last[i][is_img[i]] # (N_img, D)
n_img = img_tokens.shape[0]
if n_img == 0:
# Fallback: no img tokens found -> keep original memory.
mem_i = memory_last[i]
else:
if n_img % 256 != 0:
raise ValueError(f"IMG_CONTEXT token count must be multiple of 256, got {n_img}")
num_patches = n_img // 256
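                # 256 = 16 x 16: each 448-px InternVL tile yields a 16x16 token grid
                # after pixel shuffle (an assumption about the checkpoint's config).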
# (num_patches, D, 16, 16)
patch_feat = img_tokens.view(num_patches, 16, 16, D).permute(0, 3, 1, 2).contiguous()
# Build levels by pooling
level_tokens: list[torch.Tensor] = []
feat = patch_feat
cur = 16
for lvl in levels:
# Ensure feat is at correct resolution
while cur > lvl:
feat = F.avg_pool2d(feat, kernel_size=2, stride=2)
cur //= 2
# Flatten: (num_patches, D, H, W) -> (num_patches*H*W, D)
h, w = feat.shape[-2:]
lvl_tok = feat.permute(0, 2, 3, 1).reshape(num_patches * h * w, D).contiguous()
level_tokens.append(lvl_tok)
mem_i = torch.cat(level_tokens, dim=0) # (L_img_fpn, D)
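                # Caveat: ~is_img selects every non-image position, which also
                # includes any padding tokens if input_ids was padded.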
if self.token_fpn_include_text:
txt_tokens = memory_last[i][~is_img[i]] # (N_txt, D)
mem_i = torch.cat([txt_tokens, mem_i], dim=0)
per_sample_memory.append(mem_i)
max_len = max(max_len, mem_i.shape[0])
# Pad to (B, max_len, D)
padded = memory_last.new_zeros((B, max_len, D))
padding_mask = torch.ones((B, max_len), device=device, dtype=torch.bool)
for i, mem_i in enumerate(per_sample_memory):
seq_len = mem_i.shape[0]
padded[i, :seq_len] = mem_i
padding_mask[i, :seq_len] = False
        return padded, padding_mask

    def forward_fused(
self,
pixel_values: torch.Tensor,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
patch_mask: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
# pixel_values: (B, P, 3, H, W)
if pixel_values.dim() == 5:
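            # Flatten tiles into the batch dimension: (B, P, 3, H, W) -> (B*P, 3, H, W).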
bsz, num_patches, channels, height, width = pixel_values.shape
pixel_values = pixel_values.view(-1, channels, height, width)
if patch_mask is not None:
patch_mask = patch_mask.view(bsz * num_patches)
else:
bsz = pixel_values.shape[0]
num_patches = 1
patch_mask = None
# image_flags must match number of images provided to the VLM
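        # (1 = real tile, 0 = padded tile; InternVL's remote modeling code uses the
        # flags to drop padded tiles from the visual embeddings.)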
if patch_mask is not None:
image_flags = patch_mask.to(pixel_values.device, dtype=torch.long)
else:
image_flags = torch.ones(pixel_values.shape[0], dtype=torch.long, device=pixel_values.device)
outputs = self.vlm(
pixel_values=pixel_values,
input_ids=input_ids,
attention_mask=attention_mask,
image_flags=image_flags,
output_hidden_states=True,
return_dict=True
)
# CausalLMOutputWithPast has hidden_states tuple, last element is the final layer output
memory = outputs.hidden_states[-1] # (B, T, hidden_size)
memory = self.fused_proj(memory) # (B, T, hidden_size_detr)
if self.use_token_fpn:
memory, padding_mask = self._build_token_fpn_memory(
memory_last=memory,
input_ids=input_ids,
)
return memory, padding_mask
return memory, None
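

if __name__ == "__main__":
    # Minimal smoke test for the token-FPN construction only (a sketch added for
    # illustration; not part of the original module). It avoids downloading the
    # VLM by calling _build_token_fpn_memory with a dummy stand-in for `self`.
    # The token id (7) and all shapes below are arbitrary assumptions.
    from types import SimpleNamespace

    dummy_self = SimpleNamespace(
        token_fpn_levels=(16, 8, 4, 2),
        token_fpn_include_text=True,
        img_context_token_id=7,
    )
    B, T, D = 1, 300, 1024
    memory = torch.randn(B, T, D)
    input_ids = torch.zeros(B, T, dtype=torch.long)
    input_ids[0, 10:266] = dummy_self.img_context_token_id  # one tile = 256 IMG_CONTEXT tokens

    mem, pad = InternVL3_5_Backbone._build_token_fpn_memory(dummy_self, memory, input_ids)
    # Expected: 256 + 64 + 16 + 4 = 340 image tokens + 44 text tokens = 384 total.
    print(mem.shape, pad.shape)  # torch.Size([1, 384, 1024]) torch.Size([1, 384])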