File size: 5,100 Bytes
f638d9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import torch
import torch.nn as nn
from easydict import EasyDict as edict


class Block(nn.Module):
    """Skeleton encoder block: currently just records its build config.

    No layers are defined yet; instances act as placeholders for the
    (commented-out) i2t/t2i encoder stacks in ``Model``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

class Model(torch.nn.Module):
    """CLIP wrapper exposing token-level features plus pooled embeddings and
    computing the symmetric contrastive (InfoNCE) loss, optionally folding
    tokens carried over from previous steps into the negative set.
    """

    def __init__(self, clip_model, config):
        super().__init__()
        self.clip_model = clip_model
        # NOTE: optional i2t / t2i encoder stacks built from ``Block`` were
        # sketched here but are unused; re-add them via
        # nn.ModuleList([Block(config) for _ in range(n_layers)]) when needed.
        self.config = config

    def img_forward(self, x: torch.Tensor):
        """Run the CLIP visual tower.

        Args:
            x: image batch, [N, 3, H, W] (e.g. 224x224 for ViT backbones).

        Returns:
            Tuple of:
              - all tokens after the final LayerNorm, [N, grid**2 + 1, width];
              - the projected class token, [N, output_dim] (un-normalized).
        """
        visual = self.clip_model.visual
        x = visual.conv1(x)  # [N, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # [N, width, grid ** 2]
        x = x.permute(0, 2, 1)  # [N, grid ** 2, width]
        # Broadcast the learned class embedding across the batch and prepend it.
        cls_embed = visual.class_embedding.to(x.dtype) + torch.zeros(
            x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
        x = torch.cat([cls_embed, x], dim=1)  # [N, grid ** 2 + 1, width]
        x = x + visual.positional_embedding.to(x.dtype)
        x = visual.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = visual.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        x = visual.ln_post(x)  # [N, L, D]
        # BUG FIX: the original applied ln_post a second time to x[:, 0, :]
        # even though the full sequence was already normalized above; the
        # reference CLIP visual tower applies ln_post exactly once.
        cls_token = x[:, 0, :]
        if visual.proj is not None:
            cls_token = cls_token @ visual.proj
        return x, cls_token

    def txt_forward(self, text):
        """Run the CLIP text tower.

        Args:
            text: token-id batch, [N, n_ctx]; the EOT token is assumed to be
                the highest id in each row (CLIP convention).

        Returns:
            Tuple of:
              - all token features after ln_final, [N, n_ctx, d_model];
              - the projected EOT embedding, [N, output_dim] (un-normalized).
        """
        dtype = self.clip_model.dtype
        x = self.clip_model.token_embedding(text).type(dtype)  # [N, n_ctx, d_model]

        x = x + self.clip_model.positional_embedding.type(dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.clip_model.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.clip_model.ln_final(x).type(dtype)

        # Take features from the EOT embedding (highest token id per row).
        eot = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.clip_model.text_projection
        return x, eot

    def var_img_forward(self, image):
        """Encode images; a 5-D input [N, 2, 3, H, W] holds two views whose
        features are averaged. The pooled token is L2-normalized."""
        if len(image.shape) == 5:
            img_features1, img_token1 = self.img_forward(image[:, 0, ...])
            img_features2, img_token2 = self.img_forward(image[:, 1, ...])
            img_token = (img_token1 + img_token2) / 2
            img_features = (img_features1 + img_features2) / 2
        else:
            img_features, img_token = self.img_forward(image)
        img_token = img_token / img_token.norm(dim=-1, keepdim=True)
        return img_features, img_token

    def var_txt_forward(self, text):
        """Encode text; the pooled EOT token is L2-normalized."""
        txt_features, txt_token = self.txt_forward(text)
        txt_token = txt_token / txt_token.norm(dim=-1, keepdim=True)
        return txt_features, txt_token

    def _contrastive_loss(self, img_tokens, txt_tokens):
        """Symmetric CLIP contrastive loss over matched (image, text) rows.

        Both inputs are [B, D] and assumed L2-normalized; row i of one modality
        is the positive for row i of the other.
        """
        logit_scale = self.clip_model.logit_scale.exp()
        logits_per_img = logit_scale * img_tokens @ txt_tokens.t()
        logits_per_txt = logits_per_img.t()
        target = torch.arange(img_tokens.shape[0], dtype=torch.long, device=img_tokens.device)
        return (nn.functional.cross_entropy(logits_per_img, target)
                + nn.functional.cross_entropy(logits_per_txt, target)) / 2

    def forward(self, image, text, past_img_tokens=None, past_txt_tokens=None):
        """Encode a batch and compute the contrastive loss.

        When ``past_img_tokens``/``past_txt_tokens`` are given, the current
        tokens are appended to them and the loss is computed over the enlarged
        batch; the grown token banks are returned for the next step.

        Returns:
            dict with keys: img_token, txt_token, img_features, txt_features,
            loss, past_img_tokens, past_txt_tokens.
        """
        img_features, img_token = self.var_img_forward(image)
        txt_features, txt_token = self.var_txt_forward(text)

        if past_img_tokens is not None:
            past_img_tokens = torch.cat([past_img_tokens, img_token], dim=0)
            past_txt_tokens = torch.cat([past_txt_tokens, txt_token], dim=0)
            loss = self._contrastive_loss(past_img_tokens, past_txt_tokens)
        else:
            loss = self._contrastive_loss(img_token, txt_token)

        return dict(
            img_token=img_token,
            txt_token=txt_token,
            img_features=img_features,
            txt_features=txt_features,
            loss=loss,
            past_img_tokens=past_img_tokens,
            past_txt_tokens=past_txt_tokens,
        )