import torch
import torchvision.transforms as transforms
from typing import Union, List
import open_clip


class CLIPLoss(torch.nn.Module):
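    """Directional CLIP loss with negative (contrastive) text directions.

    Images and text prompts are embedded with an open_clip model; the loss aligns
    the image edit direction (target image embedding minus source image embedding)
    with the text direction (target class embedding minus source class embedding),
    contrasted against directions toward negative classes.
    """
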
    def __init__(self, device, lambda_direction=1., direction_loss_type='cosine', clip_model='ViT-B/32',
                 pretrained=None, class_names=None):
        super(CLIPLoss, self).__init__()

        self.device = device
        self.model, _, clip_preprocess = open_clip.create_model_and_transforms(
            clip_model,
            # Fall back to the local LAION checkpoint when no pretrained tag/path is given.
            pretrained=pretrained or 'convs/CLIP-ViT-B-32-laion2B-s34B-b79K/open_clip_pytorch_model.bin',
            jit=True,
            device=device,
        )
        self.tokenizer = open_clip.get_tokenizer(clip_model)

        self.clip_preprocess = clip_preprocess

        self.preprocess = transforms.Compose(
            # Un-normalize from [-1.0, 1.0] (GAN output) to [0, 1].
            [transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0])] +
            clip_preprocess.transforms[:2] +  # CLIP's Resize and CenterCrop
            clip_preprocess.transforms[4:]    # CLIP's Normalize; skip the PIL-to-tensor steps, inputs are already tensors
        )

        self.target_direction = None
        self.negative_direction = None
        self.patch_text_directions = None

        self.patch_direction_loss = torch.nn.CosineSimilarity(dim=2)
        self.cosine_sim = torch.nn.CosineSimilarity(dim=1)

        self.lambda_direction = lambda_direction

        self.src_text_features = None
        self.target_text_features = None
        self.angle_loss = torch.nn.L1Loss()

        self.text_class_features = dict()
        self.image_class_features = dict()
        self.predicted_classes = class_names
        self.mse = torch.nn.MSELoss()

    def tokenize(self, strings: list):
        return self.tokenizer(strings).to(self.device)

    def encode_text(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.model.encode_text(tokens)

    def encode_images(self, images: torch.Tensor) -> torch.Tensor:
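        """Encode image tensors given in the GAN output range [-1, 1] into CLIP embedding space."""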
        images = self.preprocess(images).to(self.device)
        return self.model.encode_image(images)

    def get_text_features(self, class_str: Union[str, List[str]]) -> torch.Tensor:
        if isinstance(class_str, str):
            class_str = [class_str]
        text_features = [self.text_class_features[i] for i in class_str]
        return torch.stack(text_features, dim=0)

    def precompute_text_features(self, class_str: Union[str, List[str]], templates=None, norm: bool = True) -> None:
        if isinstance(class_str, str):
            class_str = [class_str]
        for classes in class_str:
            if classes not in self.text_class_features:
                template_text = self.compose_text_with_templates(classes, templates)
                tokens = self.tokenizer(template_text).to(self.device)
                text_features = self.encode_text(tokens).detach()
                text_features = text_features.mean(dim=0)
                if norm:
                    text_features /= text_features.norm(dim=-1, keepdim=True)
                self.text_class_features[classes] = text_features

    def get_image_features(self, img: torch.Tensor, norm: bool = True) -> torch.Tensor:
        image_features = self.encode_images(img)

        if norm:
            image_features /= image_features.clone().norm(dim=-1, keepdim=True)

        return image_features

    def compute_text_direction(self, source_class: Union[str, List[str]], target_class: Union[str, List[str]],
                               broadcast=False) -> torch.Tensor:
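        """Return normalized text direction(s) from source to target class features.

        With broadcast=True, target features of shape (N, D) are combined with source
        features of shape (B, D) into a (B, D, N) stack of directions, one per
        source/target pair, as used for the negative classes in clip_loss.
        """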
        source_features = self.get_text_features(source_class)
        target_features = self.get_text_features(target_class)

        if broadcast:
            text_direction = (target_features.T.unsqueeze(0) - source_features.unsqueeze(-1))
        else:
            text_direction = (target_features - source_features)
        text_direction /= text_direction.norm(dim=-1, keepdim=True)

        return text_direction

    def compose_text_with_templates(self, text: str, templates=None) -> list:
        if templates is None:
            templates = ['{}']  # Identity template so the raw class name is still usable.
        return [template.format(text) for template in templates]

    def clip_loss(self, src_img: torch.Tensor, source_class: Union[str, List[str]],
                              target_img: torch.Tensor, target_class: Union[str, List[str]],
                              negative_class: Union[str, List[str]]) -> torch.Tensor:
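        """Contrastive directional loss: the image edit direction is pulled toward the
        (source -> target) text direction and pushed away from the (source -> negative)
        text directions via a softmax-style ratio of exponentiated cosine similarities.
        """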
        self.target_direction = self.compute_text_direction(source_class, target_class)
        self.negative_direction = self.compute_text_direction(source_class, negative_class, broadcast=True)

        src_encoding = self.get_image_features(src_img)
        target_encoding = self.get_image_features(target_img)

        edit_direction = (target_encoding - src_encoding)
        edit_direction /= (edit_direction.clone().norm(dim=-1, keepdim=True) + 1e-7)

        logit_target = self.cosine_sim(self.target_direction, edit_direction)
        logit_negative = self.cosine_sim(self.negative_direction, edit_direction.unsqueeze(-1))
        pp = torch.exp(logit_target)
        pn = torch.sum(torch.exp(logit_negative), dim=-1)
        p = pp / (pp + pn)
        return -torch.log(p).mean()

    def forward(self, src_img: torch.Tensor, source_class: Union[str, List[str]], target_img: torch.Tensor,
                target_class: Union[str, List[str]], negative_class: Union[str, List[str]],
                texture_image: torch.Tensor = None):
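        """Return the weighted sum of the enabled CLIP terms (currently only the directional loss)."""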
        clip_loss = 0.0

        if self.lambda_direction:
            direction_loss = self.clip_loss(src_img, source_class, target_img, target_class, negative_class)
            clip_loss += self.lambda_direction * direction_loss

        return clip_loss
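

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original training pipeline: the pretrained
    # tag, class names, template, and image sizes below are illustrative assumptions.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    loss_fn = CLIPLoss(device, lambda_direction=1.0, clip_model="ViT-B-32",
                       pretrained="laion2b_s34b_b79k")

    # Cache text features once for every class name referenced below.
    loss_fn.precompute_text_features(["photo", "sketch", "painting"],
                                     templates=["a {}."])

    # Fake generator outputs in [-1, 1], the range expected by self.preprocess.
    src_img = torch.rand(4, 3, 256, 256, device=device) * 2 - 1
    tgt_img = torch.rand(4, 3, 256, 256, device=device) * 2 - 1

    loss = loss_fn(src_img, ["photo"] * 4, tgt_img, ["sketch"] * 4,
                   negative_class=["painting"] * 4)
    print(float(loss))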