# File size: 6,740 Bytes
# bbc0514
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from transformers import Dinov2Model, Dinov2Config
from torchvision.transforms import v2
from code import interact
import json
import os
from PIL import Image
import numpy as np
from typing import Union

# Shared image preprocessing pipeline: convert the input to a tensor image,
# rescale it to float32, resize it to 224x224 and normalise each channel.
# The mean/std values match the commonly used ImageNet channel statistics.
transforms = v2.Compose(
    [
        v2.ToImage(),
        v2.ToDtype(dtype=torch.float32, scale=True),
        v2.Resize(size=(224, 224)),
        v2.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

class CrossAttention(nn.Module):
    """Single-head attention between candidate tokens and query tokens.

    The candidate tokens supply the attention queries, while the ``queries``
    tensor supplies both the keys and the values, so the output has one row
    per candidate.

    The layer names do not match their roles: ``Wk`` projects the candidates
    (used as attention queries) and ``Wq`` projects ``queries`` (used as
    attention keys). The attribute names are left untouched because they are
    part of the module's ``state_dict`` keys.
    """

    def __init__(self, d_model: int, *args, **kwargs):
        """Create three ``d_model -> d_model`` linear projections.

        Any extra positional/keyword arguments are forwarded to ``nn.Module``.
        """
        super().__init__(*args, **kwargs)
        self.Wq = nn.Linear(d_model, d_model)
        self.Wk = nn.Linear(d_model, d_model)
        self.Wv = nn.Linear(d_model, d_model)

    def forward(self, queries, candidates):
        """Return the attention output, one row per candidate token.

        Args:
            queries: expected shape (B, num_queries, d_model).
            candidates: expected shape (B, num_candidates, d_model).

        Returns:
            Tensor of shape (B, num_candidates, d_model). The softmax runs
            over the ``num_queries`` axis, so with a single query token every
            output row equals ``Wv(queries)``.
        """
        attn_queries = self.Wk(candidates)  # (B, num_candidates, d_model)
        attn_keys = self.Wq(queries)        # (B, num_queries, d_model)
        attn_values = self.Wv(queries)      # (B, num_queries, d_model)

        return F.scaled_dot_product_attention(attn_queries, attn_keys, attn_values)

class JointTransformer(nn.Module):
    """Transformer encoder applied to the query and candidate tokens jointly.

    The query token and the candidate tokens are concatenated along the
    sequence axis, passed through a stack of ``nn.TransformerEncoderLayer``
    blocks (so every token can attend to every other token), and then split
    back into their query and candidate parts.
    """

    def __init__(
            self,
            d_model=384,
            nhead=4,
            num_layers=4,
            *args, **kwargs
        ):
        """Build the encoder stack.

        Args:
            d_model: token embedding size.
            nhead: number of attention heads per layer.
            num_layers: number of stacked encoder layers.
            *args, **kwargs: forwarded to ``nn.Module``.
        """
        super().__init__(*args, **kwargs)

        # Transformer encoder; the feed-forward width is 4x the model width
        # and dropout is disabled.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=4 * d_model,
            batch_first=True,
            dropout=0.0
        )

        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)

    def forward(self, query: Tensor, candidates: Tensor) -> tuple[Tensor, Tensor]:
        """Encode query and candidates together and return them separately.

        Args:
            query: tensor of shape (B, 1, D); exactly one query token is
                required.
            candidates: tensor of shape (B, C, D).

        Returns:
            A ``(query, candidates)`` tuple with shapes (B, 1, D) and
            (B, C, D), holding the encoded tokens.

        Raises:
            AssertionError: if ``query`` does not hold exactly one token.
        """
        num_query_tokens = query.size(1)
        assert num_query_tokens == 1, (
            f"expected exactly one query token, got {num_query_tokens}"
        )

        tokens = torch.cat((query, candidates), dim=1)  # (B, 1+C, D)
        tokens = self.transformer(tokens)  # (B, 1+C, D)

        # Split the sequence back at the query/candidate boundary.
        encoded_query = tokens[:, :num_query_tokens, :]  # (B, 1, D)
        encoded_candidates = tokens[:, num_query_tokens:, :]  # (B, C, D)

        return encoded_query, encoded_candidates
    
