|
|
import numpy as np |
|
|
import pydicom |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
from transformers import ( |
|
|
AutoTokenizer, |
|
|
AutoModelForSequenceClassification, |
|
|
PreTrainedModel, |
|
|
) |
|
|
|
|
|
from .configuration import MRIBrainSequenceBERTConfig |
|
|
|
|
|
|
|
|
class SingleModel(nn.Module):
    """A single pretrained sequence classifier with its stock head swapped out.

    Loads a HuggingFace sequence-classification backbone, records the width of
    its original classifier head, then neutralizes the built-in dropout and
    classifier (both replaced with ``nn.Identity``) so the backbone acts as a
    pure feature extractor. A fresh dropout + linear head sized for
    ``config.num_classes`` produces the final logits.
    """

    def __init__(self, config, model_id: str):
        super().__init__()
        self.llm = AutoModelForSequenceClassification.from_pretrained(model_id)
        # Capture the feature width of the pretrained head before discarding it.
        self.dim_feats = self.llm.classifier.in_features
        self.dropout = nn.Dropout(p=config.dropout)
        self.classifier = nn.Linear(self.dim_feats, config.num_classes)
        # With both replaced by Identity, the backbone's "logits" output is the
        # raw pooled representation rather than class scores.
        self.llm.dropout = nn.Identity()
        self.llm.classifier = nn.Identity()

    def forward(self, x, apply_softmax: bool = True):
        """Score tokenized input ``x`` (a dict of tensors, e.g. a BatchEncoding).

        Returns class probabilities (softmax over dim 1) when
        ``apply_softmax`` is True, otherwise the raw logits.
        """
        pooled = self.llm(**x)["logits"]
        scores = self.classifier(self.dropout(pooled))
        return torch.softmax(scores, dim=1) if apply_softmax else scores
