cwangrun commited on
Commit
075a0ed
·
verified ·
1 Parent(s): 5ebfade

Upload 5 files

Browse files
config.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "CheXficientModel"
4
+ ],
5
+ "image_size": 378,
6
+ "model_type": "chexficient_clip",
7
+ "projection_dim": 512,
8
+ "text_model_name": "emilyalsentzer/Bio_ClinicalBERT",
9
+ "torch_dtype": "float32",
10
+ "transformers_version": "4.51.3",
11
+ "vision_model_name": "dinov2_vitb14",
12
+ "auto_map": {
13
+ "AutoConfig": "configuration_chexficient.CheXficientConfig",
14
+ "AutoModel": "modeling_chexficient.CheXficientModel"
15
+ }
16
+ }
configuration_chexficient.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
class CheXficientConfig(PretrainedConfig):
    """Configuration for the CheXficient CLIP-style image-text model.

    Holds the identifiers of the two backbone encoders, the shared
    projection dimension, and the expected input image resolution.
    """

    model_type = "chexficient_clip"

    def __init__(
        self,
        vision_model_name="dinov2_vitb14",
        text_model_name="emilyalsentzer/Bio_ClinicalBERT",
        projection_dim=512,
        image_size=378,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # torch.hub DINOv2 variant used for the vision tower.
        self.vision_model_name = vision_model_name
        # Hugging Face model id for the text tower.
        self.text_model_name = text_model_name
        # Width of the shared image/text embedding space.
        self.projection_dim = projection_dim
        # Side length (pixels) of the square input image.
        self.image_size = image_size
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f691262c3f77c3c850bebff420180602cf4ca5d5214449e377366f3205548336
3
+ size 780793036
modeling_chexficient.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from transformers import (
6
+ PreTrainedModel,
7
+ AutoTokenizer,
8
+ AutoModel
9
+ )
10
+
11
+ from dinov2.models.vision_transformer import vit_base
12
+ from projection import load_projection_head
13
+ from configuration_chexficient import CheXficientConfig
14
+
15
+
16
# Pretrained DINOv2 checkpoint URLs (reg4 variants, per the file names),
# keyed by the torch.hub model name stored in the config's `vision_model_name`.
URL_DICT = {
    "dinov2_vits14": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_pretrain.pth",
    "dinov2_vitb14": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_pretrain.pth",
    "dinov2_vitl14": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_pretrain.pth",
}
21
+
22
+
23
+
24
class TextEncoder(nn.Module):
    """Wraps a Hugging Face text model and exposes token-level hidden states.

    Args:
        model_name: Hugging Face model id for the text backbone.
    """

    def __init__(self, model_name='emilyalsentzer/Bio_ClinicalBERT'):
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name, use_safetensors=True, ignore_mismatched_sizes=False)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # BERT-style tokenizers have no BOS token; fall back to [CLS] so code
        # that reads `bos_token_id` still works.
        if self.tokenizer.bos_token_id is None:
            self.tokenizer.bos_token_id = self.tokenizer.cls_token_id
        # Feature width consumed by the downstream projection head.
        self.out_dim = self.model.config.hidden_size

    def forward(self, inputs):
        """Run the backbone on `inputs` (a dict of tokenizer outputs).

        Returns:
            Tensor of shape (batch, seq_len, hidden_size) — the last hidden state.
        """
        outputs = self.model(**inputs)
        return outputs["last_hidden_state"]
38
+
39
+
40
class ImageEncoder(nn.Module):
    """DINOv2 ViT-B/14 image encoder initialized from torch.hub weights.

    When the requested `image_size` yields a patch grid different from the
    pretraining grid, the checkpoint's positional embedding is bicubically
    interpolated to the new grid before loading.

    Args:
        model_name: Key into URL_DICT selecting the DINOv2 checkpoint.
        image_size: Side length (pixels) of the square input image.
    """

    def __init__(self, model_name='dinov2_vitb14', image_size=224):
        super().__init__()
        self.model = vit_base(patch_size=14, img_size=image_size, init_values=1.0, block_chunks=0)
        state_dict = torch.hub.load_state_dict_from_url(URL_DICT[model_name], map_location="cpu")
        # Resize checkpoint pos_embed if its sequence length (1 CLS + patches)
        # does not match the freshly built model's.
        if self.model.pos_embed.shape[1] != state_dict['pos_embed'].shape[1]:
            cls_pos_embed = state_dict['pos_embed'][:, 0:1, :]   # [1, 1, dim] CLS slot
            patch_pos_embed = state_dict['pos_embed'][:, 1:, :]  # [1, n_patches, dim]
            # Pretraining patch grid, assumed square (reshape below relies on this).
            orig_size = int(patch_pos_embed.shape[1] ** 0.5)
            # Target grid; e.g. 378 // 14 = 27 (original comment's "512 // 16" was wrong).
            new_size = image_size // self.model.patch_size
            patch_pos_embed = patch_pos_embed.reshape(1, orig_size, orig_size, -1).permute(0, 3, 1, 2)
            patch_pos_embed = F.interpolate(patch_pos_embed, size=(new_size, new_size), mode='bicubic', align_corners=False)
            patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).reshape(1, new_size * new_size, -1)
            state_dict['pos_embed'] = torch.cat((cls_pos_embed, patch_pos_embed), dim=1)  # [1, 1+new_size^2, dim]
        # strict=False: checkpoint may hold keys the model lacks (e.g. register tokens).
        res = self.model.load_state_dict(state_dict, strict=False)
        print('load dinov2 pretrained model:', res)
        # Feature width consumed by the downstream projection head.
        self.out_dim = self.model.embed_dim

    def forward(self, x):
        """Encode a batch of images; returns pooled features of shape (b, d)."""
        feats = self.model(x)
        return feats
