xpuenabler commited on
Commit
e3454bb
·
verified ·
1 Parent(s): 054fad9

Upload folder using huggingface_hub

Browse files
README.md ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## InternVL-OVD (inference-only)
2
+
3
+ This repository contains inference-only artifacts exported from a training checkpoint.
4
+
5
+ ### Quick start (single image)
6
+
7
+ ```python
8
+ import torch
9
+ from PIL import Image
10
+ from transformers import AutoConfig, AutoModel, AutoTokenizer
11
+
12
+ repo_id = "YOUR_ORG/YOUR_MODEL"
13
+ image_path = "flower.jpg"
14
+ query = "flower"
15
+
16
+ cfg = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
17
+ tokenizer = AutoTokenizer.from_pretrained(cfg.vlm_model_name, trust_remote_code=True, use_fast=False)
18
+ model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
19
+ model.eval()
20
+
21
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
+ model = model.to(device)
23
+
24
+ pil = Image.open(image_path).convert("RGB")
25
+ outputs = model.infer_image(image=pil, query=query, tokenizer=tokenizer)
26
+
27
+ pred_boxes = outputs.pred_boxes[0].float().cpu()
28
+ pred_scores = outputs.pred_scores[0].squeeze(-1).float().sigmoid().cpu()
29
+ print(pred_boxes[:5])
30
+ print(pred_scores[:5])
31
+ ```
config.json ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "InternVLOVDForDetection"
4
+ ],
5
+ "head_type": "small",
6
+ "cost_bbox": 5.0,
7
+ "cost_class": 0.0,
8
+ "cost_giou": 2.0,
9
+ "device_map": "cuda",
10
+ "dim_feedforward": 1024,
11
+ "dropout": 0.0,
12
+ "dtype": "bfloat16",
13
+ "eos_coef": 0.1,
14
+ "focal_alpha": 0.75,
15
+ "focal_gamma": 2.0,
16
+ "freeze_backbone": true,
17
+ "hidden_size": 1024,
18
+ "loss_bbox": 5.0,
19
+ "loss_cls": 0.0,
20
+ "loss_giou": 2.0,
21
+ "loss_mode": "bbox_only",
22
+ "model_type": "internvl_ovd",
23
+ "nhead": 8,
24
+ "num_decoder_layers": 2,
25
+ "num_queries": 1,
26
+ "token_fpn_include_text": true,
27
+ "token_fpn_levels": [
28
+ 16,
29
+ 8,
30
+ 4,
31
+ 2
32
+ ],
33
+ "transformers_version": "4.57.3",
34
+ "use_focal_loss": false,
35
+ "use_token_fpn": false,
36
+ "vlm_model_name": "OpenGVLab/InternVL3_5-1B",
37
+ "auto_map": {
38
+ "AutoConfig": "configuration_internvl_ovd.InternVLOVDConfig",
39
+ "AutoModel": "modeling_internvl_ovd.InternVLOVDForDetection"
40
+ }
41
+ }
configuration_internvl_ovd.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Optional
4
+
5
+ from transformers import PretrainedConfig
6
+
7
+
8
+ def _as_tuple_ints(value: object, *, default: tuple[int, ...]) -> tuple[int, ...]:
9
+ """Normalize token_fpn_levels which may come from JSON as a list."""
10
+ if value is None:
11
+ return default
12
+ if isinstance(value, tuple):
13
+ return tuple(int(v) for v in value)
14
+ if isinstance(value, list):
15
+ return tuple(int(v) for v in value)
16
+ return default
17
+
18
+
19
class InternVLOVDConfig(PretrainedConfig):
    """
    HuggingFace configuration for InternVL-OVD.

    NOTE:
    - For convenience this config also stores some runtime/data fields used by this repo.
    - For Hub inference, only model-relevant fields are strictly required.

    Field groups:
    - Model/backbone: VLM backbone id plus detection-head hyperparameters.
    - Loss: Hungarian-matching costs and loss weights.
    - Data/runtime and Train/runtime: repo-local training conveniences that
      are irrelevant for Hub inference but kept so one config drives both.
    """

    model_type = "internvl_ovd"

    def __init__(
        self,
        # Model/backbone
        vlm_model_name: str = "OpenGVLab/InternVL3_5-1B",
        hidden_size: int = 1024,
        num_queries: int = 1,
        use_token_fpn: bool = False,
        token_fpn_levels: tuple[int, ...] = (16, 8, 4, 2),
        token_fpn_include_text: bool = True,
        # NOTE(review): class default is "detr" while the shipped config.json
        # uses "small" — confirm which one checkpoints actually expect.
        head_type: str = "detr",
        nhead: int = 8,
        num_decoder_layers: int = 2,
        dim_feedforward: int = 1024,
        dropout: float = 0.0,
        dtype: str = "bfloat16",
        device_map: str = "cuda",
        freeze_backbone: bool = True,
        # Loss
        cost_bbox: float = 5.0,
        cost_giou: float = 2.0,
        cost_class: float = 1.0,
        loss_bbox: float = 5.0,
        loss_giou: float = 2.0,
        loss_cls: float = 1.0,
        eos_coef: float = 0.1,
        use_focal_loss: bool = False,
        focal_alpha: float = 0.75,
        focal_gamma: float = 2.0,
        loss_mode: str = "hungarian",
        # Data/runtime (repo convenience)
        dataset_type: str = "unknown",
        coco_root: str = "/coco/root/path",
        train_ann_file: Optional[str] = None,
        val_ann_file: Optional[str] = None,
        train_img_dir: Optional[str] = None,
        val_img_dir: Optional[str] = None,
        refcoco_train_split: str = "val",
        refcoco_val_split: str = "testB",
        refcoco_train_max_samples: Optional[int] = None,
        refcoco_val_max_samples: Optional[int] = None,
        max_num_patches: int = 12,
        input_size: int = 448,
        max_instances: int = 100,
        num_workers: int = 4,
        # Train/runtime (repo convenience)
        batch_size: int = 8,
        gradient_accumulation_steps: int = 1,
        num_epochs: int = 50,
        lr: float = 1e-4,
        weight_decay: float = 1e-4,
        lr_scheduler: str = "cosine",
        warmup_epochs: float = 1.0,
        max_grad_norm: float = 0.1,
        log_every: int = 50,
        use_wandb: bool = True,
        wandb_project: str = "internvl-ovd",
        wandb_run_name: Optional[str] = None,
        save_dir: str = "./checkpoints",
        save_every_steps: int = 1000,
        save_total_limit: int = 5,
        val_every_steps: int = 1000,
        eval_on_train: bool = True,
        train_eval_ratio: float = 0.001,
        train_eval_max_samples: Optional[int] = 128,
        device: str = "cuda",
        resume_from: Optional[str] = None,
        seed: int = 42,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)

        # Model/backbone
        self.vlm_model_name = vlm_model_name
        self.hidden_size = hidden_size
        self.num_queries = num_queries
        self.use_token_fpn = use_token_fpn
        # token_fpn_levels may arrive as a JSON list; normalize to tuple[int, ...].
        self.token_fpn_levels = _as_tuple_ints(token_fpn_levels, default=(16, 8, 4, 2))
        self.token_fpn_include_text = token_fpn_include_text
        self.head_type = head_type
        self.nhead = nhead
        self.num_decoder_layers = num_decoder_layers
        self.dim_feedforward = dim_feedforward
        self.dropout = dropout
        self.dtype = dtype
        self.device_map = device_map
        self.freeze_backbone = freeze_backbone

        # Loss
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou
        self.cost_class = cost_class
        self.loss_bbox = loss_bbox
        self.loss_giou = loss_giou
        self.loss_cls = loss_cls
        self.eos_coef = eos_coef
        self.use_focal_loss = use_focal_loss
        self.focal_alpha = focal_alpha
        self.focal_gamma = focal_gamma
        self.loss_mode = loss_mode

        # Data
        self.dataset_type = dataset_type
        self.coco_root = coco_root
        # Annotation/image paths default to the standard COCO 2017 layout
        # under coco_root when not given explicitly.
        self.train_ann_file = train_ann_file or f"{self.coco_root}/annotations/instances_train2017.json"
        self.val_ann_file = val_ann_file or f"{self.coco_root}/annotations/instances_val2017.json"
        self.train_img_dir = train_img_dir or f"{self.coco_root}/train2017"
        self.val_img_dir = val_img_dir or f"{self.coco_root}/val2017"
        self.refcoco_train_split = refcoco_train_split
        self.refcoco_val_split = refcoco_val_split
        self.refcoco_train_max_samples = refcoco_train_max_samples
        self.refcoco_val_max_samples = refcoco_val_max_samples
        self.max_num_patches = max_num_patches
        self.input_size = input_size
        self.max_instances = max_instances
        self.num_workers = num_workers

        # Train/runtime
        self.batch_size = batch_size
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.num_epochs = num_epochs
        self.lr = lr
        self.weight_decay = weight_decay
        self.lr_scheduler = lr_scheduler
        self.warmup_epochs = warmup_epochs
        self.max_grad_norm = max_grad_norm
        self.log_every = log_every
        self.use_wandb = use_wandb
        self.wandb_project = wandb_project
        self.wandb_run_name = wandb_run_name
        self.save_dir = save_dir
        self.save_every_steps = save_every_steps
        self.save_total_limit = save_total_limit
        self.val_every_steps = val_every_steps
        self.eval_on_train = eval_on_train
        self.train_eval_ratio = train_eval_ratio
        self.train_eval_max_samples = train_eval_max_samples
        self.device = device
        self.resume_from = resume_from
        self.seed = seed
