ACE-LoRA / loss.py
aydnarda's picture
Upload loss.py
a56c7bf
import warnings
warnings.filterwarnings("ignore")
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
def set_random_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
class Identity(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x
class CLIPLoss(nn.Module):
def __init__(self, args, logit_scale):
super(CLIPLoss, self).__init__()
self.args = args
if args.learnable_logit_scale:
self.logit_scale = nn.Parameter(logit_scale.clone().detach())
else:
self.register_buffer('logit_scale', logit_scale.clone().detach())
def forward(self, image_features, text_features, merged_df=None, indices=None):
device = image_features.device
batch_size, feature_dim = image_features.size()
labels = torch.arange(batch_size, device=device, dtype=torch.long)
logits_per_image = self.logit_scale * image_features @ text_features.t()
logits_per_text = logits_per_image.T
if merged_df is not None:
compare_matrix = merged_df.iloc[indices, 2:].to_numpy()
vector_similarity_matrix = np.ones((compare_matrix.shape[0], compare_matrix.shape[0]), dtype=np.int32)
comparison = (compare_matrix[:, None, :] == compare_matrix[None, :, :]).all(axis=2)
vector_similarity_matrix[comparison] = 0
np.fill_diagonal(vector_similarity_matrix, 1)
vector_similarity_matrix = torch.from_numpy(vector_similarity_matrix).bool().to(device)
masked_logits_per_image = logits_per_image.masked_fill(~vector_similarity_matrix, float('-inf'))
masked_logits_per_text = logits_per_text.masked_fill(~vector_similarity_matrix.T, float('-inf'))
loss = (F.cross_entropy(masked_logits_per_image, labels) + F.cross_entropy(masked_logits_per_text, labels)) / 2
else:
loss = (F.cross_entropy(logits_per_image, labels) + F.cross_entropy(logits_per_text, labels)) / 2
return loss
class ResidualAdapter(nn.Module):
def __init__(self, dim, bottleneck_dim=128):
super().__init__()
self.down = nn.Linear(dim, bottleneck_dim)
self.act = nn.LeakyReLU(0.2)
self.up = nn.Linear(bottleneck_dim, dim)
nn.init.kaiming_normal_(self.down.weight)
nn.init.kaiming_normal_(self.up.weight)
def forward(self, x):
return self.up(self.act(self.down(x)))
class CLIPLossACE_HGAT(nn.Module):
def __init__(self, args, logit_scale, in_channels):
super(CLIPLossACE_HGAT, self).__init__()
set_random_seed(args.seed)
self.args = args
self.img_edge_adapter = ResidualAdapter(in_channels, args.hidden_features)
self.text_edge_adapter = ResidualAdapter(in_channels, args.hidden_features)
self.img_node_adapter = ResidualAdapter(in_channels, args.hidden_features)
self.text_node_adapter = ResidualAdapter(in_channels, args.hidden_features)
if args.learnable_logit_scale:
self.logit_scale = nn.Parameter(logit_scale.clone().detach())
else:
self.register_buffer('logit_scale', logit_scale.clone().detach())
def apply_ace_hgat(self, features, attn_weights, encoder="img"):
if encoder =="img":
edge_adapter = self.img_edge_adapter
node_adapter = self.img_node_adapter
elif encoder == 'text':
edge_adapter = self.text_edge_adapter
node_adapter = self.text_node_adapter
else:
raise ValueError(f"encoder must be img or text but given {encoder}")
B, N, D = features.shape
patches_norm = F.normalize(features[:, 1:, :], p=2, dim=-1)
# Similarity Matrix: (B, P, P)
sim = torch.zeros(size=(B, N, N), device=features.device)
patch_sim = torch.bmm(patches_norm, patches_norm.transpose(1, 2)) # [B, P, P]
sim[:, 1:, 1:] = patch_sim
sim[:, 0, 1:] = attn_weights
mask_logic = torch.eye(N, device=features.device).bool().unsqueeze(0).repeat(B, 1, 1)
mask_logic[:, 1:, 0] = True
sim = sim.masked_fill(mask_logic, -float('inf'))
topk_vals, topk_indices = torch.topk(sim, k=self.args.topk, dim=-1)
mask_sparse = torch.full_like(sim, -float('inf'))
mask_sparse.scatter_(-1, topk_indices, topk_vals)
A = F.softmax(mask_sparse, dim=-1)
A = A.masked_fill(torch.eye(N, device=features.device).bool().unsqueeze(0).repeat(B, 1, 1), 1)
A[:, 1:, 0] = A[:, 0, 1:]
H_edges_raw = torch.matmul(A, features)
H_edges_refined = edge_adapter(H_edges_raw)
H_context_raw = torch.matmul(A.transpose(1, 2), H_edges_refined)
H_context_processed = node_adapter(H_context_raw)
x_out = H_context_processed
return x_out
def forward(self, clip_model, images, texts, merged_df=None, indices=None):
device = images.device
clip_model.visual.trunk.global_pool = ''
image_features, img_attn_scores = clip_model.visual.trunk.get_attn_scores(images)
image_features = F.normalize(clip_model.visual.head(image_features), dim=-1)
text_features, text_attn_scores = clip_model.encode_text(texts, normalize=True, output_attentions=True, output_tokens=True)
img_attn_scores = img_attn_scores.mean(dim=1) # [B, 197, 197]
img_attn_weights = img_attn_scores[:, 0, 1:] # relationship between CLS token and patch embeddings [B, 196]
text_attn_scores = text_attn_scores[-1].mean(dim=1) # [B, 256, 256]
text_attn_weights = text_attn_scores[:, 0, 1:] # relationship between global token and other token embeddings [B, 255]
if self.args.apply_gnn_encoders == 'vision':
image_features = self.apply_ace_hgat(image_features, img_attn_weights, encoder="img")
image_features = F.normalize(image_features, dim=-1)
logits_per_image = self.logit_scale * image_features[:, 0] @ text_features[:, 0].t()
logits_per_text = logits_per_image.T
elif self.args.apply_gnn_encoders == 'text':
text_features = self.apply_ace_hgat(text_features, text_attn_weights, encoder="text")
text_features = F.normalize(text_features, dim=-1)
logits_per_image = self.logit_scale * image_features[:, 0] @ text_features[:, 0].t()
logits_per_text = logits_per_image.T
elif self.args.apply_gnn_encoders == 'both':
image_features = self.apply_ace_hgat(image_features, img_attn_weights, encoder="img")
image_features = F.normalize(image_features, dim=-1)
text_features = self.apply_ace_hgat(text_features, text_attn_weights, encoder="text")
text_features = F.normalize(text_features, dim=-1)
logits_per_image = self.logit_scale * image_features[:, 0] @ text_features[:, 0].t()
logits_per_text = logits_per_image.T
labels = torch.arange(image_features.shape[0], device=device, dtype=torch.long)
if logits_per_image.isnan().sum() > 0:
raise ValueError('NaN value in logits_per_image')
if merged_df is not None: # Label-Guided InfoNCE loss
compare_matrix = merged_df.iloc[indices, 2:].to_numpy()
vector_similarity_matrix = np.ones((compare_matrix.shape[0], compare_matrix.shape[0]), dtype=np.int32)
comparison = (compare_matrix[:, None, :] == compare_matrix[None, :, :]).all(axis=2)
vector_similarity_matrix[comparison] = 0
np.fill_diagonal(vector_similarity_matrix, 1)
vector_similarity_matrix = torch.from_numpy(vector_similarity_matrix).bool().to(device)
masked_logits_per_image = logits_per_image.masked_fill(~vector_similarity_matrix, float('-inf'))
masked_logits_per_text = logits_per_text.masked_fill(~vector_similarity_matrix.T, float('-inf'))
loss = (F.cross_entropy(masked_logits_per_image, labels) + F.cross_entropy(masked_logits_per_text, labels)) / 2
else:
loss = (F.cross_entropy(logits_per_image, labels) + F.cross_entropy(logits_per_text, labels)) / 2
return loss