"""CheXficient image-text model.

A DINOv2 ViT image encoder and a Hugging Face text encoder
(``emilyalsentzer/Bio_ClinicalBERT`` by default) are each followed by a
projection head. Their L2-normalised embeddings are compared with a scaled dot
product and trained with a symmetric cross-entropy (contrastive) loss.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    PreTrainedModel,
    AutoTokenizer,
    AutoModel,
)

from dinov2.models.vision_transformer import vit_base
from projection import load_projection_head
from configuration_chexficient import CheXficientConfig


# Pretrained DINOv2 checkpoints; the file names mark them as the "reg4" variants.
URL_DICT = {
    "dinov2_vits14": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_pretrain.pth",
    "dinov2_vitb14": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_pretrain.pth",
    "dinov2_vitl14": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_pretrain.pth",
}


class TextEncoder(nn.Module):
    """Wrap a Hugging Face text model together with its tokenizer.

    Args:
        model_name: Hub id or local path passed to ``AutoModel`` / ``AutoTokenizer``.

    Attributes:
        out_dim: hidden size of the text model (width of each token feature).
    """

    def __init__(self, model_name='emilyalsentzer/Bio_ClinicalBERT'):
        super().__init__()
        self.model = AutoModel.from_pretrained(
            model_name,
            use_safetensors=True,
            ignore_mismatched_sizes=False,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Some tokenizers (e.g. BERT-style ones) define no BOS token; fall back
        # to the [CLS] id so callers reading ``bos_token_id`` get a valid id.
        if self.tokenizer.bos_token_id is None:
            self.tokenizer.bos_token_id = self.tokenizer.cls_token_id
        self.out_dim = self.model.config.hidden_size

    def forward(self, inputs):
        """Return per-token hidden states of shape (batch, seq_len, hidden_size).

        Args:
            inputs: mapping of tokenizer outputs (e.g. ``input_ids``,
                ``attention_mask``) forwarded as keyword arguments to the model.
        """
        outputs = self.model(**inputs)
        return outputs["last_hidden_state"]


def _interpolate_pos_embed(pos_embed, new_grid):
    """Resize a ViT positional embedding to a ``new_grid`` x ``new_grid`` patch grid.

    Args:
        pos_embed: tensor of shape (1, 1 + N, dim). Index 0 along dim 1 is the
            class-token embedding; the remaining N entries are assumed to form
            a square grid.
        new_grid: side length of the target patch grid.

    Returns:
        Tensor of shape (1, 1 + new_grid * new_grid, dim).
    """
    cls_pos_embed = pos_embed[:, 0:1, :]    # (1, 1, dim)
    patch_pos_embed = pos_embed[:, 1:, :]   # (1, N, dim)

    orig_grid = int(patch_pos_embed.shape[1] ** 0.5)
    # (1, N, dim) -> (1, dim, H, W) so F.interpolate treats dim as channels.
    patch_pos_embed = patch_pos_embed.reshape(1, orig_grid, orig_grid, -1).permute(0, 3, 1, 2)
    patch_pos_embed = F.interpolate(
        patch_pos_embed,
        size=(new_grid, new_grid),
        mode='bicubic',
        align_corners=False,
    )
    # Back to (1, new_grid * new_grid, dim).
    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, -1)
    return torch.cat((cls_pos_embed, patch_pos_embed), dim=1)


class ImageEncoder(nn.Module):
    """DINOv2 ViT-B/14 backbone initialised from a pretrained checkpoint.

    Args:
        model_name: key into ``URL_DICT`` selecting the checkpoint to download.
        image_size: input resolution. When it implies a different patch grid
            than the checkpoint, the positional embeddings are resized bicubically.

    Attributes:
        out_dim: embedding width of the backbone (``embed_dim``).
    """

    def __init__(self, model_name='dinov2_vitb14', image_size=224):
        super().__init__()
        # NOTE(review): the architecture is always ``vit_base`` whatever
        # ``model_name`` is, so only "dinov2_vitb14" matches it; the vits14 /
        # vitl14 checkpoints would not fit. Also, the checkpoints are "reg4"
        # variants while no register tokens are configured here, so their
        # ``register_tokens`` presumably end up in ``unexpected_keys`` under
        # strict=False -- confirm this is intended.
        self.model = vit_base(patch_size=14, img_size=image_size, init_values=1.0, block_chunks=0)
        state_dict = torch.hub.load_state_dict_from_url(URL_DICT[model_name], map_location="cpu")

        # Adapt the checkpoint's positional embeddings to this model's patch grid.
        if self.model.pos_embed.shape[1] != state_dict['pos_embed'].shape[1]:
            new_grid = image_size // self.model.patch_size  # e.g. 224 // 14 = 16
            state_dict['pos_embed'] = _interpolate_pos_embed(state_dict['pos_embed'], new_grid)

        # strict=False: mismatched/missing keys are reported rather than raised.
        res = self.model.load_state_dict(state_dict, strict=False)
        print('load dinov2 pretrained model:', res)
        self.out_dim = self.model.embed_dim

    def forward(self, x):
        """Return the backbone's output features for a batch of images ``x``."""
        return self.model(x)