|
|
|
|
|
|
|
|
class MRIBrainSequenceBERT(PreTrainedModel):
    """Ensemble of two ModernBERT classifiers that labels brain-MRI series.

    A DICOM header is flattened into a single ``"key value | key value ..."``
    string (see :meth:`create_string_from_dicom`), tokenized, and scored by
    one or two fine-tuned ModernBERT heads whose outputs are averaged when
    ``self.ensemble`` is True.
    """

    config_class = MRIBrainSequenceBERTConfig

    def __init__(self, config):
        super().__init__(config)
        self.model_id = "answerdotai/ModernBERT-base"
        self.m1 = SingleModel(config, self.model_id)
        self.m2 = SingleModel(config, self.model_id)

        # When True, forward() averages the predictions of m1 and m2.
        self.ensemble = True

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        self.max_len = config.max_len

        # DICOM keywords serialized (in this order) into the model's text input.
        self.metadata_elements = [
            "SeriesDescription",
            "ImageType",
            "Manufacturer",
            "ManufacturerModelName",
            "ContrastBolusAgent",
            "ScanningSequence",
            "SequenceVariant",
            "ScanOptions",
            "MRAcquisitionType",
            "SequenceName",
            "AngioFlag",
            "SliceThickness",
            "RepetitionTime",
            "EchoTime",
            "InversionTime",
            "NumberOfAverages",
            "ImagingFrequency",
            "ImagedNucleus",
            "EchoNumbers",
            "SpacingBetweenSlices",
            "NumberOfPhaseEncodingSteps",
            "EchoTrainLength",
            "PercentSampling",
            "PercentPhaseFieldOfView",
            "PixelBandwidth",
            "ContrastBolusVolume",
            "ContrastBolusTotalDose",
            "AcquisitionMatrix",
            "InPlanePhaseEncodingDirection",
            "FlipAngle",
            "VariableFlipAngleFlag",
            "SAR",
            "dBdt",
            "SeriesNumber",
            "AcquisitionNumber",
            "PhotometricInterpretation",
            "PixelSpacing",
            "ImagesInAcquisition",
            "SmallestImagePixelValue",
            "LargestImagePixelValue",
        ]

        # Mapping from series-type label to output-class index.
        self.label2index = {
            "t1": 0,
            "t1c": 1,
            "t2": 2,
            "flair": 3,
            "dwi": 4,
            "adc": 5,
            "eadc": 6,
            "swi": 7,
            "swi_mag": 8,
            "swi_phase": 9,
            "swi_minip": 10,
            "t2_gre": 11,
            "perfusion": 12,
            "pd": 13,
            "mra": 14,
            "loc": 15,
            "other": 16,
        }

        # Inverse mapping for decoding predictions.
        self.index2label = {v: k for k, v in self.label2index.items()}

    def forward(
        self, x: str, device: str | torch.device = "cpu", apply_softmax: bool = True
    ):
        """Tokenize the metadata string ``x`` and return class scores.

        Args:
            x: Serialized DICOM metadata (output of ``create_string_from_dicom``).
            device: Device to move the tokenized tensors to.
            apply_softmax: If True, each sub-model returns probabilities
                (which are then averaged); otherwise raw logits are averaged.

        Returns:
            A ``(batch, num_classes)`` tensor of (averaged) scores.
        """
        x = self.tokenizer(
            x,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_len,
        )

        # Move every tokenizer output tensor (input_ids, attention_mask, ...)
        # to the requested device.
        for k, v in x.items():
            x[k] = v.to(device)

        logits = self.m1(x, apply_softmax=apply_softmax)
        if self.ensemble:
            # Average the two heads' outputs.
            logits += self.m2(x, apply_softmax=apply_softmax)
            logits /= 2.0

        return logits

    def create_string_from_dicom(
        self, ds: pydicom.Dataset | dict, exclude_elements: list[str] | None = None
    ):
        """Serialize selected DICOM metadata into one pipe-separated string.

        Iterates ``self.metadata_elements`` in order, skipping keywords that
        are absent, excluded, ``None``, or stringify to ``"nan"``, and joins
        the remaining ``"keyword value"`` pairs with ``" | "``. Bracket,
        comma, and quote characters are stripped so multi-valued tags read
        naturally as text.

        Args:
            ds: DICOM dataset or plain dict keyed by DICOM keyword.
            exclude_elements: Keywords to omit; defaults to none.

        Returns:
            The flattened metadata string ready for the tokenizer.
        """
        # Fix: avoid a shared mutable default argument (was ``= []``).
        if exclude_elements is None:
            exclude_elements = []

        # NOTE(review): for a pydicom Dataset, ``ds[keyword]`` returns a
        # DataElement (not the raw value) — confirm whether callers pass a
        # plain dict or rely on DataElement string formatting here.
        parts = []
        for each_element in self.metadata_elements:
            if each_element in ds and each_element not in exclude_elements:
                if ds[each_element] is not None and str(ds[each_element]) != "nan":
                    parts.append(f"{each_element} {ds[each_element]}")

        x = " | ".join(parts)
        x = x.replace("[", "").replace("]", "").replace(",", "").replace("'", "")
        return x

    @staticmethod
    def determine_plane_from_dicom(ds: pydicom.Dataset | dict):
        """Infer the acquisition plane ("SAG", "COR", or "AX") from a header.

        Computes the slice normal as the cross product of the row and column
        direction cosines in ``ImageOrientationPatient``; the dominant
        component of the normal determines the plane. Returns ``None`` when
        the tag is missing. Ties fall through to "AX".
        """
        iop = ds.get("ImageOrientationPatient", None)
        if iop is None:
            return None
        iop = np.asarray(iop)

        # Normal to the image plane: row cosines x column cosines.
        normal_vector = np.cross(iop[:3], iop[3:])

        abs_normal = np.abs(normal_vector)
        if abs_normal[0] > abs_normal[1] and abs_normal[0] > abs_normal[2]:
            return "SAG"
        elif abs_normal[1] > abs_normal[0] and abs_normal[1] > abs_normal[2]:
            return "COR"
        else:
            return "AX"
|
|
|