| | import torch
|
| | import torch.nn as nn
|
| | import torch.nn.functional as F
|
| |
|
| | from transformers import (
|
| | PreTrainedModel,
|
| | AutoTokenizer,
|
| | AutoModel
|
| | )
|
| |
|
| | from dinov2.models.vision_transformer import vit_base
|
| | from projection import load_projection_head
|
| | from configuration_chexficient import CheXficientConfig
|
| |
|
| |
|
# Download URLs for the DINOv2 pretrained checkpoints (the "reg4" variants,
# i.e. ViTs trained with 4 register tokens). Keyed by torch.hub model name.
_DINOV2_BASE = "https://dl.fbaipublicfiles.com/dinov2"
URL_DICT = {
    name: f"{_DINOV2_BASE}/{name}/{name}_reg4_pretrain.pth"
    for name in ("dinov2_vits14", "dinov2_vitb14", "dinov2_vitl14")
}
|
| |
|
| |
|
| |
|
class TextEncoder(nn.Module):
    """Thin wrapper around a HuggingFace text model that exposes per-token features.

    Loads both the model and its tokenizer; ``out_dim`` mirrors the model's
    hidden size so downstream projection heads can be sized from it.
    """

    def __init__(self, model_name='emilyalsentzer/Bio_ClinicalBERT'):
        super().__init__()
        self.model = AutoModel.from_pretrained(
            model_name,
            use_safetensors=True,
            ignore_mismatched_sizes=False,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # BERT-style tokenizers define no BOS token; fall back to [CLS] so that
        # code relying on bos_token_id still works.
        if self.tokenizer.bos_token_id is None:
            self.tokenizer.bos_token_id = self.tokenizer.cls_token_id
        self.out_dim = self.model.config.hidden_size

    def forward(self, inputs):
        """Return the last hidden states for a dict of tokenizer outputs."""
        return self.model(**inputs)["last_hidden_state"]
|
| |
|
| |
|
class ImageEncoder(nn.Module):
    """DINOv2 ViT backbone initialized from the public pretrained checkpoint.

    The checkpoint's position embeddings are resampled when the requested
    ``image_size`` yields a different patch-grid length than the checkpoint's.
    ``out_dim`` mirrors the backbone's embedding dimension.
    """

    def __init__(self, model_name='dinov2_vitb14', image_size=224):
        super().__init__()
        # NOTE(review): vit_base is instantiated regardless of model_name;
        # model_name only selects which checkpoint URL is fetched.
        self.model = vit_base(patch_size=14, img_size=image_size, init_values=1.0, block_chunks=0)
        state_dict = torch.hub.load_state_dict_from_url(URL_DICT[model_name], map_location="cpu")

        if self.model.pos_embed.shape[1] != state_dict['pos_embed'].shape[1]:
            grid_side = image_size // self.model.patch_size
            state_dict['pos_embed'] = self._resize_pos_embed(state_dict['pos_embed'], grid_side)

        result = self.model.load_state_dict(state_dict, strict=False)
        print('load dinov2 pretrained model:', result)
        self.out_dim = self.model.embed_dim

    @staticmethod
    def _resize_pos_embed(pos_embed, new_size):
        """Bicubically resample patch position embeddings onto a new square grid.

        The leading CLS embedding is kept as-is; only the patch grid is resized.
        Assumes the checkpoint's patch embeddings form a square grid.
        """
        cls_tok = pos_embed[:, :1, :]
        patch_tok = pos_embed[:, 1:, :]
        old_side = int(patch_tok.shape[1] ** 0.5)
        grid = patch_tok.reshape(1, old_side, old_side, -1).permute(0, 3, 1, 2)
        grid = F.interpolate(grid, size=(new_size, new_size), mode='bicubic', align_corners=False)
        flat = grid.permute(0, 2, 3, 1).reshape(1, new_size * new_size, -1)
        return torch.cat((cls_tok, flat), dim=1)

    def forward(self, x):
        """Run the backbone and return its features for a batch of images."""
        return self.model(x)
|
| |
|
| |
|
| |
|
class CheXficientModel(PreTrainedModel):
    """CLIP-style dual-encoder: a DINOv2 image tower and a clinical-BERT text
    tower, each followed by a linear projection into a shared embedding space,
    trained with a symmetric contrastive (InfoNCE) loss.
    """

    config_class = CheXficientConfig
    base_model_prefix = "chexficient"

    def __init__(self, config: CheXficientConfig):
        super().__init__(config)

        self.image_encoder = ImageEncoder(model_name=config.vision_model_name, image_size=config.image_size)
        self.text_encoder = TextEncoder(model_name=config.text_model_name)
        # How token features are pooled into one text vector; see encode_text.
        self.text_pooling = 'eos'

        self.image_projection = load_projection_head(
            embedding_dim=self.image_encoder.out_dim,
            config_projection_head={'name': 'linear', 'dropout': 0.1, 'proj_dim': config.projection_dim}
        )
        self.text_projection = load_projection_head(
            embedding_dim=self.text_encoder.out_dim,
            config_projection_head={'name': 'linear', 'dropout': 0.1, 'proj_dim': config.projection_dim}
        )

        # Learnable log-temperature for the contrastive logits.
        # NOTE(review): initialized to 0.01 (scale exp(0.01) ~= 1.01), which is
        # far lower than CLIP's log(1/0.07) ~= 2.66 — confirm this is intended.
        self.logit_scale = nn.Parameter(torch.ones([]) * 0.01)

        self.post_init()

    def encode_image(self, pixel_values):
        """Encode images into L2-normalized embeddings of size projection_dim."""
        image_features = self.image_encoder(pixel_values)
        image_embeddings = self.image_projection(image_features)
        image_embeddings = image_embeddings / image_embeddings.norm(dim=1, keepdim=True)
        return image_embeddings

    def encode_text(self, text_tokens):
        """Encode tokenized text into L2-normalized embeddings.

        Pooling of the per-token features is controlled by self.text_pooling:
        'eos'  — feature at the last non-padding position (per attention_mask);
        'bos'  — feature at position 0;
        'mean' — attention-mask-weighted mean over positions.

        Raises:
            NotImplementedError: if self.text_pooling is none of the above.
        """
        text_features = self.text_encoder(text_tokens)

        if self.text_pooling == "eos":
            # Index of the last real (non-padding) token in each sequence.
            eos_token_indices = text_tokens["attention_mask"].sum(dim=-1) - 1
            text_features = text_features[torch.arange(text_features.shape[0]), eos_token_indices]
        elif self.text_pooling == "bos":
            text_features = text_features[:, 0]
        elif self.text_pooling == "mean":
            input_mask_expanded = text_tokens["attention_mask"].unsqueeze(axis=-1).expand(text_features.size()).float()
            text_features = torch.sum(text_features * input_mask_expanded, axis=1) / torch.clamp(input_mask_expanded.sum(axis=1), min=1e-9)
        else:
            # BUG FIX: the original passed the format string and the argument as
            # two separate exception args, so "%s" was never interpolated and the
            # message rendered as a tuple. Format explicitly instead.
            raise NotImplementedError("Not supported pooling method : %s" % self.text_pooling)

        text_embeddings = self.text_projection(text_features)
        text_embeddings = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)
        return text_embeddings

    def forward(
        self,
        pixel_values=None,
        text_tokens=None,
        return_loss=False
    ):
        """Compute paired image/text embeddings and contrastive logits.

        Returns a dict with 'loss' (None unless return_loss), the
        image->text and text->image logit matrices, and both embedding sets.
        """
        image_features = self.encode_image(pixel_values)
        text_features = self.encode_text(text_tokens)

        logit_scale = self.logit_scale.exp()

        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        loss = None
        if return_loss:
            # Matching pairs lie on the diagonal; create labels on-device.
            labels = torch.arange(len(logits_per_image), device=logits_per_image.device)
            loss_i = F.cross_entropy(logits_per_image, labels)
            loss_t = F.cross_entropy(logits_per_text, labels)
            # Symmetric InfoNCE: average of the two directions.
            loss = (loss_i + loss_t) / 2

        return {
            "loss": loss,
            "logits_per_image": logits_per_image,
            "logits_per_text": logits_per_text,
            "image_embeds": image_features,
            "text_embeds": text_features,
        }
|
| |
|
| |
|