"""Helpers for running a (modified) CLIP model on image batches and turning
its image/text embeddings into similarity logits.

Includes input preprocessing, a small-batch workaround for multi-GPU
execution, a text-prompt-tuning forward path, and multiplicative noise
modulation of the embeddings.
"""
import torch
from modified_clip import clip

# Per-channel RGB statistics used by `normalize`.
# NOTE(review): despite the IMAGENET_* names, these look like OpenAI CLIP's
# preprocessing statistics rather than the usual ImageNet ones — confirm.
IMAGENET_MEAN = (0.48145466, 0.4578275, 0.40821073)
IMAGENET_STD = (0.26862954, 0.26130258, 0.27577711)

# Shaped (3, 1, 1) so they broadcast over (..., 3, H, W) image tensors.
mu = torch.tensor(IMAGENET_MEAN).view(3, 1, 1)
std = torch.tensor(IMAGENET_STD).view(3, 1, 1)


def normalize(X):
    """Standardize ``X`` channel-wise: ``(X - mu) / std``.

    ``X`` is expected to have its 3 colour channels third from last, e.g.
    ``(N, 3, H, W)``. The statistics are moved to ``X.device`` before use.
    """
    return (X - mu.to(X.device)) / std.to(X.device)


def clip_img_preprocessing(X, img_size=224):
    """Resize an image batch to ``img_size`` x ``img_size`` and normalize it.

    Args:
        X: 4-D image batch ``(N, C, H, W)``; bicubic ``interpolate`` only
            accepts 4-D input. Presumably RGB in [0, 1] — TODO confirm
            against callers.
        img_size: Output side length (default 224, the previously
            hard-coded value).

    Returns:
        The bicubically resized batch, normalized with ``normalize``.
    """
    X = torch.nn.functional.interpolate(X, size=(img_size, img_size), mode='bicubic')
    return normalize(X)


def create_logits(x1, x2, logit_scale):
    """Return scaled cosine-similarity logits between two embedding sets.

    Both inputs are L2-normalized along their last dimension, then
    ``logits_per_x1 = logit_scale * x1 @ x2.T`` and
    ``logits_per_x2 = logit_scale * x2 @ x1.T``.

    Returns:
        ``(logits_per_x1, logits_per_x2)``.
    """
    x1 = x1 / x1.norm(dim=-1, keepdim=True)
    x2 = x2 / x2.norm(dim=-1, keepdim=True)
    logits_per_x1 = logit_scale * x1 @ x2.t()
    logits_per_x2 = logit_scale * x2 @ x1.t()
    return logits_per_x1, logits_per_x2


def _expand_prompt_token(prompt_token, batch_size):
    """Tile ``prompt_token`` ``batch_size`` times along a leading dim (None passes through)."""
    if prompt_token is None:
        return None
    return prompt_token.repeat(batch_size, 1, 1)


def _embeddings_to_logits(img_embed, scale_text_embed, is_embedding):
    """Build the (image->text, text->image) logits and optionally append the embeddings."""
    logits_per_image = img_embed @ scale_text_embed.t()
    logits_per_text = scale_text_embed @ img_embed.t()
    if is_embedding:
        return logits_per_image, logits_per_text, img_embed, scale_text_embed
    return logits_per_image, logits_per_text


def multiGPU_CLIP_image_logits(images, model, text_tokens, prompter=None, add_prompter=None):
    """Preprocess ``images`` and return only the image-to-text logits.

    Args:
        images: Raw image batch, passed through ``clip_img_preprocessing``.
        model: Forwarded to ``multiGPU_CLIP``.
        text_tokens: Forwarded to ``multiGPU_CLIP``.
        prompter: Optional callable applied to the preprocessed images;
            skipped when None.
        add_prompter: Optional zero-argument callable whose result is used
            as ``prompt_token``; no prompt token when None.

    Returns:
        ``logits_per_image`` from ``multiGPU_CLIP``.
    """
    image_tokens = clip_img_preprocessing(images)
    prompt_token = None if add_prompter is None else add_prompter()
    if prompter is not None:
        image_tokens = prompter(image_tokens)
    return multiGPU_CLIP(model, image_tokens, text_tokens, prompt_token=prompt_token)[0]


