File size: 3,221 Bytes
3f8bf9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
"""Command-line interface for stroke-deepisles-demo."""

from __future__ import annotations

import argparse
import sys
from pathlib import Path

from stroke_deepisles_demo.data import list_case_ids
from stroke_deepisles_demo.pipeline import run_pipeline_on_case


def main(argv: list[str] | None = None) -> int:
    """Main CLI entry point."""
    parser = argparse.ArgumentParser(
        prog="stroke-demo",
        description="Run DeepISLES stroke segmentation on HF datasets",
    )
    subparsers = parser.add_subparsers(dest="command", required=True)

    # List command
    list_parser = subparsers.add_parser("list", help="List available cases")
    list_parser.add_argument("--dataset", default=None, help="HF dataset ID (not used yet)")

    # Run command
    run_parser = subparsers.add_parser("run", help="Run segmentation")
    run_parser.add_argument("--case", type=str, help="Case ID (e.g., sub-stroke0001)")
    run_parser.add_argument("--index", type=int, help="Case index (alternative to --case)")
    run_parser.add_argument("--output", type=Path, default=None, help="Output directory")
    run_parser.add_argument(
        "--no-fast", action="store_false", dest="fast", help="Disable fast mode (SEALS-only)"
    )
    run_parser.set_defaults(fast=True)

    run_parser.add_argument("--no-gpu", action="store_true", help="Disable GPU")

    args = parser.parse_args(argv)

    if args.command == "list":
        return cmd_list(args)
    elif args.command == "run":
        return cmd_run(args)

    return 0


def cmd_list(args: argparse.Namespace) -> int:  # noqa: ARG001
    """Print every available case ID alongside its index.

    Args:
        args: Parsed arguments for the 'list' command (currently unused).

    Returns:
        0 when the listing was printed, 1 when anything went wrong
        (the error is reported on stderr).
    """
    status = 0
    try:
        available = list_case_ids()
        print(f"Found {len(available)} cases:")
        for position, name in enumerate(available):
            print(f"[{position}] {name}")
    except Exception as exc:
        print(f"Error listing cases: {exc}", file=sys.stderr)
        status = 1
    return status


def cmd_run(args: argparse.Namespace) -> int:
    """Handle 'run' command.

    Picks the case from ``args.case`` or, failing that, ``args.index``,
    runs the pipeline on it and prints a summary of the result.

    Args:
        args: Parsed arguments; reads ``case``, ``index``, ``output``,
            ``fast`` and ``no_gpu``.

    Returns:
        0 when the pipeline completed, 1 when no case was selected or the
        pipeline raised (the error is reported on stderr).
    """
    # The selection below skips a falsy --case (e.g. an empty string), so the
    # guard must use the same test; otherwise `--case ""` without --index
    # would slip through and hand None to the pipeline.
    if not args.case and args.index is None:
        print("Error: Must specify --case or --index", file=sys.stderr)
        return 1

    # --case takes precedence when both are given; index 0 is a valid choice.
    case_id: str | int = args.case if args.case else args.index

    try:
        print(f"Running pipeline on case: {case_id} (fast={args.fast}, gpu={not args.no_gpu})")
        result = run_pipeline_on_case(
            case_id=case_id,
            output_dir=args.output,
            fast=args.fast,
            gpu=not args.no_gpu,
            compute_dice=True,
            cleanup_staging=True,  # Clean up by default for CLI runs
        )

        print("\nPipeline Completed Successfully!")
        print(f"Case ID: {result.case_id}")
        print(f"Prediction: {result.prediction_mask}")
        if result.ground_truth:
            print(f"Ground Truth: {result.ground_truth}")
            # The Dice score is only reported alongside a ground truth.
            if result.dice_score is not None:
                print(f"Dice Score: {result.dice_score:.4f}")
        else:
            print("No Ground Truth available.")

        print(f"Elapsed: {result.elapsed_seconds:.1f}s")
        return 0

    except Exception as e:
        # Top-level CLI boundary: report any failure and map it to exit status 1.
        print(f"Pipeline failed: {e}", file=sys.stderr)
        return 1


# Script entry point: propagate main()'s return value as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())