import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel
from transformers.models.clip.modeling_clip import CLIPTextModel
from transformers.models.mpnet.modeling_mpnet import MPNetModel
from transformers.trainer import logger

from .align_transformers import build_align_transformer
from .common_layers import BasePreTrainedModel
from .configuration_radzero import CxrAlignConfig
from .losses import KeyPhraseAlignmentLoss
from .text_encoders import build_text_encoder
from .vision_encoders import Dinov2Model, build_vision_encoder


class CxrAlignModel(BasePreTrainedModel):
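    """Image-text alignment model for chest X-ray (CXR) data.

    Pairs a vision encoder (e.g. DINOv2) with a text encoder (e.g. MPNet),
    refines vision tokens with an alignment transformer, and trains with
    configurable alignment losses such as KeyPhraseAlignmentLoss.
    """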

    config_class = CxrAlignConfig

    def build_vision_model(self, config: CxrAlignConfig):
        vision_config = config.vision_config
        vision_config.pretrained_dir = config.pretrained_dir
        vision_model = build_vision_encoder(vision_config)
        return vision_model

    def build_text_model(self, config: CxrAlignConfig):
        text_config = config.text_config
        text_model = build_text_encoder(text_config)
        return text_model

    def build_align_transformer_model(self, config: CxrAlignConfig):
        align_transformer_config = config.align_transformer_config
        align_transformer = build_align_transformer(align_transformer_config)
        return align_transformer

    def __init__(self, config: CxrAlignConfig):
        super().__init__(config)

        logger.info("Build vision model ...")
        self.vision_model = self.build_vision_model(config)

        logger.info("Build text model ...")
        self.text_model = self.build_text_model(config)

        if isinstance(self.text_model, (CLIPTextModel, MPNetModel, BertModel)):
            text_dim = self.text_model.config.hidden_size
        else:
            raise NotImplementedError(
                f"Cannot infer hidden size for {type(self.text_model).__name__}"
            )

        self.hidden_size = config.align_transformer_config.hidden_size

        if config.text_config.use_text_projection:
            # Project text tokens to 2 * hidden_size so they match the
            # concatenated [CLS; mean-patch] image feature dimension.
            self.text_projector = nn.Linear(text_dim, 2 * self.hidden_size)
        else:
            self.text_projector = None

        logger.info("Build align transformer model ...")
        self.align_transformer = self.build_align_transformer_model(config)

        logger.info("Build loss functions ...")
        loss_cfg = config.kwargs["loss"]
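        # Expected shape of the loss config (inferred from the loop below), e.g.
        #   {"apply": ["KeyPhraseAlignmentLoss"], "ratio": [1.0],
        #    "KeyPhraseAlignmentLoss": {...kwargs for the loss, or None...}}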
        self.loss_ratio = dict()
        self.loss_fns = nn.ModuleDict()
        for loss_type, ratio in zip(loss_cfg["apply"], loss_cfg["ratio"]):
            logger.info(f"Build {loss_type} loss function ...")
            if loss_cfg[loss_type] is None:
                loss_cfg[loss_type] = dict()
            # Pass distributed-training context to losses that gather across ranks.
            if torch.distributed.is_available() and torch.distributed.is_initialized():
                loss_cfg[loss_type]["rank"] = torch.distributed.get_rank()
                loss_cfg[loss_type]["world_size"] = torch.distributed.get_world_size()
            # Resolve the loss class by name from this module's namespace.
            self.loss_fns[loss_type] = globals()[loss_type](**loss_cfg[loss_type])
            self.loss_ratio[loss_type] = ratio

        self.compute_logits_type = config.kwargs.get("compute_logits_type")
        self.use_negative_logits = config.kwargs.get("use_negative_logits")

        self.module_to_update = config.kwargs.get("module_to_update")

    def forward_vision_model(self, pixel_values):
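        """Encode images into aligned token features.

        Runs the vision backbone, refines tokens with the alignment
        transformer, and returns the CLS token, patch tokens, and an
        L2-normalized [CLS; mean-patch] global image feature.
        """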
        if isinstance(self.vision_model, Dinov2Model):
            vision_tokens = self.vision_model(pixel_values)["last_hidden_state"]
        else:
            raise NotImplementedError

        vision_tokens = self.align_transformer(vision_tokens)

        # Concatenate the CLS token with the mean of the patch tokens, then
        # L2-normalize to form the global image feature.
        cls_token = vision_tokens[:, 0]
        patch_tokens = vision_tokens[:, 1:]
        image_features = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)
        image_features = F.normalize(image_features, p=2, dim=1)

        outputs = {}
        outputs["vision_tokens"] = vision_tokens
        outputs["image_cls_token"] = cls_token
        outputs["image_patch_tokens"] = patch_tokens
        outputs["image_features"] = image_features

        return outputs

    def forward_text_model(self, encoded_input):
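        """Encode tokenized text into L2-normalized sentence features.

        Supports MPNet-style encoders; features come from the CLS token or
        from attention-masked mean pooling, optionally after a linear
        projection.
        """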
        text_outputs = {}

        if isinstance(self.text_model, MPNetModel):
            model_output = self.text_model(
                input_ids=encoded_input["input_ids"],
                attention_mask=encoded_input["attention_mask"],
            )
            token_embeddings = model_output[0]  # last hidden state

            if self.text_projector is not None:
                token_embeddings = self.text_projector(token_embeddings)

            if self.config.text_config.use_cls_token:
                text_features = token_embeddings[:, 0, :]
            else:
                # Mean pooling over valid tokens: weight each token embedding
                # by its attention mask and divide by the token count.
                input_mask_expanded = (
                    encoded_input["attention_mask"]
                    .unsqueeze(-1)
                    .expand(token_embeddings.size())
                    .float()
                )
                text_features = torch.sum(
                    token_embeddings * input_mask_expanded, 1
                ) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        else:
            raise NotImplementedError

        text_outputs["text_features_wo_l2_norm"] = text_features
        text_outputs["text_features"] = F.normalize(text_features, p=2, dim=1)

        return text_outputs

    def forward(
        self,
        pixel_values,
        encoded_key_phrases=None,
        return_loss=True,
        **kwargs,
    ):
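        """Encode images and, when return_loss is True, compute the weighted
        sum of the configured alignment losses over the key phrases."""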
        vision_outputs = self.forward_vision_model(pixel_values)

        outputs = {}
        outputs.update(vision_outputs)

        if return_loss:
            loss = 0
            losses = {}

            for loss_type, loss_fn in self.loss_fns.items():
                if isinstance(loss_fn, KeyPhraseAlignmentLoss):
                    loss_outputs = loss_fn(
                        encoded_key_phrases,
                        outputs["vision_tokens"],
                        self.forward_text_model,
                    )
                    key_phrase_alignment_losses = loss_outputs["losses"]
                    losses["key_phrase_alignment_loss"] = (
                        key_phrase_alignment_losses.pop("loss")
                    )
                    for loss_name, loss_value in key_phrase_alignment_losses.items():
                        losses[loss_name] = loss_value
                    loop_loss = losses["key_phrase_alignment_loss"]
                else:
                    raise NotImplementedError

                # Accumulate each loss, weighted by its configured ratio.
                loss += loop_loss * self.loss_ratio[loss_type]

            losses["loss"] = loss

            outputs["losses"] = losses

        return outputs

    def compute_logits(
        self,
        pixel_values,
        encoded_key_phrases,
        **kwargs,
    ):
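        """Score key phrases against images at inference time.

        Splits the encoded key phrases into single-phrase inputs and reuses
        the KeyPhraseAlignmentLoss module with loss computation disabled to
        obtain patch-level similarity scores and temperature-scaled logits.
        """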
        vision_outputs = self.forward_vision_model(pixel_values)

        outputs = {}

        if self.compute_logits_type == "key_phrase_alignment":
            # Split the batch of encoded key phrases into single-phrase inputs.
            split_key_phrases = [
                {
                    "input_ids": encoded_key_phrases[0]["input_ids"][i : i + 1],
                    "attention_mask": encoded_key_phrases[0]["attention_mask"][
                        i : i + 1
                    ],
                }
                for i in range(encoded_key_phrases[0]["input_ids"].size(0))
            ]

            loss_outputs = self.loss_fns["KeyPhraseAlignmentLoss"](
                split_key_phrases,
                vision_outputs["vision_tokens"],
                self.forward_text_model,
                ddp_gather=False,
                need_attn_weights=True,
                compute_loss=False,
            )
            outputs.update(loss_outputs)

            # Patch-level similarity: average the text-to-image attention weights.
            outputs["similarity_scores"] = torch.mean(
                torch.stack(loss_outputs["t2i_attn_weights"]), dim=0
            )

            # Drop the CLS position so scores align with patch tokens only.
            if self.loss_fns["KeyPhraseAlignmentLoss"].use_vision_cls_token:
                outputs["similarity_scores"] = outputs["similarity_scores"][:, :, 1:]

            logits = loss_outputs["t2i_logits"]
            logits = logits.T

            # Rescale the logits by the learned temperature (logit-scale) factor.
            logits = (
                logits / self.loss_fns["KeyPhraseAlignmentLoss"].loss_temperature.exp()
            )

            outputs["logits"] = logits
        return outputs
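

# Minimal usage sketch (illustrative only; the checkpoint path, tokenizer
# choice, and input shapes are assumptions, not part of this module):
#
#   config = CxrAlignConfig.from_pretrained("path/to/checkpoint")
#   model = CxrAlignModel(config).eval()
#   with torch.no_grad():
#       outputs = model.compute_logits(pixel_values, [encoded_key_phrases])
#   print(outputs["logits"].shape, outputs["similarity_scores"].shape)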