"""CheXficient (modeling_chexficient.py): a CLIP-style contrastive image-text model pairing a DINOv2 ViT image encoder with Bio_ClinicalBERT."""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
PreTrainedModel,
AutoTokenizer,
AutoModel
)
from dinov2.models.vision_transformer import vit_base
from projection import load_projection_head
from configuration_chexficient import CheXficientConfig
URL_DICT = {
"dinov2_vits14": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_pretrain.pth",
"dinov2_vitb14": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_pretrain.pth",
"dinov2_vitl14": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_pretrain.pth",
}
class TextEncoder(nn.Module):
def __init__(self, model_name='emilyalsentzer/Bio_ClinicalBERT'):
super().__init__()
        self.model = AutoModel.from_pretrained(model_name, use_safetensors=True, ignore_mismatched_sizes=False)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # BERT-style tokenizers define [CLS] but no BOS token; fall back to [CLS].
        if self.tokenizer.bos_token_id is None:
            self.tokenizer.bos_token_id = self.tokenizer.cls_token_id
self.out_dim = self.model.config.hidden_size
def forward(self, inputs):
outputs = self.model(**inputs)
return outputs["last_hidden_state"] # (batch, seq_len, hidden_size)
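# Usage sketch (added; not in the original file): TextEncoder.forward expects the
# batch produced by its own tokenizer, e.g.
#
#   enc = TextEncoder()
#   inputs = enc.tokenizer(["No focal consolidation."], padding=True, return_tensors="pt")
#   hidden = enc(inputs)  # -> (1, seq_len, hidden_size)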
class ImageEncoder(nn.Module):
def __init__(self, model_name='dinov2_vitb14', image_size=224):
super().__init__()
        # Note: the backbone is always vit_base; model_name only selects the checkpoint URL.
        self.model = vit_base(patch_size=14, img_size=image_size, init_values=1.0, block_chunks=0)
        # Download the official DINOv2 checkpoint.
        state_dict = torch.hub.load_state_dict_from_url(URL_DICT[model_name], map_location="cpu")
        ##########################################################
        # DINOv2 is pretrained at 518x518 (a 37x37 grid of 14x14 patches); if our
        # input resolution differs, bicubically resize the patch positional embeddings.
        if self.model.pos_embed.shape[1] != state_dict['pos_embed'].shape[1]:
            cls_pos_embed = state_dict['pos_embed'][:, 0:1, :]   # [1, 1, dim]
            patch_pos_embed = state_dict['pos_embed'][:, 1:, :]  # [1, 1369, dim]
            # original patch grid size
            orig_size = int(patch_pos_embed.shape[1] ** 0.5)     # 518 // 14 = 37
            new_size = image_size // self.model.patch_size       # e.g. 224 // 14 = 16
            patch_pos_embed = patch_pos_embed.reshape(1, orig_size, orig_size, -1).permute(0, 3, 1, 2)  # [1, dim, 37, 37]
            patch_pos_embed = F.interpolate(patch_pos_embed, size=(new_size, new_size), mode='bicubic', align_corners=False)
            patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).reshape(1, new_size * new_size, -1)
            state_dict['pos_embed'] = torch.cat((cls_pos_embed, patch_pos_embed), dim=1)  # [1, 1 + new_size*new_size, dim]
##########################################################
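        # Added sanity check (not in the original): after any interpolation the
        # checkpoint's positional-embedding table must match the model's.
        assert state_dict['pos_embed'].shape == self.model.pos_embed.shape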
        res = self.model.load_state_dict(state_dict, strict=False)
        print('load dinov2 pretrained model:', res)
self.out_dim = self.model.embed_dim
def forward(self, x):
feats = self.model(x) # Shape: (b, d)
return feats
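# Hedged sketch (added): load_projection_head is imported from the repo's
# projection.py, whose contents are not shown here. Given the config used below
# ({'name': 'linear', 'dropout': 0.1, 'proj_dim': ...}), a 'linear' head
# plausibly looks like this; the class name is hypothetical.
class _LinearProjectionHeadSketch(nn.Module):
    def __init__(self, embedding_dim: int, proj_dim: int, dropout: float = 0.1):
        super().__init__()
        self.proj = nn.Linear(embedding_dim, proj_dim)  # embedding_dim -> proj_dim
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.dropout(self.proj(x))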
class CheXficientModel(PreTrainedModel):
config_class = CheXficientConfig
base_model_prefix = "chexficient"
def __init__(self, config: CheXficientConfig):
super().__init__(config)
# ===== Encoders =====
self.image_encoder = ImageEncoder(model_name=config.vision_model_name, image_size=config.image_size)
self.text_encoder = TextEncoder(model_name=config.text_model_name)
        self.text_pooling = 'eos'  # one of 'eos', 'bos', 'mean'
# ===== Projection heads =====
self.image_projection = load_projection_head(
embedding_dim=self.image_encoder.out_dim,
config_projection_head={'name': 'linear', 'dropout': 0.1, 'proj_dim': config.projection_dim}
)
self.text_projection = load_projection_head(
embedding_dim=self.text_encoder.out_dim,
config_projection_head={'name': 'linear', 'dropout': 0.1, 'proj_dim': config.projection_dim}
)
        # Learnable inverse-temperature; applied via exp() in forward().
        self.logit_scale = nn.Parameter(torch.ones([]) * 0.01)
self.post_init()
def encode_image(self, pixel_values):
image_features = self.image_encoder(pixel_values)
image_embeddings = self.image_projection(image_features)
image_embeddings = image_embeddings / image_embeddings.norm(dim=1, keepdim=True)
return image_embeddings
def encode_text(self, text_tokens):
text_features = self.text_encoder(text_tokens)
        if self.text_pooling == "eos":
            # take features at the last non-padding token of each sequence
            # (the [SEP]/EOS position, located via the attention mask)
            eos_token_indices = text_tokens["attention_mask"].sum(dim=-1) - 1
            text_features = text_features[torch.arange(text_features.shape[0]), eos_token_indices]
        elif self.text_pooling == "bos":  # [CLS] token
            text_features = text_features[:, 0]
        elif self.text_pooling == "mean":
            # mask-weighted mean over non-padding tokens
            input_mask_expanded = text_tokens["attention_mask"].unsqueeze(-1).expand(text_features.size()).float()
            text_features = torch.sum(text_features * input_mask_expanded, dim=1) / torch.clamp(input_mask_expanded.sum(dim=1), min=1e-9)
        else:
            raise NotImplementedError(f"Unsupported pooling method: {self.text_pooling}")
text_embeddings = self.text_projection(text_features)
text_embeddings = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)
return text_embeddings
def forward(
self,
pixel_values=None,
text_tokens=None,
return_loss=False
):
image_features = self.encode_image(pixel_values)
text_features = self.encode_text(text_tokens)
logit_scale = self.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logits_per_image.t()
loss = None
if return_loss:
            labels = torch.arange(len(logits_per_image), device=logits_per_image.device)
loss_i = F.cross_entropy(logits_per_image, labels)
loss_t = F.cross_entropy(logits_per_text, labels)
loss = (loss_i + loss_t) / 2
return {
"loss": loss,
"logits_per_image": logits_per_image,
"logits_per_text": logits_per_text,
"image_embeds": image_features,
"text_embeds": text_features,
}
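if __name__ == "__main__":
    # Smoke-test sketch (added; not in the original file). It assumes
    # CheXficientConfig exposes the fields read in __init__ above
    # (vision_model_name, text_model_name, image_size, projection_dim) with
    # usable defaults; adjust if the real config differs.
    config = CheXficientConfig()
    model = CheXficientModel(config).eval()

    reports = [
        "No acute cardiopulmonary abnormality.",
        "Right lower lobe consolidation concerning for pneumonia.",
    ]
    text_tokens = model.text_encoder.tokenizer(reports, padding=True, truncation=True, return_tensors="pt")
    pixel_values = torch.randn(len(reports), 3, config.image_size, config.image_size)

    with torch.no_grad():
        out = model(pixel_values=pixel_values, text_tokens=text_tokens, return_loss=True)
    print(out["logits_per_image"].shape, out["loss"])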