| | import torch |
| | import numpy as np |
| | from monai.transforms import Compose, LoadImage, EnsureChannelFirst, Lambda, Resize, NormalizeIntensity, GaussianSmooth, ScaleIntensity, AsDiscrete, KeepLargestConnectedComponent, Invert, Rotate90, SaveImage, Transform |
| | from monai.inferers import SlidingWindowInferer |
| | from monai.networks.nets import UNet |
| |
|
class RgbaToGrayscale(Transform):
    """Collapse a channel-first RGB(A) image into a single grayscale channel.

    The first axis is treated as the channel axis. For four-channel input
    only the first three channels are used (the fourth, presumably alpha, is
    dropped). Three colour channels are combined as a weighted sum using
    ``_RGB_WEIGHTS``. Single-channel input is passed through unchanged.
    """

    # Per-channel weights applied, in order, to the first three channels.
    _RGB_WEIGHTS = (0.2989, 0.5870, 0.1140)

    def __call__(self, x):
        """Convert ``x`` to a single-channel tensor.

        Args:
            x: Tensor of shape ``(C, W, H)`` or ``(C, W, H, 1)`` where ``C``
                is 1, 3 or 4.

        Returns:
            A tensor of shape ``(1, W, H)``. Single-channel input is returned
            as-is; integer colour input is promoted to ``float32`` so the
            weighted sum is well defined.

        Raises:
            ValueError: If the input is not 3-D (after dropping a trailing
                singleton axis from 4-D input) or has an unsupported number
                of channels.
        """
        # Only strip the trailing singleton axis from 4-D input. Squeezing
        # unconditionally would also collapse a genuine spatial axis of
        # size 1 on a 3-D image and make the rank check below fail.
        if x.ndim == 4:
            x = x.squeeze(-1)

        if x.ndim != 3:
            raise ValueError(f"Input tensor must be 3D. Shape: {x.shape}")

        channels = x.shape[0]
        if channels == 1:
            return x
        if channels not in (3, 4):
            raise ValueError(f"Unsupported channel number: {channels}")

        # Slicing to the first three channels is a no-op for RGB and drops
        # the fourth channel for RGBA.
        rgb = x[:3]
        if not rgb.is_floating_point():
            rgb = rgb.to(torch.float32)

        # einsum requires both operands to share a dtype, so build the
        # weights in the image's dtype instead of always using float32.
        weights = torch.tensor(
            self._RGB_WEIGHTS, dtype=rgb.dtype, device=rgb.device
        )
        return torch.einsum('cwh,c->wh', rgb, weights).unsqueeze(0)

    def inverse(self, x):
        """Return ``x`` unchanged; the colour channels are not reconstructed."""
        return x
| | |
# Segmentation network: 2-D UNet with one input channel and four output
# channels. The architecture arguments must match those used when the
# checkpoint below was saved, otherwise the state-dict keys will not line up.
model = UNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=4,
    channels=[64, 128, 256, 512],
    strides=[2, 2, 2],
    num_res_units=3,
)

checkpoint_path = 'segmentation_model.pt'
# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints that come from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location='cpu')

# Validate with an explicit raise rather than `assert`: assertions are
# stripped when Python runs with -O, which would silently skip this check.
# Dict key views compare as sets, so ordering differences are ignored.
if model.state_dict().keys() != checkpoint['network'].keys():
    raise RuntimeError("Model and checkpoint keys do not match")

# The weights are stored under the 'network' entry of the checkpoint.
model.load_state_dict(checkpoint['network'])
# Switch to evaluation mode for inference.
model.eval()
| |
|
| | |
# Preprocessing applied to an image path before inference, in order:
# load the image, move channels first, reduce to one grayscale channel,
# resize to 768x768, drop a trailing singleton axis if present, normalise
# intensities, apply a light Gaussian smoothing (sigma=0.1) and finally
# rescale intensities to the range [-1, 1].
pre_transforms = Compose(
    transforms=(
        LoadImage(image_only=True),
        EnsureChannelFirst(),
        RgbaToGrayscale(),
        Resize(spatial_size=(768, 768)),
        Lambda(func=lambda img: img.squeeze(-1)),
        NormalizeIntensity(),
        GaussianSmooth(sigma=0.1),
        ScaleIntensity(minv=-1, maxv=1),
    )
)
| |
|
| |
|
| |
|
| | |
# Postprocessing applied to the raw network output, in order:
# argmax over channels re-encoded as a 4-class one-hot map, keep only the
# largest connected component (default settings), argmax back to a single
# label channel, then undo the invertible steps of `pre_transforms`.
post_transforms = Compose(
    transforms=(
        AsDiscrete(argmax=True, to_onehot=4),
        KeepLargestConnectedComponent(),
        AsDiscrete(argmax=True),
        Invert(pre_transforms),
    )
)
| |
|
| |
|
| |
|
def load_and_segment_image(input_image_path, device):
    """Segment one image file and return the label map as a NumPy array.

    The image is run through the module-level ``pre_transforms``, passed
    through the module-level ``model`` using sliding-window inference, and
    cleaned up with the module-level ``post_transforms``.

    Args:
        input_image_path: Value handed to ``pre_transforms`` (whose first
            step loads the image).
        device: Device on which inference runs; both the input batch and
            the model are moved to it.

    Returns:
        A ``numpy.uint8`` array holding the post-processed output with all
        singleton axes squeezed out.
    """
    # Preprocess, add a leading batch axis and move to the inference device.
    batch = pre_transforms(input_image_path).unsqueeze(0).to(device)

    # 512x512 windows with 75% overlap, 16 windows per forward pass.
    sliding_window = SlidingWindowInferer(
        roi_size=(512, 512), sw_batch_size=16, overlap=0.75
    )
    with torch.no_grad():
        predictions = sliding_window(batch, model.to(device))

    # Remove the batch axis before the per-image postprocessing.
    labels = post_transforms(predictions.squeeze(0)).to('cpu')

    return labels.squeeze().detach().numpy().astype(np.uint8)