from __future__ import annotations
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F


class _CrossAttentionLayer(nn.Module):
    def __init__(self, config: object) -> None:
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=config.hidden_size,
            num_heads=config.nhead,
            dropout=config.dropout,
            batch_first=True,
        )
        self.linear1 = nn.Linear(config.hidden_size, config.dim_feedforward)
        self.linear2 = nn.Linear(config.dim_feedforward, config.hidden_size)
        self.norm1 = nn.LayerNorm(config.hidden_size)
        self.norm2 = nn.LayerNorm(config.hidden_size)
        self.dropout = nn.Dropout(config.dropout)

    def forward(
        self,
        tgt: torch.Tensor,  # (B,K,D)
        memory: torch.Tensor,  # (B,L,D)
        memory_mask: Optional[torch.Tensor] = None,  # (B,L) bool; True = padded position to ignore
    ) -> torch.Tensor:
        # cross-attn: queries attend to the encoder memory
        residual = tgt
        tgt2, _ = self.cross_attn(
            query=tgt,
            key=memory,
            value=memory,
            key_padding_mask=memory_mask,
        )
        tgt = residual + self.dropout(tgt2)
        tgt = self.norm1(tgt)
        # FFN
        residual = tgt
        tgt2 = self.linear2(F.gelu(self.linear1(tgt)))
        tgt = residual + self.dropout(tgt2)
        tgt = self.norm2(tgt)
        return tgt


class _DetrDecoderLayer(nn.Module):
    def __init__(self, config: object) -> None:
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            embed_dim=config.hidden_size,
            num_heads=config.nhead,
            dropout=config.dropout,
            batch_first=True,
        )
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=config.hidden_size,
            num_heads=config.nhead,
            dropout=config.dropout,
            batch_first=True,
        )
        self.linear1 = nn.Linear(config.hidden_size, config.dim_feedforward)
        self.linear2 = nn.Linear(config.dim_feedforward, config.hidden_size)
        self.norm1 = nn.LayerNorm(config.hidden_size)
        self.norm2 = nn.LayerNorm(config.hidden_size)
        self.norm3 = nn.LayerNorm(config.hidden_size)
        self.dropout = nn.Dropout(config.dropout)

    def forward(
        self,
        tgt: torch.Tensor,  # (B,K,D)
        memory: torch.Tensor,  # (B,L,D)
        memory_mask: Optional[torch.Tensor] = None,  # (B,L) bool; True = padded position to ignore
    ) -> torch.Tensor:
        # self-attn among object queries
        residual = tgt
        tgt2, _ = self.self_attn(tgt, tgt, tgt)
        tgt = residual + self.dropout(tgt2)
        tgt = self.norm1(tgt)
        # cross-attn: queries attend to the encoder memory
        residual = tgt
        tgt2, _ = self.cross_attn(
            query=tgt,
            key=memory,
            value=memory,
            key_padding_mask=memory_mask,
        )
        tgt = residual + self.dropout(tgt2)
        tgt = self.norm2(tgt)
        # FFN
        residual = tgt
        tgt2 = self.linear2(F.gelu(self.linear1(tgt)))
        tgt = residual + self.dropout(tgt2)
        tgt = self.norm3(tgt)
        return tgt


class DetrOvdHead(nn.Module):
    """
    Unified OVD head:
      - head_type="small": single cross-attention pooling (fast)
      - head_type="detr": DETR-style decoder stack (heavier, experimental)
    """

    def __init__(self, config: object) -> None:
        super().__init__()
        self.config = config
        self.num_queries = int(getattr(config, "num_queries"))
        self.d_model = int(getattr(config, "hidden_size"))
        head_type = getattr(config, "head_type", "small")
        self.head_type = str(head_type)
        self.query_embed = nn.Embedding(self.num_queries, self.d_model)
        if self.head_type == "detr":
            n_layers = int(getattr(config, "num_decoder_layers"))
            self.layers = nn.ModuleList([_DetrDecoderLayer(config) for _ in range(n_layers)])
            self.pooling = None
        else:
            # default: "small"
            self.pooling = _CrossAttentionLayer(config)
            self.layers = None
        self.bbox_head = nn.Sequential(
            nn.Linear(self.d_model, self.d_model),
            nn.ReLU(),
            nn.Linear(self.d_model, 4),
        )
        self.score_head = nn.Sequential(
            nn.Linear(self.d_model, self.d_model),
            nn.ReLU(),
            nn.Linear(self.d_model, 1),
        )

    def forward(
        self,
        memory: torch.Tensor,  # (B,L,D)
        memory_mask: Optional[torch.Tensor] = None,  # (B,L) or None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        B, _, _ = memory.shape
        device = memory.device
        # object queries
        query_idx = torch.arange(self.num_queries, device=device)
        tgt = self.query_embed(query_idx).unsqueeze(0).expand(B, -1, -1)  # (B,K,D)
        if self.head_type == "detr":
            assert self.layers is not None
            for layer in self.layers:
                tgt = layer(tgt, memory, memory_mask)
        else:
            assert self.pooling is not None
            tgt = self.pooling(tgt, memory, memory_mask)
        # Predict (cx, cy, w, h) in [0, 1] range
        pred_cxcywh = self.bbox_head(tgt).sigmoid()  # (B,K,4), 0~1
        # Convert to (x1, y1, x2, y2) format
        cx, cy, w, h = pred_cxcywh.unbind(-1)
        pred_boxes = torch.stack(
            [
                cx - w / 2,  # x1
                cy - h / 2,  # y1
                cx + w / 2,  # x2
                cy + h / 2,  # y2
            ],
            dim=-1,
        ).clamp(0, 1)
        pred_logits = self.score_head(tgt)  # (B,K,1), raw logits for BCE with logits
        return pred_boxes, pred_logits
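

# --- Usage sketch (illustrative only) ----------------------------------------
# A minimal smoke test showing how the head is wired up, per the docstring above.
# The _DemoConfig class and its values are assumptions for demonstration, not part
# of the real codebase; the actual config object only needs to expose hidden_size,
# nhead, dropout, dim_feedforward, num_queries, head_type, and (for head_type="detr")
# num_decoder_layers, which is everything this module reads from it.
if __name__ == "__main__":
    from dataclasses import dataclass

    @dataclass
    class _DemoConfig:  # hypothetical config for the smoke test
        hidden_size: int = 256
        nhead: int = 8
        dropout: float = 0.1
        dim_feedforward: int = 1024
        num_queries: int = 100
        head_type: str = "small"  # or "detr" for the decoder stack
        num_decoder_layers: int = 3

    cfg = _DemoConfig()
    head = DetrOvdHead(cfg)
    memory = torch.randn(2, 196, cfg.hidden_size)        # (B, L, D) encoder features
    padding = torch.zeros(2, 196, dtype=torch.bool)      # True = padded position to ignore
    boxes, logits = head(memory, padding)
    print(boxes.shape, logits.shape)  # expected: (2, 100, 4) and (2, 100, 1)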