Commit 6205663
Parent(s): 2ba7893

auto model bug fix

Changed file: modeling_radzero.py (+3, -64)
@@ -1,8 +1,7 @@
-import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from transformers import AutoTokenizer, BertModel
+from transformers import BertModel
 from transformers.models.clip.modeling_clip import CLIPTextModel
 from transformers.models.mpnet.modeling_mpnet import MPNetModel
 from transformers.trainer import logger
@@ -11,8 +10,8 @@ from .align_transformers import build_align_transformer
 from .common_layers import BasePreTrainedModel
 from .configuration_radzero import CxrAlignConfig
 from .losses import KeyPhraseAlignmentLoss
-from .text_encoders import aggregate_tokens, build_text_encoder
-from .vision_encoders import Dinov2Model, MRM, build_vision_encoder
+from .text_encoders import build_text_encoder
+from .vision_encoders import Dinov2Model, build_vision_encoder
 
 
 class CxrAlignModel(BasePreTrainedModel):
@@ -28,13 +27,6 @@ class CxrAlignModel(BasePreTrainedModel):
     def build_text_model(self, config: CxrAlignConfig):
         text_config = config.text_config
         text_model = build_text_encoder(text_config)
-
-        if text_config.model_type == "bioclinicalmpbert":
-            self.tokenizer = AutoTokenizer.from_pretrained(
-                text_config.pretrained_tokenizer_name_or_path
-            )
-            self.idxtoword = {v: k for k, v in self.tokenizer.get_vocab().items()}
-
         return text_model
 
     def build_align_transformer_model(self, config: CxrAlignConfig):
@@ -94,13 +86,7 @@ class CxrAlignModel(BasePreTrainedModel):
 
         if isinstance(self.vision_model, Dinov2Model):
             vision_tokens = self.vision_model(pixel_values)["last_hidden_state"]
-        elif isinstance(self.vision_model, MRM):
-            img_emb_g, img_emb_l = self.vision_model(pixel_values)
-            img_emb_g = img_emb_g.unsqueeze(1)
-            img_emb_l = img_emb_l.view(img_emb_l.size(0), img_emb_l.size(1), -1)
-            img_emb_l = img_emb_l.permute(0, 2, 1)
 
-            vision_tokens = torch.cat([img_emb_g, img_emb_l], dim=1)
         else:
             raise NotImplementedError
 
@@ -152,53 +138,6 @@ class CxrAlignModel(BasePreTrainedModel):
                 token_embeddings * input_mask_expanded, 1
             ) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
 
-        elif isinstance(self.text_model, BertModel):
-            # BioClinicalMPBERT
-
-            model_output = self.text_model(
-                input_ids=encoded_input["input_ids"],
-                attention_mask=encoded_input["attention_mask"],
-                token_type_ids=encoded_input.get("token_type_ids", None),
-            )
-
-            if self.config.text_config.use_cls_token:
-                text_features = model_output.last_hidden_state[:, 0, :]
-
-            elif self.config.text_config.use_aggregate_tokens:
-
-                all_embeddings = model_output[2]
-                embeddings = torch.stack(
-                    all_embeddings[-self.config.text_config.last_n_layers :]
-                )
-                embeddings = embeddings.permute(1, 0, 2, 3)
-
-                embeddings, sents = aggregate_tokens(
-                    embeddings, encoded_input["input_ids"], self.idxtoword
-                )
-                sent_embeddings = embeddings.mean(axis=2)
-
-                if self.config.text_config.aggregate_method == "sum":
-                    word_embeddings = embeddings.sum(axis=1)
-                    sent_embeddings = sent_embeddings.sum(axis=1)
-                elif self.config.text_config.aggregate_method == "mean":
-                    word_embeddings = embeddings.mean(axis=1)
-                    sent_embeddings = sent_embeddings.mean(axis=1)
-
-                word_embeddings = word_embeddings.permute(0, 2, 1)
-
-                text_features = sent_embeddings
-                text_outputs["word_embeddings"] = word_embeddings
-
-            else:
-                text_features = model_output.last_hidden_state
-                mask = encoded_input["attention_mask"].unsqueeze(-1).float()
-                text_features = torch.sum(text_features * mask, dim=1) / torch.clamp(
-                    mask.sum(dim=1), min=1e-9
-                )
-
-            if self.text_projector is not None:
-                text_features = self.text_projector(text_features)
-
         else:
             raise NotImplementedError
 
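Note on the surviving pooling path: the kept branch mean-pools token embeddings under the attention mask, with torch.clamp guarding against division by an all-zero mask. A minimal self-contained sketch of that technique follows; the function name and tensor shapes are illustrative, not taken from the repo.

import torch

def masked_mean_pool(token_embeddings, attention_mask):
    # Expand the mask over the hidden dimension so padded positions contribute zero.
    input_mask_expanded = (
        attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    )
    # Sum over the sequence dimension, then divide by the clamped token count,
    # mirroring the torch.clamp(..., min=1e-9) guard in the kept code path above.
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
        input_mask_expanded.sum(1), min=1e-9
    )

# Illustrative usage: batch of 2 sequences of length 4 with hidden size 8.
emb = torch.randn(2, 4, 8)
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
pooled = masked_mean_pool(emb, mask)  # shape (2, 8)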
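Usage note: given the commit message, a plausible reading is that this cleanup lets the checkpoint load through the transformers auto classes without pulling in the removed tokenizer side effects. A hedged sketch, assuming a Hub checkpoint that ships this custom code (the repo id your-org/radzero is a placeholder, not the actual path):

from transformers import AutoModel

# Placeholder repo id; models that ship custom modeling code on the Hub
# must be loaded with trust_remote_code=True.
model = AutoModel.from_pretrained("your-org/radzero", trust_remote_code=True)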