r"""
Detect and remove constant-value artifact planes at volume boundaries.

Interpolation during preprocessing can introduce planes filled with a single
non-zero constant value (e.g. 8.0 for CT) at the start or end of each spatial
axis.

This script:
  1. Scans all .nii.gz files under --image_dir (recursively with --recursive).
  2. For each image, identifies boundary planes that hold exactly one value
     AND whose value is foreign to the nearest non-constant interior plane
     (it covers at most 1% of that plane). Constant planes whose value is
     common in the interior (e.g. zero or air background) are kept, as they
     are legitimate background.
  3. Crops the artifact planes from image AND matching label (if present),
     preserving the spatial origin so the image stays in the same physical
     coordinate space.
  4. Overwrites in-place (use --dry-run to preview without writing).

Usage:
    # Dry-run: report artifacts without modifying files
    python clean_artifact_planes.py \
        --image_dir /path/to/MSD_processed/images \
        --label_dir /path/to/MSD_processed/labels \
        --dry-run

    # Actually clean:
    python clean_artifact_planes.py \
        --image_dir /path/to/MSD_processed/images \
        --label_dir /path/to/MSD_processed/labels
"""

import os
import glob
import argparse
import numpy as np
import SimpleITK as sitk
from tqdm import tqdm


def _get_plane(arr, axis, idx):
    """Extract a single plane from the array along the given axis."""
    slc = [slice(None)] * arr.ndim
    slc[axis] = idx
    return arr[tuple(slc)]


