# LesionSegmenter / p_processing.py
# Last updated by ChrisXzZ in commit b434218 ("Update p_processing.py", verified).
import os
import argparse
from concurrent.futures import ProcessPoolExecutor, as_completed
import nibabel as nib
import numpy as np
from scipy import ndimage
from tqdm import tqdm
def parse_rules(rule_str):
    """Parse a removal-rule string into a ``{source_label: [target_labels]}`` dict.

    The syntax is a comma-separated list of ``src:dst1|dst2|...`` items, e.g.
    ``"1:2|3|4,5:6"`` -> ``{1: [2, 3, 4], 5: [6]}``. Whitespace around items
    and labels is ignored, as are empty items and empty labels.

    Args:
        rule_str: The raw rule string (e.g. the ``--rules`` CLI value).

    Returns:
        dict mapping each integer source label to a non-empty list of integer
        target labels. If a source label appears more than once, the last
        occurrence wins.

    Raises:
        ValueError: If an item does not contain exactly one ``:``, if a label
            is not an integer, or if an item has no target labels. The message
            always names the offending rule.
    """
    rules = {}
    items = [x.strip() for x in rule_str.split(",") if x.strip()]
    for item in items:
        # Require exactly one ':'; otherwise "1:2:3" would fail with an opaque
        # tuple-unpacking error that does not mention the rule.
        if item.count(":") != 1:
            raise ValueError(
                f"Invalid rule format: {item}. Expected format like '1:2|3|4,5:6'"
            )
        src, dst = item.split(":")
        try:
            src_label = int(src.strip())
            dst_labels = [int(x.strip()) for x in dst.split("|") if x.strip()]
        except ValueError as err:
            raise ValueError(
                f"Non-integer label in rule: {item}. Expected format like '1:2|3|4,5:6'"
            ) from err
        if not dst_labels:
            raise ValueError(f"No target labels found in rule: {item}")
        rules[src_label] = dst_labels
    return rules
def parse_labels(label_str):
    """Parse a comma-separated string of integer labels, e.g. ``"1, 2,3"``.

    Blank entries (such as those left by a trailing comma) are skipped.

    Returns:
        list[int] in the order the labels appear in ``label_str``.

    Raises:
        ValueError: If a non-blank entry is not an integer literal.
    """
    labels = []
    for token in label_str.split(","):
        stripped = token.strip()
        if stripped:
            labels.append(int(stripped))
    return labels
def find_nii_gz_files(input_dir):
    """Recursively collect every file under ``input_dir`` whose name ends in ``.nii.gz``.

    The suffix check is case-sensitive. A non-existent ``input_dir`` yields an
    empty list (``os.walk`` produces nothing for it).

    Returns:
        Sorted list of paths, each formed as ``os.path.join(<walked dir>, name)``.
    """
    matches = (
        os.path.join(dirpath, name)
        for dirpath, _subdirs, names in os.walk(input_dir)
        for name in names
        if name.endswith(".nii.gz")
    )
    return sorted(matches)
def get_structure(connectivity=26):
    """Return a 3-D binary structuring element for the requested voxel connectivity.

    Args:
        connectivity: ``6`` -> face neighbours only; ``18`` -> faces + edges;
            ``26`` -> faces + edges + corners (the full 3x3x3 cube).

    Returns:
        Boolean ``(3, 3, 3)`` array from ``ndimage.generate_binary_structure``,
        usable as ``structure=`` for ``ndimage.label`` / ``binary_dilation``.

    Raises:
        ValueError: For any other connectivity value.
    """
    # generate_binary_structure(3, c) includes neighbours whose squared
    # distance from the centre is <= c: c=1 -> 6, c=2 -> 18, c=3 -> 26
    # neighbours. Note that c=2 is only 18-connectivity, so 26 needs c=3.
    if connectivity == 6:
        return ndimage.generate_binary_structure(3, 1)
    if connectivity == 18:
        return ndimage.generate_binary_structure(3, 2)
    if connectivity == 26:
        return ndimage.generate_binary_structure(3, 3)
    raise ValueError("connectivity must be 6, 18 or 26")
def should_keep_component(dilated_component, seg, target_labels, keep_mode):
    """Decide whether a (dilated) component mask touches the given target labels.

    Args:
        dilated_component: Boolean mask with the same shape as ``seg``.
        seg: Label volume.
        target_labels: Iterable of label values to look for under the mask.
        keep_mode: ``"any"`` -> keep if at least one target label occurs under
            the mask; ``"all"`` -> keep only if every target label does.

    Returns:
        bool. With empty ``target_labels``, ``"any"`` gives False and
        ``"all"`` gives True.

    Raises:
        ValueError: If ``keep_mode`` is not ``"any"`` or ``"all"``.
    """
    if keep_mode not in ("any", "all"):
        raise ValueError("keep_mode must be 'any' or 'all'")
    # Inspect only the voxels under the mask. The alternative, building a
    # full-volume `seg == label` mask per target label, costs
    # O(labels x volume) for every component.
    labels_under_mask = np.unique(seg[np.asarray(dilated_component, dtype=bool)])
    overlaps = np.isin(np.asarray(list(target_labels)), labels_under_mask)
    if keep_mode == "any":
        return bool(overlaps.any())
    return bool(overlaps.all())
