ianpan committed on
Commit
fa7dbe5
·
1 Parent(s): 26fb764

Upload MRIBrainSequenceBERT

Browse files
Files changed (4)
  1. config.json +1 -1
  2. configuration.py +1 -1
  3. model.safetensors +3 -0
  4. modeling.py +47 -29
config.json CHANGED
@@ -10,6 +10,6 @@
10
  "dtype": "float32",
11
  "max_len": 512,
12
  "model_type": "mri_brain_sequence_bert",
13
- "num_classes": 16,
14
  "transformers_version": "4.57.3"
15
  }
 
10
  "dtype": "float32",
11
  "max_len": 512,
12
  "model_type": "mri_brain_sequence_bert",
13
+ "num_classes": 17,
14
  "transformers_version": "4.57.3"
15
  }
configuration.py CHANGED
@@ -4,7 +4,7 @@ from transformers import PretrainedConfig
4
  class MRIBrainSequenceBERTConfig(PretrainedConfig):
5
  model_type = "mri_brain_sequence_bert"
6
 
7
- def __init__(self, max_len=512, dropout=0.2, num_classes=16, **kwargs):
8
  self.max_len = max_len
9
  self.dropout = dropout
10
  self.num_classes = num_classes
 
4
  class MRIBrainSequenceBERTConfig(PretrainedConfig):
5
  model_type = "mri_brain_sequence_bert"
6
 
7
+ def __init__(self, max_len=512, dropout=0.2, num_classes=17, **kwargs):
8
  self.max_len = max_len
9
  self.dropout = dropout
10
  self.num_classes = num_classes
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ba42fffeeb4437d9883787fdd868f19594989f19771a95d5761c53e95db48ea9
3
+ size 1196973888
modeling.py CHANGED
@@ -12,18 +12,34 @@ from transformers import (
12
  from .configuration import MRIBrainSequenceBERTConfig
13
 
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  class MRIBrainSequenceBERT(PreTrainedModel):
16
  config_class = MRIBrainSequenceBERTConfig
17
 
18
  def __init__(self, config):
19
  super().__init__(config)
20
  self.model_id = "answerdotai/ModernBERT-base"
21
- self.llm = AutoModelForSequenceClassification.from_pretrained(self.model_id)
22
- self.dim_feats = self.llm.classifier.in_features
23
- self.dropout = nn.Dropout(p=config.dropout)
24
- self.classifier = nn.Linear(self.dim_feats, config.num_classes)
25
- self.llm.dropout = nn.Identity()
26
- self.llm.classifier = nn.Identity()
27
 
28
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
29
  self.max_len = config.max_len
@@ -33,7 +49,7 @@ class MRIBrainSequenceBERT(PreTrainedModel):
33
  "ImageType",
34
  "Manufacturer",
35
  "ManufacturerModelName",
36
- # "ContrastBolusAgent",
37
  "ScanningSequence",
38
  "SequenceVariant",
39
  "ScanOptions",
@@ -54,8 +70,8 @@ class MRIBrainSequenceBERT(PreTrainedModel):
54
  "PercentSampling",
55
  "PercentPhaseFieldOfView",
56
  "PixelBandwidth",
57
- # "ContrastBolusVolume",
58
- # "ContrastBolusTotalDose",
59
  "AcquisitionMatrix",
60
  "InPlanePhaseEncodingDirection",
61
  "FlipAngle",
@@ -72,22 +88,23 @@ class MRIBrainSequenceBERT(PreTrainedModel):
72
  ]
73
 
74
  self.label2index = {
75
- "t1": 0, # T1 precontrast
76
- "t1c": 1, # T1 postcontrast
77
- "t2": 2, # T2
78
- "flair": 3, # T2-FLAIR
79
- "dwi": 4, # DWI trace
80
- "adc": 5, # ADC map
81
- "dti": 6, # DTI
82
- "swi": 7, # SWI
83
- "swi_mip": 8, # SWI MinIP
84
- "phase": 9, # SWI phase images
85
- "mag": 10, # SWI mag images
86
- "gre": 11, # T2* GRE
87
- "perf": 12, # Perfusion-related images
88
- "pd": 13, # Proton density
89
- "loc": 14, # Localizers
90
- "other": 15, # Other, NOS
 
91
  }
