|
|
import json |
|
|
import logging |
|
|
import os |
|
|
import re |
|
|
from abc import ABC, abstractmethod |
|
|
from pathlib import Path |
|
|
from typing import (Any, Callable, Dict, Iterable, List, Optional, Pattern, |
|
|
Set, Tuple, Union) |
|
|
|
|
|
import numpy as np |
|
|
import SimpleITK |
|
|
from evalutils.exceptions import FileLoaderError |
|
|
from evalutils.io import FileLoader, ImageLoader, SimpleITKLoader |
|
|
from evalutils.validators import (UniqueImagesValidator, |
|
|
UniquePathIndicesValidator) |
|
|
from pandas import DataFrame |
|
|
|
|
|
logger = logging.getLogger(__name__)

# Fallback settings, used when there is no `.env` file in the working
# directory or when the file does not provide a usable value for a key.
_DEFAULT_TASK_TYPE = "mri"
_DEFAULT_INPUT_FOLDER = "/input"

if Path(".env").exists():
    # Imported lazily so python-dotenv is only required when a `.env`
    # file is actually present.
    from dotenv import dotenv_values

    config = dotenv_values(".env")

    # `.get(...) or default` covers a missing key (previously a KeyError),
    # a key declared without a value (None) and an empty value, all of
    # which would otherwise yield unusable paths below.
    TASK_TYPE = config.get("TASK_TYPE") or _DEFAULT_TASK_TYPE
    INPUT_FOLDER = config.get("INPUT_FOLDER") or _DEFAULT_INPUT_FOLDER

    print("########## ENVIRONMENT VARIABLES ##########")
    print(f"TASK_TYPE: {TASK_TYPE}")
    print(f"INPUT_FOLDER: {INPUT_FOLDER}")
else:
    TASK_TYPE = _DEFAULT_TASK_TYPE
    INPUT_FOLDER = _DEFAULT_INPUT_FOLDER

# The absolute "/output" folder is only used together with the default
# "/input" folder; any other input folder writes to a local "./output".
if INPUT_FOLDER == "/input":
    OUTPUT_FOLDER = "/output"
else:
    OUTPUT_FOLDER = "./output"