class MLP(nn.Module):
    """Two-layer feed-forward block: expand, apply GELU, project back.

    The hidden width is ``emb_dim * expand_factor`` and the activation is the
    tanh approximation of GELU. Input and output share the same last
    dimension ``emb_dim``.
    """

    def __init__(self, emb_dim, expand_factor, *args, **kwargs):
        """Create the two linear layers and the activation.

        Extra positional/keyword arguments are forwarded to ``nn.Module``.
        """
        super().__init__(*args, **kwargs)
        hidden_dim = emb_dim * expand_factor
        self.lin1 = nn.Linear(emb_dim, hidden_dim)
        self.gelu = nn.GELU(approximate="tanh")
        self.lin2 = nn.Linear(hidden_dim, emb_dim)

    def forward(self, x: Tensor) -> Tensor:
        """Apply ``lin1 -> GELU -> lin2`` to the last dimension of ``x``."""
        return self.lin2(self.gelu(self.lin1(x)))

class Matcher(nn.Module):
    """Score which candidate image, if any, corresponds to a query image.

    Pipeline (see ``forward``):
      1. A frozen, pre-trained DINOv2 encoder embeds the query image and
         every candidate image; only the first (CLS) token is kept.
      2. Learned embeddings are added to tell the query apart from the
         candidates, and a learned "null candidate" token is appended after
         the real candidates.
      3. A joint transformer lets all tokens attend to each other.
      4. A cross-attention layer and a linear head produce one logit per
         candidate plus one for the null candidate (last position).
    """

    def __init__(self, max_candidates, num_layers, dino_dir, *args, **kwargs):
        """Load the encoder from ``dino_dir`` and build the trainable layers.

        Args:
            max_candidates: size of the candidate position-embedding table,
                i.e. the largest number of candidates ``forward`` accepts.
            num_layers: number of layers of the joint transformer.
            dino_dir: directory holding the DINOv2 weights and ``config.json``.
            *args, **kwargs: forwarded to ``nn.Module``.
        """
        super().__init__(*args, **kwargs)

        # -------------- Pre-trained Encoder (frozen) -----------------
        assert isinstance(dino_dir, str)
        with open(os.path.join(dino_dir, "config.json"), "r") as f:
            dino_cfg = json.load(f)

        self.encoder = Dinov2Model.from_pretrained(dino_dir, config = Dinov2Config(**dino_cfg))
        self.freeze_encoder()

        # ----------------- Embeddings to distinguish queries and candidates ---------------------
        self.query_image_embed = nn.Parameter(torch.randn(1, 1, dino_cfg["hidden_size"]))
        self.candidates_image_embed = nn.Embedding(max_candidates, dino_cfg["hidden_size"])
        self.null_candidate = nn.Parameter(torch.randn(1, 1, dino_cfg["hidden_size"])) # null candidate embedding

        # ---------------- Joint transformer (trained) ----------------------
        self.max_candidates = max_candidates
        self.num_layers = num_layers
        self.joint_transformer = JointTransformer(
            d_model = dino_cfg["hidden_size"],
            nhead = dino_cfg["num_attention_heads"],
            num_layers = num_layers,
        )
        self.lnormq = nn.LayerNorm(dino_cfg["hidden_size"], )
        self.lnormc = nn.LayerNorm(dino_cfg["hidden_size"], )

        # ------------------------ Final operation ---------------------------
        self.cross_attn = CrossAttention(dino_cfg["hidden_size"])
        self.lnormc2 = nn.LayerNorm(dino_cfg["hidden_size"])
        self.classification_layer = nn.Linear(dino_cfg["hidden_size"], 1)

    def freeze_encoder(self) -> None:
        """Disable gradients for every parameter of the encoder."""
        for p in self.encoder.parameters():
            p.requires_grad_(False)

    def pre_process_img(self, image: Union[Image.Image, np.ndarray, str]) -> Tensor:
        """Turn an image (or a path to one) into a normalised input tensor.

        Args:
            image: a PIL image, a NumPy array, or a path to an image file.

        Returns:
            The result of the module-level ``transforms`` pipeline.
        """
        if isinstance(image, str):
            # Load eagerly inside a context manager so the file handle is
            # released, and force three channels because the normalisation
            # in ``transforms`` is defined for exactly three channels.
            with Image.open(image) as opened:
                image = opened.convert("RGB")

        return transforms(image)

    @torch.inference_mode()
    def predict(self, query_crop: np.ndarray, candidate_crops: list[np.ndarray]):
        """Return match probabilities for one query against its candidates.

        Args:
            query_crop: the query image (anything ``pre_process_img`` accepts).
            candidate_crops: non-empty list of candidate images.

        Returns:
            NumPy array of shape (1, len(candidate_crops) + 1) holding a
            softmax over the candidates; the last entry belongs to the null
            candidate.
        """
        # Run on whatever device the model currently lives on.
        device = self.query_image_embed.device

        query = self.pre_process_img(query_crop)[None, None, ...].to(device)
        candidates = torch.stack(
            [self.pre_process_img(candidate_crop) for candidate_crop in candidate_crops]
        ).unsqueeze(0).to(device)
        probs = self.forward(query, candidates).softmax(dim=-1)

        # ``.numpy()`` only works on CPU tensors.
        return probs.cpu().numpy()

    def forward(self, query: Tensor, candidates: Tensor) -> Tensor:
        """Compute one logit per candidate plus one for the null candidate.

        Args:
            query: tensor of shape (B, 1, 3, H, W).
            candidates: tensor of shape (B, C, 3, H, W) with
                ``C <= max_candidates``.

        Returns:
            Logits of shape (B, C + 1); index ``C`` is the null candidate.

        Raises:
            IndexError: if ``C`` exceeds ``max_candidates``.
        """
        # query (B,1,3,H,W), candidates (B,C,3,H,W)
        B, C, _, H, W = candidates.shape

        # Fail early with a clear message instead of an out-of-range lookup
        # in the candidate embedding table after the encoder has already run.
        if C > self.max_candidates:
            raise IndexError(
                f"got {C} candidates but the model supports at most {self.max_candidates}"
            )

        # ``reshape`` (rather than ``view``) also accepts non-contiguous input.
        query = self.encoder(
            query.reshape(B, 3, H, W)
        )['last_hidden_state'] # (B, T, D)

        # pick the CLS_TOKEN
        query = query[:,0,:].reshape(B, 1, -1) # (B, 1, D)

        candidates = self.encoder(
            candidates.reshape(B*C, 3, H, W)
        )['last_hidden_state'] # (B*C, T, D)

        # pick the CLS_TOKEN
        candidates = candidates[:,0,:].reshape(B, C, -1) # (B, C, D)

        # Add embeddings
        query = query + self.query_image_embed.repeat(B, 1, 1) # (B, 1, D)
        candidate_ids = torch.arange(C, device=query.device).view(1, C)
        candidates = candidates + self.candidates_image_embed(candidate_ids) # (B, C, D)
        # Append the learned null-candidate token after the real candidates.
        candidates = torch.cat(
            (
                candidates,
                self.null_candidate.repeat(B, 1, 1)
            ),
        dim=1) # (B, C+1, D)

        # Joint transformer, candidate and query tokens attend to each other
        q, c = self.joint_transformer(query, candidates)
        # skip connections
        query = self.lnormq(query + q)
        candidates = self.lnormc(candidates + c)

        # Cross attention: in CrossAttention the candidate tokens act as the
        # attention queries and the query token provides keys and values.
        c = self.cross_attn(query, candidates) # (B, C+1, D)
        candidates = self.lnormc2(candidates + c)
        # NOTE(review): ``c`` is added a second time after the layer norm.
        # This looks like it may be unintentional, but it is kept because
        # trained weights depend on it - confirm before changing.
        candidates = candidates + c
        logits = self.classification_layer(candidates) # (B, C+1, 1)

        return logits.squeeze(-1)
    
if __name__ == "__main__":
    # Manual smoke test: run one forward pass on random inputs, then open an
    # interactive console so the results can be inspected by hand.
    import random

    B = 1
    H = W = 224
    max_candidates, num_layers = 10, 4

    # One query image and a random number (2..max_candidates) of candidates.
    query = torch.randn(B, 1, 3, H, W)
    candidates = torch.randn(B, random.randint(2, max_candidates), 3, H, W)

    matcher = Matcher(
        max_candidates=max_candidates,
        num_layers=num_layers,
        dino_dir="DINOv2_base",
    )
    out = matcher(query, candidates)

    # All names defined above are available inside the console.
    interact(local=locals())