# Provenance (Hugging Face upload page residue, converted to a comment so the
# file parses as Python):
#   Ttius's picture
#   Upload 192 files
#   998bb30 verified
import torch
import torchvision.transforms as transforms
from typing import Union, List
import open_clip
class CLIPLoss(torch.nn.Module):
    """InfoNCE-style directional CLIP loss.

    Encodes images and class names with an open_clip model and pushes the
    image-space edit direction (target image - source image) toward the
    text direction of the target class, while contrasting it against the
    text directions toward a set of negative classes.
    """

    # Checkpoint used when the caller does not pass `pretrained`.
    _DEFAULT_PRETRAINED = 'convs/CLIP-ViT-B-32-laion2B-s34B-b79K/open_clip_pytorch_model.bin'

    def __init__(self, device, lambda_direction=1., direction_loss_type='cosine', clip_model='ViT-B/32',
                 pretrained=None, class_names=None):
        """
        Args:
            device: torch device for the CLIP model and all tensors.
            lambda_direction: weight of the directional term in ``forward``.
            direction_loss_type: accepted for interface compatibility; only
                cosine similarity is implemented here.
            clip_model: open_clip architecture name (e.g. 'ViT-B/32').
            pretrained: checkpoint path or tag forwarded to open_clip; falls
                back to ``_DEFAULT_PRETRAINED`` when None. (Fix: this
                argument was previously ignored in favor of the hard-coded
                path.)
            class_names: optional list of class names, stored for callers.
        """
        super().__init__()
        self.device = device
        self.model, _, clip_preprocess = open_clip.create_model_and_transforms(
            clip_model,
            pretrained=pretrained or self._DEFAULT_PRETRAINED,
            jit=True,
            device=device,
        )
        self.tokenizer = open_clip.get_tokenizer(clip_model)
        self.clip_preprocess = clip_preprocess
        # Un-normalize GAN output from [-1, 1] to [0, 1], then reuse CLIP's
        # resize/crop (transforms[:2]) and normalization (transforms[4:]),
        # skipping the PIL-conversion steps that do not apply to tensors.
        self.preprocess = transforms.Compose(
            [transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0])]
            + clip_preprocess.transforms[:2]
            + clip_preprocess.transforms[4:])
        self.target_direction = None
        self.patch_text_directions = None
        self.patch_direction_loss = torch.nn.CosineSimilarity(dim=2)
        self.cosine_sim = torch.nn.CosineSimilarity(dim=1)
        self.lambda_direction = lambda_direction
        self.src_text_features = None
        self.target_text_features = None
        self.angle_loss = torch.nn.L1Loss()
        # Per-class caches: class name -> CLIP embedding.
        self.text_class_features = dict()
        self.image_class_features = dict()
        self.predicted_classes = class_names
        self.mse = torch.nn.MSELoss()

    def tokenize(self, strings: List[str]) -> torch.Tensor:
        """Tokenize `strings` and move the token tensor to the model device."""
        return self.tokenizer(strings).to(self.device)

    def encode_text(self, tokens: torch.Tensor) -> torch.Tensor:
        """Encode pre-tokenized text into CLIP embedding space."""
        return self.model.encode_text(tokens)

    def encode_images(self, images: torch.Tensor) -> torch.Tensor:
        """Preprocess [-1, 1] image tensors and encode them with CLIP."""
        images = self.preprocess(images).to(self.device)
        return self.model.encode_image(images)

    def get_text_features(self, class_str: List[str]) -> torch.Tensor:
        """Stack cached text features for the class names in `class_str`.

        Requires ``precompute_text_features`` to have been called for every
        name in `class_str`; raises KeyError for unseen names.
        """
        text_features = [self.text_class_features[name] for name in class_str]
        return torch.stack(text_features, dim=0)

    def precompute_text_features(self, class_str: List[str], templates=None, norm: bool = True) -> None:
        """Encode and cache a mean text embedding for each new class name.

        Each class is rendered through every prompt template, encoded, and
        averaged; already-cached classes are skipped.
        """
        for class_name in class_str:
            if class_name in self.text_class_features:
                continue
            template_text = self.compose_text_with_templates(class_name, templates)
            tokens = self.tokenize(template_text)
            text_features = self.encode_text(tokens).detach()
            # Average over templates, then optionally L2-normalize.
            text_features = text_features.mean(dim=0)
            if norm:
                text_features /= text_features.norm(dim=-1, keepdim=True)
            self.text_class_features[class_name] = text_features

    def get_image_features(self, img: torch.Tensor, norm: bool = True) -> torch.Tensor:
        """Encode images; optionally L2-normalize each embedding."""
        image_features = self.encode_images(img)
        if norm:
            # .clone() keeps the norm out of the in-place division's autograd path.
            image_features /= image_features.clone().norm(dim=-1, keepdim=True)
        return image_features

    def compute_text_direction(self, source_class: Union[str, List[str]], target_class: Union[str, List[str]],
                               broadcast=False) -> torch.Tensor:
        """Normalized CLIP-text direction from source classes to target classes.

        With ``broadcast=True`` every source is paired with every target,
        yielding a (n_src, dim, n_tgt) tensor; otherwise sources and targets
        are matched element-wise as (n, dim).
        """
        source_features = self.get_text_features(source_class)
        target_features = self.get_text_features(target_class)
        if broadcast:
            text_direction = target_features.T.unsqueeze(0) - source_features.unsqueeze(-1)
        else:
            text_direction = target_features - source_features
        text_direction /= text_direction.norm(dim=-1, keepdim=True)
        return text_direction

    def compose_text_with_templates(self, text: str, templates=None) -> List[str]:
        """Fill `text` into each prompt template.

        Fix: omitting `templates` previously raised TypeError; None now means
        the identity template, returning ``[text]``.
        """
        if templates is None:
            templates = ['{}']
        return [template.format(text) for template in templates]

    def clip_loss(self, src_img: torch.Tensor, source_class: Union[str, List[str]],
                  target_img: torch.Tensor, target_class: Union[str, List[str]],
                  negative_class: Union[str, List[str]]) -> torch.Tensor:
        """Contrastive directional loss.

        The image edit direction is scored against the positive text
        direction and against every negative text direction; the loss is the
        negative log of the positive's softmax-style share.
        """
        self.target_direction = self.compute_text_direction(source_class, target_class)
        self.negative_direction = self.compute_text_direction(source_class, negative_class, broadcast=True)
        src_encoding = self.get_image_features(src_img)
        target_encoding = self.get_image_features(target_img)
        edit_direction = (target_encoding - src_encoding)
        # Epsilon avoids division by zero when the two encodings coincide.
        edit_direction /= (edit_direction.clone().norm(dim=-1, keepdim=True) + 1e-7)
        logit_target = self.cosine_sim(self.target_direction, edit_direction)
        logit_negative = self.cosine_sim(self.negative_direction, edit_direction.unsqueeze(-1))
        pp = torch.exp(logit_target)
        pn = torch.sum(torch.exp(logit_negative), dim=-1)
        p = pp / (pp + pn)
        return -torch.log(p).mean()

    def forward(self, src_img: torch.Tensor, source_class: Union[str, List[str]], target_img: torch.Tensor,
                target_class: Union[str, List[str]], negative_class: Union[str, List[str]],
                texture_image: torch.Tensor = None):
        """Weighted sum of enabled loss terms (currently only directional)."""
        clip_loss = 0.0
        if self.lambda_direction:
            direction_loss = self.clip_loss(src_img, source_class, target_img, target_class, negative_class)
            clip_loss += self.lambda_direction * direction_loss
        return clip_loss