def remove_components_by_rules(
    seg,
    source_label,
    target_labels,
    dilation_iters,
    structure,
    keep_mode="any"
):
    """Zero out connected components of ``source_label`` that fail the overlap rule.

    Each connected component of ``seg == source_label`` (labelled with
    ``structure``) is dilated ``dilation_iters`` times with ``structure`` and
    passed to ``should_keep_component`` with ``target_labels`` and
    ``keep_mode``. Components that are not kept are set to 0.

    ``seg`` is modified in place and also returned. Components are visited in
    label order, and each check sees removals already made for earlier ones.

    Args:
        seg: NumPy label volume, modified in place.
        source_label: Label whose components are tested.
        target_labels: Labels forwarded to ``should_keep_component``.
        dilation_iters: Number of dilation steps. Values < 1 mean "no
            dilation": the raw component is tested.
        structure: Structuring element used for labelling and dilation.
        keep_mode: ``"any"`` or ``"all"``; see ``should_keep_component``.

    Returns:
        The same ``seg`` array.
    """
    source_mask = (seg == source_label)
    if not np.any(source_mask):
        return seg
    cc_map, _ = ndimage.label(source_mask, structure=structure)

    # scipy's binary_dilation treats iterations < 1 as "repeat until nothing
    # changes", which would grow every component over the whole volume.
    # Treat it as no dilation instead.
    dilate = dilation_iters >= 1
    if dilate:
        # Conservative upper bound (in voxels, per axis) on how far the
        # dilation can spread from the component. scipy uses a 3-wide default
        # structure when none is given.
        extent = max(np.shape(structure)) if structure is not None else 3
        pad = dilation_iters * extent
    else:
        pad = 0

    # Work inside each component's bounding box padded by `pad` instead of
    # the full volume. This is equivalent because the dilation cannot reach
    # the window edge, and far cheaper for many small components.
    for cc_id, bbox in enumerate(ndimage.find_objects(cc_map), start=1):
        if bbox is None:
            continue
        window = tuple(
            slice(max(sl.start - pad, 0), min(sl.stop + pad, size))
            for sl, size in zip(bbox, seg.shape)
        )
        seg_window = seg[window]  # basic slicing -> view; writes reach `seg`
        component = (cc_map[window] == cc_id)
        if dilate:
            dilated_component = ndimage.binary_dilation(
                component,
                structure=structure,
                iterations=dilation_iters
            )
        else:
            dilated_component = component
        keep = should_keep_component(
            dilated_component=dilated_component,
            seg=seg_window,
            target_labels=target_labels,
            keep_mode=keep_mode
        )
        if not keep:
            seg_window[component] = 0
    return seg
def remove_small_components(seg, label_value, min_size, structure):
    """Set to 0 every connected component of ``label_value`` smaller than ``min_size``.

    Args:
        seg: NumPy label volume, modified in place.
        label_value: Label whose components are size-filtered.
        min_size: Components with fewer than this many voxels are removed.
        structure: Structuring element passed to ``ndimage.label``.

    Returns:
        The same ``seg`` array (returned unchanged if ``label_value`` is absent).
    """
    mask = (seg == label_value)
    if not np.any(mask):
        return seg
    cc_map, _ = ndimage.label(mask, structure=structure)
    # One pass for all component sizes instead of one full-volume
    # `cc_map == cc_id` comparison per component.
    sizes = np.bincount(cc_map.ravel())
    too_small = sizes < min_size
    too_small[0] = False  # index 0 is background, never a component
    seg[too_small[cc_map]] = 0
    return seg
def process_one_file(
    file_path,
    input_dir,
    output_dir,
    rules,
    clean_labels,
    min_size,
    dilation_iters,
    connectivity,
    keep_mode
):
    """Clean one NIfTI label volume and save it under ``output_dir``.

    Steps:
    1. Load ``file_path`` with nibabel and cast the voxels to int32.
    2. Apply each ``rules`` entry (in dict order) via
       ``remove_components_by_rules``.
    3. Drop components smaller than ``min_size`` for every label in
       ``clean_labels``.
    4. Cast the result back to the input's on-disk dtype and save it at the
       same path relative to ``input_dir`` inside ``output_dir``, reusing the
       input affine and header.

    Intended as a worker-process task: any ``Exception`` is caught and
    reported in the return value rather than raised.

    Returns:
        ``(file_path, "success", "")`` on success, or
        ``(file_path, "failed", "<ExceptionType>: <message>")`` on error.
    """
    try:
        img = nib.load(file_path)
        # NOTE(review): astype truncates non-integer voxel values; assumes
        # labels are stored as whole numbers.
        seg = np.asanyarray(img.dataobj).astype(np.int32)
        structure = get_structure(connectivity)
        for source_label, dst_labels in rules.items():
            seg = remove_components_by_rules(
                seg=seg,
                source_label=source_label,
                target_labels=dst_labels,
                dilation_iters=dilation_iters,
                structure=structure,
                keep_mode=keep_mode
            )
        for label in clean_labels:
            seg = remove_small_components(
                seg=seg,
                label_value=label,
                min_size=min_size,
                structure=structure
            )
        rel_path = os.path.relpath(file_path, input_dir)
        out_path = os.path.join(output_dir, rel_path)
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        out_img = nib.Nifti1Image(
            seg.astype(img.get_data_dtype()),
            affine=img.affine,
            header=img.header
        )
        nib.save(out_img, out_path)
    except Exception as e:
        # Worker boundary: report instead of crashing the pool. Include the
        # exception type, since str(e) alone is often uninformative
        # (e.g. a KeyError renders as just the missing key).
        return file_path, "failed", f"{type(e).__name__}: {e}"
    return file_path, "success", ""
