import torch
import torchvision.transforms as transforms
from typing import Union, List

import open_clip


class CLIPLoss(torch.nn.Module):
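    """Directional CLIP loss with contrastive negatives.

    In the spirit of StyleGAN-NADA-style directional losses: the edit
    direction between a source and a target image embedding is encouraged to
    align with the text direction from the source class to the target class,
    while being contrasted against the directions toward negative classes.
    Call ``precompute_text_features`` for all class names before ``forward``.
    """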

    def __init__(self, device, lambda_direction=1., direction_loss_type='cosine', clip_model='ViT-B-32',
                 pretrained='convs/CLIP-ViT-B-32-laion2B-s34B-b79K/open_clip_pytorch_model.bin',
                 class_names=None):
        super().__init__()

        self.device = device
        self.model, _, clip_preprocess = open_clip.create_model_and_transforms(
            clip_model,
            pretrained=pretrained,
            jit=True,
            device=device,
        )
        self.tokenizer = open_clip.get_tokenizer(clip_model)

        self.clip_preprocess = clip_preprocess

        # Map generator outputs from [-1, 1] to [0, 1], then reuse CLIP's
        # resize and center-crop (transforms[:2]) and normalization
        # (transforms[4:]), skipping the PIL-specific RGB conversion and
        # ToTensor steps since the inputs are already tensors.
        self.preprocess = transforms.Compose(
            [transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0])] +
            clip_preprocess.transforms[:2] +
            clip_preprocess.transforms[4:])

        self.target_direction = None
        self.patch_text_directions = None

        self.patch_direction_loss = torch.nn.CosineSimilarity(dim=2)
        self.cosine_sim = torch.nn.CosineSimilarity(dim=1)

        self.lambda_direction = lambda_direction

        self.src_text_features = None
        self.target_text_features = None
        self.angle_loss = torch.nn.L1Loss()

        # Caches for precomputed per-class embeddings (see
        # precompute_text_features).
        self.text_class_features = dict()
        self.image_class_features = dict()
        self.predicted_classes = class_names
        self.mse = torch.nn.MSELoss()

    def tokenize(self, strings: List[str]) -> torch.Tensor:
        return self.tokenizer(strings).to(self.device)

    def encode_text(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.model.encode_text(tokens)

    def encode_images(self, images: torch.Tensor) -> torch.Tensor:
        images = self.preprocess(images).to(self.device)
        return self.model.encode_image(images)

    def get_text_features(self, class_str: List[str]) -> torch.Tensor:
        # Look up cached embeddings; precompute_text_features must have been
        # called for every class name in class_str.
        text_features = [self.text_class_features[i] for i in class_str]
        return torch.stack(text_features, dim=0)

    def precompute_text_features(self, class_str: List[str], templates=None, norm: bool = True) -> None:
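        """Embed each class name under all templates, average, and cache.

        Each class string is expanded with the prompt templates, tokenized,
        and encoded; the per-template embeddings are mean-pooled (and
        optionally L2-normalized) into a single feature vector per class.
        """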
        for classes in class_str:
            if classes not in self.text_class_features:
                template_text = self.compose_text_with_templates(classes, templates)
                tokens = self.tokenizer(template_text).to(self.device)
                text_features = self.encode_text(tokens).detach()
                text_features = text_features.mean(dim=0)
                if norm:
                    text_features /= text_features.norm(dim=-1, keepdim=True)
                self.text_class_features[classes] = text_features

    def get_image_features(self, img: torch.Tensor, norm: bool = True) -> torch.Tensor:
        image_features = self.encode_images(img)

        if norm:
            # Clone before taking the norm so the in-place division does not
            # invalidate the tensor the norm's backward pass depends on.
            image_features /= image_features.clone().norm(dim=-1, keepdim=True)

        return image_features

    def compute_text_direction(self, source_class: Union[str, List[str]], target_class: Union[str, List[str]],
                               broadcast=False) -> torch.Tensor:
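        """Normalized text direction(s) from source class(es) to target class(es).

        With broadcast=False the result is (B, D): one direction per
        source/target pair. With broadcast=True the result is (B, D, N): the
        direction from each of the B source classes to each of the N target
        (e.g. negative) classes.
        """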
        source_features = self.get_text_features(source_class)
        target_features = self.get_text_features(target_class)

        if broadcast:
            # (1, D, N) - (B, D, 1) -> (B, D, N); normalize over the feature
            # dimension (dim=1), not the trailing negatives dimension.
            text_direction = (target_features.T.unsqueeze(0) - source_features.unsqueeze(-1))
            text_direction /= text_direction.norm(dim=1, keepdim=True)
        else:
            text_direction = (target_features - source_features)
            text_direction /= text_direction.norm(dim=-1, keepdim=True)

        return text_direction

    def compose_text_with_templates(self, text: str, templates=None) -> List[str]:
        if templates is None:
            # Fall back to the bare class name when no templates are given.
            templates = ['{}']
        return [template.format(text) for template in templates]

    def clip_loss(self, src_img: torch.Tensor, source_class: Union[str, List[str]],
                  target_img: torch.Tensor, target_class: Union[str, List[str]],
                  negative_class: Union[str, List[str]]) -> torch.Tensor:
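        """Contrastive directional loss.

        With edit direction d = E(target_img) - E(src_img), target text
        direction t, and negative text directions t_1..t_N, this computes

            loss = -log( e^{cos(t, d)} / (e^{cos(t, d)} + sum_n e^{cos(t_n, d)}) )

        averaged over the batch: an InfoNCE-style term (temperature 1) that
        pulls the edit toward the target direction and pushes it away from
        the negative-class directions.
        """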
        self.target_direction = self.compute_text_direction(source_class, target_class)
        self.negative_direction = self.compute_text_direction(source_class, negative_class, broadcast=True)

        src_encoding = self.get_image_features(src_img)
        target_encoding = self.get_image_features(target_img)

        edit_direction = (target_encoding - src_encoding)
        edit_direction /= (edit_direction.clone().norm(dim=-1, keepdim=True) + 1e-7)

        # Similarity with the target direction: (B,).
        logit_target = self.cosine_sim(self.target_direction, edit_direction)
        # Similarity with each negative direction: (B, N).
        logit_negative = self.cosine_sim(self.negative_direction, edit_direction.unsqueeze(-1))
        pp = torch.exp(logit_target)
        pn = torch.sum(torch.exp(logit_negative), dim=-1)
        p = pp / (pp + pn)
        return -torch.log(p).mean()

    def forward(self, src_img: torch.Tensor, source_class: Union[str, List[str]], target_img: torch.Tensor,
                target_class: Union[str, List[str]], negative_class: Union[str, List[str]],
                texture_image: torch.Tensor = None):
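        """Weighted sum of the enabled loss terms.

        Only the directional term is active here; ``texture_image`` is
        accepted for interface compatibility but unused.
        """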
        clip_loss = 0.0

        if self.lambda_direction:
            direction_loss = self.clip_loss(src_img, source_class, target_img, target_class, negative_class)
            clip_loss += self.lambda_direction * direction_loss

        return clip_loss
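

# A minimal usage sketch (assumptions: the 'laion2b_s34b_b79k' open_clip
# weight tag stands in for the local checkpoint path, and the generator
# outputs are 224x224 images in [-1, 1]; the class names and templates below
# are illustrative, not part of the original API):
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    loss_fn = CLIPLoss(device, clip_model='ViT-B-32', pretrained='laion2b_s34b_b79k')

    # Precompute text features for every class referenced below.
    templates = ["a photo of a {}.", "a rendering of a {}."]
    loss_fn.precompute_text_features(["dog", "cat", "wolf", "fox"], templates)

    # Fake generator outputs in [-1, 1].
    src = torch.rand(4, 3, 224, 224, device=device) * 2 - 1
    tgt = torch.rand(4, 3, 224, 224, device=device) * 2 - 1

    # One source/target class per batch element; the negatives are shared.
    loss = loss_fn(src, ["dog"] * 4, tgt, ["cat"] * 4, ["wolf", "fox"])
    print(loss.item())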