64
+
65
+
66
+
67
class CheXficientModel(PreTrainedModel):
    """CLIP-style dual encoder: DINOv2 image tower + Bio_ClinicalBERT text tower.

    Both towers are projected into a shared `config.projection_dim` space and
    compared with a learnable temperature, as in CLIP.
    """

    config_class = CheXficientConfig
    base_model_prefix = "chexficient"

    def __init__(self, config: CheXficientConfig):
        super().__init__(config)

        # ===== Encoders =====
        self.image_encoder = ImageEncoder(model_name=config.vision_model_name, image_size=config.image_size)
        self.text_encoder = TextEncoder(model_name=config.text_model_name)

        # ===== Projection heads (map both towers into the shared space) =====
        self.image_projection = load_projection_head(
            embedding_dim=self.image_encoder.out_dim,
            config_projection_head={'name': 'linear', 'dropout': 0.1, 'proj_dim': config.projection_dim}
        )
        self.text_projection = load_projection_head(
            embedding_dim=self.text_encoder.out_dim,
            config_projection_head={'name': 'linear', 'dropout': 0.1, 'proj_dim': config.projection_dim}
        )

        # Learnable temperature, stored in log space and exponentiated in forward.
        # NOTE(review): CLIP initializes this to log(1/0.07) ≈ 2.66; 0.01 yields an
        # initial scale of ~1.01 — confirm this matches the training recipe.
        self.logit_scale = nn.Parameter(torch.ones([]) * 0.01)

        self.post_init()

    def get_image_features(self, pixel_values):
        """Return L2-normalized image embeddings of shape (batch, projection_dim)."""
        # Fix: ImageEncoder.forward takes a positional tensor and already returns
        # pooled (b, d) features. The previous `pixel_values=` keyword call and
        # `.last_hidden_state[:, 0]` access matched the commented-out HF AutoModel
        # path and would raise TypeError/AttributeError against this encoder.
        pooled = self.image_encoder(pixel_values)
        projected = self.image_projection(pooled)
        return F.normalize(projected, dim=-1)

    def get_text_features(self, input_ids, attention_mask):
        """Return L2-normalized text embeddings of shape (batch, projection_dim)."""
        # Fix: TextEncoder.forward expects a single dict of tokenizer outputs and
        # returns the last-hidden-state tensor directly (not a model output
        # object); pool the [CLS] token at position 0.
        hidden = self.text_encoder({"input_ids": input_ids, "attention_mask": attention_mask})
        pooled = hidden[:, 0]
        projected = self.text_projection(pooled)
        return F.normalize(projected, dim=-1)

    def forward(
        self,
        pixel_values=None,
        input_ids=None,
        attention_mask=None,
        return_loss=False
    ):
        """Compute pairwise image-text similarity logits (and optional CLIP loss).

        Returns a dict with `loss` (None unless `return_loss`),
        `logits_per_image`, `logits_per_text`, `image_embeds`, `text_embeds`.
        """
        image_features = self.get_image_features(pixel_values)
        text_features = self.get_text_features(input_ids, attention_mask)

        logit_scale = self.logit_scale.exp()

        # (scale * image) @ text^T — scalar multiply binds first (left assoc).
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        loss = None
        if return_loss:
            # Symmetric InfoNCE: matching pairs lie on the diagonal.
            labels = torch.arange(len(logits_per_image)).to(logits_per_image.device)
            loss_i = F.cross_entropy(logits_per_image, labels)
            loss_t = F.cross_entropy(logits_per_text, labels)
            loss = (loss_i + loss_t) / 2

        return {
            "loss": loss,
            "logits_per_image": logits_per_image,
            "logits_per_text": logits_per_text,
            "image_embeds": image_features,
            "text_embeds": text_features,
        }
144
+
projection.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ from typing import Dict
3
+
4
+
5
class MLPProjectionHead(nn.Module):
    """Projection head: linear map, then a GELU/fc/dropout residual branch
    followed by layer normalization.

    Maps `embedding_dim` inputs to `projection_dim` outputs.
    """

    def __init__(self, embedding_dim, projection_dim, dropout):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        # Shortcut is the first linear projection; the fc branch refines it.
        shortcut = self.projection(x)
        branch = self.dropout(self.fc(self.gelu(shortcut)))
        return self.layer_norm(branch + shortcut)
22
+
23
+
24
class LinearProjectionHead(nn.Module):
    """Single affine projection from `embedding_dim` to `projection_dim`."""

    def __init__(self, embedding_dim, projection_dim):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)

    def forward(self, x):
        projected = self.projection(x)
        return projected
31
+
32
+
33
def load_projection_head(embedding_dim: int, config_projection_head: Dict):
    """Build a projection head from a config dict.

    Args:
        embedding_dim: Input feature dimension.
        config_projection_head: Dict with key "name" ("mlp" or "linear"),
            "proj_dim", and — for "mlp" only — "dropout".

    Returns:
        An nn.Module mapping `embedding_dim` -> `proj_dim` features.

    Raises:
        KeyError: If "name" is not a supported head type.
    """
    name = config_projection_head["name"].lower()
    if name == "mlp":
        return MLPProjectionHead(
            embedding_dim=embedding_dim,
            projection_dim=config_projection_head["proj_dim"],
            dropout=config_projection_head["dropout"],
        )
    if name == "linear":
        return LinearProjectionHead(
            embedding_dim=embedding_dim,
            projection_dim=config_projection_head["proj_dim"],
        )
    # Fix: the original message said "Not supported text encoder", which is
    # wrong — this function selects a projection head, not a text encoder.
    raise KeyError(f"Not supported projection head: {config_projection_head}")