# Snapshot metadata from the file viewer: 4,715 bytes, revision 2b1b1ac.
import torch
from modified_clip import clip
# Per-channel mean / standard deviation used by normalize() below.
# NOTE(review): despite the IMAGENET_* names, these values look like the
# statistics shipped with OpenAI CLIP's preprocessing -- confirm before reuse.
IMAGENET_MEAN = (0.48145466, 0.4578275, 0.40821073)
IMAGENET_STD = (0.26862954, 0.26130258, 0.27577711)

# Stored with shape (3, 1, 1) so they broadcast across the two trailing
# dimensions of a tensor whose third-from-last dimension has size 3.
mu = torch.tensor(IMAGENET_MEAN).reshape(3, 1, 1)
std = torch.tensor(IMAGENET_STD).reshape(3, 1, 1)
def normalize(X):
    """Standardize ``X`` with the module-level ``mu`` and ``std`` tensors.

    Both statistics are moved to ``X.device`` first; their (3, 1, 1) shape is
    broadcast against ``X``.

    Returns:
        ``(X - mu) / std`` as a new tensor.
    """
    device = X.device
    centered = X - mu.to(device)
    return centered / std.to(device)
def clip_img_preprocessing(X, img_size=224):
    """Resize an image batch to a square resolution and normalize it.

    Args:
        X: Image batch. Bicubic ``interpolate`` operates on 4-D input, so this
            is expected to be laid out as (N, C, H, W).
        img_size: Side length, in pixels, of the square output. Defaults to
            224, the value that was previously hard-coded.

    Returns:
        The resized batch after ``normalize`` has been applied.
    """
    X = torch.nn.functional.interpolate(X, size=(img_size, img_size), mode='bicubic')
    return normalize(X)
def create_logits(x1, x2, logit_scale):
    """Return scaled cosine-similarity logits between two sets of vectors.

    Each input is divided by its L2 norm along the last dimension, then the
    pairwise dot products are computed in both directions.

    Args:
        x1: Tensor whose rows are the first set of vectors.
        x2: Tensor whose rows are the second set of vectors.
        logit_scale: Factor multiplied into the left operand of each product.

    Returns:
        ``(logits_per_x1, logits_per_x2)``: similarities of every ``x1`` row
        against every ``x2`` row, and the reverse.
    """
    unit_x1 = x1 / x1.norm(dim=-1, keepdim=True)
    unit_x2 = x2 / x2.norm(dim=-1, keepdim=True)

    # The scale is applied to the left operand before the matrix product.
    logits_per_x1 = torch.matmul(logit_scale * unit_x1, unit_x2.t())
    logits_per_x2 = torch.matmul(logit_scale * unit_x2, unit_x1.t())
    return logits_per_x1, logits_per_x2
def multiGPU_CLIP_image_logits(images, model, text_tokens, prompter=None, add_prompter=None):
    """Preprocess raw images and return the per-image logits.

    Args:
        images: Image batch passed to ``clip_img_preprocessing``.
        model: Model forwarded to ``multiGPU_CLIP``.
        text_tokens: Text tokens forwarded to ``multiGPU_CLIP``.
        prompter: Optional callable applied to the preprocessed images.
        add_prompter: Optional zero-argument callable whose result is passed
            on as ``prompt_token``.

    Returns:
        The first element of the ``multiGPU_CLIP`` result (logits per image).
    """
    image_tokens = clip_img_preprocessing(images)

    if add_prompter is None:
        prompt_token = None
    else:
        prompt_token = add_prompter()

    if prompter is not None:
        image_tokens = prompter(image_tokens)

    outputs = multiGPU_CLIP(model, image_tokens, text_tokens, prompt_token=prompt_token)
    return outputs[0]
def multiGPU_CLIP(model, images, text_tokens, prompt_token=None, is_embedding=False):
    """Run ``model`` on images and text and return the similarity logits.

    Args:
        model: Callable invoked as ``model(images, text_tokens, prompt_token)``
            and returning ``(img_embed, scale_text_embed)``.
        images: Image batch; dim 0 is the batch dimension. A batch of one is
            repeated along dim 0 with four repeat factors, so it must be 4-D
            on that path.
        text_tokens: Text tokens. When there are exactly two rows (and more
            than one image) they are repeated along dim 0 with two repeat
            factors, so they must be 2-D on that path.
        prompt_token: Optional tensor tiled along dim 0 once per image that
            is actually sent to the model.
        is_embedding: When true, the embeddings are returned as well.

    Returns:
        ``(logits_per_image, logits_per_text)``, followed by
        ``img_embed, scale_text_embed`` when ``is_embedding`` is true.
    """
    # Batch of one: send two copies of the image and keep only the first
    # embedding afterwards. The source marks this case "2 GPUs" -- presumably
    # so the batch can be split across two devices; confirm against the model.
    single_image = images.size(0) == 1
    if single_image:
        images = images.repeat(2, 1, 1, 1)
    elif text_tokens.size(0) == 2:
        # Exactly two text rows: duplicate them (marked "4 GPUs" in the source).
        # NOTE(review): scale_text_embed is used exactly as the model returns
        # it; if that is one row per duplicated token, the logits carry
        # duplicated columns. Confirm whether it should be sliced back to the
        # first two rows.
        text_tokens = text_tokens.repeat(2, 1)

    if prompt_token is not None:
        # Tile after the duplication above so the prompt batch always matches
        # the image batch that reaches the model (previously a single image
        # was paired with only one prompt copy for two images).
        prompt_token = prompt_token.repeat(images.size(0), 1, 1)

    img_embed, scale_text_embed = model(images, text_tokens, prompt_token)

    if single_image:
        # Drop the duplicate so the caller sees one row for its one image.
        img_embed = img_embed[0].unsqueeze(0)

    logits_per_image = img_embed @ scale_text_embed.t()
    logits_per_text = scale_text_embed @ img_embed.t()

    if is_embedding:
        return logits_per_image, logits_per_text, img_embed, scale_text_embed
    return logits_per_image, logits_per_text
