"""Clean segmentation label maps by adjacency rules and minimum component size.

For each ``.nii.gz`` file found under an input directory, two passes run:

1. For every rule ``src:dst1|dst2|...``, each connected component of label
   ``src`` is dilated.
   - The component is removed (set to 0) unless the dilated component
     overlaps the target labels ("any" or "all" of them).
2. Connected components of selected labels that are smaller than a voxel
   threshold are removed.

Results are written to an output directory that mirrors the input layout.
"""
import os
import argparse
from concurrent.futures import ProcessPoolExecutor, as_completed

import nibabel as nib
import numpy as np
from scipy import ndimage
from tqdm import tqdm


# Neighbourhood size in 3-D -> connectivity rank for
# ndimage.generate_binary_structure (1: faces, 2: +edges, 3: +corners).
_CONNECTIVITY_TO_RANK = {6: 1, 18: 2, 26: 3}


def parse_rules(rule_str):
    """Parse a rule string such as ``"1:2|3|4,5:6"`` into a dict.

    Each comma-separated item has the form ``src:dst1|dst2|...``.
    Connected components of label ``src`` are later checked for overlap
    against the listed target labels.

    Args:
        rule_str: Comma-separated rule items. Empty items are ignored.

    Returns:
        dict mapping each source label (int) to its list of target labels
        (list[int]), in the order the rules appear in ``rule_str``.

    Raises:
        ValueError: If an item is malformed (missing or repeated ``:``),
            contains a non-integer label, has no target labels, or repeats
            a source label that an earlier item already defined.
    """
    rules = {}
    items = [x.strip() for x in rule_str.split(",") if x.strip()]
    for item in items:
        src, sep, dst = item.partition(":")
        if not sep or ":" in dst:
            raise ValueError(
                f"Invalid rule format: {item}. Expected format like '1:2|3|4,5:6'"
            )
        try:
            src_label = int(src.strip())
            dst_labels = [int(x.strip()) for x in dst.split("|") if x.strip()]
        except ValueError as err:
            raise ValueError(f"Non-integer label in rule: {item}") from err
        if not dst_labels:
            raise ValueError(f"No target labels found in rule: {item}")
        if src_label in rules:
            # Silently overwriting would drop the earlier rule; almost
            # certainly a typo in the rule string.
            raise ValueError(f"Duplicate source label {src_label} in rule: {item}")
        rules[src_label] = dst_labels
    return rules


def parse_labels(label_str):
    """Parse a comma-separated list of integer labels, e.g. ``"1,5"``.

    Empty items are ignored. Raises ValueError on a non-integer item.
    """
    return [int(x.strip()) for x in label_str.split(",") if x.strip()]


def find_nii_gz_files(input_dir):
    """Return the sorted paths of all ``*.nii.gz`` files under ``input_dir``.

    The search is recursive (``os.walk``).
    """
    files = []
    for root, _, filenames in os.walk(input_dir):
        for fname in filenames:
            if fname.endswith(".nii.gz"):
                files.append(os.path.join(root, fname))
    return sorted(files)


def get_structure(connectivity=26):
    """Return a 3-D boolean structuring element for the given connectivity.

    Args:
        connectivity: 6 (face neighbours), 18 (faces + edges), or 26
            (faces + edges + corners).

    Raises:
        ValueError: If ``connectivity`` is not 6, 18 or 26.
    """
    try:
        rank_connectivity = _CONNECTIVITY_TO_RANK[connectivity]
    except KeyError:
        raise ValueError("connectivity must be 6, 18 or 26") from None
    return ndimage.generate_binary_structure(3, rank_connectivity)


def should_keep_component(dilated_component, seg, target_labels, keep_mode):
    """Decide whether a dilated component touches the required target labels.

    Args:
        dilated_component: Boolean mask with the same shape as ``seg``.
        seg: Integer label array.
        target_labels: Labels to test for overlap with the mask.
        keep_mode: ``"any"`` keeps the component if the mask overlaps at
            least one target label. ``"all"`` requires overlap with every
            target label.

    Returns:
        True if the component should be kept.

    Raises:
        ValueError: If ``keep_mode`` is not ``"any"`` or ``"all"``.
    """
    if keep_mode not in ("any", "all"):
        raise ValueError("keep_mode must be 'any' or 'all'")
    # Collect the labels under the mask once, instead of building one
    # full-size comparison mask per target label.
    mask = np.asarray(dilated_component, dtype=bool)
    touched = set(np.unique(seg[mask]).tolist())
    hits = (label in touched for label in target_labels)
    return any(hits) if keep_mode == "any" else all(hits)


