AnikS22 commited on
Commit
1fc7794
·
verified ·
1 Parent(s): d1fb167

Sync all source code, docs, and configs

Browse files
scripts/verify_data.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Data verification script: loads all images and annotations,
3
+ validates counts, and saves visual overlays.
4
+
5
+ Usage:
6
+ python scripts/verify_data.py --config config/config.yaml
7
+ """
8
+
9
+ import argparse
10
+ import sys
11
+ from pathlib import Path
12
+
13
+ import numpy as np
14
+ import yaml
15
+
16
+ sys.path.insert(0, str(Path(__file__).parent.parent))
17
+
18
+ from src.preprocessing import discover_synapse_data, load_synapse
19
+ from src.visualize import overlay_annotations
20
+
21
+
22
def main():
    """Verify the immunogold dataset end-to-end.

    Loads every synapse record listed in the YAML config, prints per-synapse
    particle counts and image shapes, reports how many particles fall inside
    the synapse mask (when one exists), saves an annotation overlay per
    synapse, and finally compares the grand totals against the expected
    403 (6nm) + 50 (12nm) = 453 particles.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="config/config.yaml")
    args = parser.parse_args()

    with open(args.config) as f:
        cfg = yaml.safe_load(f)

    print("=" * 60)
    print("Immunogold Data Verification")
    print("=" * 60)

    records = discover_synapse_data(cfg["data"]["root"], cfg["data"]["synapse_ids"])

    total_6nm = 0
    total_12nm = 0
    output_dir = Path("results/verification")
    # Create the overlay directory up front so save_path is always writable,
    # regardless of whether overlay_annotations creates parent dirs itself.
    output_dir.mkdir(parents=True, exist_ok=True)

    for record in records:
        print(f"\n--- {record.synapse_id} ---")
        print(f"  Image: {record.image_path.name}")
        print(f"  Mask: {record.mask_path.name if record.mask_path else 'NONE'}")
        print(f"  6nm CSVs: {[p.name for p in record.csv_6nm_paths]}")
        print(f"  12nm CSVs: {[p.name for p in record.csv_12nm_paths]}")

        data = load_synapse(record)
        img = data["image"]
        annots = data["annotations"]

        n6 = len(annots["6nm"])
        n12 = len(annots["12nm"])
        total_6nm += n6
        total_12nm += n12

        print(f"  Image shape: {img.shape}")
        print(f"  6nm particles: {n6}")
        print(f"  12nm particles: {n12}")

        if data["mask"] is not None:
            # Check how many particles fall within mask (bounds-checked so
            # out-of-frame annotations never index past the mask edges).
            mask = data["mask"]
            for cls, coords in annots.items():
                if len(coords) == 0:
                    continue
                inside = sum(
                    1 for x, y in coords
                    if 0 <= int(y) < mask.shape[0] and
                       0 <= int(x) < mask.shape[1] and
                       mask[int(y), int(x)]
                )
                print(f"  {cls} in mask: {inside}/{len(coords)}")

        # Save overlay
        overlay_annotations(
            img, annots,
            title=f"{record.synapse_id}: {n6} 6nm, {n12} 12nm",
            save_path=output_dir / f"{record.synapse_id}_annotations.png",
        )

    print(f"\n{'=' * 60}")
    print(f"TOTAL: {total_6nm} 6nm + {total_12nm} 12nm = {total_6nm + total_12nm}")
    print("Expected: 403 6nm + 50 12nm = 453")

    if total_6nm + total_12nm >= 400:
        print("PASS: Particle counts look reasonable")
    else:
        print("WARNING: Total count is lower than expected")

    print(f"\nOverlays saved to: {output_dir}")


if __name__ == "__main__":
    main()
slurm/02_train_single_fold.sh ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Train one (fold, seed) ensemble member of the immunogold CenterNet.
# Submitted as a 50-task array: 10 leave-one-synapse-out folds x 5 seeds.
#SBATCH --job-name=immunogold_train
#SBATCH --partition=gpu              # TODO: adjust to your GPU partition
#SBATCH --gres=gpu:1                 # single A100 or V100
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=8
#SBATCH --mem=80G
#SBATCH --time=12:00:00
#SBATCH --array=0-49                 # 10 folds x 5 seeds = 50 jobs
#SBATCH --output=logs/train_%A_%a.out
#SBATCH --error=logs/train_%A_%a.err

set -euo pipefail
mkdir -p logs

eval "$(conda shell.bash hook)"
conda activate immunogold

