# Robust_vlm / models/model.py
# Uploaded by Yaning1001 ("Add files using upload-large-folder tool"),
# commit 2b1b1ac (verified).
import torch
from modified_clip import clip
# Per-channel RGB statistics used to standardize images before the CLIP
# vision encoder. NOTE: despite the IMAGENET_ prefix, these are the OpenAI CLIP
# preprocessing constants, not the usual torchvision ImageNet ones.
IMAGENET_MEAN = (0.48145466, 0.4578275, 0.40821073)
IMAGENET_STD = (0.26862954, 0.26130258, 0.27577711)

# Shaped (3, 1, 1) so they broadcast over the trailing (C, H, W) axes of an
# image batch. They are created on the CPU; normalize() moves them as needed.
mu, std = (torch.tensor(stats).reshape(3, 1, 1) for stats in (IMAGENET_MEAN, IMAGENET_STD))


def normalize(X):
    """Return ``X`` standardized channel-wise as ``(X - mu) / std``.

    The statistics are copied onto ``X``'s device on every call, so both CPU
    and GPU inputs are accepted. ``X`` must broadcast against a ``(3, 1, 1)``
    tensor, e.g. an ``(N, 3, H, W)`` image batch.
    """
    target = X.device
    centered = X - mu.to(target)
    return centered / std.to(target)
def clip_img_preprocessing(X, img_size=224):
    """Resize an image batch to a square CLIP input and standardize it.

    Args:
        X: image batch; bicubic ``interpolate`` requires a 4-D
            ``(N, C, H, W)`` tensor. Presumably pixel values in [0, 1]
            -- TODO confirm against callers.
        img_size: side length of the square output. Defaults to 224, the
            value this function previously hard-coded.

    Returns:
        The batch resized to ``(N, C, img_size, img_size)`` and passed through
        ``normalize``.
    """
    # Bicubic resampling may overshoot the input's value range; the result is
    # intentionally left unclamped, as before.
    X = torch.nn.functional.interpolate(X, size=(img_size, img_size), mode='bicubic')
    return normalize(X)
def create_logits(x1, x2, logit_scale):
    """Return scaled cosine-similarity logits between two sets of embeddings.

    Rows of ``x1`` and ``x2`` are L2-normalized along the last dimension and
    compared pairwise.

    Args:
        x1: embeddings, one per row (``.t()`` is used, so at most 2-D).
        x2: embeddings with the same feature size as ``x1``.
        logit_scale: multiplier applied to the cosine similarities (scalar or
            tensor).

    Returns:
        ``(logits_per_x1, logits_per_x2)``. ``logits_per_x1[i, j]`` is the
        scaled cosine similarity of ``x1[i]`` and ``x2[j]``, and
        ``logits_per_x2`` is the same comparison indexed from ``x2``.
    """
    # F.normalize clamps the norm at a tiny eps, so an all-zero embedding gives
    # zero similarities instead of 0/0 = NaN poisoning every downstream loss.
    # Non-zero rows are normalized exactly as x / x.norm(dim=-1).
    x1 = torch.nn.functional.normalize(x1, dim=-1)
    x2 = torch.nn.functional.normalize(x2, dim=-1)
    # cosine similarity as logits
    logits_per_x1 = logit_scale * x1 @ x2.t()
    logits_per_x2 = logit_scale * x2 @ x1.t()
    return logits_per_x1, logits_per_x2
def multiGPU_CLIP_image_logits(images, model, text_tokens, prompter=None, add_prompter=None):
    """Preprocess raw images and return only the image-to-text logits.

    Args:
        images: image batch handed to ``clip_img_preprocessing``.
        model: forwarded to ``multiGPU_CLIP``.
        text_tokens: forwarded to ``multiGPU_CLIP``.
        prompter: optional callable applied to the preprocessed images.
        add_prompter: optional zero-argument callable whose result is passed
            to ``multiGPU_CLIP`` as ``prompt_token``.

    Returns:
        Element 0 of ``multiGPU_CLIP``'s result (``logits_per_image``).
    """
    pixels = clip_img_preprocessing(images)
    if add_prompter is None:
        prompt_token = None
    else:
        prompt_token = add_prompter()
    if prompter is not None:
        pixels = prompter(pixels)
    outputs = multiGPU_CLIP(model, pixels, text_tokens, prompt_token=prompt_token)
    return outputs[0]
def multiGPU_CLIP(model, images, text_tokens, prompt_token=None, is_embedding=False):
    """Run ``model`` on images and text tokens and return CLIP logits.

    ``model(images, text_tokens, prompt_token)`` must return
    ``(img_embed, scale_text_embed)``. The logits are their plain matrix
    products, so any normalization or logit scaling is assumed to happen
    inside the model.

    Two batch-size workarounds are applied. They appear to target
    ``nn.DataParallel``, which splits every positional tensor along dim 0
    (the 2-/4-GPU counts come from the original comments -- TODO confirm):

    * a single image is duplicated to two samples, and only the first image
      embedding is kept afterwards;
    * otherwise, exactly two text tokens are duplicated to four, and only the
      first two text embeddings are kept afterwards.

    Args:
        model: callable as described above.
        images: 4-D image batch; the first dimension is the batch size.
        text_tokens: 2-D tokenized texts; the first dimension is the number
            of texts.
        prompt_token: optional prompt tensor, repeated once per image of the
            (possibly duplicated) batch the model receives.
        is_embedding: if True, also return the two embeddings.

    Returns:
        ``(logits_per_image, logits_per_text)``, followed by
        ``(img_embed, scale_text_embed)`` when ``is_embedding`` is True.
    """
    duplicated_image = images.size(0) == 1
    duplicated_text = not duplicated_image and text_tokens.size(0) == 2

    if duplicated_image:  # 2 GPUs
        images = images.repeat(2, 1, 1, 1)
    elif duplicated_text:  # 4 GPUs
        text_tokens = text_tokens.repeat(2, 1)

    if prompt_token is not None:
        # Repeat after any image duplication so the prompt batch matches the
        # image batch the model actually receives.
        prompt_token = prompt_token.repeat(images.size(0), 1, 1)

    img_embed, scale_text_embed = model(images, text_tokens, prompt_token)

    # Undo the duplication on the outputs so each real input appears once.
    if duplicated_image:
        img_embed = img_embed[:1]
    elif duplicated_text:
        # Keep one text embedding per real text; without this the logits
        # would carry four (duplicated) text columns instead of two.
        scale_text_embed = scale_text_embed[:2]

    logits_per_image = img_embed @ scale_text_embed.t()
    logits_per_text = scale_text_embed @ img_embed.t()
    if is_embedding:
        return logits_per_image, logits_per_text, img_embed, scale_text_embed
    return logits_per_image, logits_per_text
