# OVD_SOSP-B_Internvl_model2 / modeling_internvl_ovd.py
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Optional
import torch
from PIL import Image
from transformers import PreTrainedModel
from transformers.utils import ModelOutput
from losses.hungarian_matcher import build_criterion
from .configuration_internvl_ovd import InternVLOVDConfig
from .internvl_ovd import build_internvl_ovd
from .internvl_image_procesing import load_image
@dataclass
class InternVLOVDOutput(ModelOutput):
    """Detection output: predicted boxes and scores, plus an optional loss breakdown."""

loss: torch.Tensor | None = None
pred_boxes: torch.Tensor | None = None
pred_scores: torch.Tensor | None = None
loss_total: torch.Tensor | None = None
loss_bbox: torch.Tensor | None = None
loss_giou: torch.Tensor | None = None
loss_cls: torch.Tensor | None = None
class InternVLOVDForDetection(PreTrainedModel):
    """InternVL-based open-vocabulary detector packaged as a `PreTrainedModel`."""

    config_class = InternVLOVDConfig

def __init__(self, config: InternVLOVDConfig) -> None:
super().__init__(config)
amp_dtype = torch.bfloat16 if config.dtype == "bfloat16" else torch.float16
self.inner = build_internvl_ovd(
model_config=config,
device=config.device_map,
dtype=amp_dtype,
)
if config.freeze_backbone:
for name, param in self.inner.named_parameters():
if name.startswith("backbone.vlm.vision_model"):
param.requires_grad = False
# Training criterion is created lazily to keep Hub inference-only loads minimal.
self._criterion = None
@property
def criterion(self) -> torch.nn.Module:
if self._criterion is None:
cfg = self.config
self._criterion = build_criterion(
cost_bbox=cfg.cost_bbox,
cost_giou=cfg.cost_giou,
cost_class=cfg.cost_class,
loss_bbox=cfg.loss_bbox,
loss_giou=cfg.loss_giou,
loss_cls=cfg.loss_cls,
eos_coef=cfg.eos_coef,
use_focal_loss=cfg.use_focal_loss,
focal_alpha=cfg.focal_alpha,
focal_gamma=cfg.focal_gamma,
loss_mode=cfg.loss_mode,
)
return self._criterion
def forward_inference(
self,
*,
pixel_values: torch.Tensor,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
patch_mask: Optional[torch.Tensor] = None,
**kwargs: Any,
    ) -> InternVLOVDOutput:
        """Run detection without computing any loss."""
        if hasattr(self.inner.backbone.vlm, "vision_model"):
            # Keep the vision tower in eval mode so norm/dropout statistics stay fixed.
            self.inner.backbone.vlm.vision_model.eval()
pred_boxes, pred_scores = self.inner(
pixel_values=pixel_values,
input_ids=input_ids,
attention_mask=attention_mask,
patch_mask=patch_mask,
)
return InternVLOVDOutput(loss=None, pred_boxes=pred_boxes, pred_scores=pred_scores)
def forward(
self,
pixel_values: torch.Tensor,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
patch_mask: Optional[torch.Tensor] = None,
boxes: Optional[torch.Tensor] = None,
box_mask: Optional[torch.Tensor] = None,
compute_loss: bool = False,
**kwargs: Any,
    ) -> InternVLOVDOutput:
        """Forward pass; with `compute_loss=True`, also evaluates the matching criterion."""
        if compute_loss and (boxes is None or box_mask is None):
            # Fail fast before running the (expensive) forward pass.
            raise ValueError("compute_loss=True requires both `boxes` and `box_mask`.")
        outputs = self.forward_inference(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            patch_mask=patch_mask,
            **kwargs,
        )
        if not compute_loss:
            return outputs
pred_boxes = outputs.pred_boxes
pred_scores = outputs.pred_scores
losses = self.criterion(pred_boxes, pred_scores, boxes, box_mask)
loss_total = losses.get("loss_total")
return InternVLOVDOutput(
loss=loss_total,
pred_boxes=pred_boxes,
pred_scores=pred_scores,
loss_total=loss_total,
loss_bbox=losses.get("loss_bbox"),
loss_giou=losses.get("loss_giou"),
loss_cls=losses.get("loss_cls"),
)
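    # Illustrative training step (a sketch; the `batch` keys below are assumptions
    # about an external collator, not a documented interface):
    #
    #     out = model(
    #         pixel_values=batch["pixel_values"],
    #         input_ids=batch["input_ids"],
    #         attention_mask=batch["attention_mask"],
    #         patch_mask=batch["patch_mask"],
    #         boxes=batch["boxes"],
    #         box_mask=batch["box_mask"],
    #         compute_loss=True,
    #     )
    #     out.loss.backward()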
@torch.no_grad()
def infer_image(
self,
*,
image: Image.Image | str,
query: str,
tokenizer,
max_length: int = 4096,
device: Optional[torch.device] = None,
) -> InternVLOVDOutput:
"""
        Convenience single-query inference helper: accepts a PIL image (or a path/URL)
        and a query string, and handles image preprocessing and prompt construction.
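
        Example (illustrative; assumes a compatible tokenizer is already loaded):
            >>> out = model.infer_image(image="cat.jpg", query="the cat", tokenizer=tokenizer)
            >>> boxes, scores = out.pred_boxes, out.pred_scores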
"""
cfg = self.config
if device is None:
device = next(self.parameters()).device
amp_dtype = torch.bfloat16 if cfg.dtype == "bfloat16" else torch.float16
        if device.type == "cpu" and amp_dtype == torch.float16:
            # float16 autocast is poorly supported on CPU; fall back to bfloat16.
            amp_dtype = torch.bfloat16
pixel_values = load_image(image, input_size=cfg.input_size, max_num=cfg.max_num_patches)
num_patches = int(pixel_values.shape[0])
pixel_values = pixel_values.unsqueeze(0)
patch_mask = torch.ones((1, num_patches), dtype=torch.bool)
img_context_token = "<IMG_CONTEXT>"
img_start_token = "<img>"
img_end_token = "</img>"
        # InternVL represents each image tile with 256 <IMG_CONTEXT> tokens.
        tokens_per_patch = 256
image_tokens = img_start_token + img_context_token * (tokens_per_patch * num_patches) + img_end_token
prompt = (
f"{image_tokens}\n"
"Please provide the bounding box coordinate of the region this sentence describes: "
f"<ref>{query}</ref>"
)
tokens = tokenizer(
[prompt],
return_tensors="pt",
padding=True,
truncation=True,
max_length=max_length,
)
pixel_values = pixel_values.to(device=device, dtype=amp_dtype)
patch_mask = patch_mask.to(device=device)
input_ids = tokens["input_ids"].to(device=device)
attention_mask = tokens["attention_mask"].to(device=device)
self.eval()
amp_device_type = "cuda" if device.type == "cuda" else "cpu"
with torch.amp.autocast(device_type=amp_device_type, dtype=amp_dtype):
return self.forward_inference(
pixel_values=pixel_values,
input_ids=input_ids,
attention_mask=attention_mask,
patch_mask=patch_mask,
)
@torch.no_grad()
def infer_batch(
self,
*,
image: Image.Image | str,
queries: list[str],
tokenizer,
max_length: int = 4096,
device: Optional[torch.device] = None,
) -> InternVLOVDOutput:
"""
Batch inference helper that accepts a single PIL image (or path) + multiple query texts.
Each query produces one detection result, all sharing the same image.
Args:
image: PIL Image or path/URL to image
queries: List of query strings
tokenizer: Tokenizer instance
max_length: Maximum token length
device: Device to run inference on
        Returns:
            InternVLOVDOutput with pred_boxes of shape (batch_size, num_queries, 4),
            where batch_size == len(queries).
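
        Example (illustrative; assumes a compatible tokenizer is already loaded):
            >>> out = model.infer_batch(
            ...     image="street.jpg",
            ...     queries=["a red car", "a pedestrian"],
            ...     tokenizer=tokenizer,
            ... )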
"""
cfg = self.config
if device is None:
device = next(self.parameters()).device
amp_dtype = torch.bfloat16 if cfg.dtype == "bfloat16" else torch.float16
if device.type == "cpu" and amp_dtype == torch.float16:
amp_dtype = torch.bfloat16
# Load image once
pixel_values = load_image(image, input_size=cfg.input_size, max_num=cfg.max_num_patches)
num_patches = int(pixel_values.shape[0])
img_context_token = "<IMG_CONTEXT>"
img_start_token = "<img>"
img_end_token = "</img>"
tokens_per_patch = 256
image_tokens = img_start_token + img_context_token * (tokens_per_patch * num_patches) + img_end_token
# Build prompts for all queries
prompts = []
for query in queries:
prompt = (
f"{image_tokens}\n"
"Please provide the bounding box coordinate of the region this sentence describes: "
f"<ref>{query}</ref>"
)
prompts.append(prompt)
# Tokenize all prompts with padding
tokens = tokenizer(
prompts,
return_tensors="pt",
padding=True,
truncation=True,
max_length=max_length,
)
# Repeat pixel_values for each query
batch_size = len(queries)
pixel_values = pixel_values.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) # (B, num_patches, 3, H, W)
pixel_values = pixel_values.view(-1, *pixel_values.shape[2:]) # (B*num_patches, 3, H, W)
patch_mask = torch.ones((batch_size, num_patches), dtype=torch.bool)
pixel_values = pixel_values.to(device=device, dtype=amp_dtype)
patch_mask = patch_mask.to(device=device)
input_ids = tokens["input_ids"].to(device=device)
attention_mask = tokens["attention_mask"].to(device=device)
self.eval()
amp_device_type = "cuda" if device.type == "cuda" else "cpu"
with torch.amp.autocast(device_type=amp_device_type, dtype=amp_dtype):
return self.forward_inference(
pixel_values=pixel_values,
input_ids=input_ids,
attention_mask=attention_mask,
patch_mask=patch_mask,
)
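

# Minimal end-to-end sketch (illustrative, not part of the model API). The repo id
# below is a placeholder; substitute the checkpoint this file ships with.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    repo_id = "path/to/internvl-ovd-checkpoint"  # placeholder repo id
    tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
    model = InternVLOVDForDetection.from_pretrained(repo_id)

    out = model.infer_image(
        image="example.jpg",
        query="the dog on the left",
        tokenizer=tokenizer,
    )
    print(out.pred_boxes, out.pred_scores)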