| import os |
| import argparse |
| from concurrent.futures import ProcessPoolExecutor, as_completed |
|
|
| import nibabel as nib |
| import numpy as np |
| from scipy import ndimage |
| from tqdm import tqdm |
|
|
|
|
def parse_rules(rule_str):
    """Parse a rule string such as ``"1:2|3|4,5:6"`` into a mapping.

    Rules are separated by commas. Each rule has the form
    ``source:target1|target2|...`` where every label is an integer.
    Whitespace around labels and empty comma-separated items are ignored.
    If the same source label appears more than once, the last rule wins.

    Args:
        rule_str: The raw rule string.

    Returns:
        A dict mapping each source label (int) to its list of target
        labels (list of int), in the order they were written.

    Raises:
        ValueError: If a rule does not contain exactly one ``:``, contains a
            non-integer label, or lists no target labels.
    """
    rules = {}

    for item in (part.strip() for part in rule_str.split(",")):
        if not item:
            continue

        src, sep, dst = item.partition(":")
        # Exactly one ':' is allowed; "1:2:3" used to fail with an opaque
        # tuple-unpacking error instead of pointing at the offending rule.
        if not sep or ":" in dst:
            raise ValueError(
                f"Invalid rule format: {item}. Expected format like '1:2|3|4,5:6'"
            )

        try:
            src_label = int(src.strip())
            dst_labels = [int(x.strip()) for x in dst.split("|") if x.strip()]
        except ValueError:
            raise ValueError(
                f"Non-integer label in rule: {item}. Expected format like '1:2|3|4,5:6'"
            ) from None

        if not dst_labels:
            raise ValueError(f"No target labels found in rule: {item}")

        rules[src_label] = dst_labels

    return rules
|
|
|
|
def parse_labels(label_str):
    """Convert a comma-separated string of integers into a list of ints.

    Surrounding whitespace is stripped from each item and empty items
    (for example from a trailing comma) are skipped.
    """
    labels = []
    for token in label_str.split(","):
        token = token.strip()
        if token:
            labels.append(int(token))
    return labels
|
|
|
|
def find_nii_gz_files(input_dir):
    """Return the sorted paths of all ``.nii.gz`` files below ``input_dir``.

    The directory tree is walked recursively; the match on the file name
    suffix is case-sensitive.
    """
    matches = [
        os.path.join(dirpath, name)
        for dirpath, _subdirs, names in os.walk(input_dir)
        for name in names
        if name.endswith(".nii.gz")
    ]
    matches.sort()
    return matches
|
|
|
|
def get_structure(connectivity=26):
    """Return the 3x3x3 structuring element for a 3-D voxel connectivity.

    Args:
        connectivity: Number of neighbours per voxel: 6 (shared faces),
            18 (faces and edges) or 26 (faces, edges and corners).

    Returns:
        A boolean array of shape (3, 3, 3) from
        ``scipy.ndimage.generate_binary_structure``.

    Raises:
        ValueError: If ``connectivity`` is not 6, 18 or 26.
    """
    # Second argument of generate_binary_structure for each neighbourhood.
    # Note that (3, 2) is the 18-neighbourhood; 26 neighbours require (3, 3).
    squared_distance = {6: 1, 18: 2, 26: 3}
    try:
        order = squared_distance[connectivity]
    except (KeyError, TypeError):
        raise ValueError("connectivity must be 6, 18 or 26") from None
    return ndimage.generate_binary_structure(3, order)
|
|
|
|
def should_keep_component(dilated_component, seg, target_labels, keep_mode):
    """Decide whether a dilated component touches the required target labels.

    Args:
        dilated_component: Boolean mask with the same shape as ``seg``.
        seg: Label array.
        target_labels: Iterable of label values to look for under the mask.
        keep_mode: ``"any"`` to require at least one target label under the
            mask, ``"all"`` to require every target label.

    Returns:
        True if the component should be kept, False otherwise. With an empty
        ``target_labels`` this is False for ``"any"`` and True for ``"all"``.

    Raises:
        ValueError: If ``keep_mode`` is neither ``"any"`` nor ``"all"``.
    """
    # Reject a bad mode before doing any array work.
    if keep_mode not in ("any", "all"):
        raise ValueError("keep_mode must be 'any' or 'all'")

    # Look only at the voxels under the mask instead of building one
    # full-size boolean array per target label.
    region = np.asarray(dilated_component, dtype=bool)
    present = set(np.unique(np.asarray(seg)[region]).tolist())

    hits = (label in present for label in target_labels)
    return any(hits) if keep_mode == "any" else all(hits)
