File size: 5,180 Bytes
0a015a7 fa7dbe5 0a015a7 26fb764 fa7dbe5 0a015a7 26fb764 0a015a7 fa7dbe5 0a015a7 fa7dbe5 0a015a7 fa7dbe5 0a015a7 fa7dbe5 0a015a7 d6725e0 0a015a7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
import numpy as np
import pydicom
import torch
import torch.nn as nn
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
PreTrainedModel,
)
from .configuration import MRIBrainSequenceBERTConfig
class SingleModel(nn.Module):
    """A pretrained sequence-classification backbone with a new linear head.

    The backbone's own ``classifier`` (and the attribute named ``dropout``) are
    swapped for ``nn.Identity``, so the tensor the backbone reports under
    ``"logits"`` is the feature vector that used to feed its classifier.
    That vector is passed through ``self.dropout`` and ``self.classifier``.
    """

    def __init__(self, config, model_id: str):
        super().__init__()
        self.llm = AutoModelForSequenceClassification.from_pretrained(model_id)
        # Width of the features entering the pretrained classifier; the new
        # head consumes vectors of this size.
        self.dim_feats = self.llm.classifier.in_features
        self.dropout = nn.Dropout(p=config.dropout)
        self.classifier = nn.Linear(self.dim_feats, config.num_classes)
        # NOTE(review): this sets an attribute literally named ``dropout`` on
        # the HF model; confirm that is the name the loaded architecture uses
        # for its head dropout, otherwise it only registers an unused Identity.
        self.llm.dropout = nn.Identity()
        # Strip the pretrained head so the backbone emits raw features.
        self.llm.classifier = nn.Identity()

    def forward(self, x, apply_softmax: bool = True):
        """Score tokenised inputs.

        Args:
            x: Mapping of keyword arguments for the backbone (e.g. tokenizer
                output); it is unpacked with ``**``.
            apply_softmax: Return class probabilities (softmax over dim 1)
                when True, raw head outputs otherwise.
        """
        pooled = self.llm(**x)["logits"]
        scores = self.classifier(self.dropout(pooled))
        return torch.softmax(scores, dim=1) if apply_softmax else scores