def main():
    """Command-line entry point: parse arguments and clean every ``.nii.gz`` file.

    All ``.nii.gz`` files found recursively under ``--input_dir`` are processed
    in parallel with ``process_one_file``, and each result is written to the
    mirrored path under ``--output_dir``. A success/failure summary is
    printed. If any file failed, its path and error message are written
    (sorted by path) to ``<output_dir>/failed_cases.txt``.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Remove connected components of source labels based on overlap with one or more "
            "target labels after dilation, and then remove small connected components."
        )
    )
    parser.add_argument(
        "--input_dir",
        type=str,
        required=True,
        help="Directory searched recursively for .nii.gz label volumes."
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help=(
            "Directory for the cleaned volumes; each file keeps its path "
            "relative to --input_dir."
        )
    )
    parser.add_argument(
        "--rules",
        type=str,
        required=True,
        help=(
            "Removal rules as 'src:dst1|dst2,...', e.g. '1:2|3|4,5:6'. Each "
            "connected component of label src is dilated and set to 0 unless "
            "it overlaps the dst labels (see --keep_mode)."
        )
    )
    parser.add_argument(
        "--target_labels",
        type=str,
        default=None,
        help=(
            "Comma-separated labels whose small components are removed, "
            "e.g. '1,5'. Defaults to the source labels of --rules."
        )
    )
    parser.add_argument(
        "--keep_mode",
        type=str,
        default="any",
        choices=["any", "all"],
        help=(
            "'any': keep a component that overlaps at least one of its dst "
            "labels; 'all': keep it only if it overlaps every dst label."
        )
    )
    parser.add_argument(
        "--min_size",
        type=int,
        default=30,
        help=(
            "Components of the small-component labels with fewer voxels "
            "than this are removed (default: 30)."
        )
    )
    parser.add_argument(
        "--dilation_iters",
        type=int,
        default=1,
        help=(
            "Dilation iterations applied to each component before the "
            "overlap test (default: 1)."
        )
    )
    parser.add_argument(
        "--connectivity",
        type=int,
        default=26,
        choices=[6, 26],
        help="Voxel connectivity used for labelling and dilation (default: 26)."
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=4,
        help="Number of worker processes (default: 4)."
    )
    args = parser.parse_args()

    # Report bad CLI input as a usage error instead of a traceback.
    if args.num_workers < 1:
        parser.error("--num_workers must be at least 1")
    if args.dilation_iters < 0:
        # scipy's binary_dilation treats negative iterations as "repeat until
        # convergence", which is never what is intended here.
        parser.error("--dilation_iters must be non-negative")
    try:
        rules = parse_rules(args.rules)
        if args.target_labels is None:
            clean_labels = list(rules.keys())
        else:
            clean_labels = parse_labels(args.target_labels)
    except ValueError as err:
        parser.error(str(err))

    os.makedirs(args.output_dir, exist_ok=True)
    file_list = find_nii_gz_files(args.input_dir)
    if not file_list:
        print(f"No .nii.gz files found in: {args.input_dir}")
        return

    results = []
    with ProcessPoolExecutor(max_workers=args.num_workers) as executor:
        futures = [
            executor.submit(
                process_one_file,
                file_path,
                args.input_dir,
                args.output_dir,
                rules,
                clean_labels,
                args.min_size,
                args.dilation_iters,
                args.connectivity,
                args.keep_mode
            )
            for file_path in file_list
        ]
        for future in tqdm(as_completed(futures), total=len(futures), desc="Processing"):
            results.append(future.result())

    success_count = sum(1 for _, status, _ in results if status == "success")
    # as_completed yields in completion order; sort so the report is reproducible.
    failed_cases = sorted(
        (fp, err) for fp, status, err in results if status == "failed"
    )
    print(f"\nDone. Success: {success_count}, Failed: {len(failed_cases)}")
    if failed_cases:
        failed_txt = os.path.join(args.output_dir, "failed_cases.txt")
        with open(failed_txt, "w", encoding="utf-8") as f:
            for fp, err in failed_cases:
                f.write(f"{fp}\t{err}\n")
        print(f"Failed cases saved to: {failed_txt}")
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not on import (worker
    # processes started by ProcessPoolExecutor may import this module).
    main()