File size: 2,411 Bytes
d295ca1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""
RBoxEncoder - pure PyTorch, no ldm/bldm dependency.

Encodes rotated bounding boxes (8 coords) with Fourier embedding and text embeddings.
"""

import torch
import torch.nn as nn


class FourierEmbedder:
    """Sin/cos Fourier feature expansion of a tensor.

    The frequency bands are ``temperature ** (i / num_freqs)`` for
    ``i = 0 .. num_freqs - 1``.

    Args:
        num_freqs: Number of frequency bands.
        temperature: Base of the geometric frequency progression.
    """

    def __init__(self, num_freqs=64, temperature=100):
        self.num_freqs = num_freqs
        self.temperature = temperature
        exponents = torch.arange(num_freqs) / num_freqs
        self.freq_bands = temperature ** exponents

    @torch.no_grad()
    def __call__(self, x, cat_dim=-1):
        """Embed ``x`` with every frequency band.

        Args:
            x: Tensor to embed.
            cat_dim: Dimension along which the per-band features are joined.

        Returns:
            Tensor laid out as ``[sin(f0*x), cos(f0*x), sin(f1*x), cos(f1*x), ...]``
            along ``cat_dim``, i.e. ``2 * num_freqs`` times the size of ``x``
            on that dimension. Computed under ``torch.no_grad()``.
        """
        # Sin comes before cos for each band; bands stay in ascending order.
        features = [
            trig(band * x)
            for band in self.freq_bands
            for trig in (torch.sin, torch.cos)
        ]
        return torch.cat(features, cat_dim)


class RBoxEncoder(nn.Module):
    """Encoder for rotated bounding boxes (8 coords) with text embeddings.

    Box coordinates are expanded with a sin/cos Fourier embedding, concatenated
    with the per-box text embedding, and projected to ``out_dim`` by a
    three-layer MLP. Entries whose mask is 0 have both their text and position
    features replaced by learnable "null" vectors before the projection.

    Args:
        in_dim: Size of the last dimension of the text embeddings.
        out_dim: Size of the last dimension of the returned tensor.
        fourier_freqs: Number of Fourier frequency bands per coordinate.
    """

    def __init__(self, in_dim, out_dim, fourier_freqs=8):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
        self.position_dim = fourier_freqs * 2 * 8  # 2 is sin&cos, 8 is xyxyxyxy

        self.linears = nn.Sequential(
            nn.Linear(self.in_dim + self.position_dim, 512),
            nn.SiLU(),
            nn.Linear(512, 512),
            nn.SiLU(),
            nn.Linear(512, out_dim),
        )

        # Substitutes used wherever the mask marks an entry as absent.
        self.null_text_feature = nn.Parameter(torch.zeros([self.in_dim]))
        self.null_position_feature = nn.Parameter(torch.zeros([self.position_dim]))

    @staticmethod
    def _first_entry(value, name):
        """Return the first element of a list-wrapped input, or raise if it is missing."""
        if isinstance(value, torch.Tensor):
            raise TypeError(
                f"RBoxEncoder.forward expects `{name}` wrapped in a list, "
                f"e.g. {name}=[tensor], but got a bare tensor"
            )
        if value is None or len(value) == 0:
            raise ValueError(
                f"RBoxEncoder.forward requires `{name}` as a non-empty list, "
                f"e.g. {name}=[tensor]"
            )
        return value[0]

    def forward(self, boxes=None, masks=None, text_embeddings=None, **kwargs):
        """Encode boxes and text embeddings into ``out_dim``-sized tokens.

        Args:
            boxes: One-element list holding a ``(B, N, 8)`` tensor of box coordinates.
            masks: One-element list holding a ``(B, N)`` tensor; 1/True keeps an
                entry, 0/False swaps in the null features. Float, integer and
                boolean masks are accepted.
            text_embeddings: One-element list holding a ``(B, N, in_dim)`` tensor.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            Tensor of shape ``(B, N, out_dim)``.

        Raises:
            ValueError: If an input is missing/empty, or the embedded box width
                does not match ``position_dim``.
            TypeError: If an input is passed as a bare tensor instead of a list.
        """
        # Pipeline passes boxes=[bboxes], masks=[mask_vector], text_embeddings=[category_conditions]
        boxes = self._first_entry(boxes, "boxes")
        masks = self._first_entry(masks, "masks")
        text_embeddings = self._first_entry(text_embeddings, "text_embeddings")

        B, N, _ = boxes.shape
        masks = masks.unsqueeze(-1)
        if not masks.is_floating_point():
            # `1 - masks` is rejected by PyTorch for bool tensors; casting keeps
            # the blend below valid and leaves integer-mask results unchanged.
            masks = masks.to(text_embeddings.dtype)

        xyxy_embedding = self.fourier_embedder(boxes)  # B*N*8 --> B*N*C
        if xyxy_embedding.shape[-1] != self.position_dim:
            raise ValueError(
                f"embedded boxes have width {xyxy_embedding.shape[-1]}, expected "
                f"{self.position_dim}; boxes must carry 8 coordinates each"
            )

        text_null = self.null_text_feature.view(1, 1, -1)
        xyxy_null = self.null_position_feature.view(1, 1, -1)

        # Blend: keep the real feature where mask == 1, the null feature where mask == 0.
        text_embeddings = text_embeddings * masks + (1 - masks) * text_null
        xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null

        objs = self.linears(torch.cat([text_embeddings, xyxy_embedding], dim=-1))
        assert objs.shape == torch.Size([B, N, self.out_dim])
        return objs