|
|
"""Command-line interface for stroke-deepisles-demo.""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import argparse |
|
|
import sys |
|
|
from pathlib import Path |
|
|
|
|
|
from stroke_deepisles_demo.data import list_case_ids |
|
|
from stroke_deepisles_demo.pipeline import run_pipeline_on_case |
|
|
|
|
|
|
|
|
def main(argv: list[str] | None = None) -> int: |
|
|
"""Main CLI entry point.""" |
|
|
parser = argparse.ArgumentParser( |
|
|
prog="stroke-demo", |
|
|
description="Run DeepISLES stroke segmentation on HF datasets", |
|
|
) |
|
|
subparsers = parser.add_subparsers(dest="command", required=True) |
|
|
|
|
|
|
|
|
list_parser = subparsers.add_parser("list", help="List available cases") |
|
|
list_parser.add_argument("--dataset", default=None, help="HF dataset ID (not used yet)") |
|
|
|
|
|
|
|
|
run_parser = subparsers.add_parser("run", help="Run segmentation") |
|
|
run_parser.add_argument("--case", type=str, help="Case ID (e.g., sub-stroke0001)") |
|
|
run_parser.add_argument("--index", type=int, help="Case index (alternative to --case)") |
|
|
run_parser.add_argument("--output", type=Path, default=None, help="Output directory") |
|
|
run_parser.add_argument( |
|
|
"--no-fast", action="store_false", dest="fast", help="Disable fast mode (SEALS-only)" |
|
|
) |
|
|
run_parser.set_defaults(fast=True) |
|
|
|
|
|
run_parser.add_argument("--no-gpu", action="store_true", help="Disable GPU") |
|
|
|
|
|
args = parser.parse_args(argv) |
|
|
|
|
|
if args.command == "list": |
|
|
return cmd_list(args) |
|
|
elif args.command == "run": |
|
|
return cmd_run(args) |
|
|
|
|
|
return 0 |
|
|
|
|
|
|
|
|
def cmd_list(args: argparse.Namespace) -> int: |
|
|
"""Handle 'list' command.""" |
|
|
try: |
|
|
case_ids = list_case_ids() |
|
|
print(f"Found {len(case_ids)} cases:") |
|
|
for i, cid in enumerate(case_ids): |
|
|
print(f"[{i}] {cid}") |
|
|
return 0 |
|
|
except Exception as e: |
|
|
print(f"Error listing cases: {e}", file=sys.stderr) |
|
|
return 1 |
|
|
|
|
|
|
|
|
def cmd_run(args: argparse.Namespace) -> int: |
|
|
"""Handle 'run' command.""" |
|
|
if args.case is None and args.index is None: |
|
|
print("Error: Must specify --case or --index", file=sys.stderr) |
|
|
return 1 |
|
|
|
|
|
case_id: str | int = args.case if args.case else args.index |
|
|
|
|
|
try: |
|
|
print(f"Running pipeline on case: {case_id} (fast={args.fast}, gpu={not args.no_gpu})") |
|
|
result = run_pipeline_on_case( |
|
|
case_id=case_id, |
|
|
output_dir=args.output, |
|
|
fast=args.fast, |
|
|
gpu=not args.no_gpu, |
|
|
compute_dice=True, |
|
|
cleanup_staging=True, |
|
|
) |
|
|
|
|
|
print("\nPipeline Completed Successfully!") |
|
|
print(f"Case ID: {result.case_id}") |
|
|
print(f"Prediction: {result.prediction_mask}") |
|
|
if result.ground_truth: |
|
|
print(f"Ground Truth: {result.ground_truth}") |
|
|
if result.dice_score is not None: |
|
|
print(f"Dice Score: {result.dice_score:.4f}") |
|
|
else: |
|
|
print("No Ground Truth available.") |
|
|
|
|
|
print(f"Elapsed: {result.elapsed_seconds:.1f}s") |
|
|
return 0 |
|
|
|
|
|
except Exception as e: |
|
|
print(f"Pipeline failed: {e}", file=sys.stderr) |
|
|
return 1 |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
sys.exit(main()) |
|
|
|