# Map array task ID to fold and seed:
# task // 5 -> fold index, task % 5 -> seed index; seeds run 42..46.
SYNAPSE_IDS=("S1" "S4" "S7" "S8" "S13" "S15" "S22" "S25" "S27" "S29")
FOLD_IDX=$((SLURM_ARRAY_TASK_ID / 5))
SEED_IDX=$((SLURM_ARRAY_TASK_ID % 5))
SEED=$((SEED_IDX + 42))
FOLD_NAME=${SYNAPSE_IDS[$FOLD_IDX]}

echo "=== Training immunogold CenterNet ==="
echo "Date: $(date)"
echo "Array task: ${SLURM_ARRAY_TASK_ID}"
echo "Fold: ${FOLD_NAME} (idx ${FOLD_IDX})"
echo "Seed: ${SEED} (idx ${SEED_IDX})"
echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader 2>/dev/null || echo 'N/A')"

python train.py \
    --fold "${FOLD_NAME}" \
    --seed "${SEED}" \
    --config config/config.yaml \
    --device cuda:0

echo "=== Training complete for fold=${FOLD_NAME} seed=${SEED} ==="
slurm/03_evaluate_ensemble.sh ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Evaluate the trained ensemble on one held-out test fold per array task.
#SBATCH --job-name=immunogold_eval
#SBATCH --partition=gpu              # TODO: adjust
#SBATCH --gres=gpu:1                 # single GPU per fold (model loaded sequentially)
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=8
#SBATCH --mem=64G
#SBATCH --time=04:00:00
#SBATCH --array=0-9                  # one job per test fold
#SBATCH --output=logs/eval_%A_%a.out
#SBATCH --error=logs/eval_%A_%a.err

set -euo pipefail
mkdir -p logs

eval "$(conda shell.bash hook)"
conda activate immunogold

# Array task index selects the held-out synapse directly.
SYNAPSE_IDS=("S1" "S4" "S7" "S8" "S13" "S15" "S22" "S25" "S27" "S29")
FOLD_NAME=${SYNAPSE_IDS[$SLURM_ARRAY_TASK_ID]}

echo "=== Ensemble evaluation ==="
echo "Date: $(date)"
echo "Test fold: ${FOLD_NAME}"
echo "Array task: ${SLURM_ARRAY_TASK_ID}"
echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader 2>/dev/null || echo 'N/A')"

# Evaluate this single fold using the same script as local evaluation
# evaluate_loocv.py handles loading all ensemble members per fold
python evaluate_loocv.py \
    --config config/config.yaml \
    --ensemble-dir checkpoints \
    --device cuda:0 \
    --use-tta \
    --fold "${FOLD_NAME}" \
    --output "results/per_fold_predictions/${FOLD_NAME}_metrics.csv"

echo "=== Evaluation complete for fold ${FOLD_NAME} ==="
slurm/04_full_pipeline.sh ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Master submission script for the full immunogold detection pipeline.
# Submits all stages with dependency chaining so they run sequentially.
#
# Usage:
#   bash slurm/04_full_pipeline.sh

set -euo pipefail

echo "=== Immunogold Detection Pipeline ==="
echo "Date: $(date)"
echo "Submitting all pipeline stages..."

# Create log directory
mkdir -p logs

# Stage 0: Environment setup (single job)
# --parsable makes sbatch print just the job ID for dependency chaining.
JOB0=$(sbatch --parsable slurm/00_setup_env.sh)
echo "Stage 0 (setup): Job ${JOB0}"

# Stage 1: Training — 50 parallel jobs (depends on setup)
JOB1=$(sbatch --parsable --dependency=afterok:${JOB0} slurm/02_train_single_fold.sh)
echo "Stage 1 (training): Job ${JOB1} (array 0-49, depends on ${JOB0})"

# Stage 2: Ensemble evaluation — 10 parallel jobs (depends on ALL training tasks)
# afterok on an array job waits for ALL array tasks to complete successfully
JOB2=$(sbatch --parsable --dependency=afterok:${JOB1} slurm/03_evaluate_ensemble.sh)
echo "Stage 2 (evaluate): Job ${JOB2} (array 0-9, depends on ${JOB1})"

echo ""
echo "Pipeline submitted successfully!"
echo "Final results job: ${JOB2}"
echo ""
echo "Monitor with:"
echo "  squeue -u \${USER}"
echo "  sacct -j ${JOB0},${JOB1},${JOB2}"
echo ""
echo "Results will be in:"
echo "  results/loocv_metrics.csv"
echo "  results/per_fold_predictions/"
slurm/05_train_final.sh ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Train the final deployable model on all data (no held-out fold).
#SBATCH --job-name=immunogold_final
#SBATCH --partition=shortq7-gpu
#SBATCH --gres=gpu:1
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=8
#SBATCH --mem=32G
#SBATCH --time=06:00:00
#SBATCH --output=logs/train_final_%j.out
#SBATCH --error=logs/train_final_%j.err

