|
|
import torch |
|
|
from modified_clip import clip |
|
|
|
|
|
# Per-channel (three-channel) mean and standard deviation used to standardize
# images before they are fed to the model.
# NOTE(review): despite the IMAGENET_* names, these values look like the
# normalization statistics distributed with CLIP rather than the classic
# ImageNet ones -- confirm before reusing them elsewhere.
IMAGENET_MEAN = (0.48145466, 0.4578275, 0.40821073)
IMAGENET_STD = (0.26862954, 0.26130258, 0.27577711)

# Tensor versions shaped (3, 1, 1) so they broadcast over the trailing two
# (spatial) dimensions of an image tensor.
mu, std = (
    torch.tensor(stats).view(3, 1, 1)
    for stats in (IMAGENET_MEAN, IMAGENET_STD)
)
|
|
|
|
|
def normalize(X):
    """Standardize ``X`` with the module-level ``mu`` and ``std`` tensors.

    Parameters:
        X (torch.Tensor): Tensor to standardize. ``mu`` and ``std`` are shaped
            (3, 1, 1), so ``X`` must broadcast against that shape
            (presumably a (N, 3, H, W) image batch -- confirm with callers).

    Returns:
        torch.Tensor: ``(X - mu) / std`` computed on the device of ``X``.
    """
    target_device = X.device
    # The statistics live on the CPU at import time; move them next to X.
    centered = X - mu.to(target_device)
    return centered / std.to(target_device)
|
|
|
|
|
def clip_img_preprocessing(X, img_size=224):
    """Resize an image batch to a square resolution and standardize it.

    Parameters:
        X (torch.Tensor): Image batch. Bicubic interpolation requires a 4-D
            input; presumably (N, 3, H, W) with values in [0, 1] -- confirm
            with callers.
        img_size (int): Side length of the square output. Defaults to 224,
            the value that used to be hard-coded here.

    Returns:
        torch.Tensor: The batch resized to (img_size, img_size) with bicubic
        interpolation and then passed through ``normalize``.
    """
    resized = torch.nn.functional.interpolate(
        X, size=(img_size, img_size), mode='bicubic'
    )
    return normalize(resized)
|
|
|
|
|
def create_logits(x1, x2, logit_scale):
    """Compute scaled cosine-similarity logits between two embedding sets.

    Parameters:
        x1 (torch.Tensor): Embeddings, one vector per row along the last dim.
        x2 (torch.Tensor): Embeddings with the same feature size as ``x1``.
        logit_scale: Multiplier applied to the unit-normalized embeddings
            before the matrix product.

    Returns:
        tuple: ``(logits_per_x1, logits_per_x2)`` where the first has one row
        per ``x1`` vector and the second one row per ``x2`` vector.
    """
    # L2-normalize every vector so the dot products are cosine similarities.
    unit_x1 = x1 / x1.norm(dim=-1, keepdim=True)
    unit_x2 = x2 / x2.norm(dim=-1, keepdim=True)

    # Scale first, then multiply -- the same evaluation order as
    # ``logit_scale * a @ b.t()``.
    scaled_x1 = logit_scale * unit_x1
    scaled_x2 = logit_scale * unit_x2
    return scaled_x1 @ unit_x2.t(), scaled_x2 @ unit_x1.t()
|
|
|
|
|
def multiGPU_CLIP_image_logits(images, model, text_tokens, prompter=None, add_prompter=None):
    """Preprocess raw images and return only the per-image logits.

    Parameters:
        images (torch.Tensor): Raw image batch passed to
            ``clip_img_preprocessing``.
        model: Callable forwarded to ``multiGPU_CLIP``.
        text_tokens (torch.Tensor): Text tokens forwarded to ``multiGPU_CLIP``.
        prompter (callable, optional): If given, applied to the preprocessed
            images before they reach the model.
        add_prompter (callable, optional): If given, called with no arguments
            to produce the ``prompt_token`` passed to ``multiGPU_CLIP``.

    Returns:
        torch.Tensor: The first element returned by ``multiGPU_CLIP``
        (the per-image logits).
    """
    image_tokens = clip_img_preprocessing(images)

    # Build the optional prompt token before running the optional image
    # prompter, in the same order as the calls were originally made.
    if add_prompter is None:
        prompt_token = None
    else:
        prompt_token = add_prompter()

    if prompter is not None:
        image_tokens = prompter(image_tokens)

    outputs = multiGPU_CLIP(model, image_tokens, text_tokens, prompt_token=prompt_token)
    return outputs[0]