def _padded_window(bbox, reach, shape):
    """Grow a tuple of slices by ``reach[i]`` voxels per axis, clipped to ``shape``."""
    return tuple(
        slice(max(sl.start - r, 0), min(sl.stop + r, dim))
        for sl, r, dim in zip(bbox, reach, shape)
    )


def remove_components_by_rules(
    seg, source_label, target_labels, dilation_iters, structure, keep_mode="any"
):
    """Zero out components of ``source_label`` that do not touch the targets.

    Each connected component of ``source_label`` is dilated
    ``dilation_iters`` times with ``structure``. The component is kept if
    the dilated mask overlaps the target labels according to ``keep_mode``
    (see ``should_keep_component``); otherwise its voxels are set to 0.

    ``seg`` is modified in place and also returned. Components are processed
    in label order. Each overlap test sees the removals made before it, which
    only matters when 0 or ``source_label`` itself is a target label.

    Raises:
        ValueError: If ``dilation_iters`` < 1. For values below 1, scipy would
            dilate until convergence, flooding the whole volume.
    """
    if dilation_iters < 1:
        raise ValueError("dilation_iters must be >= 1")
    source_mask = seg == source_label
    if not np.any(source_mask):
        return seg

    cc_map, _ = ndimage.label(source_mask, structure=structure)

    # Each dilation step grows a mask by at most half the structuring
    # element's size per axis. Padding a component's bounding box by that
    # reach therefore contains the whole dilated component. Working in that
    # window gives the same result as a full-volume dilation, at a fraction
    # of the cost.
    if structure is None:
        half_widths = [1] * seg.ndim  # scipy's default element is the cross
    else:
        half_widths = [size // 2 for size in np.asarray(structure).shape]
    reach = [half * dilation_iters for half in half_widths]

    for cc_id, bbox in enumerate(ndimage.find_objects(cc_map), start=1):
        if bbox is None:
            continue
        window = _padded_window(bbox, reach, seg.shape)
        seg_window = seg[window]  # basic slicing: a view, so writes reach seg
        component = cc_map[window] == cc_id
        dilated_component = ndimage.binary_dilation(
            component, structure=structure, iterations=dilation_iters
        )
        keep = should_keep_component(
            dilated_component=dilated_component,
            seg=seg_window,
            target_labels=target_labels,
            keep_mode=keep_mode,
        )
        if not keep:
            seg_window[component] = 0
    return seg


def remove_small_components(seg, label_value, min_size, structure):
    """Zero out connected components of ``label_value`` smaller than ``min_size``.

    Component size is the voxel count. ``seg`` is modified in place and
    returned.
    """
    mask = seg == label_value
    if not np.any(mask):
        return seg
    cc_map, _ = ndimage.label(mask, structure=structure)
    # Voxel count of every component in a single pass; index 0 is background.
    sizes = np.bincount(cc_map.ravel())
    too_small = sizes < min_size
    too_small[0] = False
    seg[too_small[cc_map]] = 0
    return seg


def process_one_file(
    file_path,
    input_dir,
    output_dir,
    rules,
    clean_labels,
    min_size,
    dilation_iters,
    connectivity,
    keep_mode,
):
    """Clean one label map and write it under ``output_dir``.

    Steps:
    - Apply every rule in ``rules`` in order, then remove small components
      of each label in ``clean_labels``.
    - Write the result to the same relative path under ``output_dir``,
      using the input image class, affine, header and on-disk dtype.

    Returns:
        ``(file_path, "success", "")`` on success, or
        ``(file_path, "failed", message)``. Errors are reported, not raised,
        so one bad file does not abort the batch.
    """
    try:
        img = nib.load(file_path)
        data = np.asanyarray(img.dataobj)
        if np.issubdtype(data.dtype, np.floating):
            # Float-stored label maps can hold values like 1.9999999;
            # round instead of truncating so they keep the intended label.
            data = np.rint(data)
        seg = data.astype(np.int32)

        structure = get_structure(connectivity)

        for source_label, dst_labels in rules.items():
            seg = remove_components_by_rules(
                seg=seg,
                source_label=source_label,
                target_labels=dst_labels,
                dilation_iters=dilation_iters,
                structure=structure,
                keep_mode=keep_mode,
            )

        for label in clean_labels:
            seg = remove_small_components(
                seg=seg, label_value=label, min_size=min_size, structure=structure
            )

        rel_path = os.path.relpath(file_path, input_dir)
        out_path = os.path.join(output_dir, rel_path)
        os.makedirs(os.path.dirname(out_path), exist_ok=True)

        # Reuse the loaded image's class so NIfTI-2 input stays NIfTI-2.
        out_img = type(img)(
            seg.astype(img.get_data_dtype()), affine=img.affine, header=img.header
        )
        nib.save(out_img, out_path)
        return file_path, "success", ""
    except Exception as e:
        return file_path, "failed", f"{type(e).__name__}: {e}"


def main():
    """Parse command-line arguments and clean every label map in parallel.

    Exits with a usage error on invalid arguments. Files that fail are
    listed in ``<output_dir>/failed_cases.txt``.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Remove connected components of source labels based on overlap with one or more "
            "target labels after dilation, and then remove small connected components."
        )
    )
    parser.add_argument(
        "--input_dir",
        type=str,
        required=True,
        help="Directory searched recursively for .nii.gz label maps.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Directory for the cleaned files; the input sub-directory layout is mirrored.",
    )
    parser.add_argument(
        "--rules",
        type=str,
        required=True,
        help=(
            "Comma-separated rules 'src:dst1|dst2|...', e.g. '1:2|3|4,5:6'. "
            "Components of label src are removed unless their dilation "
            "overlaps the target labels (see --keep_mode)."
        ),
    )
    parser.add_argument(
        "--target_labels",
        type=str,
        default=None,
        help=(
            "Comma-separated labels whose small components are removed, e.g. '1,5'. "
            "Defaults to the source labels of --rules."
        ),
    )
    parser.add_argument(
        "--keep_mode",
        type=str,
        default="any",
        choices=["any", "all"],
        help=(
            "'any': keep a component whose dilation touches at least one target label; "
            "'all': it must touch every target label of its rule."
        ),
    )
    parser.add_argument(
        "--min_size",
        type=int,
        default=30,
        help="Components with fewer voxels than this are removed (for --target_labels).",
    )
    parser.add_argument(
        "--dilation_iters",
        type=int,
        default=1,
        help="Number of dilation steps (>= 1) applied before the overlap test.",
    )
    parser.add_argument(
        "--connectivity",
        type=int,
        default=26,
        choices=[6, 18, 26],
        help="Voxel connectivity used for component labelling and dilation.",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=4,
        help="Number of worker processes.",
    )
    args = parser.parse_args()

    if args.dilation_iters < 1:
        parser.error("--dilation_iters must be >= 1")
    if args.num_workers < 1:
        parser.error("--num_workers must be >= 1")

    try:
        rules = parse_rules(args.rules)
        if args.target_labels is None:
            clean_labels = list(rules.keys())
        else:
            clean_labels = parse_labels(args.target_labels)
    except ValueError as err:
        parser.error(str(err))

    os.makedirs(args.output_dir, exist_ok=True)

    file_list = find_nii_gz_files(args.input_dir)
    if not file_list:
        print(f"No .nii.gz files found in: {args.input_dir}")
        return

    results = []
    with ProcessPoolExecutor(max_workers=args.num_workers) as executor:
        future_to_path = {
            executor.submit(
                process_one_file,
                file_path,
                args.input_dir,
                args.output_dir,
                rules,
                clean_labels,
                args.min_size,
                args.dilation_iters,
                args.connectivity,
                args.keep_mode,
            ): file_path
            for file_path in file_list
        }

        for future in tqdm(
            as_completed(future_to_path), total=len(future_to_path), desc="Processing"
        ):
            try:
                results.append(future.result())
            except Exception as e:
                # process_one_file handles its own errors. This covers failures
                # outside it, e.g. a worker killed by the OS (BrokenProcessPool),
                # so the results collected so far are not lost.
                results.append(
                    (future_to_path[future], "failed", f"{type(e).__name__}: {e}")
                )

    success_count = sum(1 for _, status, _ in results if status == "success")
    failed_cases = sorted(
        (fp, err) for fp, status, err in results if status == "failed"
    )

    print(f"\nDone. Success: {success_count}, Failed: {len(failed_cases)}")

    if failed_cases:
        failed_txt = os.path.join(args.output_dir, "failed_cases.txt")
        with open(failed_txt, "w", encoding="utf-8") as f:
            for fp, err in failed_cases:
                f.write(f"{fp}\t{err}\n")
        print(f"Failed cases saved to: {failed_txt}")


if __name__ == "__main__":
    main()