"""Ensemble of transformer classifiers that label MRI brain series from DICOM metadata text."""

import numpy as np
import pydicom
import torch
import torch.nn as nn
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    PreTrainedModel,
)

from .configuration import MRIBrainSequenceBERTConfig


class SingleModel(nn.Module):
    """One pretrained transformer backbone topped with a new linear classification head.

    The backbone's own ``classifier`` is replaced with ``nn.Identity`` so the
    ``"logits"`` entry of its output is the pooled feature vector, which is then
    passed through ``self.dropout`` and ``self.classifier``.

    Args:
        config: Model configuration; ``config.dropout`` and
            ``config.num_classes`` are read.
        model_id: Identifier passed to
            ``AutoModelForSequenceClassification.from_pretrained``.
    """

    def __init__(self, config, model_id: str):
        super().__init__()
        self.llm = AutoModelForSequenceClassification.from_pretrained(model_id)
        # Read the original head's input width before that head is replaced below.
        self.dim_feats = self.llm.classifier.in_features
        self.dropout = nn.Dropout(p=config.dropout)
        self.classifier = nn.Linear(self.dim_feats, config.num_classes)
        # Strip the backbone's head so it emits features instead of class logits.
        # NOTE(review): the head-dropout attribute name differs between
        # architectures (some use ``drop`` rather than ``dropout``); this
        # assignment only disables it where the attribute is named ``dropout``
        # -- confirm for the backbone in use.
        self.llm.dropout = nn.Identity()
        self.llm.classifier = nn.Identity()

    def forward(self, x, apply_softmax: bool = True):
        """Return class scores for tokenised input ``x``.

        Args:
            x: Mapping of tokenizer outputs, unpacked as keyword arguments
                into the backbone.
            apply_softmax: If True, return softmax probabilities over dim 1;
                otherwise return raw logits.
        """
        features = self.llm(**x)["logits"]
        logits = self.classifier(self.dropout(features))
        if apply_softmax:
            logits = torch.softmax(logits, dim=1)
        return logits