set -euo pipefail
mkdir -p logs

# Cluster-specific toolchain; versions pinned to the site's module tree.
module load miniconda3/24.3.0-gcc-13.2.0-rslr3to
module load cuda/12.4.0-gcc-13.2.0-bxjolrw
eval "$(conda shell.bash hook)"
conda activate immunogold

echo "=== Training final deployable model ==="
echo "Date: $(date)"
echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader 2>/dev/null)"

python train_final.py --config config/config.yaml --device cuda:0

echo "=== Final model training complete ==="
src/heatmap.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Ground truth heatmap generation and peak extraction for CenterNet.
3
+
4
+ Generates Gaussian-splat heatmaps at stride-2 resolution with
5
+ class-specific sigma values calibrated to bead size.
6
+ """
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from typing import Dict, List, Tuple, Optional
12
+
13
# Class index mapping: heatmap channel 0 holds 6nm beads, channel 1 holds 12nm.
CLASS_IDX = {"6nm": 0, "12nm": 1}
# Inverse of CLASS_IDX: channel index -> class name, in channel order.
CLASS_NAMES = ["6nm", "12nm"]
# Output stride: heatmaps/offsets live at half the input resolution.
STRIDE = 2
17
+
18
+
19
def generate_heatmap_gt(
    coords_6nm: np.ndarray,
    coords_12nm: np.ndarray,
    image_h: int,
    image_w: int,
    sigmas: Optional[Dict[str, float]] = None,
    stride: int = STRIDE,
    confidence_6nm: Optional[np.ndarray] = None,
    confidence_12nm: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Build CenterNet ground-truth targets from particle coordinates.

    Args:
        coords_6nm: (N, 2) array of (x, y) in ORIGINAL pixel space
        coords_12nm: (M, 2) array of (x, y) in ORIGINAL pixel space
        image_h: original image height
        image_w: original image width
        sigmas: per-class Gaussian sigma in feature space
        stride: output stride (default 2)
        confidence_6nm: optional per-particle confidence weights
        confidence_12nm: optional per-particle confidence weights

    Returns:
        heatmap: (2, H//stride, W//stride) float32 in [0, 1]
        offsets: (2, H//stride, W//stride) float32 sub-pixel offsets
        offset_mask: (H//stride, W//stride) bool — True at particle centers
        conf_map: (2, H//stride, W//stride) float32 confidence weights
    """
    if sigmas is None:
        sigmas = {"6nm": 1.0, "12nm": 1.5}

    fh = image_h // stride
    fw = image_w // stride

    heatmap = np.zeros((2, fh, fw), dtype=np.float32)
    offsets = np.zeros((2, fh, fw), dtype=np.float32)
    offset_mask = np.zeros((fh, fw), dtype=bool)
    conf_map = np.ones((2, fh, fw), dtype=np.float32)

    def _with_labels(coords, confs, label):
        # Pair every (x, y) with its class label and confidence weight.
        if len(coords) == 0:
            return []
        weights = confs if confs is not None else np.ones(len(coords))
        return [(px, py, label, w) for (px, py), w in zip(coords, weights)]

    particles = _with_labels(coords_6nm, confidence_6nm, "6nm")
    particles += _with_labels(coords_12nm, confidence_12nm, "12nm")

    for px, py, label, weight in particles:
        channel = CLASS_IDX[label]
        sigma = sigmas[label]

        # Exact feature-space center, then its nearest integer grid cell.
        fx = px / stride
        fy = py / stride
        gx = int(round(fx))
        gy = int(round(fy))

        # Truncate the Gaussian footprint at ~3 sigma (at least radius 2).
        radius = max(int(3 * sigma + 1), 2)

        row0 = max(0, gy - radius)
        row1 = min(fh, gy + radius + 1)
        col0 = max(0, gx - radius)
        col1 = min(fw, gx + radius + 1)
        if row0 >= row1 or col0 >= col1:
            continue

        rows, cols = np.ogrid[row0:row1, col0:col1]

        # Gaussian centered on the INTEGER cell: the CornerNet focal loss
        # treats only gt == 1.0 pixels as positives, so the integer center
        # must peak at exactly 1.0 (times the confidence weight). Centering
        # at the float position would yield sub-1.0 peaks that the loss
        # would count as negatives — no positive training signal.
        splat = weight * np.exp(
            -((cols - gx) ** 2 + (rows - gy) ** 2) / (2 * sigma ** 2)
        )

        # Element-wise max handles overlapping particles without summing >1.
        region = heatmap[channel, row0:row1, col0:col1]
        np.maximum(region, splat, out=region)

        # Sub-pixel offset and confidence live only at the center cell.
        if 0 <= gy < fh and 0 <= gx < fw:
            offsets[0, gy, gx] = fx - gx
            offsets[1, gy, gx] = fy - gy
            offset_mask[gy, gx] = True
            conf_map[channel, gy, gx] = weight

    return heatmap, offsets, offset_mask, conf_map