def multiGPU_CLIP_Text_Prompt_Tuning(model, images, text_tokens, prompt_token=None, prompt_learner=None, is_embedding=False):
    """Run CLIP with learned text prompts and return image/text logits.

    Args:
        model: called as ``model(images, text_tokens, prompt_token, prompts,
            tokenized_prompts, forward_type='Text_Prompt_Tuning')`` and must
            return ``(img_embed, scale_text_embed)``.
        images: image batch; the first dimension is the batch size.
        text_tokens: forwarded unchanged to ``model``.
        prompt_token: optional prompt tensor, repeated once per image.
        prompt_learner: required callable that produces the learned prompts
            and exposes ``tokenized_prompts`` either on itself or on a wrapped
            ``.module`` (e.g. an ``nn.DataParallel`` wrapper).
        is_embedding: if True, also return the two embeddings.

    Returns:
        ``(logits_per_image, logits_per_text)``, followed by
        ``(img_embed, scale_text_embed)`` when ``is_embedding`` is True.

    Raises:
        TypeError: if ``prompt_learner`` is not supplied.
    """
    if prompt_learner is None:
        raise TypeError("multiGPU_CLIP_Text_Prompt_Tuning requires a prompt_learner")
    if prompt_token is not None:
        bs = images.size(0)
        prompt_token = prompt_token.repeat(bs, 1, 1)
    prompts = prompt_learner()
    # Wrappers such as nn.DataParallel keep the learner's attributes on
    # ``.module``; an unwrapped learner carries them directly.
    tokenized_prompts = getattr(prompt_learner, 'module', prompt_learner).tokenized_prompts
    img_embed, scale_text_embed = model(images, text_tokens, prompt_token, prompts, tokenized_prompts, forward_type='Text_Prompt_Tuning')
    logits_per_image = img_embed @ scale_text_embed.t()
    logits_per_text = scale_text_embed @ img_embed.t()
    if is_embedding:
        return logits_per_image, logits_per_text, img_embed, scale_text_embed
    return logits_per_image, logits_per_text
############################## Noise Modulated CLIP ###########################################
def apply_multiplicative_noise(signal, beta=0.0):
    """
    Apply Gaussian multiplicative noise to a signal.

    A single noise row of shape (1, d), drawn from N(1, beta^2), is shared by
    all m rows of ``signal`` via broadcasting.

    Parameters:
        signal (torch.Tensor): Tensor of shape (m, d) where m is the number of samples and d is the dimension.
        beta (float): Standard deviation of the Gaussian noise (must be >= 0).
    Returns:
        torch.Tensor: Noisy signal, on the same device as ``signal``.
    Raises:
        ValueError: if ``signal`` is not 2-D.
    """
    _, d = signal.shape  # unpacking also enforces the 2-D shape
    # Draw from the CPU generator (the same random stream as before) and move
    # the noise to signal's device rather than hard-coding .cuda(), which
    # failed for CPU tensors and for tensors on a non-default GPU.
    noise = torch.normal(mean=1.0, std=beta, size=(1, d)).to(signal.device)
    return signal * noise
def multiGPU_CLIP_multiply_noise(model, images, text_tokens, prompt_token=None, is_embedding=False, beta=0.0):
    """Compute CLIP logits from noise-modulated image and text embeddings.

    ``model(images, text_tokens, prompt_token)`` must return
    ``(img_embed, scale_text_embed)``. The image embedding and then the text
    embedding are each passed through ``apply_multiplicative_noise`` with the
    same ``beta``, as two separate draws. The logits are products of the noisy
    embeddings.

    Args:
        model: callable as described above.
        images: image batch; the first dimension is the batch size.
        text_tokens: forwarded unchanged to ``model``.
        prompt_token: optional prompt tensor, repeated once per image.
        is_embedding: if True, also return the noisy embeddings.
        beta: standard deviation of the multiplicative noise.

    Returns:
        ``(logits_per_image, logits_per_text)``, followed by the noisy
        ``(img_embed, scale_text_embed)`` when ``is_embedding`` is True.
    """
    if prompt_token is not None:
        prompt_token = prompt_token.repeat(images.size(0), 1, 1)
    clean_img, clean_text = model(images, text_tokens, prompt_token)

    # Noise modulation: image first, then text.
    noisy_img = apply_multiplicative_noise(clean_img, beta)
    noisy_text = apply_multiplicative_noise(clean_text, beta)

    logits = (noisy_img @ noisy_text.t(), noisy_text @ noisy_img.t())
    return logits + (noisy_img, noisy_text) if is_embedding else logits