def multiGPU_CLIP_Text_Prompt_Tuning(model, images, text_tokens, prompt_token=None, prompt_learner=None, is_embedding=False):
    """Compute logits with the model's ``'Text_Prompt_Tuning'`` forward path.

    Args:
        model: Callable invoked with ``images, text_tokens, prompt_token,
            prompts, tokenized_prompts`` and ``forward_type='Text_Prompt_Tuning'``;
            it must return ``(img_embed, scale_text_embed)``.
        images: Image batch; dim 0 is the batch dimension.
        text_tokens: Text tokens, forwarded to the model unchanged.
        prompt_token: Optional tensor tiled along dim 0 once per image.
        prompt_learner: Callable producing the prompts. It is called
            unconditionally and must expose ``module.tokenized_prompts``.
        is_embedding: When true, the embeddings are returned as well.

    Returns:
        ``(logits_per_image, logits_per_text)``, followed by
        ``img_embed, scale_text_embed`` when ``is_embedding`` is true.
    """
    if prompt_token is not None:
        batch_size = images.size(0)
        prompt_token = prompt_token.repeat(batch_size, 1, 1)

    learned_prompts = prompt_learner()
    tokenized = prompt_learner.module.tokenized_prompts

    img_embed, scale_text_embed = model(
        images,
        text_tokens,
        prompt_token,
        learned_prompts,
        tokenized,
        forward_type='Text_Prompt_Tuning',
    )

    logits_per_image = img_embed @ scale_text_embed.t()
    logits_per_text = scale_text_embed @ img_embed.t()

    if not is_embedding:
        return logits_per_image, logits_per_text
    return logits_per_image, logits_per_text, img_embed, scale_text_embed
############################## Noise Modulated CLIP ###########################################
def apply_multiplicative_noise(signal, beta=0.0):
    """Apply Gaussian multiplicative noise to a signal.

    One noise factor is drawn per feature (last dimension) from a normal
    distribution with mean 1 and standard deviation ``beta``; the same factors
    are shared by every sample through broadcasting.

    Parameters:
        signal (torch.Tensor): Tensor whose last dimension is the feature
            dimension, typically of shape (m, d).
        beta (float): Standard deviation of the Gaussian noise. With the
            default of 0.0 every factor equals 1.

    Returns:
        torch.Tensor: Noisy signal, on the same device as ``signal``.
    """
    feature_dim = signal.size(-1)
    # Draw on the CPU and then move to wherever `signal` lives, rather than
    # forcing CUDA: this works on CPU-only hosts and keeps the noise on the
    # same device as the signal when several GPUs are in use.
    noise = torch.normal(mean=1.0, std=beta, size=(1, feature_dim)).to(signal.device)
    return signal * noise
def multiGPU_CLIP_multiply_noise(model, images, text_tokens, prompt_token=None, is_embedding=False, beta=0.0):
    """Compute logits after perturbing both embeddings with multiplicative noise.

    Args:
        model: Callable invoked as ``model(images, text_tokens, prompt_token)``
            and returning ``(img_embed, scale_text_embed)``.
        images: Image batch; dim 0 is the batch dimension.
        text_tokens: Text tokens, forwarded to the model unchanged.
        prompt_token: Optional tensor tiled along dim 0 once per image.
        is_embedding: When true, the (noisy) embeddings are returned as well.
        beta: Standard deviation handed to ``apply_multiplicative_noise``.

    Returns:
        ``(logits_per_image, logits_per_text)``, followed by the noisy
        ``img_embed, scale_text_embed`` when ``is_embedding`` is true.
    """
    if prompt_token is not None:
        batch_size = images.size(0)
        prompt_token = prompt_token.repeat(batch_size, 1, 1)

    clean_img_embed, clean_text_embed = model(images, text_tokens, prompt_token)

    # Noise is drawn independently for the image and the text embeddings,
    # image first.
    img_embed = apply_multiplicative_noise(clean_img_embed, beta)
    scale_text_embed = apply_multiplicative_noise(clean_text_embed, beta)

    logits_per_image = img_embed @ scale_text_embed.t()
    logits_per_text = scale_text_embed @ img_embed.t()

    if not is_embedding:
        return logits_per_image, logits_per_text
    return logits_per_image, logits_per_text, img_embed, scale_text_embed
# End of module.