129
+
130
+
131
def extract_peaks(
    heatmap: torch.Tensor,
    offset_map: torch.Tensor,
    stride: int = STRIDE,
    conf_threshold: float = 0.3,
    nms_kernel_sizes: Optional[Dict[str, int]] = None,
) -> List[dict]:
    """
    Extract detections from predicted heatmap via max-pool NMS.

    Args:
        heatmap: (2, H/stride, W/stride) sigmoid-activated
        offset_map: (2, H/stride, W/stride) raw offset predictions
        stride: output stride
        conf_threshold: minimum confidence to keep
        nms_kernel_sizes: per-class NMS kernel sizes

    Returns:
        List of {'x': float, 'y': float, 'class': str, 'conf': float}
    """
    if nms_kernel_sizes is None:
        nms_kernel_sizes = {"6nm": 3, "12nm": 5}

    detections = []

    for cls_idx, cls_name in enumerate(CLASS_NAMES):
        hm_cls = heatmap[cls_idx].unsqueeze(0).unsqueeze(0)  # (1,1,H,W)
        kernel = nms_kernel_sizes[cls_name]

        # Max-pool NMS: a pixel survives only if it equals the local max.
        hmax = F.max_pool2d(
            hm_cls, kernel_size=kernel, stride=1, padding=kernel // 2
        )
        # Index [0, 0] explicitly rather than .squeeze(): squeeze() would
        # also drop spatial dimensions of size 1, breaking the comparison
        # for degenerate (H or W == 1) heatmaps.
        peaks = (hmax[0, 0] == heatmap[cls_idx]) & (
            heatmap[cls_idx] > conf_threshold
        )

        ys, xs = torch.where(peaks)
        for y_i, x_i in zip(ys.tolist(), xs.tolist()):
            conf = heatmap[cls_idx, y_i, x_i].item()
            # Offset channels are shared across classes: 0 = dx, 1 = dy.
            dx = offset_map[0, y_i, x_i].item()
            dy = offset_map[1, y_i, x_i].item()

            # Back to input space with sub-pixel offset
            detections.append({
                "x": (x_i + dx) * stride,
                "y": (y_i + dy) * stride,
                "class": cls_name,
                "conf": conf,
            })

    return detections
