Upload 3 files
Browse files- configuration_chexficient.py +1 -0
- modeling_chexficient.py +28 -27
configuration_chexficient.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
from transformers import PretrainedConfig
|
| 2 |
|
|
|
|
| 3 |
class CheXficientConfig(PretrainedConfig):
|
| 4 |
model_type = "chexficient_clip"
|
| 5 |
def __init__(
|
|
|
|
| 1 |
from transformers import PretrainedConfig
|
| 2 |
|
| 3 |
+
|
| 4 |
class CheXficientConfig(PretrainedConfig):
|
| 5 |
model_type = "chexficient_clip"
|
| 6 |
def __init__(
|
modeling_chexficient.py
CHANGED
|
@@ -24,8 +24,6 @@ URL_DICT = {
|
|
| 24 |
class TextEncoder(nn.Module):
|
| 25 |
def __init__(self, model_name='emilyalsentzer/Bio_ClinicalBERT'):
|
| 26 |
super().__init__()
|
| 27 |
-
# self.model = AutoModel.from_pretrained(model_name, ignore_mismatched_sizes=False, cache_dir='./huggingface',)
|
| 28 |
-
# self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir='./huggingface/tokenizers')
|
| 29 |
self.model = AutoModel.from_pretrained(model_name, use_safetensors=True, ignore_mismatched_sizes=False, )
|
| 30 |
self.tokenizer = AutoTokenizer.from_pretrained(model_name, )
|
| 31 |
if self.tokenizer.bos_token_id is None:
|
|
@@ -72,14 +70,6 @@ class CheXficientModel(PreTrainedModel):
|
|
| 72 |
super().__init__(config)
|
| 73 |
|
| 74 |
# ===== Encoders =====
|
| 75 |
-
# self.image_encoder = AutoModel.from_pretrained(
|
| 76 |
-
# config.vision_model_name,
|
| 77 |
-
# use_safetensors=True
|
| 78 |
-
# )
|
| 79 |
-
# self.text_encoder = AutoModel.from_pretrained(
|
| 80 |
-
# config.text_model_name,
|
| 81 |
-
# use_safetensors=True
|
| 82 |
-
# )
|
| 83 |
self.image_encoder = ImageEncoder(model_name=config.vision_model_name, image_size=config.image_size)
|
| 84 |
self.text_encoder = TextEncoder(model_name=config.text_model_name)
|
| 85 |
|
|
@@ -97,30 +87,41 @@ class CheXficientModel(PreTrainedModel):
|
|
| 97 |
|
| 98 |
self.post_init()
|
| 99 |
|
| 100 |
-
def
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
return
|
| 105 |
|
| 106 |
-
def
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
|
| 115 |
def forward(
|
| 116 |
self,
|
| 117 |
pixel_values=None,
|
| 118 |
-
|
| 119 |
-
attention_mask=None,
|
| 120 |
return_loss=False
|
| 121 |
):
|
| 122 |
-
image_features = self.
|
| 123 |
-
text_features = self.
|
| 124 |
|
| 125 |
logit_scale = self.logit_scale.exp()
|
| 126 |
|
|
|
|
| 24 |
class TextEncoder(nn.Module):
|
| 25 |
def __init__(self, model_name='emilyalsentzer/Bio_ClinicalBERT'):
|
| 26 |
super().__init__()
|
|
|
|
|
|
|
| 27 |
self.model = AutoModel.from_pretrained(model_name, use_safetensors=True, ignore_mismatched_sizes=False, )
|
| 28 |
self.tokenizer = AutoTokenizer.from_pretrained(model_name, )
|
| 29 |
if self.tokenizer.bos_token_id is None:
|
|
|
|
| 70 |
super().__init__(config)
|
| 71 |
|
| 72 |
# ===== Encoders =====
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
self.image_encoder = ImageEncoder(model_name=config.vision_model_name, image_size=config.image_size)
|
| 74 |
self.text_encoder = TextEncoder(model_name=config.text_model_name)
|
| 75 |
|
|
|
|
| 87 |
|
| 88 |
self.post_init()
|
| 89 |
|
| 90 |
+
def encode_image(self, pixel_values):
    """Embed a batch of images into the joint image-text space.

    Runs the image encoder on *pixel_values*, projects the resulting
    features, and L2-normalizes each row of the projection.

    Args:
        pixel_values: batch of image tensors accepted by ``self.image_encoder``
            (exact shape depends on the encoder — presumably (batch, C, H, W);
            not verifiable from this block alone).

    Returns:
        Projected embeddings scaled to unit L2 norm along dim 1.
    """
    features = self.image_encoder(pixel_values)
    projected = self.image_projection(features)
    # Normalize every embedding to unit length so cosine similarity
    # reduces to a dot product downstream.
    norms = projected.norm(dim=1, keepdim=True)
    return projected / norms
|
| 95 |
|
| 96 |
+
def encode_text(self, text_tokens):
    """Embed tokenized text into the joint image-text space.

    Pools the per-token features according to ``self.text_pooling``,
    optionally projects them, and L2-normalizes each row.

    Args:
        text_tokens: mapping with at least an ``"attention_mask"`` entry
            (e.g. a tokenizer output); also passed whole to
            ``self.text_encoder``. Token features are assumed to be
            (batch, seq, hidden) — TODO confirm against TextEncoder.

    Returns:
        Pooled, optionally projected embeddings with unit L2 norm along dim 1.

    Raises:
        NotImplementedError: if ``self.text_pooling`` is not one of
            ``"eos"``, ``"bos"``, or ``"mean"``.
    """
    text_features = self.text_encoder(text_tokens)

    if self.text_pooling == "eos":
        # Take features from the EOS embedding: the last non-padding
        # position of each sequence, located via the attention mask.
        eos_token_indices = text_tokens["attention_mask"].sum(dim=-1) - 1
        text_features = text_features[torch.arange(text_features.shape[0]), eos_token_indices]
    elif self.text_pooling == "bos":  # [CLS] token
        text_features = text_features[:, 0]
    elif self.text_pooling == "mean":
        # Mask-aware mean over the sequence dimension; the clamp guards
        # against division by zero for all-padding rows.
        # NOTE: use torch's canonical `dim=` keyword (Tensor.unsqueeze does
        # not accept numpy-style `axis=`).
        input_mask_expanded = text_tokens["attention_mask"].unsqueeze(-1).expand(text_features.size()).float()
        text_features = torch.sum(text_features * input_mask_expanded, dim=1) / torch.clamp(input_mask_expanded.sum(dim=1), min=1e-9)
    else:
        # Bug fix: the original passed (fmt, value) as two exception args,
        # logging-style, so the message was never %-formatted.
        raise NotImplementedError(f"Not supported pooling method : {self.text_pooling}")

    text_embeddings = self.text_projection(text_features) if self.projection else text_features

    text_embeddings = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)

    return text_embeddings
|
| 116 |
|
| 117 |
def forward(
|
| 118 |
self,
|
| 119 |
pixel_values=None,
|
| 120 |
+
text_tokens=None,
|
|
|
|
| 121 |
return_loss=False
|
| 122 |
):
|
| 123 |
+
image_features = self.encode_image(pixel_values)
|
| 124 |
+
text_features = self.encode_text(text_tokens)
|
| 125 |
|
| 126 |
logit_scale = self.logit_scale.exp()
|
| 127 |
|