def find_artifact_slices(arr, axis, max_search=20):
    """Find contiguous constant-value boundary slices along `axis`.

    A boundary slice is considered an artifact if:
      - it has exactly 1 unique value, AND
      - that value is foreign to the reference interior plane, i.e. it
        matches at most 1% of the reference plane's voxels.

    The reference plane is the first non-constant plane met when walking
    inwards from the boundary (searched over at most ``max_search + 5``
    planes). If no such plane is found, the boundary plane itself is the
    reference, which means nothing is trimmed on that side.

    This avoids trimming legitimate background planes (e.g. -300 in CT air)
    that are naturally connected to interior regions with the same value.

    Args:
        arr: N-dimensional numpy array.
        axis: Axis of ``arr`` along which boundary planes are inspected.
        max_search: Maximum number of planes inspected on each side. The two
            sides never cross the middle of the axis.

    Returns:
        (n_start, n_end): number of artifact slices to trim from the start
        and from the end of the given axis.
    """
    n = arr.shape[axis]

    def _is_artifact(idx, interior_idx):
        plane = _get_plane(arr, axis, idx)
        unique = np.unique(plane)
        if len(unique) != 1:
            return False
        val = float(unique[0])
        neighbor = _get_plane(arr, axis, interior_idx)
        # If the value appears in >1% of the reference plane's voxels, it is
        # likely connected background, not an artifact.
        match_ratio = np.mean(np.abs(neighbor - val) < 1e-6)
        return not match_ratio > 0.01

    def _find_reference(start, stop, step):
        # First non-constant plane from a boundary; falls back to the
        # boundary plane itself when every inspected plane is constant.
        for idx in range(start, stop, step):
            if len(np.unique(_get_plane(arr, axis, idx))) > 1:
                return idx
        return start

    ref_start = _find_reference(0, min(max_search + 5, n), 1)
    ref_end = _find_reference(n - 1, max(n - 1 - max_search - 5, -1), -1)

    # Walk inwards from the start; stop at the first non-artifact plane.
    n_start = 0
    for i in range(min(max_search, n // 2)):
        if not _is_artifact(i, ref_start):
            break
        n_start = i + 1

    # Walk inwards from the end; never go past the middle of the axis.
    n_end = 0
    for i in range(n - 1, max(n - 1 - max_search, n // 2), -1):
        if not _is_artifact(i, ref_end):
            break
        n_end = n - i

    return n_start, n_end


def detect_artifacts(arr, max_search=20):
    """Detect artifact planes on all spatial axes.

    For 4D arrays (e.g. BRATS with shape [C, D, H, W]), only spatial axes
    (1, 2, 3) are checked; the channel axis (0) is skipped. For any other
    dimensionality every axis is checked.

    Returns a dict: {axis: (n_start, n_end)} for axes that need trimming.
    """
    if arr.ndim == 3:
        spatial_axes = [0, 1, 2]
    elif arr.ndim == 4:
        spatial_axes = [1, 2, 3]
    else:
        spatial_axes = list(range(arr.ndim))

    crops = {}
    for axis in spatial_axes:
        n_start, n_end = find_artifact_slices(arr, axis, max_search=max_search)
        if n_start > 0 or n_end > 0:
            crops[axis] = (n_start, n_end)
    return crops


def build_crop_slices(ndim, crops):
    """Build a tuple of slices to crop the array according to `crops`.

    Axes absent from ``crops`` are kept whole.
    """
    slices = [slice(None)] * ndim
    for axis, (n_start, n_end) in crops.items():
        # A stop of -0 would produce an empty slice, so use None instead.
        end = None if n_end == 0 else -n_end
        slices[axis] = slice(n_start, end)
    return tuple(slices)


def crop_sitk_image(sitk_img, crops):
    """Crop a SimpleITK image according to the detected artifact planes.

    Spacing, direction and meta-data are copied from ``sitk_img``; the origin
    is shifted by the planes removed at the start of each axis so the cropped
    image occupies the correct physical space.

    Args:
        sitk_img: Source SimpleITK image.
        crops: {array_axis: (n_start, n_end)} in numpy (reversed) axis order,
            as returned by `detect_artifacts`.

    Returns:
        A new SimpleITK image holding the cropped voxels.
    """
    arr = sitk.GetArrayFromImage(sitk_img)
    cropped_arr = arr[build_crop_slices(arr.ndim, crops)]

    # GetImageFromArray treats any 4D array as a 3D *vector* image unless
    # told otherwise, which breaks 4D scalar volumes (the 4D spacing and
    # direction below would no longer fit). Keep the source's pixel layout.
    is_vector = sitk_img.GetNumberOfComponentsPerPixel() > 1
    cropped_img = sitk.GetImageFromArray(cropped_arr, isVector=is_vector)
    cropped_img.SetSpacing(sitk_img.GetSpacing())
    cropped_img.SetDirection(sitk_img.GetDirection())

    # Adjust origin: SimpleITK arrays are in ZYX order, origin is in XYZ.
    ndim_phys = sitk_img.GetDimension()  # physical dimensions (3 for 3D, 4 for 4D)
    origin = list(sitk_img.GetOrigin())
    spacing = list(sitk_img.GetSpacing())
    direction = np.array(sitk_img.GetDirection()).reshape(ndim_phys, ndim_phys)

    for axis, (n_start, _) in crops.items():
        if n_start <= 0:
            continue
        # Map array axis to physical axis
        # (SimpleITK: last array axis = first physical axis).
        if arr.ndim == 3:
            phys_axis = 2 - axis
        elif arr.ndim == 4:
            phys_axis = 2 - (axis - 1)
        else:
            continue
        if 0 <= phys_axis < ndim_phys:
            # Shift along the physical direction of the cropped axis.
            for i in range(min(3, ndim_phys)):
                origin[i] += n_start * spacing[phys_axis] * direction[i, phys_axis]
    cropped_img.SetOrigin(origin)

    for key in sitk_img.GetMetaDataKeys():
        cropped_img.SetMetaData(key, sitk_img.GetMetaData(key))

    return cropped_img


def _collect_image_files(image_dir, recursive):
    """Return the sorted list of .nii.gz image paths to scan.

    In recursive mode, files inside a ``segmentation`` directory and files
    whose path (relative to ``image_dir``) has a component starting with
    "label" are excluded. Only the part of the path below ``image_dir`` is
    examined, so the location of ``image_dir`` itself never excludes files.
    """
    if not recursive:
        return sorted(glob.glob(os.path.join(image_dir, "*.nii.gz")))

    pattern = os.path.join(image_dir, "**", "*.nii.gz")
    image_files = []
    for path in sorted(glob.glob(pattern, recursive=True)):
        parts = os.path.relpath(path, image_dir).split(os.sep)
        if "segmentation" in parts[:-1]:
            continue
        if any(part.lower().startswith("label") for part in parts):
            continue
        image_files.append(path)
    return image_files


def _find_label_path(img_path, label_dir, recursive):
    """Return the label file matching `img_path`, or None if there is none.

    Non-recursive mode looks for the same filename in ``label_dir``;
    recursive mode looks at ``{image parent dir}/segmentation/{filename}``.
    """
    filename = os.path.basename(img_path)
    if recursive:
        candidate = os.path.join(os.path.dirname(img_path), 'segmentation', filename)
    elif label_dir:
        candidate = os.path.join(label_dir, filename)
    else:
        return None
    return candidate if os.path.isfile(candidate) else None


def _label_crops(crops, image_shape, label_shape):
    """Translate image crops to the axes of the matching label array.

    Returns the crops unchanged when both arrays have the same shape, shifts
    the axes down by one when a 4D image is paired with a 3D label of the
    same spatial shape, and returns None when the shapes are incompatible.
    """
    image_shape = tuple(image_shape)
    label_shape = tuple(label_shape)
    if label_shape == image_shape:
        return dict(crops)
    if len(image_shape) == 4 and label_shape == image_shape[1:]:
        # 4D crops only ever refer to axes 1..3 (see detect_artifacts).
        return {axis - 1: trim for axis, trim in crops.items()}
    return None


def _artifact_value(arr, crops):
    """Return the constant value of the first detected artifact plane."""
    axis, (n_start, _) = next(iter(crops.items()))
    idx = 0 if n_start > 0 else arr.shape[axis] - 1
    return _get_plane(arr, axis, idx).flat[0]


def main():
    """Command-line entry point: scan, report and (unless --dry-run) crop."""
    parser = argparse.ArgumentParser(description="Detect and remove constant-value artifact planes at volume boundaries.")
    parser.add_argument("--image_dir", type=str, required=True,
                        help="Directory containing .nii.gz image files.")
    parser.add_argument("--label_dir", type=str, default=None,
                        help="Directory containing matching .nii.gz label files (same filenames). "
                             "In recursive mode, labels are found at {subject_dir}/segmentation/{filename}.")
    parser.add_argument("--recursive", action="store_true",
                        help="Recursively search for .nii.gz files, excluding segmentation/ subdirs.")
    parser.add_argument("--max_search", type=int, default=20,
                        help="Max number of boundary slices to check per side (default: 20).")
    parser.add_argument("--dry-run", action="store_true",
                        help="Report artifacts without modifying any files.")
    args = parser.parse_args()

    image_files = _collect_image_files(args.image_dir, args.recursive)

    print(f"Found {len(image_files)} images in {args.image_dir}{' (recursive)' if args.recursive else ''}")
    if args.label_dir:
        print(f"Label dir: {args.label_dir}")
    if args.dry_run:
        print("*** DRY RUN — no files will be modified ***")

    total_artifacts = 0
    total_clean = 0
    total_slices_removed = 0
    total_skipped = 0

    for img_path in tqdm(image_files, desc="Scanning"):
        filename = os.path.basename(img_path)
        sitk_img = sitk.ReadImage(img_path)
        arr = sitk.GetArrayFromImage(sitk_img)

        crops = detect_artifacts(arr, max_search=args.max_search)
        if not crops:
            total_clean += 1
            continue

        total_artifacts += 1
        slices_removed = sum(s + e for s, e in crops.values())
        total_slices_removed += slices_removed

        detail = ", ".join(
            f"axis{ax}: -{s} start, -{e} end" for ax, (s, e) in sorted(crops.items())
        )
        val = _artifact_value(arr, crops)
        print(f" {filename}: {arr.shape} -> trim {slices_removed} planes, val={val} ({detail})")

        if args.dry_run:
            continue

        # Crop everything in memory first, so that a problem with the label
        # can never leave an already-overwritten image out of sync with it.
        label_path = _find_label_path(img_path, args.label_dir, args.recursive)
        cropped_lbl = None
        if label_path is not None:
            sitk_lbl = sitk.ReadImage(label_path)
            lbl_shape = sitk.GetArrayFromImage(sitk_lbl).shape
            lbl_crops = _label_crops(crops, arr.shape, lbl_shape)
            if lbl_crops is None:
                total_skipped += 1
                print(f" WARNING: {filename}: label shape {lbl_shape} does not match "
                      f"image shape {arr.shape}; image and label left unmodified")
                continue
            cropped_lbl = crop_sitk_image(sitk_lbl, lbl_crops)

        cropped_img = crop_sitk_image(sitk_img, crops)
        sitk.WriteImage(cropped_img, img_path)
        if cropped_lbl is not None:
            sitk.WriteImage(cropped_lbl, label_path)

    print("\nSummary:")
    print(f" Total images: {len(image_files)}")
    print(f" With artifacts: {total_artifacts}")
    print(f" Clean: {total_clean}")
    print(f" Planes removed: {total_slices_removed}")
    if total_skipped:
        print(f" Skipped (label mismatch): {total_skipped}")
    if args.dry_run:
        print(" (dry-run — nothing was modified)")


if __name__ == "__main__":
    main()