# Default locations consumed as constructor defaults further down the module.
DEFAULT_IMAGE_PATH = Path(f"{INPUT_FOLDER}/images/{TASK_TYPE}")
DEFAULT_REGION_PATH = Path(f"{INPUT_FOLDER}/region.json")
DEFAULT_MASK_PATH = Path(f"{INPUT_FOLDER}/images/body")
DEFAULT_OUTPUT_PATH = Path(f"{OUTPUT_FOLDER}/images/synthetic-ct")
DEFAULT_OUTPUT_FILE = Path(f"{OUTPUT_FOLDER}/results.json")
|
|
|
|
|
|
|
|
class BaseSynthradAlgorithm(ABC):
    """Abstract base for algorithms that turn an image/mask pair into one image.

    Subclasses implement :meth:`predict`. :meth:`process` then runs the whole
    pipeline: load the cases, validate (currently a stub), run ``predict`` on
    every image/mask pair, write each result image and finally dump a JSON
    summary of all processed cases.
    """

    def __init__(
        self,
        input_path: Path = DEFAULT_IMAGE_PATH,
        mask_path: Path = DEFAULT_MASK_PATH,
        region_path: Path = DEFAULT_REGION_PATH,
        output_path: Path = DEFAULT_OUTPUT_PATH,
        output_file: Path = DEFAULT_OUTPUT_FILE,
        validators: Optional[Dict[str, Callable]] = None,
        file_loader: Optional[FileLoader] = None,
    ):
        """
        Parameters
        ----------
        input_path
            The folder the input images will be loaded from.
            Default: ``DEFAULT_IMAGE_PATH``
        mask_path
            The folder the input masks will be loaded from.
            Default: ``DEFAULT_MASK_PATH``
        region_path
            The JSON file describing the region; its parsed content is
            handed to ``predict`` under the ``"region"`` key.
            Default: ``DEFAULT_REGION_PATH``
        output_path
            The folder the output images will be written to.
            Default: ``DEFAULT_OUTPUT_PATH``
        output_file
            The file the per-case results will be written to as JSON.
            Default: ``DEFAULT_OUTPUT_FILE``
        validators
            A dictionary of validators per file_loader key. Accepted for
            interface compatibility only: it is not stored or applied,
            because ``validate`` is still a stub.
        file_loader
            The loader used for both images and masks.
            Default: a new ``evalutils.io.SimpleITKLoader`` per instance.
        """
        # Build the default loader here rather than in the signature, so it
        # is not created once at definition time and shared by all instances.
        if file_loader is None:
            file_loader = SimpleITKLoader()

        self._index_keys = ["image", "mask"]
        self.input_path = input_path
        self.mask_path = mask_path
        self.region_path = region_path
        self.output_path = output_path
        self.output_file = output_file
        self._file_loader = file_loader

        self.cases = {}
        self._case_results = []

    def load(self):
        """Load all image cases, all mask cases and the region JSON file."""
        self.images = self._load_cases(
            folder=self.input_path, file_loader=self._file_loader
        )
        self.masks = self._load_cases(
            folder=self.mask_path, file_loader=self._file_loader
        )

        with open(self.region_path, "r", encoding="utf-8") as f:
            self.region = json.load(f)

    def _load_cases(
        self,
        folder: Path,
        file_loader: ImageLoader,
    ) -> List[Dict[str, Any]]:
        """Load every entry of ``folder`` with ``file_loader``.

        Entries are visited in sorted order. An entry the loader rejects
        with ``FileLoaderError`` is logged and skipped.

        Returns
        -------
        The concatenated case records returned by ``file_loader.load``
        (``_load_input_image`` reads their ``"path"`` and ``"hash"`` keys).

        Raises
        ------
        FileLoaderError
            If not a single case could be loaded from ``folder``.
        """
        cases = []

        for fp in sorted(folder.glob("*")):
            try:
                new_cases = file_loader.load(fname=fp)
            except FileLoaderError:
                logger.warning(
                    "Could not load %s using %s.", fp.name, file_loader
                )
            else:
                cases.extend(new_cases)

        if not cases:
            raise FileLoaderError(
                f"Could not load any files in {folder} with {file_loader}."
            )

        return cases

    def validate(self):
        """TODO: Validates each dataframe for each fileloader separately"""
        pass

    def _validate_data_frame(self, df: DataFrame):
        """TODO: Validate the dataframe for a specific fileloader"""
        pass

    def process_cases(self):
        """Run ``process_case`` on every image/mask pair and keep the results.

        Images and masks are paired by position in their loaded order.
        """
        self._case_results = []

        # zip() stops at the shorter sequence; surface a count mismatch
        # instead of dropping the surplus cases without any trace.
        if len(self.images) != len(self.masks):
            logger.warning(
                "Number of images (%d) and masks (%d) differ; "
                "surplus cases will not be processed.",
                len(self.images),
                len(self.masks),
            )

        for idx, case in enumerate(zip(self.images, self.masks)):
            self._case_results.append(self.process_case(idx=idx, case=case))

    def process_case(
        self, idx: int, case: Tuple[Dict[str, Any], Dict[str, Any]]
    ) -> Dict:
        """Predict a single case and write the resulting image.

        Parameters
        ----------
        idx
            Position of the case in the loaded sequence (not used here).
        case
            The ``(image_record, mask_record)`` pair to process.

        Returns
        -------
        A dictionary with the written output file, the input files that
        were used (image, mask and region file) and an empty list of error
        messages.
        """
        images, images_file_paths = {}, {}

        images["image"], images_file_paths["image"] = self._load_input_image(case[0])
        images["mask"], images_file_paths["mask"] = self._load_input_image(case[1])

        # Parsed content of the region JSON file, set by load().
        images["region"] = self.region

        out = self.predict(input_dict=images)

        # The output keeps the file name of the input image.
        out_path = self.output_path / images_file_paths["image"].name
        self.output_path.mkdir(parents=True, exist_ok=True)

        SimpleITK.WriteImage(out, str(out_path), True)

        return {
            "outputs": [dict(type="metaio_image", filename=str(out_path))],
            "inputs": [
                dict(type="metaio_image", filename=str(fn))
                for fn in images_file_paths.values()
            ] + [dict(type="String", filename=str(self.region_path))],
            "error_messages": [],
        }

    def _load_input_image(self, image) -> Tuple[SimpleITK.Image, Path]:
        """Load the image behind a case record and verify its hash.

        Parameters
        ----------
        image
            A case record providing the ``"path"`` and ``"hash"`` keys.

        Returns
        -------
        The loaded image together with the path it was loaded from.

        Raises
        ------
        RuntimeError
            If the configured file loader is not an ``ImageLoader``, or if
            the hash of the loaded image differs from the recorded hash.
        """
        input_image_file_path = image["path"]
        input_image_file_loader = self._file_loader

        if not isinstance(input_image_file_loader, ImageLoader):
            raise RuntimeError("The used FileLoader was not of subclass ImageLoader")

        input_image = input_image_file_loader.load_image(input_image_file_path)

        # Compare against the hash recorded in the case record.
        if input_image_file_loader.hash_image(input_image) != image["hash"]:
            raise RuntimeError("Image hashes do not match")
        return input_image, input_image_file_path

    @abstractmethod
    def predict(self, *, input_dict: Dict[str, SimpleITK.Image]) -> SimpleITK.Image:
        """Compute the output image for one case; implemented by subclasses.

        ``input_dict`` holds the loaded ``"image"`` and ``"mask"`` plus a
        ``"region"`` entry with the parsed content of the region JSON file.
        The returned image is written to ``output_path`` by ``process_case``.
        """

    def save(self):
        """Write the collected case results to ``output_file`` as JSON."""
        output_file = Path(self.output_file)
        # Make sure the target folder exists, so a custom output_file
        # location does not fail after all cases have been processed.
        output_file.parent.mkdir(parents=True, exist_ok=True)

        with open(output_file, "w", encoding="utf-8") as f:
            json.dump(self._case_results, f)

    def process(self):
        """Run the full pipeline: load, validate, process all cases, save."""
        self.load()
        self.validate()
        self.process_cases()
        self.save()
|
|
|