|
|
|
|
|
|
|
|
def multiGPU_CLIP(model, images, text_tokens, prompt_token=None, is_embedding=False):
    """Run ``model`` on images and text tokens and build similarity logits.

    Very small inputs are padded by duplication before the forward pass and
    the duplicates are removed from the outputs afterwards:

    * a single image is duplicated to a batch of two;
    * otherwise, exactly two text rows are duplicated to four.

    NOTE(review): the duplication presumably works around a multi-GPU
    (``DataParallel``) problem with such small batches -- confirm.

    Parameters:
        model: Callable invoked as ``model(images, text_tokens, prompt_token)``
            returning ``(img_embed, scale_text_embed)``.
        images (torch.Tensor): Image batch; must be 4-D when it holds a
            single image, because it is duplicated with ``repeat(2, 1, 1, 1)``.
        text_tokens (torch.Tensor): Text tokens, one row per text.
        prompt_token (torch.Tensor, optional): Prompt tensor that is tiled
            along dim 0 to the size of the image batch sent to the model.
        is_embedding (bool): When True, also return the two embeddings.

    Returns:
        tuple: ``(logits_per_image, logits_per_text)``, followed by
        ``img_embed`` and ``scale_text_embed`` when ``is_embedding`` is True.
    """
    single_image = images.size(0) == 1
    # The text padding is only used when the image padding is not.
    two_texts = (not single_image) and text_tokens.size(0) == 2

    if single_image:
        images = images.repeat(2, 1, 1, 1)
    elif two_texts:
        text_tokens = text_tokens.repeat(2, 1)

    if prompt_token is not None:
        # Tile after the padding above so the prompt batch matches the image
        # batch that is actually handed to the model.
        prompt_token = prompt_token.repeat(images.size(0), 1, 1)

    img_embed, scale_text_embed = model(images, text_tokens, prompt_token)

    # Drop the rows produced by the padding so callers get one embedding per
    # original image / text.
    if single_image:
        img_embed = img_embed[0].unsqueeze(0)
    elif two_texts:
        scale_text_embed = scale_text_embed[0:2]

    logits_per_image = img_embed @ scale_text_embed.t()
    logits_per_text = scale_text_embed @ img_embed.t()

    if is_embedding:
        return logits_per_image, logits_per_text, img_embed, scale_text_embed
    return logits_per_image, logits_per_text