tests/__init__.py ADDED
File without changes
tests/test_evaluate.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Unit tests for evaluation matching and metrics."""
2
+
3
+ import numpy as np
4
+ import pytest
5
+
6
+ from src.evaluate import match_detections_to_gt, compute_f1, compute_average_precision
7
+
8
+
9
+ class TestComputeF1:
10
+ def test_perfect_score(self):
11
+ f1, p, r = compute_f1(10, 0, 0)
12
+ assert f1 == pytest.approx(1.0, abs=0.001)
13
+ assert p == pytest.approx(1.0, abs=0.001)
14
+ assert r == pytest.approx(1.0, abs=0.001)
15
+
16
+ def test_zero_detections(self):
17
+ f1, p, r = compute_f1(0, 0, 10)
18
+ assert f1 == pytest.approx(0.0, abs=0.01)
19
+ assert r == pytest.approx(0.0, abs=0.01)
20
+
21
+ def test_all_false_positives(self):
22
+ f1, p, r = compute_f1(0, 10, 0)
23
+ assert p == pytest.approx(0.0, abs=0.01)
24
+
25
+
26
+ class TestMatchDetections:
27
+ def test_perfect_matching(self):
28
+ """Detections at exact GT locations should all match."""
29
+ gt_6nm = np.array([[100.0, 100.0], [200.0, 200.0]])
30
+ gt_12nm = np.array([[300.0, 300.0]])
31
+
32
+ dets = [
33
+ {"x": 100.0, "y": 100.0, "class": "6nm", "conf": 0.9},
34
+ {"x": 200.0, "y": 200.0, "class": "6nm", "conf": 0.8},
35
+ {"x": 300.0, "y": 300.0, "class": "12nm", "conf": 0.7},
36
+ ]
37
+
38
+ results = match_detections_to_gt(dets, gt_6nm, gt_12nm)
39
+ assert results["6nm"]["tp"] == 2
40
+ assert results["6nm"]["fp"] == 0
41
+ assert results["6nm"]["fn"] == 0
42
+ assert results["12nm"]["tp"] == 1
43
+
44
+ def test_wrong_class_no_match(self):
45
+ """Detection near GT but wrong class should not match."""
46
+ gt_6nm = np.array([[100.0, 100.0]])
47
+ gt_12nm = np.empty((0, 2))
48
+
49
+ dets = [
50
+ {"x": 100.0, "y": 100.0, "class": "12nm", "conf": 0.9},
51
+ ]
52
+
53
+ results = match_detections_to_gt(dets, gt_6nm, gt_12nm)
54
+ assert results["6nm"]["fn"] == 1 # missed
55
+ assert results["12nm"]["fp"] == 1 # false positive
56
+
57
+ def test_beyond_radius_no_match(self):
58
+ """Detection beyond match radius should not match."""
59
+ gt_6nm = np.array([[100.0, 100.0]])
60
+ gt_12nm = np.empty((0, 2))
61
+
62
+ dets = [
63
+ {"x": 120.0, "y": 100.0, "class": "6nm", "conf": 0.9}, # 20px away > 9px radius
64
+ ]
65
+
66
+ results = match_detections_to_gt(
67
+ dets, gt_6nm, gt_12nm, match_radii={"6nm": 9.0, "12nm": 15.0}
68
+ )
69
+ assert results["6nm"]["tp"] == 0
70
+ assert results["6nm"]["fp"] == 1
71
+ assert results["6nm"]["fn"] == 1
72
+
73
+ def test_within_radius_matches(self):
74
+ """Detection within match radius should match."""
75
+ gt_6nm = np.array([[100.0, 100.0]])
76
+ gt_12nm = np.empty((0, 2))
77
+
78
+ dets = [
79
+ {"x": 105.0, "y": 100.0, "class": "6nm", "conf": 0.9}, # 5px away < 9px
80
+ ]
81
+
82
+ results = match_detections_to_gt(
83
+ dets, gt_6nm, gt_12nm, match_radii={"6nm": 9.0, "12nm": 15.0}
84
+ )
85
+ assert results["6nm"]["tp"] == 1
86
+
87
+ def test_no_detections(self):
88
+ """No detections: all GT are false negatives."""
89
+ gt_6nm = np.array([[100.0, 100.0], [200.0, 200.0]])
90
+ results = match_detections_to_gt([], gt_6nm, np.empty((0, 2)))
91
+ assert results["6nm"]["fn"] == 2
92
+ assert results["6nm"]["f1"] == pytest.approx(0.0, abs=0.01)
93
+
94
+ def test_no_ground_truth(self):
95
+ """No GT: all detections are false positives."""
96
+ dets = [{"x": 100.0, "y": 100.0, "class": "6nm", "conf": 0.9}]
97
+ results = match_detections_to_gt(dets, np.empty((0, 2)), np.empty((0, 2)))
98
+ assert results["6nm"]["fp"] == 1
99
+
100
+
101
+ class TestAveragePrecision:
102
+ def test_perfect_ap(self):
103
+ """All detections match in rank order → AP = 1.0."""
104
+ gt = np.array([[100.0, 100.0], [200.0, 200.0]])
105
+ dets = [
106
+ {"x": 100.0, "y": 100.0, "class": "6nm", "conf": 0.9},
107
+ {"x": 200.0, "y": 200.0, "class": "6nm", "conf": 0.8},
108
+ ]
109
+ ap = compute_average_precision(dets, gt, match_radius=9.0)
110
+ assert ap == pytest.approx(1.0, abs=0.01)
tests/test_heatmap.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Unit tests for heatmap GT generation and peak extraction."""
2
+
3
+ import numpy as np
4
+ import pytest
5
+ import torch
6
+
7
+ from src.heatmap import generate_heatmap_gt, extract_peaks
8
+
9
+
10
class TestHeatmapGeneration:
    """Tests for generate_heatmap_gt: peak placement, offsets, confidence."""

    def test_single_particle_peak(self):
        """A single particle should produce a Gaussian peak at the correct location."""
        coords_6nm = np.array([[100.0, 200.0]])
        coords_12nm = np.empty((0, 2))

        hm, off, mask, conf = generate_heatmap_gt(
            coords_6nm, coords_12nm, 512, 512, stride=2,
        )

        assert hm.shape == (2, 256, 256)
        assert off.shape == (2, 256, 256)
        assert mask.shape == (256, 256)

        # Peak should be at (50, 100) in stride-2 space
        peak_y, peak_x = np.unravel_index(hm[0].argmax(), hm[0].shape)
        assert abs(peak_x - 50) <= 1
        assert abs(peak_y - 100) <= 1

        # Peak value should be 1.0 (confidence=1.0 default)
        assert hm[0].max() == pytest.approx(1.0, abs=0.01)

        # 12nm channel should be empty
        assert hm[1].max() == 0.0

    def test_two_classes(self):
        """Both classes should produce peaks in their respective channels."""
        coords_6nm = np.array([[100.0, 100.0]])
        coords_12nm = np.array([[200.0, 200.0]])

        hm, _, _, _ = generate_heatmap_gt(
            coords_6nm, coords_12nm, 512, 512, stride=2,
        )

        assert hm[0].max() > 0.9  # 6nm channel has peak
        assert hm[1].max() > 0.9  # 12nm channel has peak

    def test_offset_values(self):
        """Offsets should encode sub-pixel correction.

        Previously this test only checked mask.sum() == 1 and never verified
        the offsets it is named for; it now asserts the actual values.
        """
        # Place particle at (101.5, 200.5) → stride-2 center at (50.75, 100.25)
        # Integer center: (51, 100) → offset: (-0.25, 0.25)
        coords_6nm = np.array([[101.5, 200.5]])
        coords_12nm = np.empty((0, 2))

        _, off, mask, _ = generate_heatmap_gt(
            coords_6nm, coords_12nm, 512, 512, stride=2,
        )

        # Mask should have exactly one True pixel
        assert mask.sum() == 1

        # And it must be the rounded integer center with the right offsets.
        center_y, center_x = np.argwhere(mask)[0]
        assert (center_x, center_y) == (51, 100)
        assert off[0, center_y, center_x] == pytest.approx(-0.25, abs=1e-6)
        assert off[1, center_y, center_x] == pytest.approx(0.25, abs=1e-6)

    def test_empty_annotations(self):
        """Empty annotations should produce zero heatmap."""
        hm, off, mask, conf = generate_heatmap_gt(
            np.empty((0, 2)), np.empty((0, 2)), 512, 512,
        )
        assert hm.max() == 0.0
        assert mask.sum() == 0

    def test_confidence_weighting(self):
        """Confidence < 1 should scale peak value."""
        coords = np.array([[100.0, 100.0]])
        confidences = np.array([0.5])

        hm, _, _, _ = generate_heatmap_gt(
            coords, np.empty((0, 2)), 512, 512,
            confidence_6nm=confidences,
        )

        assert hm[0].max() == pytest.approx(0.5, abs=0.05)

    def test_overlapping_particles_use_max(self):
        """Overlapping Gaussians should use element-wise max, not sum."""
        coords = np.array([[100.0, 100.0], [104.0, 100.0]])  # close together
        hm, _, _, _ = generate_heatmap_gt(
            coords, np.empty((0, 2)), 512, 512, stride=2,
        )
        # Max should be 1.0, not >1.0
        assert hm[0].max() <= 1.0
