from __future__ import annotations

from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class _CrossAttentionLayer(nn.Module):
    """Single cross-attention block (queries attend to encoder memory) followed by a feed-forward sub-layer."""

    def __init__(self, config: object) -> None:
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=config.hidden_size,
            num_heads=config.nhead,
            dropout=config.dropout,
            batch_first=True,
        )
        self.linear1 = nn.Linear(config.hidden_size, config.dim_feedforward)
        self.linear2 = nn.Linear(config.dim_feedforward, config.hidden_size)
        self.norm1 = nn.LayerNorm(config.hidden_size)
        self.norm2 = nn.LayerNorm(config.hidden_size)
        self.dropout = nn.Dropout(config.dropout)

    def forward(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        memory_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Cross-attention: queries attend to the encoder memory.
        # `memory_mask` is used as a key padding mask (True marks padded positions).
        residual = tgt
        tgt2, _ = self.cross_attn(
            query=tgt,
            key=memory,
            value=memory,
            key_padding_mask=memory_mask,
        )
        tgt = residual + self.dropout(tgt2)
        tgt = self.norm1(tgt)

        # Position-wise feed-forward with residual connection and post-norm.
        residual = tgt
        tgt2 = self.linear2(F.gelu(self.linear1(tgt)))
        tgt = residual + self.dropout(tgt2)
        tgt = self.norm2(tgt)
        return tgt


class _DetrDecoderLayer(nn.Module):
    """DETR-style decoder layer: query self-attention, cross-attention to memory, then a feed-forward sub-layer."""

    def __init__(self, config: object) -> None:
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            embed_dim=config.hidden_size,
            num_heads=config.nhead,
            dropout=config.dropout,
            batch_first=True,
        )
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=config.hidden_size,
            num_heads=config.nhead,
            dropout=config.dropout,
            batch_first=True,
        )
        self.linear1 = nn.Linear(config.hidden_size, config.dim_feedforward)
        self.linear2 = nn.Linear(config.dim_feedforward, config.hidden_size)
        self.norm1 = nn.LayerNorm(config.hidden_size)
        self.norm2 = nn.LayerNorm(config.hidden_size)
        self.norm3 = nn.LayerNorm(config.hidden_size)
        self.dropout = nn.Dropout(config.dropout)

    def forward(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        memory_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Self-attention among the object queries (no mask: the query set is fixed and unpadded).
        residual = tgt
        tgt2, _ = self.self_attn(tgt, tgt, tgt)
        tgt = residual + self.dropout(tgt2)
        tgt = self.norm1(tgt)

        # Cross-attention: queries attend to the encoder memory.
        residual = tgt
        tgt2, _ = self.cross_attn(
            query=tgt,
            key=memory,
            value=memory,
            key_padding_mask=memory_mask,
        )
        tgt = residual + self.dropout(tgt2)
        tgt = self.norm2(tgt)

        # Position-wise feed-forward with residual connection and post-norm.
        residual = tgt
        tgt2 = self.linear2(F.gelu(self.linear1(tgt)))
        tgt = residual + self.dropout(tgt2)
        tgt = self.norm3(tgt)
        return tgt


class DetrOvdHead(nn.Module):
| """ |
| Unified OVD head: |
| - head_type="small": single cross-attention pooling (fast) |
| - head_type="decoder": DETR-style decoder stack (heavier, experimental) |
| """ |
|
|
| def __init__(self, config: object) -> None: |
| super().__init__() |
| self.config = config |
| self.num_queries = int(getattr(config, "num_queries")) |
| self.d_model = int(getattr(config, "hidden_size")) |
|
|
| head_type = getattr(config, "head_type", "small") |
| self.head_type = str(head_type) |
|
|
| self.query_embed = nn.Embedding(self.num_queries, self.d_model) |
|
|
| if self.head_type == "detr": |
| n_layers = int(getattr(config, "num_decoder_layers")) |
| self.layers = nn.ModuleList([_DetrDecoderLayer(config) for _ in range(n_layers)]) |
| self.pooling = None |
| else: |
| |
| self.pooling = _CrossAttentionLayer(config) |
| self.layers = None |
|
|
| self.bbox_head = nn.Sequential( |
| nn.Linear(self.d_model, self.d_model), |
| nn.ReLU(), |
| nn.Linear(self.d_model, 4), |
| ) |
| self.score_head = nn.Sequential( |
| nn.Linear(self.d_model, self.d_model), |
| nn.ReLU(), |
| nn.Linear(self.d_model, 1), |
| ) |
|
|
| def forward( |
| self, |
| memory: torch.Tensor, |
| memory_mask: torch.Tensor | None = None, |
| ) -> Tuple[torch.Tensor, torch.Tensor]: |
| B, _, _ = memory.shape |
| device = memory.device |
|
|
| |
| query_idx = torch.arange(self.num_queries, device=device) |
| tgt = self.query_embed(query_idx).unsqueeze(0).expand(B, -1, -1) |
|
|
| if self.head_type == "detr": |
| assert self.layers is not None |
| for layer in self.layers: |
| tgt = layer(tgt, memory, memory_mask) |
| else: |
| assert self.pooling is not None |
| tgt = self.pooling(tgt, memory, memory_mask) |
|
|
| |
| pred_cxcywh = self.bbox_head(tgt).sigmoid() |
|
|
| |
| cx, cy, w, h = pred_cxcywh.unbind(-1) |
| pred_boxes = torch.stack( |
| [ |
| cx - w / 2, |
| cy - h / 2, |
| cx + w / 2, |
| cy + h / 2, |
| ], |
| dim=-1, |
| ).clamp(0, 1) |
|
|
| pred_logits = self.score_head(tgt) |
| return pred_boxes, pred_logits |
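

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original head). The surrounding
    # project presumably has its own config class; SimpleNamespace below is only a
    # stand-in carrying the attributes this module actually reads (hidden_size,
    # nhead, dropout, dim_feedforward, num_queries, head_type, num_decoder_layers).
    from types import SimpleNamespace

    config = SimpleNamespace(
        hidden_size=256,
        nhead=8,
        dropout=0.1,
        dim_feedforward=1024,
        num_queries=100,
        head_type="small",     # or "detr" for the decoder-stack variant
        num_decoder_layers=6,  # only read when head_type == "detr"
    )
    head = DetrOvdHead(config)

    # Fake encoder memory: batch of 2, 196 tokens, hidden_size features.
    memory = torch.randn(2, 196, config.hidden_size)
    # Optional key padding mask: True marks padded memory positions.
    memory_mask = torch.zeros(2, 196, dtype=torch.bool)

    boxes, logits = head(memory, memory_mask)
    print(boxes.shape)   # torch.Size([2, 100, 4])  -- (x1, y1, x2, y2) in [0, 1]
    print(logits.shape)  # torch.Size([2, 100, 1])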