|
|
|
|
|
def multiGPU_CLIP_Text_Prompt_Tuning(model, images, text_tokens, prompt_token=None, prompt_learner=None, is_embedding=False):
    """Run ``model`` in its 'Text_Prompt_Tuning' mode and build logits.

    Parameters:
        model: Callable invoked as ``model(images, text_tokens, prompt_token,
            prompts, tokenized_prompts, forward_type='Text_Prompt_Tuning')``
            returning ``(img_embed, scale_text_embed)``.
        images (torch.Tensor): Image batch; its dim-0 size is the batch size.
        text_tokens: Text tokens forwarded to ``model`` unchanged.
        prompt_token (torch.Tensor, optional): Prompt tensor that is tiled
            along dim 0 to the image batch size.
        prompt_learner (callable): Required despite the ``None`` default.
            Calling it yields the learned prompts; its ``tokenized_prompts``
            attribute is read from ``prompt_learner.module`` when the learner
            is wrapped (e.g. by ``DataParallel``), otherwise from the learner
            itself.
        is_embedding (bool): When True, also return the two embeddings.

    Returns:
        tuple: ``(logits_per_image, logits_per_text)``, followed by
        ``img_embed`` and ``scale_text_embed`` when ``is_embedding`` is True.

    Raises:
        TypeError: If ``prompt_learner`` is None.
    """
    if prompt_learner is None:
        # Fail with a clear message instead of "'NoneType' object is not callable".
        raise TypeError(
            "multiGPU_CLIP_Text_Prompt_Tuning requires a callable prompt_learner"
        )

    if prompt_token is not None:
        prompt_token = prompt_token.repeat(images.size(0), 1, 1)

    prompts = prompt_learner()
    # A wrapped learner exposes the real module as ``.module``; an unwrapped
    # one carries ``tokenized_prompts`` directly.
    learner_core = getattr(prompt_learner, "module", prompt_learner)
    tokenized_prompts = learner_core.tokenized_prompts

    img_embed, scale_text_embed = model(
        images,
        text_tokens,
        prompt_token,
        prompts,
        tokenized_prompts,
        forward_type='Text_Prompt_Tuning',
    )

    logits_per_image = img_embed @ scale_text_embed.t()
    logits_per_text = scale_text_embed @ img_embed.t()

    if is_embedding:
        return logits_per_image, logits_per_text, img_embed, scale_text_embed
    return logits_per_image, logits_per_text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def apply_multiplicative_noise(signal, beta=0.0):
    """
    Apply Gaussian multiplicative noise to a signal.

    One noise value is drawn per feature dimension from N(1, beta^2) and the
    same (1, d) noise row is broadcast over all m samples.

    Parameters:
        signal (torch.Tensor): Tensor of shape (m, d) where m is the number of samples and d is the dimension.
        beta (float): Standard deviation of the Gaussian noise. With 0.0 the
            noise is exactly 1 and the signal is returned unchanged in value.

    Returns:
        torch.Tensor: Noisy signal, on the same device as ``signal``.
    """
    _, d = signal.shape
    # Sample on the CPU (as before) and then move the noise to wherever the
    # signal lives, instead of unconditionally calling .cuda(): that failed on
    # CPU-only hosts and could target a different GPU than the signal's.
    noise = torch.normal(mean=1.0, std=beta, size=(1, d)).to(signal.device)
    return signal * noise
|
|
|
|
|
def multiGPU_CLIP_multiply_noise(model, images, text_tokens, prompt_token=None, is_embedding=False, beta=0.0):
    """Run ``model`` and build logits from noise-perturbed embeddings.

    Parameters:
        model: Callable invoked as ``model(images, text_tokens, prompt_token)``
            returning ``(img_embed, scale_text_embed)``.
        images (torch.Tensor): Image batch; its dim-0 size is the batch size.
        text_tokens: Text tokens forwarded to ``model`` unchanged.
        prompt_token (torch.Tensor, optional): Prompt tensor that is tiled
            along dim 0 to the image batch size.
        is_embedding (bool): When True, also return the noisy embeddings.
        beta (float): Noise standard deviation handed to
            ``apply_multiplicative_noise`` for both embeddings.

    Returns:
        tuple: ``(logits_per_image, logits_per_text)``, followed by the noisy
        ``img_embed`` and ``scale_text_embed`` when ``is_embedding`` is True.
    """
    tiled_prompt = prompt_token
    if tiled_prompt is not None:
        batch_size = images.size(0)
        tiled_prompt = tiled_prompt.repeat(batch_size, 1, 1)

    clean_img, clean_text = model(images, text_tokens, tiled_prompt)

    # Perturb the image embedding first, then the text embedding; each call
    # draws its own noise.
    img_embed = apply_multiplicative_noise(clean_img, beta)
    scale_text_embed = apply_multiplicative_noise(clean_text, beta)

    logits_per_image = img_embed @ scale_text_embed.t()
    logits_per_text = scale_text_embed @ img_embed.t()

    if not is_embedding:
        return logits_per_image, logits_per_text
    return logits_per_image, logits_per_text, img_embed, scale_text_embed
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|