File size: 5,379 Bytes
7f58aa4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""
Inference script: detect immunogold particles in new images.

Usage:
    python predict.py --image path/to/image.tif --checkpoint checkpoints/fold_S1_seed42/phase3_best.pth
    python predict.py --fold S1 --checkpoint checkpoints/fold_S1_seed42/phase3_best.pth --config config/config.yaml
"""

import argparse
from pathlib import Path

import numpy as np
import torch
import yaml

from src.heatmap import extract_peaks
from src.model import ImmunogoldCenterNet
from src.postprocess import apply_structural_mask_filter, cross_class_nms
from src.preprocessing import load_image, load_mask
from src.ensemble import sliding_window_inference, d4_tta_predict
from src.visualize import overlay_annotations


def parse_args():
    """Build and parse the command-line arguments for the prediction CLI.

    Exactly one of --image (raw file) or --fold (known synapse ID) is
    expected by the caller; --checkpoint is mandatory.
    """
    p = argparse.ArgumentParser(description="Predict immunogold particles")
    # Input selection: a single image file, or a named fold for evaluation.
    p.add_argument("--image", type=str, help="Path to single image")
    p.add_argument("--mask", type=str, help="Path to mask (optional)")
    p.add_argument("--fold", type=str, help="Fold synapse ID for evaluation")
    # Model / runtime options.
    p.add_argument("--checkpoint", type=str, required=True,
                   help="Path to model checkpoint")
    p.add_argument("--config", type=str, default="config/config.yaml")
    p.add_argument("--device", type=str, default="auto")
    p.add_argument("--tta", action="store_true", help="Enable D4 TTA")
    # Detection / output options.
    p.add_argument("--conf-threshold", type=float, default=0.3)
    p.add_argument("--output-dir", type=str, default="results/predictions")
    return p.parse_args()


def _resolve_device(requested: str) -> torch.device:
    """Map the --device argument to a torch.device.

    "auto" selects CUDA when available and falls back to CPU; any other
    value (e.g. "cpu", "cuda:1") is passed through verbatim.
    """
    if requested == "auto":
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return torch.device(requested)


def _load_model(checkpoint_path, cfg, device):
    """Build the network from config, load checkpoint weights, move to device.

    Returns the model in eval mode plus the raw checkpoint dict (used by the
    caller only for logging epoch / validation metrics).
    """
    model = ImmunogoldCenterNet(
        bifpn_channels=cfg["model"]["bifpn_channels"],
        bifpn_rounds=cfg["model"]["bifpn_rounds"],
        num_classes=cfg["model"]["num_classes"],
    )
    # NOTE(review): weights_only=False unpickles arbitrary objects — only
    # load checkpoints from trusted sources.
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    model.load_state_dict(ckpt["model_state_dict"])
    model.to(device)
    model.eval()
    return model, ckpt


def _load_inputs(args, cfg):
    """Load image/mask/annotations from a fold record or a raw image file.

    Returns (image, preprocessed, mask, annotations, name).  A raw --image
    run has no ground truth, so annotations are empty (0, 2) arrays.

    Raises:
        ValueError: if --fold names an unknown synapse ID, or if neither
            --fold nor --image was provided.
    """
    if args.fold:
        from src.preprocessing import discover_synapse_data, load_synapse
        records = discover_synapse_data(cfg["data"]["root"], cfg["data"]["synapse_ids"])
        # Fail with a clear message instead of a bare IndexError when the
        # requested fold is not among the configured synapse IDs.
        record = next((r for r in records if r.synapse_id == args.fold), None)
        if record is None:
            known = ", ".join(r.synapse_id for r in records)
            raise ValueError(f"Unknown fold {args.fold!r}; available: {known}")
        data = load_synapse(record)
        return data["image"], data["image"], data["mask"], data["annotations"], args.fold
    # Guard the raw-image path: the original crashed with Path(None) when
    # neither flag was supplied.
    if not args.image:
        raise ValueError("Either --image or --fold must be provided")
    image = load_image(Path(args.image))
    mask = load_mask(Path(args.mask)) if args.mask else None
    annotations = {"6nm": np.empty((0, 2)), "12nm": np.empty((0, 2))}
    return image, image, mask, annotations, Path(args.image).stem


def main():
    """CLI entry point: run inference on one image or fold, post-process the
    detections, optionally evaluate against ground truth, and save an overlay.
    """
    args = parse_args()
    with open(args.config) as f:
        cfg = yaml.safe_load(f)

    device = _resolve_device(args.device)

    # Load model weights and report which checkpoint we are using.
    model, ckpt = _load_model(args.checkpoint, cfg, device)
    print(f"Loaded checkpoint from epoch {ckpt.get('epoch', '?')}, "
          f"val_f1={ckpt.get('val_f1_mean', '?')}")

    image, preprocessed, mask, annotations, name = _load_inputs(args, cfg)

    # Inference: D4 test-time augmentation or plain sliding-window.
    if args.tta:
        print("Running D4 TTA inference...")
        heatmap_np, offset_np = d4_tta_predict(model, preprocessed, device)
    else:
        print("Running sliding window inference...")
        heatmap_np, offset_np = sliding_window_inference(
            model, preprocessed,
            patch_size=cfg["data"]["patch_size"],
            device=device,
        )

    # Extract peak detections from the predicted heatmap / offset maps.
    detections = extract_peaks(
        torch.from_numpy(heatmap_np), torch.from_numpy(offset_np),
        stride=cfg["data"]["stride"],
        conf_threshold=args.conf_threshold,
        nms_kernel_sizes=cfg["postprocessing"]["nms_kernel_size"],
    )

    # Post-processing: drop detections outside the structural mask (when a
    # mask is available), then suppress duplicates across the two classes.
    if mask is not None:
        detections = apply_structural_mask_filter(
            detections, mask,
            margin_px=cfg["postprocessing"]["mask_filter_margin_px"],
        )
    detections = cross_class_nms(
        detections, cfg["postprocessing"]["cross_class_nms_distance_px"],
    )

    # Summarize per-class counts.
    n_6nm = sum(1 for d in detections if d["class"] == "6nm")
    n_12nm = sum(1 for d in detections if d["class"] == "12nm")
    print(f"\nDetections: {n_6nm} 6nm, {n_12nm} 12nm ({len(detections)} total)")

    # Evaluate only when ground-truth annotations are present.
    if annotations and (len(annotations["6nm"]) > 0 or len(annotations["12nm"]) > 0):
        from src.evaluate import match_detections_to_gt
        results = match_detections_to_gt(
            detections, annotations["6nm"], annotations["12nm"],
            {k: float(v) for k, v in cfg["evaluation"]["match_radii_px"].items()},
        )
        for cls in ["6nm", "12nm", "overall"]:
            r = results[cls]
            print(f"  {cls}: F1={r['f1']:.3f}, P={r['precision']:.3f}, "
                  f"R={r['recall']:.3f} (TP={r['tp']}, FP={r['fp']}, FN={r['fn']})")

    # Save the visualization overlay.
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    overlay_annotations(
        image, annotations,
        # Fixed: the original title f-string had no separator between the
        # image name and the detection counts ("{name}{n_6nm} 6nm ...").
        title=f"{name}: {n_6nm} 6nm, {n_12nm} 12nm detected",
        save_path=output_dir / f"{name}_predictions.png",
        predictions=detections,
    )
    print(f"Saved overlay to {output_dir / f'{name}_predictions.png'}")


# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()