92
 
93
  self.index2label = {v: k for k, v in self.label2index.items()}
@@ -106,10 +123,11 @@ class MRIBrainSequenceBERT(PreTrainedModel):
106
  for k, v in x.items():
107
  x[k] = v.to(device)
108
 
109
- features = self.llm(**x)["logits"]
110
- logits = self.classifier(self.dropout(features))
111
- if apply_softmax:
112
- logits = torch.softmax(logits, dim=1)
 
113
  return logits
114
 
115
  def create_string_from_dicom(
 
12
  from .configuration import MRIBrainSequenceBERTConfig
13
 
14
 
15
+ class SingleModel(nn.Module):
16
+ def __init__(self, config, model_id: str):
17
+ super().__init__()
18
+ self.llm = AutoModelForSequenceClassification.from_pretrained(model_id)
19
+ self.dim_feats = self.llm.classifier.in_features
20
+ self.dropout = nn.Dropout(p=config.dropout)
21
+ self.classifier = nn.Linear(self.dim_feats, config.num_classes)
22
+ self.llm.dropout = nn.Identity()
23
+ self.llm.classifier = nn.Identity()
24
+
25
+ def forward(self, x, apply_softmax: bool = True):
26
+ features = self.llm(**x)["logits"]
27
+ logits = self.classifier(self.dropout(features))
28
+ if apply_softmax:
29
+ logits = torch.softmax(logits, dim=1)
30
+ return logits
31
+
32
+
33
  class MRIBrainSequenceBERT(PreTrainedModel):
34
  config_class = MRIBrainSequenceBERTConfig
35
 
36
  def __init__(self, config):
37
  super().__init__(config)
38
  self.model_id = "answerdotai/ModernBERT-base"
39
+ self.m1 = SingleModel(config, self.model_id)
40
+ self.m2 = SingleModel(config, self.model_id)
41
+
42
+ self.ensemble = True
 
 
43
 
44
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
45
  self.max_len = config.max_len
 
49
  "ImageType",
50
  "Manufacturer",
51
  "ManufacturerModelName",
52
+ "ContrastBolusAgent",
53
  "ScanningSequence",
54
  "SequenceVariant",
55
  "ScanOptions",
 
70
  "PercentSampling",
71
  "PercentPhaseFieldOfView",
72
  "PixelBandwidth",
73
+ "ContrastBolusVolume",
74
+ "ContrastBolusTotalDose",
75
  "AcquisitionMatrix",
76
  "InPlanePhaseEncodingDirection",
77
  "FlipAngle",
 
88
  ]
89
 
90
  self.label2index = {
91
+ "t1": 0,
92
+ "t1c": 1,
93
+ "t2": 2,
94
+ "flair": 3,
95
+ "dwi": 4,
96
+ "adc": 5,
97
+ "eadc": 6,
98
+ "swi": 7,
99
+ "swi_mag": 8,
100
+ "swi_phase": 9,
101
+ "swi_minip": 10,
102
+ "t2_gre": 11,
103
+ "perfusion": 12,
104
+ "pd": 13,
105
+ "mra": 14,
106
+ "loc": 15,
107
+ "other": 16,
108
  }
109
 
110
  self.index2label = {v: k for k, v in self.label2index.items()}
 
123
  for k, v in x.items():
124
  x[k] = v.to(device)
125
 
126
+ logits = self.m1(x, apply_softmax=apply_softmax)
127
+ if self.ensemble:
128
+ logits += self.m2(x, apply_softmax=apply_softmax)
129
+ logits /= 2.0
130
+
131
  return logits
132
 
133
  def create_string_from_dicom(