# Source: Hugging Face model repository upload "MRIBrainSequenceBERT" (commit fa7dbe5).
import numpy as np
import pydicom
import torch
import torch.nn as nn
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
PreTrainedModel,
)
from .configuration import MRIBrainSequenceBERTConfig
class SingleModel(nn.Module):
    """A pretrained sequence-classification backbone with its head swapped
    for a fresh dropout + linear classifier sized to ``config.num_classes``.
    """

    def __init__(self, config, model_id: str):
        super().__init__()
        self.llm = AutoModelForSequenceClassification.from_pretrained(model_id)
        # Capture the backbone's feature width *before* detaching its head.
        self.dim_feats = self.llm.classifier.in_features
        self.dropout = nn.Dropout(p=config.dropout)
        self.classifier = nn.Linear(self.dim_feats, config.num_classes)
        # Neutralize the original head so the backbone's "logits" output is
        # actually the raw pooled features, which feed our own classifier.
        self.llm.dropout = nn.Identity()
        self.llm.classifier = nn.Identity()

    def forward(self, x, apply_softmax: bool = True):
        """Run tokenized inputs through the backbone and the custom head.

        Returns class probabilities when ``apply_softmax`` is True,
        otherwise raw logits.
        """
        feats = self.llm(**x)["logits"]
        out = self.classifier(self.dropout(feats))
        return torch.softmax(out, dim=1) if apply_softmax else out
class MRIBrainSequenceBERT(PreTrainedModel):
    """Ensemble of two ModernBERT text classifiers that predict the MRI
    sequence type (T1, T2, FLAIR, DWI, ...) from a string serialization of
    DICOM metadata elements.
    """

    config_class = MRIBrainSequenceBERTConfig

    def __init__(self, config):
        super().__init__(config)
        self.model_id = "answerdotai/ModernBERT-base"
        # Two independently initialized members; forward() averages them
        # when `ensemble` is True.
        self.m1 = SingleModel(config, self.model_id)
        self.m2 = SingleModel(config, self.model_id)
        self.ensemble = True
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        self.max_len = config.max_len
        # DICOM elements serialized (in this order) into the model's input
        # text by create_string_from_dicom().
        self.metadata_elements = [
            "SeriesDescription",
            "ImageType",
            "Manufacturer",
            "ManufacturerModelName",
            "ContrastBolusAgent",
            "ScanningSequence",
            "SequenceVariant",
            "ScanOptions",
            "MRAcquisitionType",
            "SequenceName",
            "AngioFlag",
            "SliceThickness",
            "RepetitionTime",
            "EchoTime",
            "InversionTime",
            "NumberOfAverages",
            "ImagingFrequency",
            "ImagedNucleus",
            "EchoNumbers",
            "SpacingBetweenSlices",
            "NumberOfPhaseEncodingSteps",
            "EchoTrainLength",
            "PercentSampling",
            "PercentPhaseFieldOfView",
            "PixelBandwidth",
            "ContrastBolusVolume",
            "ContrastBolusTotalDose",
            "AcquisitionMatrix",
            "InPlanePhaseEncodingDirection",
            "FlipAngle",
            "VariableFlipAngleFlag",
            "SAR",
            "dBdt",
            "SeriesNumber",
            "AcquisitionNumber",
            "PhotometricInterpretation",
            "PixelSpacing",
            "ImagesInAcquisition",
            "SmallestImagePixelValue",
            "LargestImagePixelValue",
        ]
        # Class-index mapping for the model's output logits.
        self.label2index = {
            "t1": 0,
            "t1c": 1,
            "t2": 2,
            "flair": 3,
            "dwi": 4,
            "adc": 5,
            "eadc": 6,
            "swi": 7,
            "swi_mag": 8,
            "swi_phase": 9,
            "swi_minip": 10,
            "t2_gre": 11,
            "perfusion": 12,
            "pd": 13,
            "mra": 14,
            "loc": 15,
            "other": 16,
        }
        self.index2label = {v: k for k, v in self.label2index.items()}

    def forward(
        self, x: str, device: str | torch.device = "cpu", apply_softmax: bool = True
    ):
        """Tokenize the metadata string `x` and return class scores.

        When `self.ensemble` is True the two members' outputs are averaged.
        With `apply_softmax=True` the averaged values are per-member
        probabilities; otherwise they are raw logits.
        """
        x = self.tokenizer(
            x,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_len,
        )
        # Move every tokenizer tensor (input_ids, attention_mask, ...) to the
        # requested device before inference.
        for k, v in x.items():
            x[k] = v.to(device)
        logits = self.m1(x, apply_softmax=apply_softmax)
        if self.ensemble:
            logits += self.m2(x, apply_softmax=apply_softmax)
            logits /= 2.0
        return logits

    def create_string_from_dicom(
        self, ds: pydicom.Dataset | dict, exclude_elements: list[str] | None = None
    ):
        """Serialize the selected metadata elements of `ds` into one string.

        Elements are joined as "Name value | Name value | ..." with brackets,
        commas, and quotes stripped. `exclude_elements` names elements to
        omit from the prediction input (None means exclude nothing).

        NOTE(review): for a pydicom Dataset, `ds[elem]` yields a DataElement
        (whose str() includes the tag), while a plain dict yields the raw
        value — confirm which form callers rely on. The "nan" check suggests
        dict rows sourced from pandas.
        """
        # Fix: avoid a mutable default argument (shared across calls);
        # use the None sentinel instead.
        if exclude_elements is None:
            exclude_elements = []
        # Sometimes we may want to exclude specific elements from being used for prediction
        x = []
        for each_element in self.metadata_elements:
            # Only include elements which are present
            if each_element in ds and each_element not in exclude_elements:
                if ds[each_element] is not None and str(ds[each_element]) != "nan":
                    x.append(f"{each_element} {ds[each_element]}")
        x = " | ".join(x)
        # Strip list/sequence punctuation so multi-valued elements read as
        # plain space-separated tokens.
        x = x.replace("[", "").replace("]", "").replace(",", "").replace("'", "")
        return x

    @staticmethod
    def determine_plane_from_dicom(ds: pydicom.Dataset | dict):
        """Return the imaging plane ("SAG", "COR", "AX") inferred from
        ImageOrientationPatient, or None if that element is absent.
        """
        iop = ds.get("ImageOrientationPatient", None)
        if iop is None:
            return None
        iop = np.asarray(iop)
        # Calculate the direction cosine for the normal vector of the plane
        normal_vector = np.cross(iop[:3], iop[3:])
        # Determine the plane based on the largest component of the normal vector
        abs_normal = np.abs(normal_vector)
        if abs_normal[0] > abs_normal[1] and abs_normal[0] > abs_normal[2]:
            return "SAG"
        elif abs_normal[1] > abs_normal[0] and abs_normal[1] > abs_normal[2]:
            return "COR"
        else:
            # Ties fall through to axial.
            return "AX"