cwangrun committed on
Commit
57e8231
·
verified ·
1 Parent(s): e8ca894

Upload 3 files

Browse files
configuration_chexficient.py CHANGED
@@ -1,5 +1,6 @@
1
  from transformers import PretrainedConfig
2
 
 
3
  class CheXficientConfig(PretrainedConfig):
4
  model_type = "chexficient_clip"
5
  def __init__(
 
1
  from transformers import PretrainedConfig
2
 
3
+
4
  class CheXficientConfig(PretrainedConfig):
5
  model_type = "chexficient_clip"
6
  def __init__(
modeling_chexficient.py CHANGED
@@ -24,8 +24,6 @@ URL_DICT = {
24
  class TextEncoder(nn.Module):
25
  def __init__(self, model_name='emilyalsentzer/Bio_ClinicalBERT'):
26
  super().__init__()
27
- # self.model = AutoModel.from_pretrained(model_name, ignore_mismatched_sizes=False, cache_dir='./huggingface',)
28
- # self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir='./huggingface/tokenizers')
29
  self.model = AutoModel.from_pretrained(model_name, use_safetensors=True, ignore_mismatched_sizes=False, )
30
  self.tokenizer = AutoTokenizer.from_pretrained(model_name, )
31
  if self.tokenizer.bos_token_id is None:
@@ -72,14 +70,6 @@ class CheXficientModel(PreTrainedModel):
72
  super().__init__(config)
73
 
74
  # ===== Encoders =====
75
- # self.image_encoder = AutoModel.from_pretrained(
76
- # config.vision_model_name,
77
- # use_safetensors=True
78
- # )
79
- # self.text_encoder = AutoModel.from_pretrained(
80
- # config.text_model_name,
81
- # use_safetensors=True
82
- # )
83
  self.image_encoder = ImageEncoder(model_name=config.vision_model_name, image_size=config.image_size)
84
  self.text_encoder = TextEncoder(model_name=config.text_model_name)
85
 
@@ -97,30 +87,41 @@ class CheXficientModel(PreTrainedModel):
97
 
98
  self.post_init()
99
 
100
- def get_image_features(self, pixel_values):
101
- vision_outputs = self.image_encoder(pixel_values=pixel_values)
102
- pooled = vision_outputs.last_hidden_state[:, 0]
103
- projected = self.image_projection(pooled)
104
- return F.normalize(projected, dim=-1)
105
 
106
- def get_text_features(self, input_ids, attention_mask):
107
- text_outputs = self.text_encoder(
108
- input_ids=input_ids,
109
- attention_mask=attention_mask
110
- )
111
- pooled = text_outputs.last_hidden_state[:, 0]
112
- projected = self.text_projection(pooled)
113
- return F.normalize(projected, dim=-1)
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
  def forward(
116
  self,
117
  pixel_values=None,
118
- input_ids=None,
119
- attention_mask=None,
120
  return_loss=False
121
  ):
122
- image_features = self.get_image_features(pixel_values)
123
- text_features = self.get_text_features(input_ids, attention_mask)
124
 
125
  logit_scale = self.logit_scale.exp()
126
 
 
24
  class TextEncoder(nn.Module):
25
  def __init__(self, model_name='emilyalsentzer/Bio_ClinicalBERT'):
26
  super().__init__()
 
 
27
  self.model = AutoModel.from_pretrained(model_name, use_safetensors=True, ignore_mismatched_sizes=False, )
28
  self.tokenizer = AutoTokenizer.from_pretrained(model_name, )
29
  if self.tokenizer.bos_token_id is None:
 
70
  super().__init__(config)
71
 
72
  # ===== Encoders =====
 
 
 
 
 
 
 
 
73
  self.image_encoder = ImageEncoder(model_name=config.vision_model_name, image_size=config.image_size)
74
  self.text_encoder = TextEncoder(model_name=config.text_model_name)
75
 
 
87
 
88
  self.post_init()
89
 
90
def encode_image(self, pixel_values):
    """Encode a batch of images into L2-normalized embedding vectors.

    Runs the vision backbone, projects its features into the shared
    image-text space, then scales each row to unit Euclidean length.
    """
    features = self.image_encoder(pixel_values)
    embeddings = self.image_projection(features)
    # Row-wise L2 norm; keepdim so the division broadcasts per sample.
    norms = embeddings.norm(dim=1, keepdim=True)
    return embeddings / norms
95
 
96
def encode_text(self, text_tokens):
    """Encode tokenized text into (optionally projected) L2-normalized embeddings.

    Pools the token-level features according to ``self.text_pooling``:
      - "eos":  feature at the last non-padding position per sequence
                (index = attention_mask.sum() - 1)
      - "bos":  feature at position 0 ([CLS] token)
      - "mean": attention-mask-weighted mean over the sequence

    Raises:
        NotImplementedError: for any other pooling method.
    """
    text_features = self.text_encoder(text_tokens)

    if self.text_pooling == "eos":
        # Take features from the EOT embedding: the last attended token
        # in each sequence, located via the attention mask.
        eos_token_indices = text_tokens["attention_mask"].sum(dim=-1) - 1
        text_features = text_features[torch.arange(text_features.shape[0]), eos_token_indices]
    elif self.text_pooling == "bos":  # [CLS] token
        text_features = text_features[:, 0]
    elif self.text_pooling == "mean":
        # Mask out padding positions, then average over real tokens only;
        # clamp avoids division by zero for all-padding rows.
        input_mask_expanded = text_tokens["attention_mask"].unsqueeze(dim=-1).expand(text_features.size()).float()
        text_features = torch.sum(text_features * input_mask_expanded, dim=1) / torch.clamp(input_mask_expanded.sum(dim=1), min=1e-9)
    else:
        # Bug fix: the original passed the value as a second exception arg
        # ("msg %s", x), which never formats the message.
        raise NotImplementedError("Not supported pooling method : %s" % self.text_pooling)

    text_embeddings = self.text_projection(text_features) if self.projection else text_features

    # Normalize each embedding row to unit L2 length.
    text_embeddings = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)

    return text_embeddings
116
 
117
  def forward(
118
  self,
119
  pixel_values=None,
120
+ text_tokens=None,
 
121
  return_loss=False
122
  ):
123
+ image_features = self.encode_image(pixel_values)
124
+ text_features = self.encode_text(text_tokens)
125
 
126
  logit_scale = self.logit_scale.exp()
127