def multiGPU_CLIP(model, images, text_tokens, prompt_token=None, is_embedding=False):
    """Run ``model`` on images and text tokens and return similarity logits.

    ``model(images, text_tokens, prompt_token)`` must return
    ``(img_embed, scale_text_embed)``. Two small-batch cases are padded
    before the call and trimmed afterwards:

    * a single image is duplicated (the original comment says "2 GPUs");
    * exactly two text prompts are duplicated (original comment: "4 GPUs").

    NOTE(review): this presumably exists because ``model`` is wrapped in
    ``nn.DataParallel``, which splits every tensor input along dim 0 and
    needs enough rows for each device — confirm with the caller.

    Args:
        model: Callable returning ``(img_embed, scale_text_embed)``.
        images: Preprocessed image batch; dim 0 is the batch.
        text_tokens: Tokenized text; dim 0 indexes prompts.
        prompt_token: Optional prompt tensor, tiled once per image.
        is_embedding: If True, also return the two embeddings.

    Returns:
        ``(logits_per_image, logits_per_text)``, plus
        ``(img_embed, scale_text_embed)`` when ``is_embedding`` is True.
    """
    n_images = images.size(0)
    n_texts = text_tokens.size(0)

    if n_images == 1:
        images = images.repeat(2, 1, 1, 1)
    elif n_texts == 2:
        text_tokens = text_tokens.repeat(2, 1)

    # Tile the prompt to the batch size actually sent to the model, so it
    # stays aligned with `images` after the single-image duplication above.
    prompt_token = _expand_prompt_token(prompt_token, images.size(0))
    img_embed, scale_text_embed = model(images, text_tokens, prompt_token)

    # Drop the outputs that belong to the padding copies.
    if n_images == 1:
        img_embed = img_embed[:1]
    elif n_texts == 2:
        scale_text_embed = scale_text_embed[:n_texts]

    return _embeddings_to_logits(img_embed, scale_text_embed, is_embedding)


def multiGPU_CLIP_Text_Prompt_Tuning(model, images, text_tokens, prompt_token=None, prompt_learner=None, is_embedding=False):
    """Like ``multiGPU_CLIP`` but with learned text prompts.

    Calls ``prompt_learner()`` for the prompt embeddings and reads its
    ``tokenized_prompts`` attribute (through ``.module`` when the learner is
    wrapped), then runs ``model`` with ``forward_type='Text_Prompt_Tuning'``.
    No small-batch padding is applied here.

    Raises:
        TypeError: if ``prompt_learner`` is None.
    """
    if prompt_learner is None:
        raise TypeError("multiGPU_CLIP_Text_Prompt_Tuning requires a prompt_learner")
    prompt_token = _expand_prompt_token(prompt_token, images.size(0))
    prompts = prompt_learner()
    # Wrappers such as DataParallel/DDP expose the real learner as `.module`;
    # fall back to the learner itself when it is not wrapped.
    learner = getattr(prompt_learner, "module", prompt_learner)
    tokenized_prompts = learner.tokenized_prompts
    img_embed, scale_text_embed = model(images, text_tokens, prompt_token, prompts, tokenized_prompts, forward_type='Text_Prompt_Tuning')
    return _embeddings_to_logits(img_embed, scale_text_embed, is_embedding)


############################## Noise Modulated CLIP ###########################################

def apply_multiplicative_noise(signal, beta=0.0):
    """Multiply ``signal`` by Gaussian noise with mean 1 and std ``beta``.

    A single noise vector over the last dimension is drawn and shared by
    every row, so for an ``(m, d)`` input all ``m`` rows get the same ``d``
    scale factors.

    Args:
        signal: Tensor of shape ``(..., d)``, typically ``(m, d)``.
        beta: Standard deviation of the noise; must be >= 0. With 0 the
            noise is exactly 1.

    Returns:
        ``signal * noise``, on ``signal``'s device.
    """
    d = signal.shape[-1]
    # Sample on the signal's own device (the previous `.cuda()` broke CPU
    # tensors and tensors on a non-current GPU). The noise keeps the default
    # dtype, so result dtype promotion is unchanged.
    noise = torch.empty(d, device=signal.device).normal_(mean=1.0, std=beta)
    return signal * noise


def multiGPU_CLIP_multiply_noise(model, images, text_tokens, prompt_token=None, is_embedding=False, beta=0.0):
    """Like ``multiGPU_CLIP`` but perturbs both embeddings with multiplicative noise.

    The image and text embeddings each get an independent noise vector from
    ``apply_multiplicative_noise(..., beta)`` before the logits are formed.
    No small-batch padding is applied here. The returned embeddings (when
    ``is_embedding`` is True) are the noisy ones.
    """
    prompt_token = _expand_prompt_token(prompt_token, images.size(0))
    img_embed, scale_text_embed = model(images, text_tokens, prompt_token)
    img_embed = apply_multiplicative_noise(img_embed, beta)
    scale_text_embed = apply_multiplicative_noise(scale_text_embed, beta)
    return _embeddings_to_logits(img_embed, scale_text_embed, is_embedding)