Safetensors
custom_code
RadZero / losses.py
nobottle
add modeling files
4333430
import math
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
class KeyPhraseAlignmentLoss(nn.Module):
def __init__(
self,
hidden_dim=768,
use_vision_cls_token=True,
attn_temperature=None,
loss_temperature=0.07,
text_features_l2_norm=False,
mpnce_row_sum=False,
mpnce_col_sum=False,
sim_op="cos",
use_layer_norm=True,
**kwargs,
):
super().__init__()
self.hidden_dim = hidden_dim
self.layer_norm = nn.LayerNorm(hidden_dim) if use_layer_norm else None
self.use_vision_cls_token = use_vision_cls_token
self.loss_temperature = nn.Parameter(
torch.FloatTensor([np.log(loss_temperature)])
)
if attn_temperature is not None:
self.attn_temperature = nn.Parameter(
torch.FloatTensor([np.log(attn_temperature)])
)
else:
self.attn_temperature = None
self.text_features_l2_norm = text_features_l2_norm
self.sim_op = sim_op
self.similarity_logit = SimilarityLogit(sim_op)
self.mpnce_row_sum = mpnce_row_sum
self.mpnce_col_sum = mpnce_col_sum
def forward(
self,
key_phrases,
vision_tokens,
forward_text_model,
ddp_gather=True,
need_attn_weights=False,
compute_loss=True,
**kwargs,
):
outputs = {}
text_features, group_map = self.compute_text_features(
key_phrases, forward_text_model, ddp_gather
)
if ddp_gather and dist.is_initialized():
vision_tokens = torch.cat(dist.nn.all_gather(vision_tokens), dim=0)
if self.layer_norm is not None:
vision_tokens = self.layer_norm(vision_tokens)
vision_patch_tokens = vision_tokens[:, 1:]
# text to image cross-attention
if not self.use_vision_cls_token:
vision_attn_tokens = vision_patch_tokens
else:
vision_attn_tokens = vision_tokens
t2i_logits, t2i_attn_weights_list = self.compute_t2i_logits(
text_features, vision_attn_tokens, need_attn_weights
)
outputs["t2i_logits"] = t2i_logits
outputs["t2i_attn_weights"] = t2i_attn_weights_list
if compute_loss:
losses = {}
loss = 0
# compute t2i loss
t2i_loss = multi_positive_nce_loss(
t2i_logits,
group_map,
temperature=self.loss_temperature.exp(),
row_sum=self.mpnce_row_sum,
col_sum=self.mpnce_col_sum,
)
loss += t2i_loss
losses["t2i_loss"] = t2i_loss
losses["loss"] = loss
outputs["losses"] = losses
return outputs
def compute_text_features(self, key_phrases, forward_text_model, ddp_gather=True):
key_text_features_list = list()
group_list = list()
B_local = len(key_phrases)
# Calculate offset by getting the rank of the current process when using DDP
local_rank = dist.get_rank() if (ddp_gather and dist.is_initialized()) else 0
for i, kp in enumerate(key_phrases):
feats = forward_text_model(kp)
# (N_i, D)
if self.text_features_l2_norm:
feat = feats["text_features"]
else:
feat = feats["text_features_wo_l2_norm"]
if feat.shape[-1] == 2 * self.hidden_dim:
feat = feat[:, self.hidden_dim :]
key_text_features_list.append(feat)
# Add local_rank * B_local offset to local index i
global_index = i + local_rank * B_local
group_list.extend([global_index] * feat.size(0))
text_features = torch.cat(key_text_features_list, dim=0)
group_map = torch.tensor(group_list, device=text_features.device)
if ddp_gather and dist.is_initialized():
# Gather text_features and image_features and group_map
text_features = pad_and_gather(text_features)
group_map = pad_and_gather(group_map)
group_map = group_map.long()
if self.layer_norm is not None:
text_features = self.layer_norm(text_features)
return text_features, group_map
def compute_t2i_logits(
self, text_features, vision_attn_tokens, need_attn_weights, repeat=True
):
t2i_logits, t2i_attn_weights_list = self.similarity_logit(
text_features,
vision_attn_tokens,
need_attn_weights,
repeat=repeat,
temperature=(
self.attn_temperature.exp()
if self.attn_temperature is not None
else self.loss_temperature.exp()
),
)
return t2i_logits, t2i_attn_weights_list
class SimilarityLogit(nn.Module):
def __init__(self, sim_op="dot", **kwargs):
super().__init__()
self.sim_op = sim_op
def forward(
self,
queries: torch.Tensor,
local_tokens: torch.Tensor,
need_attn_weights: bool = False,
repeat: bool = True,
**kwargs,
):
if repeat:
query_attn_features = queries.unsqueeze(0).expand(
local_tokens.shape[0], queries.shape[0], queries.shape[1]
)
else:
assert queries.dim() == 3
query_attn_features = queries
if self.sim_op == "cos":
temperature = kwargs.get("temperature")
assert temperature is not None
denominator = temperature
query_attn_features = F.normalize(query_attn_features, p=2, dim=-1)
local_tokens = F.normalize(local_tokens, p=2, dim=-1)
elif self.sim_op == "dot":
denominator = math.sqrt(local_tokens.size(-1))
else:
raise NotImplementedError
scores = (
torch.bmm(query_attn_features, local_tokens.permute(0, 2, 1)) / denominator
)
attn_weights = F.softmax(scores, dim=-1)
aggregated = torch.matmul(attn_weights, local_tokens)
query_attn_features = F.normalize(query_attn_features, p=2, dim=-1)
aggregated = F.normalize(aggregated, p=2, dim=-1)
logits = torch.matmul(
query_attn_features.unsqueeze(2), aggregated.unsqueeze(-1)
).squeeze()
logits = logits.T
if need_attn_weights:
attn_scores = [scores]
else:
attn_scores = None
return logits, attn_scores
def multi_positive_nce_loss(
logits: torch.Tensor,
group_map: torch.Tensor,
temperature: float = 1.0,
eps: float = 1e-8,
row_sum: bool = False,
col_sum: bool = False,
):
"""
Args:
logits: tensor of shape (N_total, B_global), each row is a logit between a key phrase and each candidate image.
group_map: tensor of shape (N_total,), source image index of each key phrase.
temperature: scaling factor.
For each key phrase row i, the positive is the candidate image index == group_map[i],
and the rest are treated as negatives.
For each column j, each positive for image j is considered independently.
Returns:
loss: scalar tensor.
"""
scaled_logits = torch.exp(logits / temperature) # (N_total, B_global)
pos_logits = scaled_logits[
torch.arange(scaled_logits.size(0)), group_map
] # (N_total,)
row_loss = get_row_loss(
scaled_logits,
pos_logits,
group_map,
eps,
row_sum,
)
neg_mask = torch.ones_like(scaled_logits)
neg_mask[torch.arange(scaled_logits.size(0)), group_map] = 0 # (N_total, B_global)
column_loss = get_col_loss(
scaled_logits,
pos_logits,
neg_mask,
group_map,
eps,
col_sum,
)
loss = (row_loss.mean() + column_loss.mean()) / 2
return loss
def get_row_loss(
logits: torch.Tensor,
pos_logits: torch.Tensor,
group_map: torch.Tensor,
eps: float = 1e-8,
row_sum: bool = False,
):
if row_sum:
# Create a tensor to hold the summed values
row_sum_logits = torch.zeros(
logits.shape[-1], device=logits.device
) # (B_global)
row_pos_sum_logits = torch.zeros(
logits.shape[-1], device=logits.device
) # (B_global)
# Use scatter_add to sum values based on group_map
row_sum_logits.scatter_add_(0, group_map, logits.sum(dim=1)) # (B_global)
row_pos_sum_logits.scatter_add_(0, group_map, pos_logits) # (B_global)
p_row = row_pos_sum_logits / (row_sum_logits + eps) # (B_global)
else:
row_sum_logits = logits.sum(dim=1) # (N_total)
p_row = pos_logits / (row_sum_logits + eps) # (N_total)
return -torch.log(p_row + eps)
def get_col_loss(
logits: torch.Tensor,
pos_logits: torch.Tensor,
neg_mask: torch.Tensor,
group_map: torch.Tensor,
eps: float = 1e-8,
col_sum: bool = False,
):
if col_sum:
# MIL-NCE loss
column_sum_logits = logits.sum(dim=0) # (B_global,)
pos_mask = torch.ones_like(logits) - neg_mask # (N_total, B_global)
column_pos_logits = (logits * pos_mask).sum(dim=0) # (B_global,)
p_column = column_pos_logits / (column_sum_logits + eps) # (B_global,)
else:
# MP-NCE loss (UniCLIP)
neg_logits = logits * neg_mask # (N_total, B_global)
sum_neg_logits = neg_logits.sum(dim=0) # (B_global,)
sum_neg_logits = sum_neg_logits[group_map] # (N_total)
p_column = pos_logits / (pos_logits + sum_neg_logits + eps) # (N_total)
return -torch.log(p_column + eps)
def pad_and_gather(tensor):
# Determine the size of the tensor
local_size = torch.tensor(tensor.size(), device=tensor.device)
# Gather all sizes
all_sizes = [torch.zeros_like(local_size) for _ in range(dist.get_world_size())]
dist.all_gather(all_sizes, local_size)
# Determine the maximum size
max_size = torch.stack(all_sizes).max(dim=0)[0]
# Pad the tensor to the maximum size
padded_tensor = torch.zeros(max_size.tolist(), device=tensor.device)
padded_tensor[: local_size[0]] = tensor
# Gather all padded tensors
gathered_tensors = dist.nn.all_gather(padded_tensor)
# Trim the gathered tensors to their original sizes
gathered_tensors = [g[: s[0]] for g, s in zip(gathered_tensors, all_sizes)]
gathered_tensors = torch.cat(gathered_tensors, dim=0)
return gathered_tensors