Spaces:
Sleeping
Sleeping
| import json | |
| import logging | |
| import os | |
| import re | |
| from abc import ABC, abstractmethod | |
| from pathlib import Path | |
| from typing import (Any, Callable, Dict, Iterable, List, Optional, Pattern, | |
| Set, Tuple, Union) | |
| import numpy as np | |
| import SimpleITK | |
| from evalutils.exceptions import FileLoaderError | |
| from evalutils.io import FileLoader, ImageLoader, SimpleITKLoader | |
| from evalutils.validators import (UniqueImagesValidator, | |
| UniquePathIndicesValidator) | |
| from pandas import DataFrame | |
logger = logging.getLogger(__name__)

# Fallback configuration used when no `.env` file is present, or when the
# file does not provide a usable value for a key.
_DEFAULT_TASK_TYPE = "mri"
_DEFAULT_INPUT_FOLDER = "/input"

# A `.env` file in the current working directory overrides the defaults.
if Path(".env").exists():
    # python-dotenv is only needed when a `.env` file exists, so it is
    # imported lazily here.
    from dotenv import dotenv_values

    config = dotenv_values(".env")
    # `dotenv_values` maps keys declared without a value to None; treat
    # missing keys, None and empty strings alike and fall back to the
    # defaults instead of raising KeyError / building ".../None" paths.
    TASK_TYPE = config.get("TASK_TYPE") or _DEFAULT_TASK_TYPE
    INPUT_FOLDER = config.get("INPUT_FOLDER") or _DEFAULT_INPUT_FOLDER
    print("########## ENVIRONMENT VARIABLES ##########")
    print(f"TASK_TYPE: {TASK_TYPE}")
    print(f"INPUT_FOLDER: {INPUT_FOLDER}")
else:
    TASK_TYPE = _DEFAULT_TASK_TYPE
    INPUT_FOLDER = _DEFAULT_INPUT_FOLDER

# Reading from the default `/input` location writes to `/output`; any other
# input folder (e.g. one configured via `.env`) writes to `./output`,
# relative to the current working directory.
OUTPUT_FOLDER = "/output" if INPUT_FOLDER == "/input" else "./output"

