import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    PreTrainedModel,
    AutoTokenizer,
    AutoModel,
)
from dinov2.models.vision_transformer import vit_base
from projection import load_projection_head
from configuration_chexficient import CheXficientConfig
URL_DICT = {
    "dinov2_vits14": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_pretrain.pth",
    "dinov2_vitb14": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_pretrain.pth",
    "dinov2_vitl14": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_pretrain.pth",
}


class TextEncoder(nn.Module):
    def __init__(self, model_name='emilyalsentzer/Bio_ClinicalBERT'):
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name, use_safetensors=True, ignore_mismatched_sizes=False)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # BERT-style tokenizers have no BOS token; fall back to [CLS].
        if self.tokenizer.bos_token_id is None:
            self.tokenizer.bos_token_id = self.tokenizer.cls_token_id
        self.out_dim = self.model.config.hidden_size

    def forward(self, inputs):
        outputs = self.model(**inputs)
        return outputs["last_hidden_state"]  # (batch, seq_len, hidden_size)

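# Usage sketch (illustrative, not part of the training pipeline): the `inputs`
# dict is a standard Hugging Face tokenizer output with input_ids / attention_mask.
#   enc = TextEncoder()
#   inputs = enc.tokenizer(["No acute cardiopulmonary process."],
#                          padding=True, truncation=True, return_tensors="pt")
#   hidden = enc(inputs)  # (1, seq_len, 768) for Bio_ClinicalBERT
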
class ImageEncoder(nn.Module):
    def __init__(self, model_name='dinov2_vitb14', image_size=224):
        super().__init__()
        # NOTE: a ViT-B/14 backbone is built regardless of `model_name`;
        # `model_name` only selects which pretrained checkpoint to download.
        self.model = vit_base(patch_size=14, img_size=image_size, init_values=1.0, block_chunks=0)
        state_dict = torch.hub.load_state_dict_from_url(URL_DICT[model_name], map_location="cpu")

        # Interpolate the pretrained position embeddings when the target image size
        # differs from the DINOv2 pretraining resolution (518x518 -> 37x37 patches).
        if self.model.pos_embed.shape[1] != state_dict['pos_embed'].shape[1]:
            cls_pos_embed = state_dict['pos_embed'][:, 0:1, :]   # (1, 1, dim)
            patch_pos_embed = state_dict['pos_embed'][:, 1:, :]  # (1, 1369, dim)
            # Pretrained patch grid size (37x37 for 518 / 14).
            orig_size = int(patch_pos_embed.shape[1] ** 0.5)     # 37
            new_size = image_size // self.model.patch_size       # e.g. 224 // 14 = 16
            patch_pos_embed = patch_pos_embed.reshape(1, orig_size, orig_size, -1).permute(0, 3, 1, 2)  # (1, dim, 37, 37)
            patch_pos_embed = F.interpolate(patch_pos_embed, size=(new_size, new_size), mode='bicubic', align_corners=False)
            patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).reshape(1, new_size * new_size, -1)
            state_dict['pos_embed'] = torch.cat((cls_pos_embed, patch_pos_embed), dim=1)  # (1, 1 + new_size**2, dim)

        res = self.model.load_state_dict(state_dict, strict=False)
        print('Loaded DINOv2 pretrained weights:', res)
        self.out_dim = self.model.embed_dim

    def forward(self, x):
        feats = self.model(x)  # (batch, embed_dim)
        return feats

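# Usage sketch (illustrative): inputs are (B, 3, H, W) tensors with
# H == W == image_size; the output is (B, embed_dim), embed_dim = 768 for ViT-B/14.
#   enc = ImageEncoder(model_name='dinov2_vitb14', image_size=224)
#   feats = enc(torch.randn(2, 3, 224, 224))  # (2, 768)
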
class CheXficientModel(PreTrainedModel):
    config_class = CheXficientConfig
    base_model_prefix = "chexficient"

    def __init__(self, config: CheXficientConfig):
        super().__init__(config)
        # ===== Encoders =====
        self.image_encoder = ImageEncoder(model_name=config.vision_model_name, image_size=config.image_size)
        self.text_encoder = TextEncoder(model_name=config.text_model_name)
        self.text_pooling = 'eos'
        # ===== Projection heads =====
        self.image_projection = load_projection_head(
            embedding_dim=self.image_encoder.out_dim,
            config_projection_head={'name': 'linear', 'dropout': 0.1, 'proj_dim': config.projection_dim}
        )
        self.text_projection = load_projection_head(
            embedding_dim=self.text_encoder.out_dim,
            config_projection_head={'name': 'linear', 'dropout': 0.1, 'proj_dim': config.projection_dim}
        )
        # Log-space temperature; exponentiated in forward().
        self.logit_scale = nn.Parameter(torch.ones([]) * 0.01)
        self.post_init()

    def encode_image(self, pixel_values):
        image_features = self.image_encoder(pixel_values)
        image_embeddings = self.image_projection(image_features)
        # L2-normalize so dot products are cosine similarities.
        image_embeddings = image_embeddings / image_embeddings.norm(dim=1, keepdim=True)
        return image_embeddings

    def encode_text(self, text_tokens):
        text_features = self.text_encoder(text_tokens)
        if self.text_pooling == "eos":
            # Take the features at the last non-padding token
            # (the EOS/[SEP] position for BERT-style tokenizers).
            eos_token_indices = text_tokens["attention_mask"].sum(dim=-1) - 1
            text_features = text_features[torch.arange(text_features.shape[0]), eos_token_indices]
        elif self.text_pooling == "bos":  # [CLS] token
            text_features = text_features[:, 0]
        elif self.text_pooling == "mean":
            # Mask-aware mean over the sequence dimension.
            input_mask_expanded = text_tokens["attention_mask"].unsqueeze(-1).expand(text_features.size()).float()
            text_features = torch.sum(text_features * input_mask_expanded, dim=1) / torch.clamp(input_mask_expanded.sum(dim=1), min=1e-9)
        else:
            raise NotImplementedError(f"Unsupported pooling method: {self.text_pooling}")
        text_embeddings = self.text_projection(text_features)
        text_embeddings = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)
        return text_embeddings

    def forward(
        self,
        pixel_values=None,
        text_tokens=None,
        return_loss=False,
    ):
        image_features = self.encode_image(pixel_values)
        text_features = self.encode_text(text_tokens)

        # Cosine-similarity logits, scaled by the learned temperature.
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        loss = None
        if return_loss:
            # Symmetric InfoNCE loss: matching image/text pairs lie on the diagonal.
            labels = torch.arange(len(logits_per_image), device=logits_per_image.device)
            loss_i = F.cross_entropy(logits_per_image, labels)
            loss_t = F.cross_entropy(logits_per_text, labels)
            loss = (loss_i + loss_t) / 2

        return {
            "loss": loss,
            "logits_per_image": logits_per_image,
            "logits_per_text": logits_per_text,
            "image_embeds": image_features,
            "text_embeds": text_features,
        }

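# Minimal end-to-end sketch (assumes a CheXficientConfig exposing
# vision_model_name, text_model_name, image_size and projection_dim, as used above):
#   config = CheXficientConfig()
#   model = CheXficientModel(config).eval()
#   images = torch.randn(2, 3, config.image_size, config.image_size)
#   texts = model.text_encoder.tokenizer(
#       ["No acute findings.", "Right lower lobe pneumonia."],
#       padding=True, truncation=True, return_tensors="pt")
#   with torch.no_grad():
#       out = model(pixel_values=images, text_tokens=texts, return_loss=True)
#   out["logits_per_image"]  # (2, 2) image-to-text similarity logits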