from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Optional

import torch
from PIL import Image
from transformers import PreTrainedModel
from transformers.utils import ModelOutput

from losses.hungarian_matcher import build_criterion
from .configuration_internvl_ovd import InternVLOVDConfig
from .internvl_ovd import build_internvl_ovd
from .internvl_image_procesing import load_image


@dataclass
class InternVLOVDOutput(ModelOutput):
    """Output container for detection: predicted boxes/scores plus optional losses."""

    loss: Optional[torch.Tensor] = None
    pred_boxes: Optional[torch.Tensor] = None
    pred_scores: Optional[torch.Tensor] = None
    loss_total: Optional[torch.Tensor] = None
    loss_bbox: Optional[torch.Tensor] = None
    loss_giou: Optional[torch.Tensor] = None
    loss_cls: Optional[torch.Tensor] = None
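# Note: `ModelOutput` subclasses support both attribute and mapping access
# (`out.pred_boxes` or `out["pred_boxes"]`), and fields left as None are
# omitted when the output is iterated or converted via `to_tuple()`.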


class InternVLOVDForDetection(PreTrainedModel):
    config_class = InternVLOVDConfig

    def __init__(self, config: InternVLOVDConfig) -> None:
        super().__init__(config)

        amp_dtype = torch.bfloat16 if config.dtype == "bfloat16" else torch.float16

        self.inner = build_internvl_ovd(
            model_config=config,
            device=config.device_map,
            dtype=amp_dtype,
        )

        # Optionally freeze the vision tower so only the detection head (and any
        # unfrozen language-model parameters) receive gradients.
        if config.freeze_backbone:
            for name, param in self.inner.named_parameters():
                if name.startswith("backbone.vlm.vision_model"):
                    param.requires_grad = False

        # Built lazily on first access so inference-only use never constructs
        # the matcher/criterion.
        self._criterion = None

    @property
    def criterion(self) -> torch.nn.Module:
        if self._criterion is None:
            cfg = self.config
            self._criterion = build_criterion(
                cost_bbox=cfg.cost_bbox,
                cost_giou=cfg.cost_giou,
                cost_class=cfg.cost_class,
                loss_bbox=cfg.loss_bbox,
                loss_giou=cfg.loss_giou,
                loss_cls=cfg.loss_cls,
                eos_coef=cfg.eos_coef,
                use_focal_loss=cfg.use_focal_loss,
                focal_alpha=cfg.focal_alpha,
                focal_gamma=cfg.focal_gamma,
                loss_mode=cfg.loss_mode,
            )
        return self._criterion
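
    # `build_criterion` (losses/hungarian_matcher.py) is expected to return a
    # module whose output dict carries at least the keys consumed in `forward`
    # below: "loss_total", "loss_bbox", "loss_giou", "loss_cls".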

    def forward_inference(
        self,
        *,
        pixel_values: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        patch_mask: Optional[torch.Tensor] = None,
        **kwargs: Any,
    ) -> InternVLOVDOutput:
        """Run the detection forward pass without computing any loss."""
        # Keep the (possibly frozen) vision tower in eval mode so dropout and
        # other stochastic layers stay disabled even while the rest trains.
        if hasattr(self.inner.backbone.vlm, "vision_model"):
            self.inner.backbone.vlm.vision_model.eval()

        pred_boxes, pred_scores = self.inner(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            patch_mask=patch_mask,
        )
        return InternVLOVDOutput(loss=None, pred_boxes=pred_boxes, pred_scores=pred_scores)

    def forward(
        self,
        pixel_values: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        patch_mask: Optional[torch.Tensor] = None,
        boxes: Optional[torch.Tensor] = None,
        box_mask: Optional[torch.Tensor] = None,
        compute_loss: bool = False,
        **kwargs: Any,
    ) -> InternVLOVDOutput:
        """Forward pass; with `compute_loss=True` also runs the Hungarian criterion."""
        outputs = self.forward_inference(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            patch_mask=patch_mask,
            **kwargs,
        )

        if not compute_loss:
            return outputs

        if boxes is None or box_mask is None:
            raise ValueError("compute_loss=True requires both `boxes` and `box_mask`.")

        pred_boxes = outputs.pred_boxes
        pred_scores = outputs.pred_scores
        losses = self.criterion(pred_boxes, pred_scores, boxes, box_mask)
        loss_total = losses.get("loss_total")
        return InternVLOVDOutput(
            loss=loss_total,
            pred_boxes=pred_boxes,
            pred_scores=pred_scores,
            loss_total=loss_total,
            loss_bbox=losses.get("loss_bbox"),
            loss_giou=losses.get("loss_giou"),
            loss_cls=losses.get("loss_cls"),
        )
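
    # Training-step sketch (illustrative; assumes a DataLoader yielding the
    # keyword batch used above -- `optimizer` and `batch` are placeholders):
    #
    #   out = model(**batch, compute_loss=True)
    #   out.loss.backward()
    #   optimizer.step()
    #   optimizer.zero_grad()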

    @torch.no_grad()
    def infer_image(
        self,
        *,
        image: Image.Image | str,
        query: str,
        tokenizer,
        max_length: int = 4096,
        device: Optional[torch.device] = None,
    ) -> InternVLOVDOutput:
        """
        Convenience inference helper that accepts a PIL image (or path) plus query text.
        Handles image preprocessing and prompt construction.
        """
        cfg = self.config
        if device is None:
            device = next(self.parameters()).device
        amp_dtype = torch.bfloat16 if cfg.dtype == "bfloat16" else torch.float16
        # float16 autocast is poorly supported on CPU; fall back to bfloat16.
        if device.type == "cpu" and amp_dtype == torch.float16:
            amp_dtype = torch.bfloat16

        # (num_patches, C, H, W) -> add a batch dimension of 1.
        pixel_values = load_image(image, input_size=cfg.input_size, max_num=cfg.max_num_patches)
        num_patches = int(pixel_values.shape[0])
        pixel_values = pixel_values.unsqueeze(0)
        patch_mask = torch.ones((1, num_patches), dtype=torch.bool)

        # InternVL represents each 448x448 tile with 256 <IMG_CONTEXT> tokens.
        img_context_token = "<IMG_CONTEXT>"
        img_start_token = "<img>"
        img_end_token = "</img>"
        tokens_per_patch = 256

        image_tokens = (
            img_start_token
            + img_context_token * (tokens_per_patch * num_patches)
            + img_end_token
        )
        prompt = (
            f"{image_tokens}\n"
            "Please provide the bounding box coordinate of the region this sentence describes: "
            f"<ref>{query}</ref>"
        )

        tokens = tokenizer(
            [prompt],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        )

        pixel_values = pixel_values.to(device=device, dtype=amp_dtype)
        patch_mask = patch_mask.to(device=device)
        input_ids = tokens["input_ids"].to(device=device)
        attention_mask = tokens["attention_mask"].to(device=device)

        self.eval()
        amp_device_type = "cuda" if device.type == "cuda" else "cpu"
        with torch.amp.autocast(device_type=amp_device_type, dtype=amp_dtype):
            return self.forward_inference(
                pixel_values=pixel_values,
                input_ids=input_ids,
                attention_mask=attention_mask,
                patch_mask=patch_mask,
            )
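
    # Usage sketch (illustrative; `ckpt` is a placeholder checkpoint path, and
    # the tokenizer must know InternVL's <img>/<IMG_CONTEXT> special tokens):
    #
    #   tokenizer = AutoTokenizer.from_pretrained(ckpt, trust_remote_code=True)
    #   model = InternVLOVDForDetection.from_pretrained(ckpt)
    #   out = model.infer_image(image="street.jpg", query="the red car", tokenizer=tokenizer)
    #   print(out.pred_boxes, out.pred_scores)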

    @torch.no_grad()
    def infer_batch(
        self,
        *,
        image: Image.Image | str,
        queries: list[str],
        tokenizer,
        max_length: int = 4096,
        device: Optional[torch.device] = None,
    ) -> InternVLOVDOutput:
        """
        Batch inference helper that accepts a single PIL image (or path) plus multiple
        query texts. Each query produces one detection result, all sharing the same image.

        Args:
            image: PIL Image or path/URL to the image.
            queries: List of query strings.
            tokenizer: Tokenizer instance.
            max_length: Maximum token length.
            device: Device to run inference on.

        Returns:
            InternVLOVDOutput whose pred_boxes has shape (batch_size, num_queries, 4),
            where batch_size == len(queries).
        """
        cfg = self.config
        if device is None:
            device = next(self.parameters()).device
        amp_dtype = torch.bfloat16 if cfg.dtype == "bfloat16" else torch.float16
        # float16 autocast is poorly supported on CPU; fall back to bfloat16.
        if device.type == "cpu" and amp_dtype == torch.float16:
            amp_dtype = torch.bfloat16

        pixel_values = load_image(image, input_size=cfg.input_size, max_num=cfg.max_num_patches)
        num_patches = int(pixel_values.shape[0])

        # InternVL represents each 448x448 tile with 256 <IMG_CONTEXT> tokens.
        img_context_token = "<IMG_CONTEXT>"
        img_start_token = "<img>"
        img_end_token = "</img>"
        tokens_per_patch = 256

        image_tokens = (
            img_start_token
            + img_context_token * (tokens_per_patch * num_patches)
            + img_end_token
        )

        # One prompt per query, all referring to the same image tokens.
        prompts = []
        for query in queries:
            prompt = (
                f"{image_tokens}\n"
                "Please provide the bounding box coordinate of the region this sentence describes: "
                f"<ref>{query}</ref>"
            )
            prompts.append(prompt)

        tokens = tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        )

        # Mirror `infer_image`'s (batch, num_patches, C, H, W) layout, replicating
        # the single image once per query so shapes match patch_mask (batch, num_patches).
        batch_size = len(queries)
        pixel_values = pixel_values.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
        patch_mask = torch.ones((batch_size, num_patches), dtype=torch.bool)

        pixel_values = pixel_values.to(device=device, dtype=amp_dtype)
        patch_mask = patch_mask.to(device=device)
        input_ids = tokens["input_ids"].to(device=device)
        attention_mask = tokens["attention_mask"].to(device=device)

        self.eval()
        amp_device_type = "cuda" if device.type == "cuda" else "cpu"
        with torch.amp.autocast(device_type=amp_device_type, dtype=amp_dtype):
            return self.forward_inference(
                pixel_values=pixel_values,
                input_ids=input_ids,
                attention_mask=attention_mask,
                patch_mask=patch_mask,
            )
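

# Minimal end-to-end sketch, guarded so importing this module stays side-effect
# free. The checkpoint path and image path below are placeholders; the
# checkpoint's tokenizer must carry InternVL's special image tokens.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    ckpt = "path/to/internvl_ovd_checkpoint"  # placeholder
    tokenizer = AutoTokenizer.from_pretrained(ckpt, trust_remote_code=True)
    model = InternVLOVDForDetection.from_pretrained(ckpt)

    out = model.infer_batch(
        image="example.jpg",  # placeholder image path
        queries=["a red car", "a traffic light"],
        tokenizer=tokenizer,
    )
    print(out.pred_boxes.shape, out.pred_scores.shape)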