|
|
|
|
def remove_components_by_rules(
    seg,
    source_label,
    target_labels,
    dilation_iters,
    structure,
    keep_mode="any"
):
    """Zero out components of ``source_label`` that do not touch the targets.

    Each connected component of ``seg == source_label`` is dilated
    ``dilation_iters`` times with ``structure`` and passed to
    ``should_keep_component`` together with ``target_labels`` and
    ``keep_mode``. Components that are not kept are set to 0.

    Args:
        seg: Label array; it is modified in place.
        source_label: Label whose connected components are examined.
        target_labels: Labels the dilated component is tested against.
        dilation_iters: Number of dilation iterations; 0 means the component
            is tested without dilation.
        structure: Structuring element used for both labelling and dilation.
        keep_mode: ``"any"`` or ``"all"``, forwarded to
            ``should_keep_component``.

    Returns:
        ``seg`` (the same array object that was passed in).

    Raises:
        ValueError: If ``dilation_iters`` is negative.
    """
    if dilation_iters < 0:
        raise ValueError("dilation_iters must be >= 0")

    source_mask = seg == source_label
    if not np.any(source_mask):
        return seg

    cc_map, _ = ndimage.label(source_mask, structure=structure)

    # Furthest a single dilation step can grow along each axis (1 for the
    # default cross-shaped structure scipy uses when structure is None).
    if structure is None:
        step_reach = (1,) * seg.ndim
    else:
        step_reach = tuple(size // 2 for size in np.shape(structure))

    # find_objects yields one bounding box per label, ordered by label id.
    for cc_id, box in enumerate(ndimage.find_objects(cc_map), start=1):
        # Work on the bounding box padded by the maximum dilation reach
        # (clipped to the volume) instead of the whole volume; the dilated
        # component cannot extend past this window.
        window = tuple(
            slice(
                max(sl.start - dilation_iters * reach, 0),
                min(sl.stop + dilation_iters * reach, dim),
            )
            for sl, reach, dim in zip(box, step_reach, seg.shape)
        )
        component = cc_map[window] == cc_id

        if dilation_iters > 0:
            dilated_component = ndimage.binary_dilation(
                component,
                structure=structure,
                iterations=dilation_iters
            )
        else:
            # scipy treats iterations < 1 as "repeat until nothing changes",
            # which would flood the whole array; 0 must mean no dilation.
            dilated_component = component

        keep = should_keep_component(
            dilated_component=dilated_component,
            seg=seg[window],
            target_labels=target_labels,
            keep_mode=keep_mode
        )

        if not keep:
            # seg[window] is a view, so this writes through to seg.
            seg[window][component] = 0

    return seg
|
|
|
|
def remove_small_components(seg, label_value, min_size, structure):
    """Zero out connected components of ``label_value`` smaller than ``min_size``.

    Args:
        seg: Label array; it is modified in place.
        label_value: Label whose connected components are examined.
        min_size: Components with strictly fewer voxels than this are removed.
        structure: Structuring element defining connectivity for labelling.

    Returns:
        ``seg`` (the same array object that was passed in).
    """
    mask = seg == label_value
    if not np.any(mask):
        return seg

    cc_map, num_cc = ndimage.label(mask, structure=structure)

    # Count the voxels of every component in one pass instead of building a
    # full-size mask per component.
    sizes = np.bincount(cc_map.ravel(), minlength=num_cc + 1)
    too_small = sizes < min_size
    too_small[0] = False  # index 0 is the background of cc_map, never a component

    # Look up the per-component flag for every voxel and clear flagged voxels.
    seg[too_small[cc_map]] = 0

    return seg
|
|
|
|
def process_one_file(
    file_path,
    input_dir,
    output_dir,
    rules,
    clean_labels,
    min_size,
    dilation_iters,
    connectivity,
    keep_mode
):
    """Clean one segmentation file and write the result under ``output_dir``.

    The volume is loaded as int32. Every rule in ``rules`` is applied with
    ``remove_components_by_rules``, then small components of every label in
    ``clean_labels`` are removed with ``remove_small_components``. The result
    is saved at the same path relative to ``input_dir``, using the input
    image's data dtype, affine and header.

    Returns:
        A ``(file_path, status, message)`` tuple. ``status`` is ``"success"``
        with an empty message, or ``"failed"`` with the exception text; no
        exception derived from ``Exception`` escapes this function.
    """
    try:
        image = nib.load(file_path)
        labels = np.asanyarray(image.dataobj).astype(np.int32)
        neighbourhood = get_structure(connectivity)

        # Stage 1: drop source-label components that fail their rule.
        for src, targets in rules.items():
            labels = remove_components_by_rules(
                seg=labels,
                source_label=src,
                target_labels=targets,
                dilation_iters=dilation_iters,
                structure=neighbourhood,
                keep_mode=keep_mode
            )

        # Stage 2: drop components that are too small.
        for lbl in clean_labels:
            labels = remove_small_components(
                seg=labels,
                label_value=lbl,
                min_size=min_size,
                structure=neighbourhood
            )

        # Mirror the input's relative location below the output directory.
        destination = os.path.join(
            output_dir, os.path.relpath(file_path, input_dir)
        )
        os.makedirs(os.path.dirname(destination), exist_ok=True)

        cleaned = nib.Nifti1Image(
            labels.astype(image.get_data_dtype()),
            affine=image.affine,
            header=image.header
        )
        nib.save(cleaned, destination)
    except Exception as exc:
        return file_path, "failed", str(exc)

    return file_path, "success", ""
|
|
|
|
def main():
    """Command-line entry point: clean every ``.nii.gz`` file under a directory.

    Parses the arguments, collects the input files, processes them in a
    process pool via ``process_one_file`` and prints a summary. If any file
    failed, the failures are written to ``failed_cases.txt`` in the output
    directory as ``path<TAB>error`` lines.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Remove connected components of source labels based on overlap with one or more "
            "target labels after dilation, and then remove small connected components."
        )
    )

    parser.add_argument(
        "--input_dir",
        type=str,
        required=True,
        help="Directory searched recursively for .nii.gz segmentation files."
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help=(
            "Directory the processed files are written to, mirroring the "
            "relative layout of --input_dir."
        )
    )
    parser.add_argument(
        "--rules",
        type=str,
        required=True,
        help=(
            "Comma-separated rules of the form 'source:target1|target2'. A connected "
            "component of the source label is kept only if, after dilation, it overlaps "
            "the target labels (see --keep_mode). Example: '1:2|3|4,5:6'."
        )
    )
    parser.add_argument(
        "--target_labels",
        type=str,
        default=None,
        help=(
            "Comma-separated labels whose small connected components are removed "
            "(see --min_size). Defaults to the source labels given in --rules."
        )
    )
    parser.add_argument(
        "--keep_mode",
        type=str,
        default="any",
        choices=["any", "all"],
        help=(
            "'any': keep a component if it overlaps at least one target label; "
            "'all': keep it only if it overlaps every target label (default: any)."
        )
    )
    parser.add_argument(
        "--min_size",
        type=int,
        default=30,
        help="Connected components with fewer voxels than this are removed (default: 30)."
    )
    parser.add_argument(
        "--dilation_iters",
        type=int,
        default=1,
        help=(
            "Number of binary dilation iterations applied to each component "
            "before the overlap test (default: 1)."
        )
    )
    parser.add_argument(
        "--connectivity",
        type=int,
        default=26,
        choices=[6, 26],
        help=(
            "Voxel connectivity used for connected components and dilation "
            "(default: 26)."
        )
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=4,
        help="Number of worker processes (default: 4)."
    )

    args = parser.parse_args()

    if args.num_workers < 1:
        parser.error("--num_workers must be >= 1")
    if args.dilation_iters < 0:
        parser.error("--dilation_iters must be >= 0")

    # Report malformed label/rule strings as a usage error, not a traceback.
    try:
        rules = parse_rules(args.rules)
        if args.target_labels is None:
            clean_labels = list(rules.keys())
        else:
            clean_labels = parse_labels(args.target_labels)
    except ValueError as exc:
        parser.error(str(exc))

    os.makedirs(args.output_dir, exist_ok=True)

    file_list = find_nii_gz_files(args.input_dir)
    if len(file_list) == 0:
        print(f"No .nii.gz files found in: {args.input_dir}")
        return

    results = []
    with ProcessPoolExecutor(max_workers=args.num_workers) as executor:
        # Remember which file each future belongs to so that a failure of the
        # worker itself can still be attributed to a file.
        future_to_path = {
            executor.submit(
                process_one_file,
                file_path,
                args.input_dir,
                args.output_dir,
                rules,
                clean_labels,
                args.min_size,
                args.dilation_iters,
                args.connectivity,
                args.keep_mode
            ): file_path
            for file_path in file_list
        }

        for future in tqdm(
            as_completed(future_to_path),
            total=len(future_to_path),
            desc="Processing"
        ):
            try:
                results.append(future.result())
            except Exception as exc:
                # process_one_file reports its own errors; reaching this means
                # the worker or the pool failed (e.g. a broken process pool).
                # Record it instead of aborting and losing all other results.
                results.append(
                    (future_to_path[future], "failed", f"worker error: {exc!r}")
                )

    success_count = sum(1 for _, status, _ in results if status == "success")
    failed_cases = [(fp, err) for fp, status, err in results if status == "failed"]

    print(f"\nDone. Success: {success_count}, Failed: {len(failed_cases)}")

    if failed_cases:
        failed_txt = os.path.join(args.output_dir, "failed_cases.txt")
        with open(failed_txt, "w", encoding="utf-8") as f:
            for fp, err in failed_cases:
                # Collapse whitespace so a multi-line message stays on one line.
                one_line = " ".join(err.split())
                f.write(f"{fp}\t{one_line}\n")
        print(f"Failed cases saved to: {failed_txt}")
|
|
|
|
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()