89
+
90
+
91
class TestPeakExtraction:
    """Tests for extract_peaks: decoding, NMS, and confidence filtering."""

    def test_single_peak(self):
        """Extract a single peak from synthetic heatmap."""
        hm = torch.zeros(2, 256, 256)
        offsets = torch.zeros(2, 256, 256)
        hm[0, 100, 50] = 0.9     # one 6nm peak
        offsets[0, 100, 50] = 0.3  # dx
        offsets[1, 100, 50] = 0.1  # dy

        dets = extract_peaks(hm, offsets, stride=2, conf_threshold=0.5)

        assert len(dets) == 1
        det = dets[0]
        assert det["class"] == "6nm"
        assert det["conf"] == pytest.approx(0.9, abs=0.01)
        # Back-projection: x = (50 + 0.3) * 2, y = (100 + 0.1) * 2
        assert det["x"] == pytest.approx(100.6, abs=0.1)
        assert det["y"] == pytest.approx(200.2, abs=0.1)

    def test_nms_suppresses_neighbors(self):
        """NMS should suppress weaker neighboring peaks."""
        hm = torch.zeros(2, 256, 256)
        hm[0, 100, 50] = 0.9   # strong
        hm[0, 101, 50] = 0.7   # weaker neighbor (within NMS kernel)

        dets = extract_peaks(
            hm, torch.zeros(2, 256, 256),
            stride=2, conf_threshold=0.5,
            nms_kernel_sizes={"6nm": 5, "12nm": 5},
        )

        survivors = [d for d in dets if d["class"] == "6nm"]
        assert len(survivors) == 1  # only the stronger peak remains

    def test_below_threshold_filtered(self):
        """Peaks below threshold should not be extracted."""
        hm = torch.zeros(2, 256, 256)
        hm[0, 100, 50] = 0.2  # below the 0.3 threshold

        dets = extract_peaks(hm, torch.zeros(2, 256, 256), conf_threshold=0.3)

        assert len(dets) == 0
