| import albumentations as A |
| import torch |
| import torch.nn as nn |
|
|
| from numpy.typing import NDArray |
| from transformers import PreTrainedModel |
| from timm import create_model |
| from typing import Optional |
| from .configuration import CXRConfig |
| from .unet import UnetDecoder, SegmentationHead |
|
|
| _PYDICOM_AVAILABLE = False |
| try: |
| from pydicom import dcmread |
| from pydicom.pixels import apply_voi_lut |
|
|
| _PYDICOM_AVAILABLE = True |
| except ModuleNotFoundError: |
| pass |
|
|
|
|
class CXRModel(PreTrainedModel):
    """Chest X-ray model with a segmentation head and classification heads.

    A ``timm`` backbone (``features_only=True``) feeds a U-Net decoder for
    segmentation; the deepest encoder feature map is globally pooled and fed
    to a single linear layer whose columns are split into age / view / sex
    outputs in :meth:`forward`.
    """

    config_class = CXRConfig

    def __init__(self, config):
        super().__init__(config)
        # Backbone returns a list of multi-scale feature maps.
        # pretrained=False: weights are expected to come from the HF
        # checkpoint, not from timm's own pretrained weights.
        self.encoder = create_model(
            model_name=config.backbone,
            features_only=True,
            pretrained=False,
            in_chans=config.in_chans,
        )
        self.decoder = UnetDecoder(
            decoder_n_blocks=config.decoder_n_blocks,
            decoder_channels=config.decoder_channels,
            encoder_channels=config.encoder_channels,
            decoder_center_block=config.decoder_center_block,
            decoder_norm_layer=config.decoder_norm_layer,
            decoder_attention_type=config.decoder_attention_type,
        )
        self.img_size = config.img_size
        self.segmentation_head = SegmentationHead(
            in_channels=config.decoder_channels[-1],
            out_channels=config.seg_num_classes,
            size=self.img_size,
        )
        self.pooling = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(p=config.cls_dropout)
        # Single linear layer; its output columns are sliced into the
        # age / view / female heads in forward().
        self.classifier = nn.Linear(config.feature_dim, config.cls_num_classes)

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Map pixel values from the [0, 255] range to [-1, 1]."""
        mini, maxi = 0.0, 255.0
        x = (x - mini) / (maxi - mini)  # -> [0, 1]
        x = (x - 0.5) * 2.0  # -> [-1, 1]
        return x

    @staticmethod
    def load_image_from_dicom(path: str) -> Optional[NDArray]:
        """Read a DICOM file and return its image as a uint8 array.

        Applies the VOI LUT, inverts MONOCHROME1 images (where low raw
        values are displayed bright), and min-max rescales intensities to
        the full 0-255 range.

        Returns:
            A uint8 NumPy array, or None when pydicom is not installed.
        """
        if not _PYDICOM_AVAILABLE:
            print("`pydicom` is not installed, returning None ...")
            return None
        dicom = dcmread(path)
        arr = apply_voi_lut(dicom.pixel_array, dicom)
        if dicom.PhotometricInterpretation == "MONOCHROME1":
            # MONOCHROME1 stores inverted intensities; flip so that
            # higher value = brighter pixel.
            arr = arr.max() - arr

        # Min-max rescale to [0, 255].
        arr = arr - arr.min()
        max_val = arr.max()
        if max_val > 0:
            # Guard against constant-valued images (division by zero).
            arr = arr / max_val
        arr = (arr * 255).astype("uint8")
        return arr

    def preprocess(self, x: NDArray) -> NDArray:
        """Resize an image array to the model's configured input size."""
        x = A.Resize(self.img_size[0], self.img_size[1], p=1)(image=x)["image"]
        return x

    def forward(
        self, x: torch.Tensor, return_logits: bool = False
    ) -> dict[str, torch.Tensor]:
        """Run the segmentation and classification heads.

        Args:
            x: Input image batch with pixel values in [0, 255]; normalized
                internally via :meth:`normalize`.
            return_logits: If True, return raw logits; otherwise apply
                softmax/sigmoid activations to the outputs.

        Returns:
            Dict with keys ``"mask"`` (segmentation map), ``"age"``
            (column 0 of the classifier, shape (B, 1)), ``"view"``
            (columns 1-3, shape (B, 3)) and ``"female"`` (column 4,
            shape (B, 1)).
        """
        x = self.normalize(x)
        features = self.encoder(x)
        decoder_output = self.decoder(features)
        logits = self.segmentation_head(decoder_output[-1])
        # Global-average-pool the deepest encoder map for classification.
        b, n = features[-1].shape[:2]
        pooled = self.pooling(features[-1]).reshape(b, n)
        pooled = self.dropout(pooled)
        cls_logits = self.classifier(pooled)
        out = {
            "mask": logits,
            "age": cls_logits[:, 0].unsqueeze(1),
            "view": cls_logits[:, 1:4],
            "female": cls_logits[:, 4].unsqueeze(1),
        }
        if return_logits:
            return out
        # Inference path: convert logits to probabilities. "age" stays a
        # raw regression output.
        out["mask"] = out["mask"].softmax(1)
        out["view"] = out["view"].softmax(1)
        out["female"] = out["female"].sigmoid()
        return out
|
|