Safetensors
custom_code
jonggwon-park committed on
Commit
6205663
·
1 Parent(s): 2ba7893

auto model bug fix

Browse files
Files changed (1) hide show
  1. modeling_radzero.py +3 -64
modeling_radzero.py CHANGED
@@ -1,8 +1,7 @@
1
- import numpy as np
2
  import torch
3
  import torch.nn as nn
4
  import torch.nn.functional as F
5
- from transformers import AutoTokenizer, BertModel
6
  from transformers.models.clip.modeling_clip import CLIPTextModel
7
  from transformers.models.mpnet.modeling_mpnet import MPNetModel
8
  from transformers.trainer import logger
@@ -11,8 +10,8 @@ from .align_transformers import build_align_transformer
11
  from .common_layers import BasePreTrainedModel
12
  from .configuration_radzero import CxrAlignConfig
13
  from .losses import KeyPhraseAlignmentLoss
14
- from .text_encoders import aggregate_tokens, build_text_encoder
15
- from .vision_encoders import MRM, Dinov2Model, build_vision_encoder
16
 
17
 
18
  class CxrAlignModel(BasePreTrainedModel):
@@ -28,13 +27,6 @@ class CxrAlignModel(BasePreTrainedModel):
28
  def build_text_model(self, config: CxrAlignConfig):
29
  text_config = config.text_config
30
  text_model = build_text_encoder(text_config)
31
-
32
- if text_config.model_type == "bioclinicalmpbert":
33
- self.tokenizer = AutoTokenizer.from_pretrained(
34
- text_config.pretrained_tokenizer_name_or_path
35
- )
36
- self.idxtoword = {v: k for k, v in self.tokenizer.get_vocab().items()}
37
-
38
  return text_model
39
 
40
  def build_align_transformer_model(self, config: CxrAlignConfig):
@@ -94,13 +86,7 @@ class CxrAlignModel(BasePreTrainedModel):
94
 
95
  if isinstance(self.vision_model, Dinov2Model):
96
  vision_tokens = self.vision_model(pixel_values)["last_hidden_state"]
97
- elif isinstance(self.vision_model, MRM):
98
- img_emb_g, img_emb_l = self.vision_model(pixel_values)
99
- img_emb_g = img_emb_g.unsqueeze(1)
100
- img_emb_l = img_emb_l.view(img_emb_l.size(0), img_emb_l.size(1), -1)
101
- img_emb_l = img_emb_l.permute(0, 2, 1)
102
 
103
- vision_tokens = torch.cat([img_emb_g, img_emb_l], dim=1)
104
  else:
105
  raise NotImplementedError
106
 
@@ -152,53 +138,6 @@ class CxrAlignModel(BasePreTrainedModel):
152
  token_embeddings * input_mask_expanded, 1
153
  ) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
154
 
155
- elif isinstance(self.text_model, BertModel):
156
- # BioClinicalMPBERT
157
-
158
- model_output = self.text_model(
159
- input_ids=encoded_input["input_ids"],
160
- attention_mask=encoded_input["attention_mask"],
161
- token_type_ids=encoded_input.get("token_type_ids", None),
162
- )
163
-
164
- if self.config.text_config.use_cls_token:
165
- text_features = model_output.last_hidden_state[:, 0, :]
166
-
167
- elif self.config.text_config.use_aggregate_tokens:
168
-
169
- all_embeddings = model_output[2]
170
- embeddings = torch.stack(
171
- all_embeddings[-self.config.text_config.last_n_layers :]
172
- )
173
- embeddings = embeddings.permute(1, 0, 2, 3)
174
-
175
- embeddings, sents = aggregate_tokens(
176
- embeddings, encoded_input["input_ids"], self.idxtoword
177
- )
178
- sent_embeddings = embeddings.mean(axis=2)
179
-
180
- if self.config.text_config.aggregate_method == "sum":
181
- word_embeddings = embeddings.sum(axis=1)
182
- sent_embeddings = sent_embeddings.sum(axis=1)
183
- elif self.config.text_config.aggregate_method == "mean":
184
- word_embeddings = embeddings.mean(axis=1)
185
- sent_embeddings = sent_embeddings.mean(axis=1)
186
-
187
- word_embeddings = word_embeddings.permute(0, 2, 1)
188
-
189
- text_features = sent_embeddings
190
- text_outputs["word_embeddings"] = word_embeddings
191
-
192
- else:
193
- text_features = model_output.last_hidden_state
194
- mask = encoded_input["attention_mask"].unsqueeze(-1).float()
195
- text_features = torch.sum(text_features * mask, dim=1) / torch.clamp(
196
- mask.sum(dim=1), min=1e-9
197
- )
198
-
199
- if self.text_projector is not None:
200
- text_features = self.text_projector(text_features)
201
-
202
  else:
203
  raise NotImplementedError
204
 
 
 
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
+ from transformers import BertModel
5
  from transformers.models.clip.modeling_clip import CLIPTextModel
6
  from transformers.models.mpnet.modeling_mpnet import MPNetModel
7
  from transformers.trainer import logger
 
10
  from .common_layers import BasePreTrainedModel
11
  from .configuration_radzero import CxrAlignConfig
12
  from .losses import KeyPhraseAlignmentLoss
13
+ from .text_encoders import build_text_encoder
14
+ from .vision_encoders import Dinov2Model, build_vision_encoder
15
 
16
 
17
  class CxrAlignModel(BasePreTrainedModel):
 
27
  def build_text_model(self, config: CxrAlignConfig):
28
  text_config = config.text_config
29
  text_model = build_text_encoder(text_config)
 
 
 
 
 
 
 
30
  return text_model
31
 
32
  def build_align_transformer_model(self, config: CxrAlignConfig):
 
86
 
87
  if isinstance(self.vision_model, Dinov2Model):
88
  vision_tokens = self.vision_model(pixel_values)["last_hidden_state"]
 
 
 
 
 
89
 
 
90
  else:
91
  raise NotImplementedError
92
 
 
138
  token_embeddings * input_mask_expanded, 1
139
  ) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
140
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  else:
142
  raise NotImplementedError
143