class MRIBrainSequenceBERT(PreTrainedModel):
    """Predict the MRI brain sequence type of a series from its DICOM metadata.

    Metadata elements are serialised into one text string
    (``create_string_from_dicom``), tokenised with the ModernBERT tokenizer and
    scored by two ``SingleModel`` instances whose outputs are averaged when
    ``self.ensemble`` is True.
    """

    config_class = MRIBrainSequenceBERTConfig

    def __init__(self, config):
        super().__init__(config)
        self.model_id = "answerdotai/ModernBERT-base"
        self.m1 = SingleModel(config, self.model_id)
        self.m2 = SingleModel(config, self.model_id)
        # When True, ``forward`` returns the mean of m1's and m2's outputs.
        self.ensemble = True
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        self.max_len = config.max_len
        # DICOM keywords serialised into the model input, in this order.
        self.metadata_elements = [
            "SeriesDescription",
            "ImageType",
            "Manufacturer",
            "ManufacturerModelName",
            "ContrastBolusAgent",
            "ScanningSequence",
            "SequenceVariant",
            "ScanOptions",
            "MRAcquisitionType",
            "SequenceName",
            "AngioFlag",
            "SliceThickness",
            "RepetitionTime",
            "EchoTime",
            "InversionTime",
            "NumberOfAverages",
            "ImagingFrequency",
            "ImagedNucleus",
            "EchoNumbers",
            "SpacingBetweenSlices",
            "NumberOfPhaseEncodingSteps",
            "EchoTrainLength",
            "PercentSampling",
            "PercentPhaseFieldOfView",
            "PixelBandwidth",
            "ContrastBolusVolume",
            "ContrastBolusTotalDose",
            "AcquisitionMatrix",
            "InPlanePhaseEncodingDirection",
            "FlipAngle",
            "VariableFlipAngleFlag",
            "SAR",
            "dBdt",
            "SeriesNumber",
            "AcquisitionNumber",
            "PhotometricInterpretation",
            "PixelSpacing",
            "ImagesInAcquisition",
            "SmallestImagePixelValue",
            "LargestImagePixelValue",
        ]
        # Sequence label -> output column index of the classifier.
        self.label2index = {
            "t1": 0,
            "t1c": 1,
            "t2": 2,
            "flair": 3,
            "dwi": 4,
            "adc": 5,
            "eadc": 6,
            "swi": 7,
            "swi_mag": 8,
            "swi_phase": 9,
            "swi_minip": 10,
            "t2_gre": 11,
            "perfusion": 12,
            "pd": 13,
            "mra": 14,
            "loc": 15,
            "other": 16,
        }
        self.index2label = {v: k for k, v in self.label2index.items()}

    def forward(
        self,
        x: str | list[str],
        device: str | torch.device = "cpu",
        apply_softmax: bool = True,
    ):
        """Score one or more serialised metadata strings.

        Args:
            x: Text produced by ``create_string_from_dicom`` (a single string
                or a list of them); padded/truncated to ``self.max_len``.
            device: Device the tokenised tensors are moved to; it should be the
                device holding the model weights.
            apply_softmax: If True each sub-model returns softmax
                probabilities, otherwise raw head outputs.

        Returns:
            Tensor of shape ``(batch, num_classes)``. With ``self.ensemble``
            it is the mean of the two sub-models' outputs.
        """
        encoded = self.tokenizer(
            x,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_len,
        )
        encoded = {k: v.to(device) for k, v in encoded.items()}
        logits = self.m1(encoded, apply_softmax=apply_softmax)
        if self.ensemble:
            # Out-of-place on purpose: softmax saves its output for backward,
            # so an in-place ``+=`` on it would break autograd.
            second = self.m2(encoded, apply_softmax=apply_softmax)
            logits = (logits + second) / 2.0
        return logits

    def create_string_from_dicom(
        self,
        ds: pydicom.Dataset | dict,
        exclude_elements: list[str] | None = None,
    ) -> str:
        """Serialise the model's metadata elements from ``ds`` into one string.

        Every element of ``self.metadata_elements`` that is present, not
        excluded, not None and not the string "nan" becomes
        ``"<Keyword> <value>"``. The pieces are joined with ``" | "``, and
        ``[``, ``]``, ``,`` and ``'`` are removed, so multi-valued elements such
        as ``['ORIGINAL', 'PRIMARY']`` become space-separated tokens.

        Args:
            ds: A pydicom Dataset or a mapping from DICOM keyword to value.
            exclude_elements: Keywords to leave out (e.g. to test prediction
                without a specific element). ``None`` excludes nothing.

        Returns:
            The serialised string; empty when no element qualifies.
        """
        excluded = exclude_elements if exclude_elements is not None else ()
        parts = []
        for keyword in self.metadata_elements:
            if keyword in excluded or keyword not in ds:
                continue
            # ``.get`` yields the element *value* for dicts and pydicom
            # Datasets alike; ``ds[keyword]`` on a Dataset would return a
            # DataElement whose str() embeds the tag, name and VR.
            value = ds.get(keyword)
            if value is None or str(value) == "nan":
                continue
            parts.append(f"{keyword} {value}")
        return " | ".join(parts).translate(str.maketrans("", "", "[],'"))

    @staticmethod
    def determine_plane_from_dicom(ds: pydicom.Dataset | dict) -> str | None:
        """Classify the acquisition plane from ImageOrientationPatient.

        The slice normal is the cross product of the row and column direction
        cosines (first and last three values). The plane is chosen from the
        component with the strictly largest magnitude: index 0 -> "SAG",
        index 1 -> "COR", anything else (including ties) -> "AX".

        Returns:
            "SAG", "COR" or "AX"; None when ImageOrientationPatient is missing
            or does not hold exactly six values.
        """
        iop = ds.get("ImageOrientationPatient", None)
        if iop is None:
            return None
        iop = np.asarray(iop, dtype=float)
        if iop.shape != (6,):
            return None
        abs_normal = np.abs(np.cross(iop[:3], iop[3:]))
        if abs_normal[0] > abs_normal[1] and abs_normal[0] > abs_normal[2]:
            return "SAG"
        if abs_normal[1] > abs_normal[0] and abs_normal[1] > abs_normal[2]:
            return "COR"
        return "AX"
|