class MRIBrainSequenceBERT(PreTrainedModel):
    """Two-member ``SingleModel`` ensemble that predicts an MRI sequence label.

    Input text is built from selected DICOM metadata elements (see
    ``create_string_from_dicom``). Output indices map to labels via
    ``self.index2label``.
    """

    config_class = MRIBrainSequenceBERTConfig

    def __init__(self, config):
        super().__init__(config)
        self.model_id = "answerdotai/ModernBERT-base"
        self.m1 = SingleModel(config, self.model_id)
        self.m2 = SingleModel(config, self.model_id)
        # When True, forward() averages the outputs of m1 and m2.
        self.ensemble = True
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        self.max_len = config.max_len
        # DICOM keywords serialised (in this order) into the model input text.
        self.metadata_elements = [
            "SeriesDescription",
            "ImageType",
            "Manufacturer",
            "ManufacturerModelName",
            "ContrastBolusAgent",
            "ScanningSequence",
            "SequenceVariant",
            "ScanOptions",
            "MRAcquisitionType",
            "SequenceName",
            "AngioFlag",
            "SliceThickness",
            "RepetitionTime",
            "EchoTime",
            "InversionTime",
            "NumberOfAverages",
            "ImagingFrequency",
            "ImagedNucleus",
            "EchoNumbers",
            "SpacingBetweenSlices",
            "NumberOfPhaseEncodingSteps",
            "EchoTrainLength",
            "PercentSampling",
            "PercentPhaseFieldOfView",
            "PixelBandwidth",
            "ContrastBolusVolume",
            "ContrastBolusTotalDose",
            "AcquisitionMatrix",
            "InPlanePhaseEncodingDirection",
            "FlipAngle",
            "VariableFlipAngleFlag",
            "SAR",
            "dBdt",
            "SeriesNumber",
            "AcquisitionNumber",
            "PhotometricInterpretation",
            "PixelSpacing",
            "ImagesInAcquisition",
            "SmallestImagePixelValue",
            "LargestImagePixelValue",
        ]
        self.label2index = {
            "t1": 0,
            "t1c": 1,
            "t2": 2,
            "flair": 3,
            "dwi": 4,
            "adc": 5,
            "eadc": 6,
            "swi": 7,
            "swi_mag": 8,
            "swi_phase": 9,
            "swi_minip": 10,
            "t2_gre": 11,
            "perfusion": 12,
            "pd": 13,
            "mra": 14,
            "loc": 15,
            "other": 16,
        }
        self.index2label = {v: k for k, v in self.label2index.items()}

    def forward(
        self,
        x: str | list[str],
        device: str | torch.device = "cpu",
        apply_softmax: bool = True,
    ) -> torch.Tensor:
        """Tokenise ``x`` and return class scores.

        Args:
            x: Input text (or texts), e.g. from ``create_string_from_dicom``.
                Each is padded/truncated to ``self.max_len`` tokens.
            device: Device the token tensors are moved to; should match the
                model's device.
            apply_softmax: If True, return probabilities; otherwise raw logits.

        Returns:
            Scores with ``config.num_classes`` entries per input. When
            ``self.ensemble`` is True, this is the mean of the two members.
        """
        encoded = self.tokenizer(
            x,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_len,
        )
        inputs = {key: tensor.to(device) for key, tensor in encoded.items()}
        logits = self.m1(inputs, apply_softmax=apply_softmax)
        if self.ensemble:
            # Out-of-place average: in-place ``+=`` / ``/=`` would overwrite the
            # softmax output that autograd saves for the backward pass.
            logits = (logits + self.m2(inputs, apply_softmax=apply_softmax)) / 2.0
        return logits

    def create_string_from_dicom(
        self,
        ds: pydicom.Dataset | dict,
        exclude_elements: list[str] | None = None,
    ) -> str:
        """Serialise the selected metadata elements of ``ds`` into model input text.

        Each element of ``self.metadata_elements`` that is present, not None
        and not the string ``"nan"`` is rendered as ``"<keyword> <value>"``.
        Entries are joined with ``" | "``, and the characters ``[ ] , '`` are
        removed (e.g. from list-valued elements).

        Args:
            ds: A pydicom ``Dataset`` or a mapping from DICOM keyword to value.
            exclude_elements: Keywords to omit even if present. This is useful
                when specific elements should not influence the prediction.

        Returns:
            The formatted string (empty if no element qualifies).
        """
        excluded = set(exclude_elements or ())
        parts = []
        for element in self.metadata_elements:
            if element in excluded or element not in ds:
                continue
            # ``.get`` yields the element *value* for both dicts and pydicom
            # Datasets; ``ds[keyword]`` on a Dataset would return a DataElement
            # whose str() also includes the tag, name and VR.
            value = ds.get(element)
            if value is None or str(value) == "nan":
                continue
            parts.append(f"{element} {value}")
        return " | ".join(parts).translate(str.maketrans("", "", "[],'"))

    @staticmethod
    def determine_plane_from_dicom(ds: pydicom.Dataset | dict) -> str | None:
        """Classify the acquisition plane from ``ImageOrientationPatient``.

        The two direction cosines are crossed to get the plane normal. The
        result is ``"SAG"`` when its x-component strictly dominates, ``"COR"``
        when its y-component strictly dominates, and ``"AX"`` otherwise
        (including ties).

        Returns:
            ``"SAG"``, ``"COR"``, ``"AX"``, or None when the element is missing
            or does not hold exactly six values.
        """
        iop = ds.get("ImageOrientationPatient", None)
        if iop is None:
            return None
        iop = np.asarray(iop, dtype=float).ravel()
        if iop.size != 6:
            return None
        # Absolute components of the normal vector of the image plane.
        normal = np.abs(np.cross(iop[:3], iop[3:]))
        if normal[0] > normal[1] and normal[0] > normal[2]:
            return "SAG"
        if normal[1] > normal[0] and normal[1] > normal[2]:
            return "COR"
        return "AX"