# modeling_radzero.py: RadZero custom modeling code for the Hugging Face Hub.
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel
from transformers.models.clip.modeling_clip import CLIPTextModel
from transformers.models.mpnet.modeling_mpnet import MPNetModel
from transformers.utils import logging

from .align_transformers import build_align_transformer
from .common_layers import BasePreTrainedModel
from .configuration_radzero import CxrAlignConfig
from .losses import KeyPhraseAlignmentLoss
from .text_encoders import build_text_encoder
from .vision_encoders import Dinov2Model, build_vision_encoder

logger = logging.get_logger(__name__)
class CxrAlignModel(BasePreTrainedModel):
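    """RadZero vision-language alignment model for chest X-ray (CXR) images.

    Combines a vision encoder (e.g. DINOv2) with a text encoder (CLIP, MPNet, or
    BERT), refines the vision tokens with an alignment transformer, and trains
    with the key-phrase alignment loss from .losses.
    """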
config_class = CxrAlignConfig
def build_vision_model(self, config: CxrAlignConfig):
vision_config = config.vision_config
vision_config.pretrained_dir = config.pretrained_dir
vision_model = build_vision_encoder(vision_config)
return vision_model
def build_text_model(self, config: CxrAlignConfig):
text_config = config.text_config
text_model = build_text_encoder(text_config)
return text_model
def build_align_transformer_model(self, config: CxrAlignConfig):
align_transformer_config = config.align_transformer_config
align_transformer = build_align_transformer(align_transformer_config)
return align_transformer
def __init__(self, config: CxrAlignConfig):
super().__init__(config)
logger.info("Build vision model ...")
self.vision_model = self.build_vision_model(config)
logger.info("Build text model ...")
self.text_model = self.build_text_model(config)
        if isinstance(self.text_model, (CLIPTextModel, MPNetModel, BertModel)):
text_dim = self.text_model.config.hidden_size
self.hidden_size = config.align_transformer_config.hidden_size
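            # The projector maps text embeddings to 2 * hidden_size so they match the
            # width of the concatenated [CLS; mean-patch] image features built in
            # forward_vision_model.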
if config.text_config.use_text_projection:
self.text_projector = nn.Linear(text_dim, 2 * self.hidden_size)
            else:
                self.text_projector = None
        else:
            # Fail fast on text encoder types this model does not configure.
            raise NotImplementedError(
                f"Unsupported text encoder type: {type(self.text_model).__name__}"
            )
logger.info("Build align transformer model ...")
self.align_transformer = self.build_align_transformer_model(config)
logger.info("Build loss functions ...")
loss_cfg = config.kwargs["loss"]
self.loss_ratio = dict()
self.loss_fns = nn.ModuleDict()
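        # Each entry in loss_cfg["apply"] is a loss class name paired with a weight in
        # loss_cfg["ratio"], e.g. (illustrative values only):
        #   {"apply": ["KeyPhraseAlignmentLoss"], "ratio": [1.0], "KeyPhraseAlignmentLoss": {...}}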
for loss_type, ratio in zip(loss_cfg["apply"], loss_cfg["ratio"]):
logger.info(f"Build {loss_type} loss function ...")
if loss_cfg[loss_type] is None:
loss_cfg[loss_type] = dict()
if torch.distributed.is_available() and torch.distributed.is_initialized():
loss_cfg[loss_type]["rank"] = torch.distributed.get_rank()
loss_cfg[loss_type]["world_size"] = torch.distributed.get_world_size()
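            # Resolve the loss class from its configured name; the name must match a
            # class imported in this module (e.g. KeyPhraseAlignmentLoss).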
self.loss_fns[loss_type] = eval(loss_type)(**loss_cfg[loss_type])
self.loss_ratio[loss_type] = ratio
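        # Optional behavior flags from the config kwargs; compute_logits_type selects
        # the scoring path used in compute_logits below.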
self.compute_logits_type = config.kwargs.get("compute_logits_type")
self.use_negative_logits = config.kwargs.get("use_negative_logits")
self.module_to_update = config.kwargs.get("module_to_update")
def forward_vision_model(self, pixel_values):
if isinstance(self.vision_model, Dinov2Model):
vision_tokens = self.vision_model(pixel_values)["last_hidden_state"]
else:
raise NotImplementedError
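        # Pass the raw vision tokens through the alignment transformer before pooling.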
vision_tokens = self.align_transformer(vision_tokens)
cls_token = vision_tokens[:, 0]
patch_tokens = vision_tokens[:, 1:]
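        # Global image descriptor: concatenate the CLS token with the mean of the
        # patch tokens (2 * hidden_size in total), then L2-normalize.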
image_features = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)
image_features = F.normalize(image_features, p=2, dim=1)
        outputs = {
            "vision_tokens": vision_tokens,
            "image_cls_token": cls_token,
            "image_patch_tokens": patch_tokens,
            "image_features": image_features,
        }
return outputs
def forward_text_model(self, encoded_input):
text_outputs = {}
if isinstance(self.text_model, MPNetModel):
model_output = self.text_model(
input_ids=encoded_input["input_ids"],
attention_mask=encoded_input["attention_mask"],
)
            # The first element of model_output holds the token embeddings.
            token_embeddings = model_output[0]
            # Optional text embedding projection.
            if self.text_projector is not None:
                token_embeddings = self.text_projector(token_embeddings)
if self.config.text_config.use_cls_token:
text_features = token_embeddings[:, 0, :]
else:
                # Mean pooling: mask-weighted average over valid tokens (the clamp
                # below guards against division by zero on empty masks).
input_mask_expanded = (
encoded_input["attention_mask"]
.unsqueeze(-1)
.expand(token_embeddings.size())
.float()
)
text_features = torch.sum(
token_embeddings * input_mask_expanded, 1
) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
else:
raise NotImplementedError
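        # Return both raw and L2-normalized text features; the normalized features are
        # directly comparable with the normalized image features via dot products.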
text_outputs["text_features_wo_l2_norm"] = text_features
text_outputs["text_features"] = F.normalize(text_features, p=2, dim=1)
return text_outputs
def forward(
self,
pixel_values,
encoded_key_phrases=None,
return_loss=True,
**kwargs,
):
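        """Run the vision tower and, when return_loss is True, the configured losses.

        encoded_key_phrases is expected to be tokenizer output (input_ids and
        attention_mask) for the key phrases paired with each image.
        """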
vision_outputs = self.forward_vision_model(pixel_values)
outputs = {}
outputs.update(vision_outputs)
        # Trainer sets self.can_return_loss to True when 'return_loss' appears in the
        # signature of the model's forward method.
if return_loss:
loss = 0
losses = {}
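            # Accumulate every configured loss, weighted by its ratio from the config.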
for loss_type, loss_fn in self.loss_fns.items():
if isinstance(loss_fn, KeyPhraseAlignmentLoss):
loss_outputs = loss_fn(
encoded_key_phrases,
outputs["vision_tokens"],
self.forward_text_model,
)
key_phrase_alignment_losses = loss_outputs["losses"]
losses["key_phrase_alignment_loss"] = (
key_phrase_alignment_losses.pop("loss")
)
for loss_name, loss_value in key_phrase_alignment_losses.items():
losses[loss_name] = loss_value
loop_loss = losses["key_phrase_alignment_loss"]
else:
raise NotImplementedError
loss += loop_loss * self.loss_ratio[loss_type]
losses["loss"] = loss
outputs["losses"] = losses
return outputs
def compute_logits(
self,
pixel_values,
encoded_key_phrases,
**kwargs,
):
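        """Score key phrases against the input image(s) without computing a loss."""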
vision_outputs = self.forward_vision_model(pixel_values)
outputs = {}
if self.compute_logits_type == "key_phrase_alignment":
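            # Split the batch of encoded key phrases into single-phrase inputs so the
            # loss module can return per-phrase attention weights and logits.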
            split_key_phrases = [
                {
                    "input_ids": encoded_key_phrases[0]["input_ids"][i : i + 1],
                    "attention_mask": encoded_key_phrases[0]["attention_mask"][i : i + 1],
                }
                for i in range(encoded_key_phrases[0]["input_ids"].size(0))
            ]
loss_outputs = self.loss_fns["KeyPhraseAlignmentLoss"](
                split_key_phrases,
vision_outputs["vision_tokens"],
self.forward_text_model,
ddp_gather=False,
need_attn_weights=True,
compute_loss=False,
)
outputs.update(loss_outputs)
            # Average the text-to-image attention weights across layers.
outputs["similarity_scores"] = torch.mean(
torch.stack(loss_outputs["t2i_attn_weights"]), dim=0
)
            # Drop the attention score for the vision CLS token.
if self.loss_fns["KeyPhraseAlignmentLoss"].use_vision_cls_token:
outputs["similarity_scores"] = outputs["similarity_scores"][:, :, 1:]
            # Transpose the text-to-image logits and scale them by the learned temperature.
            logits = loss_outputs["t2i_logits"].T
            logits = logits / self.loss_fns["KeyPhraseAlignmentLoss"].loss_temperature.exp()
outputs["logits"] = logits
return outputs
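

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module. Everything below is
    # illustrative: it assumes CxrAlignConfig() is constructible with defaults that
    # include the "loss" kwargs, and that the vision encoder accepts 518x518 RGB
    # inputs (a common DINOv2 resolution). Adjust both to the actual checkpoint.
    config = CxrAlignConfig()
    model = CxrAlignModel(config).eval()
    pixel_values = torch.randn(1, 3, 518, 518)  # dummy image batch
    with torch.no_grad():
        outputs = model(pixel_values, return_loss=False)
    # image_features is the L2-normalized [CLS; mean-patch] descriptor.
    print(outputs["image_features"].shape)  # expected: (1, 2 * hidden_size)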