# Default locations consumed by BaseSynthradAlgorithm.
DEFAULT_IMAGE_PATH = Path(INPUT_FOLDER) / "images" / TASK_TYPE
DEFAULT_REGION_PATH = Path(INPUT_FOLDER) / "region.json"
DEFAULT_MASK_PATH = Path(INPUT_FOLDER) / "images" / "body"
DEFAULT_OUTPUT_PATH = Path(OUTPUT_FOLDER) / "images" / "synthetic-ct"
DEFAULT_OUTPUT_FILE = Path(OUTPUT_FOLDER) / "results.json"
class BaseSynthradAlgorithm(ABC):
    """Base pipeline for synthetic-CT algorithms.

    Loads the input images, the body masks and the region JSON file, calls
    :meth:`predict` once per (image, mask) pair, writes each predicted image
    to ``output_path`` and finally writes a JSON summary of all cases to
    ``output_file``. Subclasses must override :meth:`predict`.
    """

    def __init__(
        self,
        input_path: Path = DEFAULT_IMAGE_PATH,
        mask_path: Path = DEFAULT_MASK_PATH,
        region_path: Path = DEFAULT_REGION_PATH,
        output_path: Path = DEFAULT_OUTPUT_PATH,
        output_file: Path = DEFAULT_OUTPUT_FILE,
        validators: Optional[Dict[str, Callable]] = None,
        file_loader: Optional[FileLoader] = None,
    ):
        """
        Parameters
        ----------
        input_path
            Folder from which the input images are loaded.
            Default: ``DEFAULT_IMAGE_PATH`` (``/input/images/mri`` when no
            ``.env`` file is present).
        mask_path
            Folder from which the masks are loaded.
            Default: ``DEFAULT_MASK_PATH`` (``/input/images/body`` when no
            ``.env`` file is present).
        region_path
            JSON file whose parsed content is passed to :meth:`predict`
            under the ``"region"`` key.
            Default: ``DEFAULT_REGION_PATH`` (``/input/region.json`` when no
            ``.env`` file is present).
        output_path
            Folder where the predicted images are written; created if
            missing. Default: ``DEFAULT_OUTPUT_PATH``
            (``/output/images/synthetic-ct`` when no ``.env`` file is
            present).
        output_file
            JSON file where the per-case results are written.
            Default: ``DEFAULT_OUTPUT_FILE`` (``/output/results.json`` when
            no ``.env`` file is present).
        validators
            Optional mapping of validators per loader key. Stored on the
            instance but not applied yet (:meth:`validate` is a stub).
        file_loader
            Loader used for both images and masks. It must be an evalutils
            ``ImageLoader`` for :meth:`_load_input_image` to accept it.
            Default: a new ``evalutils.io.SimpleITKLoader`` per instance.
        """
        self._index_keys = ["image", "mask"]
        # Accept str as well as Path: `glob` and `/` are used on these later.
        self.input_path = Path(input_path)
        self.mask_path = Path(mask_path)
        self.region_path = Path(region_path)
        self.output_path = Path(output_path)
        self.output_file = Path(output_file)
        self.validators = validators if validators is not None else {}
        # Build the default loader per instance: a default argument value is
        # evaluated once and would otherwise be shared by every instance.
        self._file_loader = (
            file_loader if file_loader is not None else SimpleITKLoader()
        )
        # TODO: apply validators (e.g. UniquePathIndicesValidator,
        # UniqueImagesValidator) in `validate`.
        self.cases = {}
        self._case_results = []

    def load(self):
        """Load the image and mask entries and parse the region JSON file."""
        self.images = self._load_cases(
            folder=self.input_path, file_loader=self._file_loader
        )
        self.masks = self._load_cases(
            folder=self.mask_path, file_loader=self._file_loader
        )
        with open(self.region_path, "r", encoding="utf-8") as f:
            self.region = json.load(f)

    def _load_cases(
        self,
        folder: Path,
        file_loader: ImageLoader,
    ) -> List[Dict[str, Any]]:
        """Load every file in ``folder``, in sorted order, with ``file_loader``.

        Files that the loader rejects with ``FileLoaderError`` are skipped
        with a warning.

        Returns
        -------
        The concatenation of the entries returned by ``file_loader.load``
        for each loadable file. :meth:`_load_input_image` later reads their
        ``"path"`` and ``"hash"`` keys.

        Raises
        ------
        FileLoaderError
            If no file in ``folder`` could be loaded.
        """
        cases = []
        # Sorting matters: images and masks are paired by position later.
        for fp in sorted(Path(folder).glob("*")):
            try:
                new_cases = file_loader.load(fname=fp)
            except FileLoaderError:
                logger.warning(f"Could not load {fp.name} using {file_loader}.")
            else:
                cases.extend(new_cases)
        if not cases:
            raise FileLoaderError(
                f"Could not load any files in {folder} with {file_loader}."
            )
        return cases

    def validate(self):
        """TODO: Validates each dataframe for each fileloader separately"""
        pass

    def _validate_data_frame(self, df: DataFrame):
        """TODO: Validate the dataframe for a specific fileloader"""
        pass

    def process_cases(self):
        """Run :meth:`process_case` on every (image, mask) pair.

        Images and masks are paired by their position in the sorted file
        lists built by :meth:`load`.

        Raises
        ------
        ValueError
            If the number of loaded images and masks differ. ``zip`` would
            otherwise silently drop the extra entries and could pair images
            with the wrong masks.
        """
        if len(self.images) != len(self.masks):
            raise ValueError(
                f"Found {len(self.images)} images but {len(self.masks)} "
                "masks; every image needs exactly one mask."
            )
        self._case_results = []
        for idx, case in enumerate(zip(self.images, self.masks)):
            self._case_results.append(self.process_case(idx=idx, case=case))

    def process_case(
        self, idx: int, case: Tuple[Dict[str, Any], Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Predict and write the output image for one (image, mask) pair.

        Parameters
        ----------
        idx
            Position of the case in the processing order (currently unused).
        case
            ``(image_entry, mask_entry)`` as produced by :meth:`_load_cases`.

        Returns
        -------
        The result record for this case: the written output file, the input
        files (image, mask, region JSON) and an empty ``error_messages``.
        """
        images, images_file_paths = {}, {}
        images["image"], images_file_paths["image"] = self._load_input_image(case[0])
        images["mask"], images_file_paths["mask"] = self._load_input_image(case[1])
        images["region"] = self.region

        # Predict the synthetic CT for this case.
        out = self.predict(input_dict=images)

        # The output file reuses the input image's file name.
        self.output_path.mkdir(parents=True, exist_ok=True)
        out_path = self.output_path / images_file_paths["image"].name
        # Third argument of WriteImage is `useCompression`.
        SimpleITK.WriteImage(out, str(out_path), True)

        # Describe the output and inputs of this case for results.json.
        return {
            "outputs": [dict(type="metaio_image", filename=str(out_path))],
            "inputs": [
                dict(type="metaio_image", filename=str(fn))
                for fn in images_file_paths.values()
            ] + [dict(type="String", filename=str(self.region_path))],
            "error_messages": [],
        }

    def _load_input_image(
        self, image: Dict[str, Any]
    ) -> Tuple[SimpleITK.Image, Path]:
        """Load the image referenced by ``image`` and verify its hash.

        Parameters
        ----------
        image
            Entry from :meth:`_load_cases` with the file ``"path"`` and the
            ``"hash"`` stored for it by the loader.

        Returns
        -------
        The loaded image and its file path.

        Raises
        ------
        RuntimeError
            If the configured loader is not an ``ImageLoader``, or if the
            hash of the reloaded image differs from the stored one.
        """
        input_image_file_path = image["path"]
        input_image_file_loader = self._file_loader
        if not isinstance(input_image_file_loader, ImageLoader):
            raise RuntimeError("The used FileLoader was not of subclass ImageLoader")
        input_image = input_image_file_loader.load_image(input_image_file_path)
        # Make sure the file still holds the image that was indexed in load().
        if input_image_file_loader.hash_image(input_image) != image["hash"]:
            raise RuntimeError("Image hashes do not match")
        return input_image, input_image_file_path

    def predict(self, *, input_dict: Dict[str, Any]) -> SimpleITK.Image:
        """Generate the output image for one case; subclasses must override.

        Parameters
        ----------
        input_dict
            ``"image"`` and ``"mask"`` map to the loaded ``SimpleITK.Image``
            objects; ``"region"`` maps to the parsed region JSON content.

        Returns
        -------
        The predicted image, which :meth:`process_case` writes to disk.

        Raises
        ------
        NotImplementedError
            Always, in this base class. Returning None instead would only
            fail later, with an obscure error from ``SimpleITK.WriteImage``.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement predict()"
        )

    def save(self):
        """Write the collected per-case results to ``output_file`` as JSON."""
        output_file = Path(self.output_file)
        # The results file may live outside output_path, so make sure its
        # parent folder exists before writing.
        output_file.parent.mkdir(parents=True, exist_ok=True)
        with open(output_file, "w") as f:
            json.dump(self._case_results, f)

    def process(self):
        """Run the full pipeline: load, validate, process every case, save."""
        self.load()
        self.validate()
        self.process_cases()
        self.save()