class CheXficientModel(PreTrainedModel):
    """Dual-encoder (image / text) contrastive model.

    Images and texts are encoded, projected to ``config.projection_dim`` and
    L2-normalised. ``forward`` returns their scaled similarity logits and,
    optionally, a symmetric cross-entropy loss.
    """

    config_class = CheXficientConfig
    base_model_prefix = "chexficient"

    def __init__(self, config: CheXficientConfig):
        super().__init__(config)

        # ===== Encoders =====
        self.image_encoder = ImageEncoder(model_name=config.vision_model_name, image_size=config.image_size)
        self.text_encoder = TextEncoder(model_name=config.text_model_name)
        # Token-pooling strategy used by encode_text: "eos", "bos" or "mean".
        self.text_pooling = 'eos'

        # ===== Projection heads =====
        self.image_projection = load_projection_head(
            embedding_dim=self.image_encoder.out_dim,
            config_projection_head={'name': 'linear', 'dropout': 0.1, 'proj_dim': config.projection_dim},
        )
        self.text_projection = load_projection_head(
            embedding_dim=self.text_encoder.out_dim,
            config_projection_head={'name': 'linear', 'dropout': 0.1, 'proj_dim': config.projection_dim},
        )

        # Learnable log-scale; forward() multiplies logits by exp(logit_scale),
        # i.e. ~1.01 at this init. NOTE(review): presumably overwritten when
        # trained weights are loaded -- confirm.
        self.logit_scale = nn.Parameter(torch.ones([]) * 0.01)

        self.post_init()

    def encode_image(self, pixel_values):
        """Return L2-normalised image embeddings of shape (batch, proj_dim)."""
        image_features = self.image_encoder(pixel_values)
        image_embeddings = self.image_projection(image_features)
        return F.normalize(image_embeddings, dim=1)

    def encode_text(self, text_tokens):
        """Return L2-normalised text embeddings of shape (batch, proj_dim).

        Args:
            text_tokens: tokenizer output mapping. It must contain
                ``attention_mask`` for the "eos" and "mean" pooling modes.

        Raises:
            NotImplementedError: if ``self.text_pooling`` is not a supported mode.
        """
        text_features = self.text_encoder(text_tokens)  # (batch, seq_len, hidden)

        if self.text_pooling == "eos":
            # Take the last attended token of each sequence. This assumes right
            # padding, so that position is attention_mask.sum() - 1.
            last_indices = text_tokens["attention_mask"].sum(dim=-1) - 1
            batch_indices = torch.arange(text_features.shape[0], device=text_features.device)
            text_features = text_features[batch_indices, last_indices]
        elif self.text_pooling == "bos":
            # First token ([CLS] for BERT-style tokenizers).
            text_features = text_features[:, 0]
        elif self.text_pooling == "mean":
            # Mean over non-padding tokens; clamp avoids division by zero.
            mask = text_tokens["attention_mask"].unsqueeze(dim=-1).expand(text_features.size()).float()
            text_features = torch.sum(text_features * mask, dim=1) / torch.clamp(mask.sum(dim=1), min=1e-9)
        else:
            raise NotImplementedError(f"Not supported pooling method : {self.text_pooling}")

        text_embeddings = self.text_projection(text_features)
        return F.normalize(text_embeddings, dim=1)

    def forward(
        self,
        pixel_values=None,
        text_tokens=None,
        return_loss=False,
    ):
        """Compute image-text similarity logits and, optionally, the contrastive loss.

        Args:
            pixel_values: batch of images for the image encoder.
            text_tokens: tokenizer output mapping for the text encoder.
            return_loss: if True, also compute the symmetric cross-entropy loss.
                This assumes image i is paired with text i, so both batches
                must have the same size.

        Returns:
            dict with keys "loss" (None unless ``return_loss``),
            "logits_per_image" (n_images, n_texts), "logits_per_text"
            (its transpose), "image_embeds" and "text_embeds".
        """
        image_features = self.encode_image(pixel_values)
        text_features = self.encode_text(text_tokens)

        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        loss = None
        if return_loss:
            # Matching pairs lie on the diagonal.
            labels = torch.arange(len(logits_per_image), device=logits_per_image.device)
            loss_i = F.cross_entropy(logits_per_image, labels)
            loss_t = F.cross_entropy(logits_per_text, labels)
            loss = (loss_i + loss_t) / 2

        return {
            "loss": loss,
            "logits_per_image": logits_per_image,
            "logits_per_text": logits_per_text,
            "image_embeds": image_features,
            "text_embeds": text_features,
        }