169
+
170
+
heads.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Optional, Tuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+
10
+ class _CrossAttentionLayer(nn.Module):
11
+ def __init__(self, config: object) -> None:
12
+ super().__init__()
13
+ self.cross_attn = nn.MultiheadAttention(
14
+ embed_dim=config.hidden_size,
15
+ num_heads=config.nhead,
16
+ dropout=config.dropout,
17
+ batch_first=True,
18
+ )
19
+ self.linear1 = nn.Linear(config.hidden_size, config.dim_feedforward)
20
+ self.linear2 = nn.Linear(config.dim_feedforward, config.hidden_size)
21
+ self.norm1 = nn.LayerNorm(config.hidden_size)
22
+ self.norm2 = nn.LayerNorm(config.hidden_size)
23
+ self.dropout = nn.Dropout(config.dropout)
24
+
25
+ def forward(
26
+ self,
27
+ tgt: torch.Tensor, # (B,K,D)
28
+ memory: torch.Tensor, # (B,L,D)
29
+ memory_mask: Optional[torch.Tensor] = None,
30
+ ) -> torch.Tensor:
31
+ # cross-attn
32
+ residual = tgt
33
+ tgt2, _ = self.cross_attn(
34
+ query=tgt,
35
+ key=memory,
36
+ value=memory,
37
+ key_padding_mask=memory_mask,
38
+ )
39
+ tgt = residual + self.dropout(tgt2)
40
+ tgt = self.norm1(tgt)
41
+
42
+ # FFN
43
+ residual = tgt
44
+ tgt2 = self.linear2(F.gelu(self.linear1(tgt)))
45
+ tgt = residual + self.dropout(tgt2)
46
+ tgt = self.norm2(tgt)
47
+ return tgt
48
+
49
+
50
+ class _DetrDecoderLayer(nn.Module):
51
+ def __init__(self, config: object) -> None:
52
+ super().__init__()
53
+ self.self_attn = nn.MultiheadAttention(
54
+ embed_dim=config.hidden_size,
55
+ num_heads=config.nhead,
56
+ dropout=config.dropout,
57
+ batch_first=True,
58
+ )
59
+ self.cross_attn = nn.MultiheadAttention(
60
+ embed_dim=config.hidden_size,
61
+ num_heads=config.nhead,
62
+ dropout=config.dropout,
63
+ batch_first=True,
64
+ )
65
+ self.linear1 = nn.Linear(config.hidden_size, config.dim_feedforward)
66
+ self.linear2 = nn.Linear(config.dim_feedforward, config.hidden_size)
67
+ self.norm1 = nn.LayerNorm(config.hidden_size)
68
+ self.norm2 = nn.LayerNorm(config.hidden_size)
69
+ self.norm3 = nn.LayerNorm(config.hidden_size)
70
+ self.dropout = nn.Dropout(config.dropout)
71
+
72
+ def forward(
73
+ self,
74
+ tgt: torch.Tensor, # (B,K,D)
75
+ memory: torch.Tensor, # (B,L,D)
76
+ memory_mask: Optional[torch.Tensor] = None,
77
+ ) -> torch.Tensor:
78
+ # self-attn
79
+ residual = tgt
80
+ tgt2, _ = self.self_attn(tgt, tgt, tgt)
81
+ tgt = residual + self.dropout(tgt2)
82
+ tgt = self.norm1(tgt)
83
+
84
+ # cross-attn
85
+ residual = tgt
86
+ tgt2, _ = self.cross_attn(
87
+ query=tgt,
88
+ key=memory,
89
+ value=memory,
90
+ key_padding_mask=memory_mask,
91
+ )
92
+ tgt = residual + self.dropout(tgt2)
93
+ tgt = self.norm2(tgt)
94
+
95
+ # FFN
96
+ residual = tgt
97
+ tgt2 = self.linear2(F.gelu(self.linear1(tgt)))
98
+ tgt = residual + self.dropout(tgt2)
99
+ tgt = self.norm3(tgt)
100
+ return tgt
101
+
102
+
103
class DetrOvdHead(nn.Module):
    """
    Unified OVD head over backbone token features.

    - head_type="detr": stack of DETR-style decoder layers (heavier).
    - any other value (default "small"): a single cross-attention pooling block.

    forward() returns per-query boxes in (x1, y1, x2, y2) normalized to
    [0, 1], and raw objectness logits (callers apply sigmoid themselves).
    """

    def __init__(self, config: object) -> None:
        super().__init__()
        self.config = config
        self.num_queries = int(config.num_queries)
        self.d_model = int(config.hidden_size)
        self.head_type = str(getattr(config, "head_type", "small"))

        # Learned object queries, shared across the batch.
        self.query_embed = nn.Embedding(self.num_queries, self.d_model)

        if self.head_type == "detr":
            depth = int(config.num_decoder_layers)
            self.layers = nn.ModuleList(_DetrDecoderLayer(config) for _ in range(depth))
            self.pooling = None
        else:
            # Default "small" head: one cross-attention pooling block.
            self.pooling = _CrossAttentionLayer(config)
            self.layers = None

        self.bbox_head = nn.Sequential(
            nn.Linear(self.d_model, self.d_model),
            nn.ReLU(),
            nn.Linear(self.d_model, 4),
        )
        self.score_head = nn.Sequential(
            nn.Linear(self.d_model, self.d_model),
            nn.ReLU(),
            nn.Linear(self.d_model, 1),
        )

    def forward(
        self,
        memory: torch.Tensor,  # (B, L, D) backbone tokens
        memory_mask: torch.Tensor | None = None,  # (B, L) padding mask or None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch = memory.shape[0]

        # Broadcast the learned queries over the batch: (B, K, D).
        hidden = self.query_embed.weight.unsqueeze(0).expand(batch, -1, -1)

        if self.head_type == "detr":
            assert self.layers is not None
            for block in self.layers:
                hidden = block(hidden, memory, memory_mask)
        else:
            assert self.pooling is not None
            hidden = self.pooling(hidden, memory, memory_mask)

        # Box head predicts (cx, cy, w, h), squashed to [0, 1].
        cxcywh = self.bbox_head(hidden).sigmoid()  # (B, K, 4)

        # Convert center format to corners and keep inside the unit square.
        centers = cxcywh[..., :2]
        half_wh = cxcywh[..., 2:] / 2
        pred_boxes = torch.cat([centers - half_wh, centers + half_wh], dim=-1).clamp(0, 1)

        pred_logits = self.score_head(hidden)  # (B, K, 1) raw logits for BCE-with-logits
        return pred_boxes, pred_logits
178
+
179
+
hungarian_matcher.py ADDED
@@ -0,0 +1,742 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hungarian Matcher and Loss Functions for OVD Training
3
+ Based on DETR's bipartite matching approach.
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from scipy.optimize import linear_sum_assignment
10
+
11
+
12
def generalized_box_iou(boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
    """
    Pairwise generalized IoU between two sets of boxes.

    Args:
        boxes1: (N, 4) in (x1, y1, x2, y2) format
        boxes2: (M, 4) in (x1, y1, x2, y2) format

    Returns:
        (N, M) GIoU matrix in [-1, 1].
    """
    eps = 1e-6
    a = boxes1[:, None, :]  # (N, 1, 4)
    b = boxes2[None, :, :]  # (1, M, 4)

    # Intersection: overlap of the two corner boxes, clamped at zero.
    inter_wh = (torch.min(a[..., 2:], b[..., 2:]) - torch.max(a[..., :2], b[..., :2])).clamp(min=0)
    intersection = inter_wh[..., 0] * inter_wh[..., 1]  # (N, M)

    area_a = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])  # (N,)
    area_b = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])  # (M,)
    union = area_a[:, None] + area_b[None, :] - intersection  # (N, M)

    iou = intersection / (union + eps)

    # Smallest enclosing (hull) box for the GIoU penalty term.
    hull_wh = (torch.max(a[..., 2:], b[..., 2:]) - torch.min(a[..., :2], b[..., :2])).clamp(min=0)
    hull_area = hull_wh[..., 0] * hull_wh[..., 1]  # (N, M)

    return iou - (hull_area - union) / (hull_area + eps)