tests/test_loss.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Unit tests for loss functions."""
2
+
3
+ import pytest
4
+ import torch
5
+
6
+ from src.loss import cornernet_focal_loss, offset_loss, total_loss
7
+
8
+
9
class TestCornerNetFocalLoss:
    """Behavioral checks for the CornerNet-style focal loss."""

    @staticmethod
    def _gt_with_peak():
        # One positive pixel at (32, 32) in the 6nm channel.
        target = torch.zeros(1, 2, 64, 64)
        target[0, 0, 32, 32] = 1.0
        return target

    def test_perfect_prediction_zero_loss(self):
        """Perfect predictions should produce near-zero loss."""
        target = self._gt_with_peak()
        prediction = torch.zeros(1, 2, 64, 64) + 1e-6
        prediction[0, 0, 32, 32] = 1.0 - 1e-6

        assert cornernet_focal_loss(prediction, target).item() < 0.1

    def test_all_zeros_prediction_nonzero_loss(self):
        """Predicting all zeros when particles exist should give positive loss."""
        target = self._gt_with_peak()
        prediction = torch.zeros(1, 2, 64, 64) + 1e-6

        assert cornernet_focal_loss(prediction, target).item() > 0

    def test_high_false_positive_penalized(self):
        """Predicting high confidence where GT is zero should be penalized."""
        target = torch.zeros(1, 2, 64, 64)
        mild_fp = torch.zeros(1, 2, 64, 64) + 0.01
        strong_fp = torch.zeros(1, 2, 64, 64) + 0.9

        loss_mild = cornernet_focal_loss(mild_fp, target)
        loss_strong = cornernet_focal_loss(strong_fp, target)

        assert loss_strong.item() > loss_mild.item()

    def test_near_peak_reduced_penalty(self):
        """Pixels near GT peaks should have reduced negative penalty via beta term."""
        target = self._gt_with_peak()
        target[0, 0, 31, 32] = 0.8  # nearby pixel with Gaussian falloff

        # Moderate prediction near the peak should not explode the loss.
        prediction = torch.zeros(1, 2, 64, 64) + 0.01
        prediction[0, 0, 31, 32] = 0.5

        assert cornernet_focal_loss(prediction, target).item() < 10

    def test_confidence_weighting(self):
        """Confidence weights should scale the loss."""
        target = self._gt_with_peak()
        prediction = torch.zeros(1, 2, 64, 64) + 0.5

        full = cornernet_focal_loss(
            prediction, target, conf_weights=torch.ones(1, 2, 64, 64)
        )
        halved = cornernet_focal_loss(
            prediction, target, conf_weights=torch.ones(1, 2, 64, 64) * 0.5
        )

        # Half weights should produce lower loss
        assert halved.item() < full.item()
70
+
71
+
72
class TestOffsetLoss:
    """Offset regression loss is masked to particle-center pixels."""

    def test_zero_when_no_particles(self):
        """Offset loss should be zero when mask is empty."""
        prediction = torch.randn(1, 2, 64, 64)
        target = torch.zeros(1, 2, 64, 64)
        empty_mask = torch.zeros(1, 64, 64, dtype=torch.bool)

        assert offset_loss(prediction, target, empty_mask).item() == 0.0

    def test_nonzero_with_particles(self):
        """Offset loss should be nonzero when predictions differ from GT."""
        prediction = torch.randn(1, 2, 64, 64)
        target = torch.zeros(1, 2, 64, 64)
        center_mask = torch.zeros(1, 64, 64, dtype=torch.bool)
        center_mask[0, 32, 32] = True  # one active center pixel

        assert offset_loss(prediction, target, center_mask).item() > 0
91
+
92
+
93
class TestTotalLoss:
    """total_loss combines heatmap and offset terms."""

    def test_returns_three_values(self):
        """total_loss should return (total, hm_loss, off_loss)."""
        heatmap_pred = torch.sigmoid(torch.randn(1, 2, 64, 64))
        heatmap_gt = torch.zeros(1, 2, 64, 64)
        offset_pred = torch.randn(1, 2, 64, 64)
        offset_gt = torch.zeros(1, 2, 64, 64)
        center_mask = torch.zeros(1, 64, 64, dtype=torch.bool)

        combined, heatmap_term, offset_term = total_loss(
            heatmap_pred, heatmap_gt, offset_pred, offset_gt, center_mask,
        )

        # The combined loss stays a differentiable tensor; the two
        # components are reported as plain floats for logging.
        assert isinstance(combined, torch.Tensor)
        assert combined.requires_grad
        assert isinstance(heatmap_term, float)
        assert isinstance(offset_term, float)
tests/test_model.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Unit tests for model architecture."""
2
+
3
+ import pytest
4
+ import torch
5
+
6
+ from src.model import ImmunogoldCenterNet, BiFPN
7
+
8
+
9
class TestModelForwardPass:
    """Shape, range, and capacity checks for ImmunogoldCenterNet."""

    def test_output_shapes(self):
        """Verify output shapes match stride-2 specification."""
        net = ImmunogoldCenterNet(pretrained_path=None)
        heatmap, offsets = net(torch.randn(1, 1, 512, 512))

        assert heatmap.shape == (1, 2, 256, 256), \
            f"Expected (1,2,256,256), got {heatmap.shape}"
        assert offsets.shape == (1, 2, 256, 256), \
            f"Expected (1,2,256,256), got {offsets.shape}"

    def test_heatmap_sigmoid_range(self):
        """Heatmap outputs should be in [0, 1] from sigmoid."""
        net = ImmunogoldCenterNet(pretrained_path=None)
        heatmap, _ = net(torch.randn(1, 1, 512, 512))

        assert 0.0 <= heatmap.min()
        assert heatmap.max() <= 1.0

    def test_batch_dimension(self):
        """Model should handle batch size > 1."""
        net = ImmunogoldCenterNet(pretrained_path=None)
        heatmap, offsets = net(torch.randn(4, 1, 512, 512))

        assert heatmap.shape[0] == 4
        assert offsets.shape[0] == 4

    def test_variable_input_size(self):
        """Model should handle different input sizes (multiples of 32)."""
        net = ImmunogoldCenterNet(pretrained_path=None)

        for side in (256, 384, 512):
            heatmap, offsets = net(torch.randn(1, 1, side, side))
            assert heatmap.shape == (1, 2, side // 2, side // 2)

    def test_parameter_count(self):
        """Model should have approximately 25M parameters."""
        net = ImmunogoldCenterNet(pretrained_path=None)
        total = sum(p.numel() for p in net.parameters())
        # ResNet-50 is ~25M, plus BiFPN and heads
        assert 20_000_000 < total < 40_000_000
52
+
53
+
54
class TestFreezeUnfreeze:
    """Staged fine-tuning: encoder freeze / partial / full unfreeze."""

    _ENCODER_PARTS = ("stem", "layer1", "layer2", "layer3", "layer4")

    def test_freeze_encoder(self):
        """Frozen encoder should have no gradients."""
        net = ImmunogoldCenterNet(pretrained_path=None)
        net.freeze_encoder()

        for name, param in net.named_parameters():
            if any(part in name for part in self._ENCODER_PARTS):
                assert not param.requires_grad, f"{name} should be frozen"

        # BiFPN and heads should still be trainable
        for name, param in net.bifpn.named_parameters():
            assert param.requires_grad, f"bifpn.{name} should be trainable"

    def test_unfreeze_deep(self):
        """Unfreezing deep layers should enable gradients for layer3/4."""
        net = ImmunogoldCenterNet(pretrained_path=None)
        net.freeze_encoder()
        net.unfreeze_deep_layers()

        deep = list(net.layer3.parameters()) + list(net.layer4.parameters())
        assert all(p.requires_grad for p in deep)
        # Stem (and shallow layers) remain frozen
        assert all(not p.requires_grad for p in net.stem.parameters())

    def test_unfreeze_all(self):
        """Unfreeze all should enable all gradients."""
        net = ImmunogoldCenterNet(pretrained_path=None)
        net.freeze_encoder()
        net.unfreeze_all()

        assert all(p.requires_grad for p in net.parameters())
90
+
91
+
92
class TestBiFPN:
    """BiFPN neck: channel projection and spatial preservation."""

    def test_bifpn_output_shapes(self):
        """BiFPN should output 4 feature maps at 128 channels."""
        neck = BiFPN(
            in_channels=[256, 512, 1024, 2048],
            out_channels=128,
            num_rounds=2,
        )

        # (channels, spatial) per pyramid level: P2..P5 = strides 4..32.
        level_specs = [(256, 128), (512, 64), (1024, 32), (2048, 16)]
        pyramid = [torch.randn(1, ch, side, side) for ch, side in level_specs]

        fused = neck(pyramid)

        assert len(fused) == 4
        for level, (feat_in, feat_out) in enumerate(zip(pyramid, fused)):
            assert feat_out.shape[1] == 128, f"P{level+2} channels should be 128"
            assert feat_out.shape[2:] == feat_in.shape[2:], \
                f"P{level+2} spatial dims should match input"