45
+
46
+
47
def generalized_box_iou_pairwise(boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
    """
    Element-wise generalized IoU between two equally-sized box sets.

    Args:
        boxes1: (N, 4) in (x1, y1, x2, y2)
        boxes2: (N, 4) in (x1, y1, x2, y2)

    Returns:
        (N,) GIoU for each row pair.
    """
    assert boxes1.shape == boxes2.shape
    eps = 1e-6

    # Intersection of each row pair.
    inter_wh = (torch.min(boxes1[:, 2:], boxes2[:, 2:]) - torch.max(boxes1[:, :2], boxes2[:, :2])).clamp(min=0)
    intersection = inter_wh[:, 0] * inter_wh[:, 1]  # (N,)

    area_a = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])  # (N,)
    area_b = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])  # (N,)
    union = area_a + area_b - intersection  # (N,)

    iou = intersection / (union + eps)

    # Smallest enclosing (hull) box per pair for the GIoU penalty.
    hull_wh = (torch.max(boxes1[:, 2:], boxes2[:, 2:]) - torch.min(boxes1[:, :2], boxes2[:, :2])).clamp(min=0)
    hull_area = hull_wh[:, 0] * hull_wh[:, 1]  # (N,)

    return iou - (hull_area - union) / (hull_area + eps)
81
+
82
+
83
class HungarianMatcher(nn.Module):
    """
    Optimal bipartite assignment between predictions and ground-truth boxes.

    The cost of a (prediction, target) pair combines an L1 box distance,
    a negative GIoU term, and a negative objectness probability, weighted
    by ``cost_bbox`` / ``cost_giou`` / ``cost_class``.
    """

    def __init__(
        self,
        cost_bbox: float = 5.0,
        cost_giou: float = 2.0,
        cost_class: float = 1.0,
    ) -> None:
        super().__init__()
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou
        self.cost_class = cost_class

    @torch.no_grad()
    def forward(
        self,
        pred_boxes: torch.Tensor,
        pred_scores: torch.Tensor,
        target_boxes: torch.Tensor,
        target_mask: torch.Tensor,
    ) -> list[tuple[torch.Tensor, torch.Tensor]]:
        """
        Match predictions to targets independently per batch element.

        Args:
            pred_boxes: (B, num_queries, 4) predicted boxes, (x1, y1, x2, y2)
            pred_scores: (B, num_queries, 1) objectness logits
            target_boxes: (B, max_targets, 4) padded target boxes
            target_mask: (B, max_targets) True where the target slot is valid

        Returns:
            One (pred_indices, target_indices) pair of long tensors per batch
            element; both empty when that element has no valid targets.
        """
        batch_size = pred_boxes.shape[0]
        device = pred_boxes.device

        matches: list[tuple[torch.Tensor, torch.Tensor]] = []
        for b in range(batch_size):
            keep = target_mask[b]  # (max_targets,)
            num_targets = int(keep.sum().item())

            if num_targets == 0:
                matches.append(
                    (
                        torch.empty(0, dtype=torch.long, device=device),
                        torch.empty(0, dtype=torch.long, device=device),
                    )
                )
                continue

            # Gather valid targets by index — padding need not be a prefix.
            gt_slots = keep.nonzero(as_tuple=False).squeeze(-1)  # (num_targets,)

            # float32 copies: torch.cdist does not support bfloat16.
            cand_boxes = pred_boxes[b].float()  # (num_queries, 4)
            gt_boxes = target_boxes[b, gt_slots].float()  # (num_targets, 4)
            cand_probs = pred_scores[b].squeeze(-1).sigmoid()  # (num_queries,)

            # Weighted total cost: L1 distance, negative GIoU, negative confidence.
            cost = (
                self.cost_bbox * torch.cdist(cand_boxes, gt_boxes, p=1)
                + self.cost_giou * -generalized_box_iou(cand_boxes, gt_boxes)
                + self.cost_class * -cand_probs.unsqueeze(1).expand(-1, num_targets)
            )

            # The Hungarian algorithm runs on CPU via scipy.
            row_ind, col_ind = linear_sum_assignment(cost.cpu().numpy())

            # Translate local target positions back to padded-slot indices.
            matched_slots = gt_slots.to(device)[col_ind]
            matches.append(
                (
                    torch.as_tensor(row_ind, dtype=torch.long, device=device),
                    matched_slots.to(device=device, dtype=torch.long),
                )
            )

        return matches
181
+
182
def sigmoid_focal_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    alpha: float = 0.25,
    gamma: float = 2.0,
    reduction: str = "none",
) -> torch.Tensor:
    """
    Focal loss for dense binary classification (Lin et al., RetinaNet).

    Args:
        inputs: raw logits (pre-sigmoid)
        targets: binary labels (0 or 1)
        alpha: weight on the positive class (1 - alpha on negatives)
        gamma: focusing exponent; larger down-weights easy examples more
        reduction: 'none', 'mean', or 'sum'
    """
    prob = torch.sigmoid(inputs)
    bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")

    # pt is the model's probability of the true class.
    pt = prob * targets + (1.0 - prob) * (1.0 - targets)
    # Per-element alpha: alpha for positives, (1 - alpha) for negatives.
    alpha_t = alpha * targets + (1.0 - alpha) * (1.0 - targets)

    loss = alpha_t * bce * (1.0 - pt) ** gamma

    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss
213
+
214
class OvdCriterion(nn.Module):
    """
    Loss criterion for OVD training with Hungarian matching.

    Produces L1 + GIoU box losses over matched pairs and a per-query
    objectness loss (focal or BCE), combined into ``loss_total`` with the
    configured weights.
    """

    def __init__(
        self,
        matcher: HungarianMatcher,
        loss_bbox: float = 5.0,
        loss_giou: float = 2.0,
        loss_cls: float = 1.0,
        eos_coef: float = 0.1,  # Weight for no-object class (BCE mode only)
        use_focal_loss: bool = True,  # Use focal loss instead of BCE
        focal_alpha: float = 0.75,  # Higher weight for positive (sparse in OVD)
        focal_gamma: float = 2.0,
    ) -> None:
        super().__init__()
        self.matcher = matcher
        self.loss_bbox_weight = loss_bbox
        self.loss_giou_weight = loss_giou
        self.loss_cls_weight = loss_cls
        self.eos_coef = eos_coef
        self.use_focal_loss = use_focal_loss
        self.focal_alpha = focal_alpha
        self.focal_gamma = focal_gamma

    def loss_boxes(
        self,
        pred_boxes: torch.Tensor,
        target_boxes: torch.Tensor,
        indices: list[tuple[torch.Tensor, torch.Tensor]],
        num_boxes: int,
    ) -> dict[str, torch.Tensor]:
        """
        Compute bounding box losses (L1 + GIoU) over matched pairs,
        normalized by the total target count ``num_boxes``.
        """
        src_boxes_list: list[torch.Tensor] = []
        tgt_boxes_list: list[torch.Tensor] = []

        # Gather matched (prediction, target) box pairs across the batch.
        for b, (src_idx, tgt_idx) in enumerate(indices):
            if src_idx.numel() > 0:
                src_boxes_list.append(pred_boxes[b, src_idx])
                tgt_boxes_list.append(target_boxes[b, tgt_idx])

        # No matches anywhere in the batch -> zero losses (keeps training stable).
        if not src_boxes_list:
            device = pred_boxes.device
            return {
                "loss_bbox": torch.tensor(0.0, device=device),
                "loss_giou": torch.tensor(0.0, device=device),
            }

        src_boxes_all = torch.cat(src_boxes_list, dim=0).float()  # (N, 4)
        tgt_boxes_all = torch.cat(tgt_boxes_list, dim=0).float()  # (N, 4)

        # L1 loss
        loss_bbox = F.l1_loss(src_boxes_all, tgt_boxes_all, reduction="sum") / max(num_boxes, 1)

        # Element-wise GIoU loss (avoid building full N x N matrix)
        giou = generalized_box_iou_pairwise(src_boxes_all, tgt_boxes_all)  # (N,)
        loss_giou = (1.0 - giou).sum() / max(num_boxes, 1)

        return {
            "loss_bbox": loss_bbox,
            "loss_giou": loss_giou,
        }

    def loss_labels(
        self,
        pred_scores: torch.Tensor,
        target_mask: torch.Tensor,
        indices: list[tuple[torch.Tensor, torch.Tensor]],
        num_boxes: int,
    ) -> dict[str, torch.Tensor]:
        """
        Compute classification loss (objectness): matched queries are labeled
        1, unmatched 0; per-query loss is focal or BCE per ``use_focal_loss``.
        """
        B, num_queries, _ = pred_scores.shape
        device = pred_scores.device

        # Create target labels: 1 for matched queries, 0 for unmatched
        target_labels = torch.zeros(B, num_queries, dtype=torch.float32, device=device)

        for b, (src_idx, _) in enumerate(indices):
            if src_idx.numel() > 0:
                target_labels[b, src_idx] = 1.0

        pred_logits = pred_scores.squeeze(-1)

        if self.use_focal_loss:
            # Focal Loss: focuses on hard examples, handles class imbalance better
            loss_per_query = sigmoid_focal_loss(
                pred_logits,
                target_labels,
                alpha=self.focal_alpha,
                gamma=self.focal_gamma,
                reduction="none",
            )
        else:
            # BCE with manual weighting for no-object class
            loss_per_query = F.binary_cross_entropy_with_logits(
                pred_logits,
                target_labels,
                reduction="none",
            )

        # NOTE(review): the eos_coef pos/neg re-weighting below is applied in
        # BOTH branches, although the __init__ comment says "BCE mode only" —
        # confirm whether focal mode is meant to skip it.
        pos_mask = target_labels == 1
        neg_mask = ~pos_mask

        # Clamp counts at 1 to avoid division by zero on all-pos/all-neg batches.
        pos_count = max(int(pos_mask.sum().item()), 1)
        neg_count = max(int(neg_mask.sum().item()), 1)

        pos_loss = loss_per_query[pos_mask].sum() / pos_count
        neg_loss = loss_per_query[neg_mask].sum() / neg_count

        loss_cls = pos_loss + self.eos_coef * neg_loss

        return {"loss_cls": loss_cls}

    def forward(
        self,
        pred_boxes: torch.Tensor,
        pred_scores: torch.Tensor,
        target_boxes: torch.Tensor,
        target_mask: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        """
        Compute total loss.

        Args:
            pred_boxes: (B, num_queries, 4) predicted boxes
            pred_scores: (B, num_queries, 1) predicted objectness scores (logits)
            target_boxes: (B, max_targets, 4) target boxes
            target_mask: (B, max_targets) True for valid targets

        Returns:
            Dict with "loss_bbox", "loss_giou", "loss_cls", and the weighted
            "loss_total".
        """
        # Compute matching
        indices = self.matcher(pred_boxes, pred_scores, target_boxes, target_mask)

        # Count total number of target boxes for normalization
        num_boxes = int(target_mask.sum().item())
        num_boxes = max(num_boxes, 1)

        # Box losses
        box_losses = self.loss_boxes(pred_boxes, target_boxes, indices, num_boxes)
        # Classification losses
        cls_losses = self.loss_labels(pred_scores, target_mask, indices, num_boxes)

        losses: dict[str, torch.Tensor] = {}
        losses.update(box_losses)
        losses.update(cls_losses)

        # Total loss
        losses["loss_total"] = (
            self.loss_bbox_weight * losses["loss_bbox"]
            + self.loss_giou_weight * losses["loss_giou"]
            + self.loss_cls_weight * losses["loss_cls"]
        )

        return losses
375
+
376
+
377
class BboxOnlyCriterion(nn.Module):
    """
    Box-regression-only criterion: Hungarian matching for assignment, then
    L1 + GIoU losses. No classification loss is trained (reported as zero
    for logging compatibility).
    """

    def __init__(
        self,
        matcher: HungarianMatcher,
        loss_bbox: float = 5.0,
        loss_giou: float = 2.0,
    ) -> None:
        super().__init__()
        self.matcher = matcher
        self.loss_bbox_weight = loss_bbox
        self.loss_giou_weight = loss_giou

    def loss_boxes(
        self,
        pred_boxes: torch.Tensor,
        target_boxes: torch.Tensor,
        indices: list[tuple[torch.Tensor, torch.Tensor]],
        num_boxes: int,
    ) -> dict[str, torch.Tensor]:
        """L1 + GIoU losses over matched pairs, normalized by num_boxes."""
        matched_preds = [
            pred_boxes[b, src_idx]
            for b, (src_idx, _) in enumerate(indices)
            if src_idx.numel() > 0
        ]
        matched_tgts = [
            target_boxes[b, tgt_idx]
            for b, (src_idx, tgt_idx) in enumerate(indices)
            if src_idx.numel() > 0
        ]

        # No matches at all -> zero losses.
        if not matched_preds:
            zero = torch.tensor(0.0, device=pred_boxes.device)
            return {"loss_bbox": zero, "loss_giou": zero}

        src = torch.cat(matched_preds, dim=0).float()  # (N, 4)
        tgt = torch.cat(matched_tgts, dim=0).float()  # (N, 4)

        denom = max(num_boxes, 1)
        l1 = F.l1_loss(src, tgt, reduction="sum") / denom
        # Element-wise GIoU keeps this O(N) instead of an N x N matrix.
        giou_term = (1.0 - generalized_box_iou_pairwise(src, tgt)).sum() / denom

        return {"loss_bbox": l1, "loss_giou": giou_term}

    def forward(
        self,
        pred_boxes: torch.Tensor,
        pred_scores: torch.Tensor,
        target_boxes: torch.Tensor,
        target_mask: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        """
        Compute total loss (bbox only).

        Args:
            pred_boxes: (B, num_queries, 4) predicted boxes
            pred_scores: (B, num_queries, 1) objectness logits — matching only
            target_boxes: (B, max_targets, 4) target boxes
            target_mask: (B, max_targets) True for valid targets

        Returns:
            Dict with "loss_bbox", "loss_giou", a zero "loss_cls", and the
            weighted "loss_total".
        """
        # Hungarian matching still provides the optimal assignment.
        indices = self.matcher(pred_boxes, pred_scores, target_boxes, target_mask)

        # Normalize by the number of valid targets (at least one).
        num_boxes = max(int(target_mask.sum().item()), 1)

        losses: dict[str, torch.Tensor] = dict(
            self.loss_boxes(pred_boxes, target_boxes, indices, num_boxes)
        )
        # Zero classification loss keeps downstream logging uniform.
        losses["loss_cls"] = torch.tensor(0.0, device=pred_boxes.device)
        losses["loss_total"] = (
            self.loss_bbox_weight * losses["loss_bbox"]
            + self.loss_giou_weight * losses["loss_giou"]
        )
        return losses
476
+
477
+
478
def build_criterion(
    cost_bbox: float = 5.0,
    cost_giou: float = 2.0,
    cost_class: float = 1.0,
    loss_bbox: float = 5.0,
    loss_giou: float = 2.0,
    loss_cls: float = 1.0,
    eos_coef: float = 0.1,
    use_focal_loss: bool = True,
    focal_alpha: float = 0.25,
    focal_gamma: float = 2.0,
    loss_mode: str = "hungarian",
) -> OvdCriterion | BboxOnlyCriterion:
    """
    Build the training criterion around a shared HungarianMatcher.

    Args:
        loss_mode: "bbox_only" trains box regression only; any other value
            (default "hungarian") enables the full classification + box losses.

    Returns:
        A criterion module ready for training.
    """
    matcher = HungarianMatcher(
        cost_bbox=cost_bbox,
        cost_giou=cost_giou,
        cost_class=cost_class,
    )

    if loss_mode == "bbox_only":
        return BboxOnlyCriterion(
            matcher=matcher,
            loss_bbox=loss_bbox,
            loss_giou=loss_giou,
        )

    # Default: full Hungarian criterion with classification loss.
    return OvdCriterion(
        matcher=matcher,
        loss_bbox=loss_bbox,
        loss_giou=loss_giou,
        loss_cls=loss_cls,
        eos_coef=eos_coef,
        use_focal_loss=use_focal_loss,
        focal_alpha=focal_alpha,
        focal_gamma=focal_gamma,
    )
521
+
522
+
523
def box_iou(boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
    """
    Compute element-wise IoU between two paired sets of boxes.

    Args:
        boxes1: (N, 4) in (x1, y1, x2, y2) format
        boxes2: (N, 4) in (x1, y1, x2, y2) format (same N, element-wise)

    Returns:
        iou: (N,) IoU for each pair, in [0, 1]
    """
    # Intersection rectangle; clamping makes disjoint pairs contribute 0.
    lt = torch.max(boxes1[:, :2], boxes2[:, :2])
    rb = torch.min(boxes1[:, 2:], boxes2[:, 2:])
    wh = (rb - lt).clamp(min=0)
    inter = wh[:, 0] * wh[:, 1]

    # Clamp side lengths so degenerate (x2 < x1 / y2 < y1) boxes cannot
    # produce negative areas — consistent with box_iou_matrix below.
    area1 = (boxes1[:, 2] - boxes1[:, 0]).clamp(min=0) * (boxes1[:, 3] - boxes1[:, 1]).clamp(min=0)
    area2 = (boxes2[:, 2] - boxes2[:, 0]).clamp(min=0) * (boxes2[:, 3] - boxes2[:, 1]).clamp(min=0)
    union = area1 + area2 - inter

    eps = 1e-6  # avoid division by zero for zero-area pairs
    return inter / (union + eps)
546
+
547
def box_iou_matrix(boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
    """
    Compute the all-pairs IoU matrix between two sets of boxes.

    Args:
        boxes1: (N, 4) in (x1, y1, x2, y2) format
        boxes2: (M, 4) in (x1, y1, x2, y2) format

    Returns:
        iou: (N, M)
    """
    n, m = boxes1.shape[0], boxes2.shape[0]
    if boxes1.numel() == 0 or boxes2.numel() == 0:
        # Degenerate input: an empty (N, M) grid.
        return torch.zeros((n, m), device=boxes1.device, dtype=boxes1.dtype)

    # Per-box areas with clamped side lengths (degenerate boxes -> area 0).
    area1 = (boxes1[:, 2] - boxes1[:, 0]).clamp(min=0) * (boxes1[:, 3] - boxes1[:, 1]).clamp(min=0)
    area2 = (boxes2[:, 2] - boxes2[:, 0]).clamp(min=0) * (boxes2[:, 3] - boxes2[:, 1]).clamp(min=0)

    # Broadcast to (N, M, 2) corner grids for the intersection rectangle.
    top_left = torch.max(boxes1[:, None, :2], boxes2[None, :, :2])
    bottom_right = torch.min(boxes1[:, None, 2:], boxes2[None, :, 2:])
    inter_wh = (bottom_right - top_left).clamp(min=0)
    inter = inter_wh[..., 0] * inter_wh[..., 1]

    union = area1[:, None] + area2[None, :] - inter
    eps = 1e-6
    return inter / (union + eps)
572
+
573
+
574
def _ap_101(recalls: torch.Tensor, precisions: torch.Tensor) -> float:
    """
    COCO-style AP approximation by 101-point interpolation.

    Args:
        recalls: (P,) non-decreasing recall values.
        precisions: (P,) precision values aligned with `recalls`.

    Returns:
        Scalar AP in [0, 1]; 0.0 for empty input.
    """
    if recalls.numel() == 0:
        return 0.0

    recalls = recalls.clamp(0, 1)
    precisions = precisions.clamp(0, 1)

    # Precision envelope (monotone non-increasing): at each point take the max
    # precision over all later points. Vectorized backward scan via
    # flip + cummax instead of an O(P) Python loop.
    mpre = precisions.flip(0).cummax(dim=0).values.flip(0)

    # Sample the envelope at 101 evenly spaced recall levels in one batched
    # searchsorted call. Levels above the max achieved recall (index == P)
    # contribute 0, matching the original per-point loop.
    grid = torch.linspace(0, 1, 101, device=recalls.device, dtype=recalls.dtype)
    idx = torch.searchsorted(recalls, grid)
    valid = idx < mpre.numel()
    ap = float(mpre[idx[valid]].sum().item())
    return ap / 101.0
596
+
597
+
598
@torch.no_grad()
def compute_metrics(
    pred_boxes: torch.Tensor,
    pred_scores: torch.Tensor,
    target_boxes: torch.Tensor,
    target_mask: torch.Tensor,
    iou_thresholds: list[float] | None = None,
    score_threshold: float = 0.5,
) -> dict[str, float]:
    """
    Compute detection-style metrics for OVD (single-class).

    Notes:
        - Computes COCO-style mAP over IoU thresholds [0.50:0.95:0.05].
        - Computes AP50/AP75.
        - Computes precision/recall at IoU=0.50 with a fixed score threshold.
        - Matching is greedy per prediction in descending score order; each
          prediction is compared only against its single highest-IoU GT box
          (see NOTE in the loop below).

    Args:
        pred_boxes: (B, num_queries, 4) predicted boxes
        pred_scores: (B, num_queries, 1) predicted objectness scores (logits)
        target_boxes: (B, max_targets, 4) target boxes
        target_mask: (B, max_targets) True for valid targets
        iou_thresholds: extra IoU thresholds for the recall@thr entries
        score_threshold: score threshold for the precision/recall entries

    Returns:
        Dictionary of metrics
    """
    if iou_thresholds is None:
        iou_thresholds = [0.5, 0.75]

    B = pred_boxes.shape[0]

    # COCO mAP thresholds: 0.50, 0.55, ..., 0.95 (rounded to kill float drift
    # from torch.arange so dict lookups by threshold value are exact).
    coco_thresholds = [round(x, 2) for x in torch.arange(0.5, 0.96, 0.05).tolist()]

    total_gt = int(target_mask.sum().item())
    metrics: dict[str, float] = {
        "num_gt": float(total_gt),
        "num_queries": float(pred_boxes.shape[1]),
    }
    if total_gt == 0:
        # No ground truth anywhere in the batch: all AP/PR metrics are 0 by definition.
        metrics.update({"mAP": 0.0, "AP50": 0.0, "AP75": 0.0, "precision@0.5": 0.0, "recall@0.5": 0.0})
        return metrics

    # For each IoU threshold, collect (score, is_tp) across the batch.
    thr_to_scores: dict[float, list[torch.Tensor]] = {thr: [] for thr in coco_thresholds}
    thr_to_is_tp: dict[float, list[torch.Tensor]] = {thr: [] for thr in coco_thresholds}

    # Also track precision/recall at IoU=0.5 using a fixed score threshold.
    pr_iou_thr = 0.5
    pr_tp = 0
    pr_fp = 0

    for b in range(B):
        gt_mask_b = target_mask[b]
        if gt_mask_b.sum() == 0:
            # Images without GT are skipped entirely, so their predictions do
            # not enter the AP pool as false positives.
            continue
        gt_boxes = target_boxes[b, gt_mask_b].float()  # (G, 4)

        scores = pred_scores[b].squeeze(-1).sigmoid().float()  # (Q,) probabilities
        boxes = pred_boxes[b].float()  # (Q, 4)
        order = torch.argsort(scores, descending=True)
        scores_sorted = scores[order]
        boxes_sorted = boxes[order]

        # Precompute IoU matrix (Q, G) once per image
        ious_qg = box_iou_matrix(boxes_sorted, gt_boxes)  # (Q, G)

        for thr in coco_thresholds:
            matched_gt = torch.zeros((gt_boxes.shape[0],), dtype=torch.bool, device=ious_qg.device)
            is_tp = torch.zeros((boxes_sorted.shape[0],), dtype=torch.bool, device=ious_qg.device)

            # Greedy assignment in descending score order.
            # NOTE(review): each prediction only considers its single
            # highest-IoU GT; if that GT is already matched, the prediction
            # becomes FP even when another unmatched GT also exceeds `thr`.
            # This is stricter than COCO's matcher (which picks the best
            # *unmatched* GT) — confirm this is intended.
            for i in range(boxes_sorted.shape[0]):
                if gt_boxes.shape[0] == 0:
                    break
                ious_i = ious_qg[i]  # (G,)
                max_iou, max_j = torch.max(ious_i, dim=0)
                if float(max_iou.item()) >= thr and not bool(matched_gt[max_j].item()):
                    is_tp[i] = True
                    matched_gt[max_j] = True

            thr_to_scores[thr].append(scores_sorted)
            thr_to_is_tp[thr].append(is_tp)

            # PR at IoU=0.5 with score threshold (count TP/FP after matching)
            if thr == pr_iou_thr:
                keep = scores_sorted >= score_threshold
                pr_tp += int(is_tp[keep].sum().item())
                pr_fp += int((~is_tp[keep]).sum().item())

    # Compute APs
    aps: list[float] = []
    ap50 = 0.0
    ap75 = 0.0
    for thr in coco_thresholds:
        if not thr_to_scores[thr]:
            # No image contributed predictions at this threshold.
            aps.append(0.0)
            continue
        scores_all = torch.cat(thr_to_scores[thr], dim=0)
        is_tp_all = torch.cat(thr_to_is_tp[thr], dim=0).to(dtype=torch.float32)

        # Sort globally by score
        global_order = torch.argsort(scores_all, descending=True)
        is_tp_all = is_tp_all[global_order]
        is_fp_all = 1.0 - is_tp_all

        cum_tp = torch.cumsum(is_tp_all, dim=0)
        cum_fp = torch.cumsum(is_fp_all, dim=0)

        # Recall over all GT in the batch; precision clamped to avoid 0/0.
        recalls = cum_tp / max(float(total_gt), 1.0)
        precisions = cum_tp / torch.clamp(cum_tp + cum_fp, min=1.0)

        ap = _ap_101(recalls, precisions)
        aps.append(ap)
        # Float-tolerant comparison against the rounded threshold values.
        if abs(thr - 0.5) < 1e-9:
            ap50 = ap
        if abs(thr - 0.75) < 1e-9:
            ap75 = ap

    metrics["mAP"] = float(sum(aps) / max(len(aps), 1))
    metrics["AP50"] = float(ap50)
    metrics["AP75"] = float(ap75)

    precision = pr_tp / max(pr_tp + pr_fp, 1)
    recall = pr_tp / max(total_gt, 1)
    metrics["precision@0.5"] = float(precision)
    metrics["recall@0.5"] = float(recall)

    # Keep backward compatible recalls for common IoU thresholds at the fixed score threshold.
    # (Same definition as above but for additional IoU thresholds if requested.)
    for thr in iou_thresholds:
        if thr not in coco_thresholds:
            # Only thresholds already evaluated above can be reported.
            continue
        if not thr_to_scores[thr]:
            metrics[f"recall@{thr}"] = 0.0
            continue
        scores_all = torch.cat(thr_to_scores[thr], dim=0)
        is_tp_all = torch.cat(thr_to_is_tp[thr], dim=0)
        keep = scores_all >= score_threshold
        tp = int(is_tp_all[keep].sum().item())
        metrics[f"recall@{thr}"] = float(tp / max(total_gt, 1))

    return metrics
internvl_image_procesing.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ from torchvision import transforms as T
4
+ from torchvision.transforms import InterpolationMode
5
+
6
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
7
+ IMAGENET_STD = (0.229, 0.224, 0.225)
8
+
9
def build_transform(input_size):
    """Eval-time InternVL image transform: ensure RGB, bicubic-resize to a
    square of `input_size`, convert to tensor, normalize with ImageNet stats."""
    return T.Compose(
        [
            T.Lambda(lambda img: img if img.mode == 'RGB' else img.convert('RGB')),
            T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ]
    )
18
+
19
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the tiling grid (cols, rows) whose aspect ratio is closest to the
    image's aspect ratio.

    On an exact tie in aspect-ratio distance, a later candidate wins only if
    the source image area exceeds half the tiled target area (i.e. the image
    is large enough to justify more tiles).
    """
    source_area = width * height
    best = (1, 1)
    best_diff = float('inf')
    for cols, rows in target_ratios:
        diff = abs(aspect_ratio - cols / rows)
        if diff < best_diff:
            best_diff = diff
            best = (cols, rows)
        elif diff == best_diff and source_area > 0.5 * image_size * image_size * cols * rows:
            best = (cols, rows)
    return best
33
+
34
def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    """
    Split a PIL image into a grid of square `image_size` tiles (InternVL
    dynamic tiling).

    Args:
        image: PIL image to tile.
        min_num: Minimum total number of tiles.
        max_num: Maximum total number of tiles.
        image_size: Side length of each square tile in pixels.
        use_thumbnail: If True and more than one tile was produced, append a
            whole-image thumbnail resized to (image_size, image_size).

    Returns:
        List of PIL image tiles (length `cols*rows`, plus 1 if a thumbnail
        was appended).
    """
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # Enumerate all (cols, rows) grids whose tile count lies in [min_num, max_num].
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # Choose the grid whose aspect ratio best matches the image's.
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    # Target canvas size: cols*image_size wide, rows*image_size tall.
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # Resize (possibly distorting aspect ratio slightly), then crop tiles
    # row-major: i % cols selects the column, i // cols the row.
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images
71
+
72
def load_image(image_file, input_size=448, max_num=12):
    """Load an image (path or PIL image), tile it with dynamic preprocessing
    (thumbnail included), and return a stacked tensor of shape
    (num_tiles, 3, input_size, input_size)."""
    if isinstance(image_file, str):
        pil_image = Image.open(image_file)
    else:
        pil_image = image_file
    pil_image = pil_image.convert('RGB')

    transform = build_transform(input_size=input_size)
    tiles = dynamic_preprocess(pil_image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    return torch.stack([transform(tile) for tile in tiles])
internvl_ovd.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from .heads import DetrOvdHead
3
+ from .vlmbackbone import InternVL3_5_Backbone
4
+ from torch import nn
5
+
6
+
7
class InternVL3_5_OvdModel(nn.Module):
    """Open-vocabulary detector: an InternVL backbone that produces a fused
    image+text token memory, followed by a DETR-style detection head."""

    def __init__(
        self,
        backbone: InternVL3_5_Backbone,
        model_config: object,
    ) -> None:
        super().__init__()
        self.backbone = backbone
        self.ovd_head = DetrOvdHead(model_config)
        # Keep head dtype aligned with backbone output for non-autocast inference paths.
        self.ovd_head.to(dtype=self.backbone.dtype)

    def forward(
        self,
        pixel_values: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        patch_mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Run backbone fusion, then the detection head.

        Args:
            pixel_values: Image tensor.
            input_ids: Tokenized prompt.
            attention_mask: Attention mask for the prompt.
            patch_mask: Optional mask of valid image patches.

        Returns:
            (pred_boxes, pred_scores) from the OVD head.
        """
        fused_memory, pad_mask = self.backbone.forward_fused(
            pixel_values,
            input_ids,
            attention_mask,
            patch_mask=patch_mask,
        )
        boxes_out, scores_out = self.ovd_head(fused_memory, pad_mask)
        return boxes_out, scores_out
42
+
43
+
44
def build_internvl_ovd(
    model_config: object,
    device: str,
    dtype: torch.dtype,
) -> InternVL3_5_OvdModel:
    """Construct the InternVL3.5 OVD model (backbone + detection head) from
    the given config object."""
    vlm_backbone = InternVL3_5_Backbone(
        model_config.vlm_model_name,
        device,
        dtype,
        use_token_fpn=model_config.use_token_fpn,
        token_fpn_levels=model_config.token_fpn_levels,
        token_fpn_include_text=model_config.token_fpn_include_text,
    )
    return InternVL3_5_OvdModel(vlm_backbone, model_config)
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0edac953471304a8759724dec3616d8010d51f588789ee3a8162c53e373312be
3
+ size 2140804602
modeling_internvl_ovd.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any, Optional
5
+
6
+ import torch
7
+ from PIL import Image
8
+ from transformers import PreTrainedModel
9
+ from transformers.utils import ModelOutput
10
+
11
+ from .hungarian_matcher import build_criterion
12
+ from .configuration_internvl_ovd import InternVLOVDConfig
13
+ from .internvl_ovd import build_internvl_ovd
14
+ from .internvl_image_procesing import load_image
15
+
16
+
17
@dataclass
class InternVLOVDOutput(ModelOutput):
    """Output container for InternVLOVDForDetection.

    Loss fields are populated only when the forward pass computes the
    training loss; pure inference fills in predictions only.
    """

    # Total training loss (same tensor as `loss_total`); None at inference.
    loss: torch.Tensor | None = None
    # (B, num_queries, 4) predicted boxes.
    pred_boxes: torch.Tensor | None = None
    # (B, num_queries, 1) predicted objectness logits (apply sigmoid for scores).
    pred_scores: torch.Tensor | None = None
    # Weighted sum of the individual loss terms.
    loss_total: torch.Tensor | None = None
    # Box regression (L1) loss term.
    loss_bbox: torch.Tensor | None = None
    # Generalized-IoU loss term.
    loss_giou: torch.Tensor | None = None
    # Classification loss term (fixed to zero in bbox-only loss mode).
    loss_cls: torch.Tensor | None = None
26
+
27
+
28
class InternVLOVDForDetection(PreTrainedModel):
    """HF `PreTrainedModel` wrapper around the InternVL OVD detector.

    Wraps the backbone+head model for Hub loading, optionally computes the
    training loss, and offers `infer_image` for single-image inference.
    """

    config_class = InternVLOVDConfig

    def __init__(self, config: InternVLOVDConfig) -> None:
        super().__init__(config)

        # Mixed-precision dtype selected from config ("bfloat16" vs fp16).
        amp_dtype = torch.bfloat16 if config.dtype == "bfloat16" else torch.float16

        self.inner = build_internvl_ovd(
            model_config=config,
            device=config.device_map,
            dtype=amp_dtype,
        )

        # Freeze only the vision tower; the language model and head stay trainable.
        if config.freeze_backbone:
            for name, param in self.inner.named_parameters():
                if name.startswith("backbone.vlm.vision_model"):
                    param.requires_grad = False

        # Training criterion is created lazily to keep Hub inference-only loads minimal.
        self._criterion = None

    @property
    def criterion(self) -> torch.nn.Module:
        """Lazily build and cache the training loss criterion from config."""
        if self._criterion is None:
            cfg = self.config
            self._criterion = build_criterion(
                cost_bbox=cfg.cost_bbox,
                cost_giou=cfg.cost_giou,
                cost_class=cfg.cost_class,
                loss_bbox=cfg.loss_bbox,
                loss_giou=cfg.loss_giou,
                loss_cls=cfg.loss_cls,
                eos_coef=cfg.eos_coef,
                use_focal_loss=cfg.use_focal_loss,
                focal_alpha=cfg.focal_alpha,
                focal_gamma=cfg.focal_gamma,
                loss_mode=cfg.loss_mode,
            )
        return self._criterion

    def forward_inference(
        self,
        *,
        pixel_values: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        patch_mask: Optional[torch.Tensor] = None,
        **kwargs: Any,
    ) -> InternVLOVDOutput:
        """Prediction-only forward pass; never computes losses."""
        # Keep the vision tower in eval mode (e.g. frozen batch-norm/dropout
        # behavior) even when the wrapper is in training mode.
        if hasattr(self.inner.backbone.vlm, "vision_model"):
            self.inner.backbone.vlm.vision_model.eval()

        pred_boxes, pred_scores = self.inner(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            patch_mask=patch_mask,
        )
        return InternVLOVDOutput(loss=None, pred_boxes=pred_boxes, pred_scores=pred_scores)

    def forward(
        self,
        pixel_values: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        patch_mask: Optional[torch.Tensor] = None,
        boxes: Optional[torch.Tensor] = None,
        box_mask: Optional[torch.Tensor] = None,
        compute_loss: bool = False,
        **kwargs: Any,
    ) -> InternVLOVDOutput:
        """Full forward pass.

        Args:
            pixel_values / input_ids / attention_mask / patch_mask: model inputs.
            boxes: target boxes, required when `compute_loss=True`.
            box_mask: validity mask for `boxes`, required when `compute_loss=True`.
            compute_loss: when True, run the criterion and fill the loss fields.

        Raises:
            ValueError: if `compute_loss=True` but targets are missing.
        """
        outputs = self.forward_inference(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            patch_mask=patch_mask,
            **kwargs,
        )

        if not compute_loss:
            return outputs

        if boxes is None or box_mask is None:
            raise ValueError("compute_loss=True requires both `boxes` and `box_mask`.")

        pred_boxes = outputs.pred_boxes
        pred_scores = outputs.pred_scores
        losses = self.criterion(pred_boxes, pred_scores, boxes, box_mask)
        loss_total = losses.get("loss_total")
        return InternVLOVDOutput(
            loss=loss_total,
            pred_boxes=pred_boxes,
            pred_scores=pred_scores,
            loss_total=loss_total,
            loss_bbox=losses.get("loss_bbox"),
            loss_giou=losses.get("loss_giou"),
            loss_cls=losses.get("loss_cls"),
        )

    @torch.no_grad()
    def infer_image(
        self,
        *,
        image: Image.Image | str,
        query: str,
        tokenizer,
        max_length: int = 4096,
        device: Optional[torch.device] = None,
    ) -> InternVLOVDOutput:
        """
        Convenience inference helper that accepts a PIL image (or path) + query text.
        Handles image preprocessing and prompt construction.
        """
        cfg = self.config
        if device is None:
            device = next(self.parameters()).device
        amp_dtype = torch.bfloat16 if cfg.dtype == "bfloat16" else torch.float16
        # CPU autocast does not support fp16; fall back to bf16.
        if device.type == "cpu" and amp_dtype == torch.float16:
            amp_dtype = torch.bfloat16

        # Tile the image; every tile is valid, hence an all-ones patch mask.
        pixel_values = load_image(image, input_size=cfg.input_size, max_num=cfg.max_num_patches)
        num_patches = int(pixel_values.shape[0])
        pixel_values = pixel_values.unsqueeze(0)
        patch_mask = torch.ones((1, num_patches), dtype=torch.bool)

        img_context_token = "<IMG_CONTEXT>"
        img_start_token = "<img>"
        img_end_token = "</img>"
        # 256 IMG_CONTEXT placeholders per tile, matching the backbone's
        # 256-tokens-per-patch assumption.
        tokens_per_patch = 256

        image_tokens = img_start_token + img_context_token * (tokens_per_patch * num_patches) + img_end_token
        prompt = (
            f"{image_tokens}\n"
            "Please provide the bounding box coordinate of the region this sentence describes: "
            f"<ref>{query}</ref>"
        )

        # NOTE(review): truncation at `max_length` could drop IMG_CONTEXT
        # placeholders when num_patches is large, which would break the
        # backbone's multiple-of-256 check — confirm max_length is always
        # large enough for cfg.max_num_patches.
        tokens = tokenizer(
            [prompt],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        )

        pixel_values = pixel_values.to(device=device, dtype=amp_dtype)
        patch_mask = patch_mask.to(device=device)
        input_ids = tokens["input_ids"].to(device=device)
        attention_mask = tokens["attention_mask"].to(device=device)

        self.eval()
        amp_device_type = "cuda" if device.type == "cuda" else "cpu"
        with torch.amp.autocast(device_type=amp_device_type, dtype=amp_dtype):
            return self.forward_inference(
                pixel_values=pixel_values,
                input_ids=input_ids,
                attention_mask=attention_mask,
                patch_mask=patch_mask,
            )
vlmbackbone.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from transformers import AutoModel, AutoTokenizer
4
+ import torch.nn.functional as F
5
+
6
+
7
class VlmBackboneBase(nn.Module):
    """
    Common VLM backbone interface.

    Subclasses implement:
        forward_fused(pixel_values, input_ids, attention_mask)
            -> (memory, padding_mask)
        memory: (B, L, D) fused image+text token embeddings
        padding_mask: (B, L) with True == pad, or None when nothing is padded

    NOTE(review): the original (Korean) docstring described a
    `forward_vision(pixel_values)` method that this class never declares;
    `forward_fused` below is the actual abstract hook — confirm the old name
    is obsolete.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward_fused(
        self,
        pixel_values: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        # Abstract: subclasses must fuse image + text into one token memory.
        raise NotImplementedError
25
+
26
+
27
class InternVL3_5_Backbone(VlmBackboneBase):
    """InternVL3.5-based fused backbone.

    Runs the full VLM over image tiles + prompt, projects the last hidden
    states to the DETR model dimension, and optionally repacks the
    IMG_CONTEXT token embeddings into a multi-level ("token FPN") memory.
    """

    def __init__(
        self,
        model_name: str,
        device: str,
        dtype: torch.dtype,
        *,
        use_token_fpn: bool = False,
        token_fpn_levels: tuple[int, ...] = (16, 8, 4, 2),
        token_fpn_include_text: bool = True,
    ) -> None:
        super().__init__()
        self.device = device
        self.dtype = dtype
        self.use_token_fpn = use_token_fpn
        self.token_fpn_levels = token_fpn_levels
        self.token_fpn_include_text = token_fpn_include_text
        self.vlm = AutoModel.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype=dtype,
            low_cpu_mem_usage=False,
            device_map=None,
            _attn_implementation="flash_attention_2"
        )

        # NOTE(review): both sizes are hard-coded to 1024 — confirm they match
        # the chosen checkpoint's text hidden size and the DETR head d_model.
        self.hidden_size_llm = 1024  # InternVL3_5 text hidden dim
        self.hidden_size_detr = 1024  # DETR d_model

        # With equal in/out sizes, the identity init below makes this
        # projection a no-op at the start of training.
        self.fused_proj = nn.Linear(
            self.hidden_size_llm,
            self.hidden_size_detr,
            bias=True,
            device=None,
            dtype=dtype,
        )
        nn.init.eye_(self.fused_proj.weight)
        nn.init.zeros_(self.fused_proj.bias)

        # Set img_context_token_id for the model
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False)
        IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>"
        self.img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
        self.vlm.img_context_token_id = self.img_context_token_id

    def _build_token_fpn_memory(
        self,
        memory_last: torch.Tensor,  # (B, T, D)
        input_ids: torch.Tensor,  # (B, T)
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Build an "FPN-like" multi-level token memory from IMG_CONTEXT token embeddings.

        - Extract IMG_CONTEXT tokens per sample
        - Reshape per patch into (num_patches, 16, 16, D)
        - Pool to multiple spatial levels (e.g., 16->8->4->2)
        - Flatten and concatenate levels into one sequence
        - Optionally append non-image tokens (text + special tokens)

        Returns:
            memory: (B, L, D)
            padding_mask: (B, L) with True == pad
        """
        B, T, D = memory_last.shape
        device = memory_last.device

        # Validate levels (must be descending powers of 2 from 16)
        levels = tuple(int(x) for x in self.token_fpn_levels)
        if len(levels) == 0 or levels[0] != 16:
            raise ValueError(f"token_fpn_levels must start with 16, got {levels}")
        for a, b in zip(levels, levels[1:]):
            if a % 2 != 0 or b != a // 2:
                raise ValueError(f"token_fpn_levels must be like (16,8,4,2,...) got {levels}")

        is_img = input_ids.eq(self.img_context_token_id)  # (B, T)

        per_sample_memory: list[torch.Tensor] = []
        max_len = 0
        for i in range(B):
            img_tokens = memory_last[i][is_img[i]]  # (N_img, D)
            n_img = img_tokens.shape[0]
            if n_img == 0:
                # Fallback: no img tokens found -> keep original memory.
                # NOTE(review): in this branch with token_fpn_include_text=True,
                # txt_tokens below equals memory_last[i] as well, so the whole
                # sequence is concatenated with itself — likely unintended;
                # confirm whether the fallback should skip the text concat.
                mem_i = memory_last[i]
            else:
                # Each 448x448 tile contributes a 16x16 grid of 256 tokens.
                if n_img % 256 != 0:
                    raise ValueError(f"IMG_CONTEXT token count must be multiple of 256, got {n_img}")
                num_patches = n_img // 256

                # (num_patches, D, 16, 16)
                patch_feat = img_tokens.view(num_patches, 16, 16, D).permute(0, 3, 1, 2).contiguous()

                # Build levels by repeated 2x average pooling.
                level_tokens: list[torch.Tensor] = []
                feat = patch_feat
                cur = 16
                for lvl in levels:
                    # Ensure feat is at correct resolution
                    while cur > lvl:
                        feat = F.avg_pool2d(feat, kernel_size=2, stride=2)
                        cur //= 2
                    # Flatten: (num_patches, D, H, W) -> (num_patches*H*W, D)
                    h, w = feat.shape[-2:]
                    lvl_tok = feat.permute(0, 2, 3, 1).reshape(num_patches * h * w, D).contiguous()
                    level_tokens.append(lvl_tok)

                mem_i = torch.cat(level_tokens, dim=0)  # (L_img_fpn, D)

            if self.token_fpn_include_text:
                # Prepend all non-image tokens (text + special tokens).
                txt_tokens = memory_last[i][~is_img[i]]  # (N_txt, D)
                mem_i = torch.cat([txt_tokens, mem_i], dim=0)

            per_sample_memory.append(mem_i)
            max_len = max(max_len, mem_i.shape[0])

        # Pad to (B, max_len, D); padding_mask is True at padded positions.
        padded = memory_last.new_zeros((B, max_len, D))
        padding_mask = torch.ones((B, max_len), device=device, dtype=torch.bool)
        for i, mem_i in enumerate(per_sample_memory):
            seq_len = mem_i.shape[0]
            padded[i, :seq_len] = mem_i
            padding_mask[i, :seq_len] = False

        return padded, padding_mask

    def forward_fused(
        self,
        pixel_values: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        patch_mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run the VLM and return the projected fused token memory.

        Returns (memory, padding_mask); padding_mask is None unless the
        token-FPN repacking produced per-sample sequence lengths.
        """
        # pixel_values: (B, P, 3, H, W) — flatten tiles into the batch dim;
        # 4-D input is treated as one tile per sample with no patch mask.
        if pixel_values.dim() == 5:
            bsz, num_patches, channels, height, width = pixel_values.shape
            pixel_values = pixel_values.view(-1, channels, height, width)
            if patch_mask is not None:
                patch_mask = patch_mask.view(bsz * num_patches)
        else:
            bsz = pixel_values.shape[0]
            num_patches = 1
            patch_mask = None

        # image_flags must match number of images provided to the VLM
        if patch_mask is not None:
            image_flags = patch_mask.to(pixel_values.device, dtype=torch.long)
        else:
            image_flags = torch.ones(pixel_values.shape[0], dtype=torch.long, device=pixel_values.device)

        outputs = self.vlm(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            image_flags=image_flags,
            output_hidden_states=True,
            return_dict=True
        )

        # CausalLMOutputWithPast has hidden_states tuple, last element is the final layer output
        memory = outputs.hidden_states[-1]  # (B, T, hidden_size)
        memory = self.fused_proj(memory)  # (B, T, hidden_size_detr)
        if self.use_token_fpn:
            memory, padding_mask = self._build_token_fpn_memory(
                memory_last=memory,
                input_ids=input_ids,
            )
            return memory, padding_mask

        return memory, None