Sync all source code, docs, and configs
Browse files- scripts/verify_data.py +94 -0
- slurm/02_train_single_fold.sh +40 -0
- slurm/03_evaluate_ensemble.sh +39 -0
- slurm/04_full_pipeline.sh +40 -0
- slurm/05_train_final.sh +27 -0
- src/heatmap.py +187 -0
- tests/__init__.py +0 -0
- tests/test_evaluate.py +110 -0
- tests/test_heatmap.py +132 -0
- tests/test_loss.py +109 -0
- tests/test_model.py +112 -0
scripts/verify_data.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Data verification script: loads all images and annotations,
|
| 3 |
+
validates counts, and saves visual overlays.
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
python scripts/verify_data.py --config config/config.yaml
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import argparse
|
| 10 |
+
import sys
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
import yaml
|
| 15 |
+
|
| 16 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 17 |
+
|
| 18 |
+
from src.preprocessing import discover_synapse_data, load_synapse
|
| 19 |
+
from src.visualize import overlay_annotations
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def main():
|
| 23 |
+
parser = argparse.ArgumentParser()
|
| 24 |
+
parser.add_argument("--config", default="config/config.yaml")
|
| 25 |
+
args = parser.parse_args()
|
| 26 |
+
|
| 27 |
+
with open(args.config) as f:
|
| 28 |
+
cfg = yaml.safe_load(f)
|
| 29 |
+
|
| 30 |
+
print("=" * 60)
|
| 31 |
+
print("Immunogold Data Verification")
|
| 32 |
+
print("=" * 60)
|
| 33 |
+
|
| 34 |
+
records = discover_synapse_data(cfg["data"]["root"], cfg["data"]["synapse_ids"])
|
| 35 |
+
|
| 36 |
+
total_6nm = 0
|
| 37 |
+
total_12nm = 0
|
| 38 |
+
output_dir = Path("results/verification")
|
| 39 |
+
|
| 40 |
+
for record in records:
|
| 41 |
+
print(f"\n--- {record.synapse_id} ---")
|
| 42 |
+
print(f" Image: {record.image_path.name}")
|
| 43 |
+
print(f" Mask: {record.mask_path.name if record.mask_path else 'NONE'}")
|
| 44 |
+
print(f" 6nm CSVs: {[p.name for p in record.csv_6nm_paths]}")
|
| 45 |
+
print(f" 12nm CSVs: {[p.name for p in record.csv_12nm_paths]}")
|
| 46 |
+
|
| 47 |
+
data = load_synapse(record)
|
| 48 |
+
img = data["image"]
|
| 49 |
+
annots = data["annotations"]
|
| 50 |
+
|
| 51 |
+
n6 = len(annots["6nm"])
|
| 52 |
+
n12 = len(annots["12nm"])
|
| 53 |
+
total_6nm += n6
|
| 54 |
+
total_12nm += n12
|
| 55 |
+
|
| 56 |
+
print(f" Image shape: {img.shape}")
|
| 57 |
+
print(f" 6nm particles: {n6}")
|
| 58 |
+
print(f" 12nm particles: {n12}")
|
| 59 |
+
|
| 60 |
+
if data["mask"] is not None:
|
| 61 |
+
# Check how many particles fall within mask
|
| 62 |
+
mask = data["mask"]
|
| 63 |
+
for cls, coords in annots.items():
|
| 64 |
+
if len(coords) == 0:
|
| 65 |
+
continue
|
| 66 |
+
inside = sum(
|
| 67 |
+
1 for x, y in coords
|
| 68 |
+
if 0 <= int(y) < mask.shape[0] and
|
| 69 |
+
0 <= int(x) < mask.shape[1] and
|
| 70 |
+
mask[int(y), int(x)]
|
| 71 |
+
)
|
| 72 |
+
print(f" {cls} in mask: {inside}/{len(coords)}")
|
| 73 |
+
|
| 74 |
+
# Save overlay
|
| 75 |
+
overlay_annotations(
|
| 76 |
+
img, annots,
|
| 77 |
+
title=f"{record.synapse_id}: {n6} 6nm, {n12} 12nm",
|
| 78 |
+
save_path=output_dir / f"{record.synapse_id}_annotations.png",
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
print(f"\n{'=' * 60}")
|
| 82 |
+
print(f"TOTAL: {total_6nm} 6nm + {total_12nm} 12nm = {total_6nm + total_12nm}")
|
| 83 |
+
print(f"Expected: 403 6nm + 50 12nm = 453")
|
| 84 |
+
|
| 85 |
+
if total_6nm + total_12nm >= 400:
|
| 86 |
+
print("PASS: Particle counts look reasonable")
|
| 87 |
+
else:
|
| 88 |
+
print("WARNING: Total count is lower than expected")
|
| 89 |
+
|
| 90 |
+
print(f"\nOverlays saved to: {output_dir}")
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
if __name__ == "__main__":
|
| 94 |
+
main()
|
slurm/02_train_single_fold.sh
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
#SBATCH --job-name=immunogold_train
|
| 3 |
+
#SBATCH --partition=gpu # TODO: adjust to your GPU partition
|
| 4 |
+
#SBATCH --gres=gpu:1 # single A100 or V100
|
| 5 |
+
#SBATCH --nodes=1
|
| 6 |
+
#SBATCH --ntasks=1
|
| 7 |
+
#SBATCH --cpus-per-task=8
|
| 8 |
+
#SBATCH --mem=80G
|
| 9 |
+
#SBATCH --time=12:00:00
|
| 10 |
+
#SBATCH --array=0-49 # 10 folds x 5 seeds = 50 jobs
|
| 11 |
+
#SBATCH --output=logs/train_%A_%a.out
|
| 12 |
+
#SBATCH --error=logs/train_%A_%a.err
|
| 13 |
+
|
| 14 |
+
set -euo pipefail
|
| 15 |
+
mkdir -p logs
|
| 16 |
+
|
| 17 |
+
eval "$(conda shell.bash hook)"
|
| 18 |
+
conda activate immunogold
|
| 19 |
+
|
| 20 |
+
# Map array task ID to fold and seed
|
| 21 |
+
SYNAPSE_IDS=("S1" "S4" "S7" "S8" "S13" "S15" "S22" "S25" "S27" "S29")
|
| 22 |
+
FOLD_IDX=$((SLURM_ARRAY_TASK_ID / 5))
|
| 23 |
+
SEED_IDX=$((SLURM_ARRAY_TASK_ID % 5))
|
| 24 |
+
SEED=$((SEED_IDX + 42))
|
| 25 |
+
FOLD_NAME=${SYNAPSE_IDS[$FOLD_IDX]}
|
| 26 |
+
|
| 27 |
+
echo "=== Training immunogold CenterNet ==="
|
| 28 |
+
echo "Date: $(date)"
|
| 29 |
+
echo "Array task: ${SLURM_ARRAY_TASK_ID}"
|
| 30 |
+
echo "Fold: ${FOLD_NAME} (idx ${FOLD_IDX})"
|
| 31 |
+
echo "Seed: ${SEED} (idx ${SEED_IDX})"
|
| 32 |
+
echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader 2>/dev/null || echo 'N/A')"
|
| 33 |
+
|
| 34 |
+
python train.py \
|
| 35 |
+
--fold "${FOLD_NAME}" \
|
| 36 |
+
--seed "${SEED}" \
|
| 37 |
+
--config config/config.yaml \
|
| 38 |
+
--device cuda:0
|
| 39 |
+
|
| 40 |
+
echo "=== Training complete for fold=${FOLD_NAME} seed=${SEED} ==="
|
slurm/03_evaluate_ensemble.sh
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
#SBATCH --job-name=immunogold_eval
|
| 3 |
+
#SBATCH --partition=gpu # TODO: adjust
|
| 4 |
+
#SBATCH --gres=gpu:1 # single GPU per fold (model loaded sequentially)
|
| 5 |
+
#SBATCH --nodes=1
|
| 6 |
+
#SBATCH --ntasks=1
|
| 7 |
+
#SBATCH --cpus-per-task=8
|
| 8 |
+
#SBATCH --mem=64G
|
| 9 |
+
#SBATCH --time=04:00:00
|
| 10 |
+
#SBATCH --array=0-9 # one job per test fold
|
| 11 |
+
#SBATCH --output=logs/eval_%A_%a.out
|
| 12 |
+
#SBATCH --error=logs/eval_%A_%a.err
|
| 13 |
+
|
| 14 |
+
set -euo pipefail
|
| 15 |
+
mkdir -p logs
|
| 16 |
+
|
| 17 |
+
eval "$(conda shell.bash hook)"
|
| 18 |
+
conda activate immunogold
|
| 19 |
+
|
| 20 |
+
SYNAPSE_IDS=("S1" "S4" "S7" "S8" "S13" "S15" "S22" "S25" "S27" "S29")
|
| 21 |
+
FOLD_NAME=${SYNAPSE_IDS[$SLURM_ARRAY_TASK_ID]}
|
| 22 |
+
|
| 23 |
+
echo "=== Ensemble evaluation ==="
|
| 24 |
+
echo "Date: $(date)"
|
| 25 |
+
echo "Test fold: ${FOLD_NAME}"
|
| 26 |
+
echo "Array task: ${SLURM_ARRAY_TASK_ID}"
|
| 27 |
+
echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader 2>/dev/null || echo 'N/A')"
|
| 28 |
+
|
| 29 |
+
# Evaluate this single fold using the same script as local evaluation
|
| 30 |
+
# evaluate_loocv.py handles loading all ensemble members per fold
|
| 31 |
+
python evaluate_loocv.py \
|
| 32 |
+
--config config/config.yaml \
|
| 33 |
+
--ensemble-dir checkpoints \
|
| 34 |
+
--device cuda:0 \
|
| 35 |
+
--use-tta \
|
| 36 |
+
--fold "${FOLD_NAME}" \
|
| 37 |
+
--output "results/per_fold_predictions/${FOLD_NAME}_metrics.csv"
|
| 38 |
+
|
| 39 |
+
echo "=== Evaluation complete for fold ${FOLD_NAME} ==="
|
slurm/04_full_pipeline.sh
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Master submission script for the full immunogold detection pipeline.
|
| 3 |
+
# Submits all stages with dependency chaining so they run sequentially.
|
| 4 |
+
#
|
| 5 |
+
# Usage:
|
| 6 |
+
# bash slurm/04_full_pipeline.sh
|
| 7 |
+
|
| 8 |
+
set -euo pipefail
|
| 9 |
+
|
| 10 |
+
echo "=== Immunogold Detection Pipeline ==="
|
| 11 |
+
echo "Date: $(date)"
|
| 12 |
+
echo "Submitting all pipeline stages..."
|
| 13 |
+
|
| 14 |
+
# Create log directory
|
| 15 |
+
mkdir -p logs
|
| 16 |
+
|
| 17 |
+
# Stage 0: Environment setup (single job)
|
| 18 |
+
JOB0=$(sbatch --parsable slurm/00_setup_env.sh)
|
| 19 |
+
echo "Stage 0 (setup): Job ${JOB0}"
|
| 20 |
+
|
| 21 |
+
# Stage 1: Training — 50 parallel jobs (depends on setup)
|
| 22 |
+
JOB1=$(sbatch --parsable --dependency=afterok:${JOB0} slurm/02_train_single_fold.sh)
|
| 23 |
+
echo "Stage 1 (training): Job ${JOB1} (array 0-49, depends on ${JOB0})"
|
| 24 |
+
|
| 25 |
+
# Stage 2: Ensemble evaluation — 10 parallel jobs (depends on ALL training tasks)
|
| 26 |
+
# afterok on an array job waits for ALL array tasks to complete successfully
|
| 27 |
+
JOB2=$(sbatch --parsable --dependency=afterok:${JOB1} slurm/03_evaluate_ensemble.sh)
|
| 28 |
+
echo "Stage 2 (evaluate): Job ${JOB2} (array 0-9, depends on ${JOB1})"
|
| 29 |
+
|
| 30 |
+
echo ""
|
| 31 |
+
echo "Pipeline submitted successfully!"
|
| 32 |
+
echo "Final results job: ${JOB2}"
|
| 33 |
+
echo ""
|
| 34 |
+
echo "Monitor with:"
|
| 35 |
+
echo " squeue -u \${USER}"
|
| 36 |
+
echo " sacct -j ${JOB0},${JOB1},${JOB2}"
|
| 37 |
+
echo ""
|
| 38 |
+
echo "Results will be in:"
|
| 39 |
+
echo " results/loocv_metrics.csv"
|
| 40 |
+
echo " results/per_fold_predictions/"
|
slurm/05_train_final.sh
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
#SBATCH --job-name=immunogold_final
|
| 3 |
+
#SBATCH --partition=shortq7-gpu
|
| 4 |
+
#SBATCH --gres=gpu:1
|
| 5 |
+
#SBATCH --nodes=1
|
| 6 |
+
#SBATCH --ntasks=1
|
| 7 |
+
#SBATCH --cpus-per-task=8
|
| 8 |
+
#SBATCH --mem=32G
|
| 9 |
+
#SBATCH --time=06:00:00
|
| 10 |
+
#SBATCH --output=logs/train_final_%j.out
|
| 11 |
+
#SBATCH --error=logs/train_final_%j.err
|
| 12 |
+
|
| 13 |
+
set -euo pipefail
|
| 14 |
+
mkdir -p logs
|
| 15 |
+
|
| 16 |
+
module load miniconda3/24.3.0-gcc-13.2.0-rslr3to
|
| 17 |
+
module load cuda/12.4.0-gcc-13.2.0-bxjolrw
|
| 18 |
+
eval "$(conda shell.bash hook)"
|
| 19 |
+
conda activate immunogold
|
| 20 |
+
|
| 21 |
+
echo "=== Training final deployable model ==="
|
| 22 |
+
echo "Date: $(date)"
|
| 23 |
+
echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader 2>/dev/null)"
|
| 24 |
+
|
| 25 |
+
python train_final.py --config config/config.yaml --device cuda:0
|
| 26 |
+
|
| 27 |
+
echo "=== Final model training complete ==="
|
src/heatmap.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Ground truth heatmap generation and peak extraction for CenterNet.
|
| 3 |
+
|
| 4 |
+
Generates Gaussian-splat heatmaps at stride-2 resolution with
|
| 5 |
+
class-specific sigma values calibrated to bead size.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from typing import Dict, List, Tuple, Optional
|
| 12 |
+
|
| 13 |
+
# Class index mapping
|
| 14 |
+
CLASS_IDX = {"6nm": 0, "12nm": 1}
|
| 15 |
+
CLASS_NAMES = ["6nm", "12nm"]
|
| 16 |
+
STRIDE = 2
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def generate_heatmap_gt(
|
| 20 |
+
coords_6nm: np.ndarray,
|
| 21 |
+
coords_12nm: np.ndarray,
|
| 22 |
+
image_h: int,
|
| 23 |
+
image_w: int,
|
| 24 |
+
sigmas: Optional[Dict[str, float]] = None,
|
| 25 |
+
stride: int = STRIDE,
|
| 26 |
+
confidence_6nm: Optional[np.ndarray] = None,
|
| 27 |
+
confidence_12nm: Optional[np.ndarray] = None,
|
| 28 |
+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
| 29 |
+
"""
|
| 30 |
+
Generate CenterNet ground truth heatmaps and offset maps.
|
| 31 |
+
|
| 32 |
+
Args:
|
| 33 |
+
coords_6nm: (N, 2) array of (x, y) in ORIGINAL pixel space
|
| 34 |
+
coords_12nm: (M, 2) array of (x, y) in ORIGINAL pixel space
|
| 35 |
+
image_h: original image height
|
| 36 |
+
image_w: original image width
|
| 37 |
+
sigmas: per-class Gaussian sigma in feature space
|
| 38 |
+
stride: output stride (default 2)
|
| 39 |
+
confidence_6nm: optional per-particle confidence weights
|
| 40 |
+
confidence_12nm: optional per-particle confidence weights
|
| 41 |
+
|
| 42 |
+
Returns:
|
| 43 |
+
heatmap: (2, H//stride, W//stride) float32 in [0, 1]
|
| 44 |
+
offsets: (2, H//stride, W//stride) float32 sub-pixel offsets
|
| 45 |
+
offset_mask: (H//stride, W//stride) bool — True at particle centers
|
| 46 |
+
conf_map: (2, H//stride, W//stride) float32 confidence weights
|
| 47 |
+
"""
|
| 48 |
+
if sigmas is None:
|
| 49 |
+
sigmas = {"6nm": 1.0, "12nm": 1.5}
|
| 50 |
+
|
| 51 |
+
h_feat = image_h // stride
|
| 52 |
+
w_feat = image_w // stride
|
| 53 |
+
|
| 54 |
+
heatmap = np.zeros((2, h_feat, w_feat), dtype=np.float32)
|
| 55 |
+
offsets = np.zeros((2, h_feat, w_feat), dtype=np.float32)
|
| 56 |
+
offset_mask = np.zeros((h_feat, w_feat), dtype=bool)
|
| 57 |
+
conf_map = np.ones((2, h_feat, w_feat), dtype=np.float32)
|
| 58 |
+
|
| 59 |
+
# Prepare coordinate lists with class labels and confidences
|
| 60 |
+
all_entries = []
|
| 61 |
+
if len(coords_6nm) > 0:
|
| 62 |
+
confs = confidence_6nm if confidence_6nm is not None else np.ones(len(coords_6nm))
|
| 63 |
+
for i, (x, y) in enumerate(coords_6nm):
|
| 64 |
+
all_entries.append((x, y, "6nm", confs[i]))
|
| 65 |
+
if len(coords_12nm) > 0:
|
| 66 |
+
confs = confidence_12nm if confidence_12nm is not None else np.ones(len(coords_12nm))
|
| 67 |
+
for i, (x, y) in enumerate(coords_12nm):
|
| 68 |
+
all_entries.append((x, y, "12nm", confs[i]))
|
| 69 |
+
|
| 70 |
+
for x, y, cls, conf in all_entries:
|
| 71 |
+
cidx = CLASS_IDX[cls]
|
| 72 |
+
sigma = sigmas[cls]
|
| 73 |
+
|
| 74 |
+
# Feature-space center (float)
|
| 75 |
+
cx_f = x / stride
|
| 76 |
+
cy_f = y / stride
|
| 77 |
+
|
| 78 |
+
# Integer grid center
|
| 79 |
+
cx_i = int(round(cx_f))
|
| 80 |
+
cy_i = int(round(cy_f))
|
| 81 |
+
|
| 82 |
+
# Sub-pixel offset
|
| 83 |
+
off_x = cx_f - cx_i
|
| 84 |
+
off_y = cy_f - cy_i
|
| 85 |
+
|
| 86 |
+
# Gaussian radius: truncate at 3 sigma
|
| 87 |
+
r = max(int(3 * sigma + 1), 2)
|
| 88 |
+
|
| 89 |
+
# Bounds-clipped grid
|
| 90 |
+
y0 = max(0, cy_i - r)
|
| 91 |
+
y1 = min(h_feat, cy_i + r + 1)
|
| 92 |
+
x0 = max(0, cx_i - r)
|
| 93 |
+
x1 = min(w_feat, cx_i + r + 1)
|
| 94 |
+
|
| 95 |
+
if y0 >= y1 or x0 >= x1:
|
| 96 |
+
continue
|
| 97 |
+
|
| 98 |
+
yy, xx = np.meshgrid(
|
| 99 |
+
np.arange(y0, y1),
|
| 100 |
+
np.arange(x0, x1),
|
| 101 |
+
indexing="ij",
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
# Gaussian centered at INTEGER center (not float)
|
| 105 |
+
# The integer center MUST be exactly 1.0 — the CornerNet focal loss
|
| 106 |
+
# uses pos_mask = (gt == 1.0) and treats everything else as negative.
|
| 107 |
+
# Centering the Gaussian at the float position produces peaks of 0.78-0.93
|
| 108 |
+
# which the loss sees as negatives → zero positive training signal.
|
| 109 |
+
gaussian = np.exp(
|
| 110 |
+
-((xx - cx_i) ** 2 + (yy - cy_i) ** 2) / (2 * sigma ** 2)
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
# Scale by confidence (for pseudo-label weighting)
|
| 114 |
+
gaussian = gaussian * conf
|
| 115 |
+
|
| 116 |
+
# Element-wise max (handles overlapping particles correctly)
|
| 117 |
+
heatmap[cidx, y0:y1, x0:x1] = np.maximum(
|
| 118 |
+
heatmap[cidx, y0:y1, x0:x1], gaussian
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
# Offset and confidence only at the integer center pixel
|
| 122 |
+
if 0 <= cy_i < h_feat and 0 <= cx_i < w_feat:
|
| 123 |
+
offsets[0, cy_i, cx_i] = off_x
|
| 124 |
+
offsets[1, cy_i, cx_i] = off_y
|
| 125 |
+
offset_mask[cy_i, cx_i] = True
|
| 126 |
+
conf_map[cidx, cy_i, cx_i] = conf
|
| 127 |
+
|
| 128 |
+
return heatmap, offsets, offset_mask, conf_map
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def extract_peaks(
|
| 132 |
+
heatmap: torch.Tensor,
|
| 133 |
+
offset_map: torch.Tensor,
|
| 134 |
+
stride: int = STRIDE,
|
| 135 |
+
conf_threshold: float = 0.3,
|
| 136 |
+
nms_kernel_sizes: Optional[Dict[str, int]] = None,
|
| 137 |
+
) -> List[dict]:
|
| 138 |
+
"""
|
| 139 |
+
Extract detections from predicted heatmap via max-pool NMS.
|
| 140 |
+
|
| 141 |
+
Args:
|
| 142 |
+
heatmap: (2, H/stride, W/stride) sigmoid-activated
|
| 143 |
+
offset_map: (2, H/stride, W/stride) raw offset predictions
|
| 144 |
+
stride: output stride
|
| 145 |
+
conf_threshold: minimum confidence to keep
|
| 146 |
+
nms_kernel_sizes: per-class NMS kernel sizes
|
| 147 |
+
|
| 148 |
+
Returns:
|
| 149 |
+
List of {'x': float, 'y': float, 'class': str, 'conf': float}
|
| 150 |
+
"""
|
| 151 |
+
if nms_kernel_sizes is None:
|
| 152 |
+
nms_kernel_sizes = {"6nm": 3, "12nm": 5}
|
| 153 |
+
|
| 154 |
+
detections = []
|
| 155 |
+
|
| 156 |
+
for cls_idx, cls_name in enumerate(CLASS_NAMES):
|
| 157 |
+
hm_cls = heatmap[cls_idx].unsqueeze(0).unsqueeze(0) # (1,1,H,W)
|
| 158 |
+
kernel = nms_kernel_sizes[cls_name]
|
| 159 |
+
|
| 160 |
+
# Max-pool NMS
|
| 161 |
+
hmax = F.max_pool2d(
|
| 162 |
+
hm_cls, kernel_size=kernel, stride=1, padding=kernel // 2
|
| 163 |
+
)
|
| 164 |
+
peaks = (hmax.squeeze() == heatmap[cls_idx]) & (
|
| 165 |
+
heatmap[cls_idx] > conf_threshold
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
ys, xs = torch.where(peaks)
|
| 169 |
+
for y_idx, x_idx in zip(ys, xs):
|
| 170 |
+
y_i = y_idx.item()
|
| 171 |
+
x_i = x_idx.item()
|
| 172 |
+
conf = heatmap[cls_idx, y_i, x_i].item()
|
| 173 |
+
dx = offset_map[0, y_i, x_i].item()
|
| 174 |
+
dy = offset_map[1, y_i, x_i].item()
|
| 175 |
+
|
| 176 |
+
# Back to input space with sub-pixel offset
|
| 177 |
+
det_x = (x_i + dx) * stride
|
| 178 |
+
det_y = (y_i + dy) * stride
|
| 179 |
+
|
| 180 |
+
detections.append({
|
| 181 |
+
"x": det_x,
|
| 182 |
+
"y": det_y,
|
| 183 |
+
"class": cls_name,
|
| 184 |
+
"conf": conf,
|
| 185 |
+
})
|
| 186 |
+
|
| 187 |
+
return detections
|
tests/__init__.py
ADDED
|
File without changes
|
tests/test_evaluate.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Unit tests for evaluation matching and metrics."""
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pytest
|
| 5 |
+
|
| 6 |
+
from src.evaluate import match_detections_to_gt, compute_f1, compute_average_precision
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class TestComputeF1:
|
| 10 |
+
def test_perfect_score(self):
|
| 11 |
+
f1, p, r = compute_f1(10, 0, 0)
|
| 12 |
+
assert f1 == pytest.approx(1.0, abs=0.001)
|
| 13 |
+
assert p == pytest.approx(1.0, abs=0.001)
|
| 14 |
+
assert r == pytest.approx(1.0, abs=0.001)
|
| 15 |
+
|
| 16 |
+
def test_zero_detections(self):
|
| 17 |
+
f1, p, r = compute_f1(0, 0, 10)
|
| 18 |
+
assert f1 == pytest.approx(0.0, abs=0.01)
|
| 19 |
+
assert r == pytest.approx(0.0, abs=0.01)
|
| 20 |
+
|
| 21 |
+
def test_all_false_positives(self):
|
| 22 |
+
f1, p, r = compute_f1(0, 10, 0)
|
| 23 |
+
assert p == pytest.approx(0.0, abs=0.01)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class TestMatchDetections:
|
| 27 |
+
def test_perfect_matching(self):
|
| 28 |
+
"""Detections at exact GT locations should all match."""
|
| 29 |
+
gt_6nm = np.array([[100.0, 100.0], [200.0, 200.0]])
|
| 30 |
+
gt_12nm = np.array([[300.0, 300.0]])
|
| 31 |
+
|
| 32 |
+
dets = [
|
| 33 |
+
{"x": 100.0, "y": 100.0, "class": "6nm", "conf": 0.9},
|
| 34 |
+
{"x": 200.0, "y": 200.0, "class": "6nm", "conf": 0.8},
|
| 35 |
+
{"x": 300.0, "y": 300.0, "class": "12nm", "conf": 0.7},
|
| 36 |
+
]
|
| 37 |
+
|
| 38 |
+
results = match_detections_to_gt(dets, gt_6nm, gt_12nm)
|
| 39 |
+
assert results["6nm"]["tp"] == 2
|
| 40 |
+
assert results["6nm"]["fp"] == 0
|
| 41 |
+
assert results["6nm"]["fn"] == 0
|
| 42 |
+
assert results["12nm"]["tp"] == 1
|
| 43 |
+
|
| 44 |
+
def test_wrong_class_no_match(self):
|
| 45 |
+
"""Detection near GT but wrong class should not match."""
|
| 46 |
+
gt_6nm = np.array([[100.0, 100.0]])
|
| 47 |
+
gt_12nm = np.empty((0, 2))
|
| 48 |
+
|
| 49 |
+
dets = [
|
| 50 |
+
{"x": 100.0, "y": 100.0, "class": "12nm", "conf": 0.9},
|
| 51 |
+
]
|
| 52 |
+
|
| 53 |
+
results = match_detections_to_gt(dets, gt_6nm, gt_12nm)
|
| 54 |
+
assert results["6nm"]["fn"] == 1 # missed
|
| 55 |
+
assert results["12nm"]["fp"] == 1 # false positive
|
| 56 |
+
|
| 57 |
+
def test_beyond_radius_no_match(self):
|
| 58 |
+
"""Detection beyond match radius should not match."""
|
| 59 |
+
gt_6nm = np.array([[100.0, 100.0]])
|
| 60 |
+
gt_12nm = np.empty((0, 2))
|
| 61 |
+
|
| 62 |
+
dets = [
|
| 63 |
+
{"x": 120.0, "y": 100.0, "class": "6nm", "conf": 0.9}, # 20px away > 9px radius
|
| 64 |
+
]
|
| 65 |
+
|
| 66 |
+
results = match_detections_to_gt(
|
| 67 |
+
dets, gt_6nm, gt_12nm, match_radii={"6nm": 9.0, "12nm": 15.0}
|
| 68 |
+
)
|
| 69 |
+
assert results["6nm"]["tp"] == 0
|
| 70 |
+
assert results["6nm"]["fp"] == 1
|
| 71 |
+
assert results["6nm"]["fn"] == 1
|
| 72 |
+
|
| 73 |
+
def test_within_radius_matches(self):
|
| 74 |
+
"""Detection within match radius should match."""
|
| 75 |
+
gt_6nm = np.array([[100.0, 100.0]])
|
| 76 |
+
gt_12nm = np.empty((0, 2))
|
| 77 |
+
|
| 78 |
+
dets = [
|
| 79 |
+
{"x": 105.0, "y": 100.0, "class": "6nm", "conf": 0.9}, # 5px away < 9px
|
| 80 |
+
]
|
| 81 |
+
|
| 82 |
+
results = match_detections_to_gt(
|
| 83 |
+
dets, gt_6nm, gt_12nm, match_radii={"6nm": 9.0, "12nm": 15.0}
|
| 84 |
+
)
|
| 85 |
+
assert results["6nm"]["tp"] == 1
|
| 86 |
+
|
| 87 |
+
def test_no_detections(self):
|
| 88 |
+
"""No detections: all GT are false negatives."""
|
| 89 |
+
gt_6nm = np.array([[100.0, 100.0], [200.0, 200.0]])
|
| 90 |
+
results = match_detections_to_gt([], gt_6nm, np.empty((0, 2)))
|
| 91 |
+
assert results["6nm"]["fn"] == 2
|
| 92 |
+
assert results["6nm"]["f1"] == pytest.approx(0.0, abs=0.01)
|
| 93 |
+
|
| 94 |
+
def test_no_ground_truth(self):
|
| 95 |
+
"""No GT: all detections are false positives."""
|
| 96 |
+
dets = [{"x": 100.0, "y": 100.0, "class": "6nm", "conf": 0.9}]
|
| 97 |
+
results = match_detections_to_gt(dets, np.empty((0, 2)), np.empty((0, 2)))
|
| 98 |
+
assert results["6nm"]["fp"] == 1
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class TestAveragePrecision:
|
| 102 |
+
def test_perfect_ap(self):
|
| 103 |
+
"""All detections match in rank order → AP = 1.0."""
|
| 104 |
+
gt = np.array([[100.0, 100.0], [200.0, 200.0]])
|
| 105 |
+
dets = [
|
| 106 |
+
{"x": 100.0, "y": 100.0, "class": "6nm", "conf": 0.9},
|
| 107 |
+
{"x": 200.0, "y": 200.0, "class": "6nm", "conf": 0.8},
|
| 108 |
+
]
|
| 109 |
+
ap = compute_average_precision(dets, gt, match_radius=9.0)
|
| 110 |
+
assert ap == pytest.approx(1.0, abs=0.01)
|
tests/test_heatmap.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Unit tests for heatmap GT generation and peak extraction."""
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pytest
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from src.heatmap import generate_heatmap_gt, extract_peaks
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class TestHeatmapGeneration:
|
| 11 |
+
def test_single_particle_peak(self):
|
| 12 |
+
"""A single particle should produce a Gaussian peak at the correct location."""
|
| 13 |
+
coords_6nm = np.array([[100.0, 200.0]])
|
| 14 |
+
coords_12nm = np.empty((0, 2))
|
| 15 |
+
|
| 16 |
+
hm, off, mask, conf = generate_heatmap_gt(
|
| 17 |
+
coords_6nm, coords_12nm, 512, 512, stride=2,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
assert hm.shape == (2, 256, 256)
|
| 21 |
+
assert off.shape == (2, 256, 256)
|
| 22 |
+
assert mask.shape == (256, 256)
|
| 23 |
+
|
| 24 |
+
# Peak should be at (50, 100) in stride-2 space
|
| 25 |
+
peak_y, peak_x = np.unravel_index(hm[0].argmax(), hm[0].shape)
|
| 26 |
+
assert abs(peak_x - 50) <= 1
|
| 27 |
+
assert abs(peak_y - 100) <= 1
|
| 28 |
+
|
| 29 |
+
# Peak value should be 1.0 (confidence=1.0 default)
|
| 30 |
+
assert hm[0].max() == pytest.approx(1.0, abs=0.01)
|
| 31 |
+
|
| 32 |
+
# 12nm channel should be empty
|
| 33 |
+
assert hm[1].max() == 0.0
|
| 34 |
+
|
| 35 |
+
def test_two_classes(self):
|
| 36 |
+
"""Both classes should produce peaks in their respective channels."""
|
| 37 |
+
coords_6nm = np.array([[100.0, 100.0]])
|
| 38 |
+
coords_12nm = np.array([[200.0, 200.0]])
|
| 39 |
+
|
| 40 |
+
hm, _, _, _ = generate_heatmap_gt(
|
| 41 |
+
coords_6nm, coords_12nm, 512, 512, stride=2,
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
assert hm[0].max() > 0.9 # 6nm channel has peak
|
| 45 |
+
assert hm[1].max() > 0.9 # 12nm channel has peak
|
| 46 |
+
|
| 47 |
+
def test_offset_values(self):
|
| 48 |
+
"""Offsets should encode sub-pixel correction."""
|
| 49 |
+
# Place particle at (101.5, 200.5) → stride-2 center at (50.75, 100.25)
|
| 50 |
+
# Integer center: (51, 100) → offset: (-0.25, 0.25)
|
| 51 |
+
coords_6nm = np.array([[101.5, 200.5]])
|
| 52 |
+
coords_12nm = np.empty((0, 2))
|
| 53 |
+
|
| 54 |
+
_, off, mask, _ = generate_heatmap_gt(
|
| 55 |
+
coords_6nm, coords_12nm, 512, 512, stride=2,
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
# Mask should have exactly one True pixel
|
| 59 |
+
assert mask.sum() == 1
|
| 60 |
+
|
| 61 |
+
def test_empty_annotations(self):
|
| 62 |
+
"""Empty annotations should produce zero heatmap."""
|
| 63 |
+
hm, off, mask, conf = generate_heatmap_gt(
|
| 64 |
+
np.empty((0, 2)), np.empty((0, 2)), 512, 512,
|
| 65 |
+
)
|
| 66 |
+
assert hm.max() == 0.0
|
| 67 |
+
assert mask.sum() == 0
|
| 68 |
+
|
| 69 |
+
def test_confidence_weighting(self):
|
| 70 |
+
"""Confidence < 1 should scale peak value."""
|
| 71 |
+
coords = np.array([[100.0, 100.0]])
|
| 72 |
+
confidences = np.array([0.5])
|
| 73 |
+
|
| 74 |
+
hm, _, _, _ = generate_heatmap_gt(
|
| 75 |
+
coords, np.empty((0, 2)), 512, 512,
|
| 76 |
+
confidence_6nm=confidences,
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
assert hm[0].max() == pytest.approx(0.5, abs=0.05)
|
| 80 |
+
|
| 81 |
+
def test_overlapping_particles_use_max(self):
|
| 82 |
+
"""Overlapping Gaussians should use element-wise max, not sum."""
|
| 83 |
+
coords = np.array([[100.0, 100.0], [104.0, 100.0]]) # close together
|
| 84 |
+
hm, _, _, _ = generate_heatmap_gt(
|
| 85 |
+
coords, np.empty((0, 2)), 512, 512, stride=2,
|
| 86 |
+
)
|
| 87 |
+
# Max should be 1.0, not >1.0
|
| 88 |
+
assert hm[0].max() <= 1.0
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class TestPeakExtraction:
|
| 92 |
+
def test_single_peak(self):
|
| 93 |
+
"""Extract a single peak from synthetic heatmap."""
|
| 94 |
+
heatmap = torch.zeros(2, 256, 256)
|
| 95 |
+
heatmap[0, 100, 50] = 0.9 # 6nm peak
|
| 96 |
+
|
| 97 |
+
offset_map = torch.zeros(2, 256, 256)
|
| 98 |
+
offset_map[0, 100, 50] = 0.3 # dx
|
| 99 |
+
offset_map[1, 100, 50] = 0.1 # dy
|
| 100 |
+
|
| 101 |
+
dets = extract_peaks(heatmap, offset_map, stride=2, conf_threshold=0.5)
|
| 102 |
+
|
| 103 |
+
assert len(dets) == 1
|
| 104 |
+
assert dets[0]["class"] == "6nm"
|
| 105 |
+
assert dets[0]["conf"] == pytest.approx(0.9, abs=0.01)
|
| 106 |
+
# x = (50 + 0.3) * 2 = 100.6
|
| 107 |
+
assert dets[0]["x"] == pytest.approx(100.6, abs=0.1)
|
| 108 |
+
# y = (100 + 0.1) * 2 = 200.2
|
| 109 |
+
assert dets[0]["y"] == pytest.approx(200.2, abs=0.1)
|
| 110 |
+
|
| 111 |
+
def test_nms_suppresses_neighbors(self):
|
| 112 |
+
"""NMS should suppress weaker neighboring peaks."""
|
| 113 |
+
heatmap = torch.zeros(2, 256, 256)
|
| 114 |
+
heatmap[0, 100, 50] = 0.9 # strong
|
| 115 |
+
heatmap[0, 101, 50] = 0.7 # weaker neighbor (within NMS kernel)
|
| 116 |
+
|
| 117 |
+
dets = extract_peaks(
|
| 118 |
+
heatmap, torch.zeros(2, 256, 256),
|
| 119 |
+
stride=2, conf_threshold=0.5,
|
| 120 |
+
nms_kernel_sizes={"6nm": 5, "12nm": 5},
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
# Only the stronger peak should survive
|
| 124 |
+
assert len([d for d in dets if d["class"] == "6nm"]) == 1
|
| 125 |
+
|
| 126 |
+
def test_below_threshold_filtered(self):
|
| 127 |
+
"""Peaks below threshold should not be extracted."""
|
| 128 |
+
heatmap = torch.zeros(2, 256, 256)
|
| 129 |
+
heatmap[0, 100, 50] = 0.2 # below 0.3 threshold
|
| 130 |
+
|
| 131 |
+
dets = extract_peaks(heatmap, torch.zeros(2, 256, 256), conf_threshold=0.3)
|
| 132 |
+
assert len(dets) == 0
|
tests/test_loss.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Unit tests for loss functions."""
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from src.loss import cornernet_focal_loss, offset_loss, total_loss
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class TestCornerNetFocalLoss:
|
| 10 |
+
def test_perfect_prediction_zero_loss(self):
|
| 11 |
+
"""Perfect predictions should produce near-zero loss."""
|
| 12 |
+
gt = torch.zeros(1, 2, 64, 64)
|
| 13 |
+
gt[0, 0, 32, 32] = 1.0 # one particle
|
| 14 |
+
|
| 15 |
+
# Near-perfect prediction
|
| 16 |
+
pred = torch.zeros(1, 2, 64, 64) + 1e-6
|
| 17 |
+
pred[0, 0, 32, 32] = 1.0 - 1e-6
|
| 18 |
+
|
| 19 |
+
loss = cornernet_focal_loss(pred, gt)
|
| 20 |
+
assert loss.item() < 0.1
|
| 21 |
+
|
| 22 |
+
def test_all_zeros_prediction_nonzero_loss(self):
|
| 23 |
+
"""Predicting all zeros when particles exist should give positive loss."""
|
| 24 |
+
gt = torch.zeros(1, 2, 64, 64)
|
| 25 |
+
gt[0, 0, 32, 32] = 1.0
|
| 26 |
+
|
| 27 |
+
pred = torch.zeros(1, 2, 64, 64) + 1e-6
|
| 28 |
+
loss = cornernet_focal_loss(pred, gt)
|
| 29 |
+
assert loss.item() > 0
|
| 30 |
+
|
| 31 |
+
def test_high_false_positive_penalized(self):
|
| 32 |
+
"""Predicting high confidence where GT is zero should be penalized."""
|
| 33 |
+
gt = torch.zeros(1, 2, 64, 64)
|
| 34 |
+
pred_low_fp = torch.zeros(1, 2, 64, 64) + 0.01
|
| 35 |
+
pred_high_fp = torch.zeros(1, 2, 64, 64) + 0.9
|
| 36 |
+
|
| 37 |
+
loss_low = cornernet_focal_loss(pred_low_fp, gt)
|
| 38 |
+
loss_high = cornernet_focal_loss(pred_high_fp, gt)
|
| 39 |
+
|
| 40 |
+
assert loss_high.item() > loss_low.item()
|
| 41 |
+
|
| 42 |
+
def test_near_peak_reduced_penalty(self):
|
| 43 |
+
"""Pixels near GT peaks should have reduced negative penalty via beta term."""
|
| 44 |
+
gt = torch.zeros(1, 2, 64, 64)
|
| 45 |
+
gt[0, 0, 32, 32] = 1.0
|
| 46 |
+
gt[0, 0, 31, 32] = 0.8 # nearby pixel with Gaussian falloff
|
| 47 |
+
|
| 48 |
+
# Moderate prediction near peak should have low loss
|
| 49 |
+
pred = torch.zeros(1, 2, 64, 64) + 0.01
|
| 50 |
+
pred[0, 0, 31, 32] = 0.5
|
| 51 |
+
|
| 52 |
+
loss = cornernet_focal_loss(pred, gt)
|
| 53 |
+
# Should be a reasonable value, not extremely high
|
| 54 |
+
assert loss.item() < 10
|
| 55 |
+
|
| 56 |
+
def test_confidence_weighting(self):
|
| 57 |
+
"""Confidence weights should scale the loss."""
|
| 58 |
+
gt = torch.zeros(1, 2, 64, 64)
|
| 59 |
+
gt[0, 0, 32, 32] = 1.0
|
| 60 |
+
pred = torch.zeros(1, 2, 64, 64) + 0.5
|
| 61 |
+
|
| 62 |
+
weights_full = torch.ones(1, 2, 64, 64)
|
| 63 |
+
weights_half = torch.ones(1, 2, 64, 64) * 0.5
|
| 64 |
+
|
| 65 |
+
loss_full = cornernet_focal_loss(pred, gt, conf_weights=weights_full)
|
| 66 |
+
loss_half = cornernet_focal_loss(pred, gt, conf_weights=weights_half)
|
| 67 |
+
|
| 68 |
+
# Half weights should produce lower loss
|
| 69 |
+
assert loss_half.item() < loss_full.item()
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class TestOffsetLoss:
|
| 73 |
+
def test_zero_when_no_particles(self):
|
| 74 |
+
"""Offset loss should be zero when mask is empty."""
|
| 75 |
+
pred = torch.randn(1, 2, 64, 64)
|
| 76 |
+
gt = torch.zeros(1, 2, 64, 64)
|
| 77 |
+
mask = torch.zeros(1, 64, 64, dtype=torch.bool)
|
| 78 |
+
|
| 79 |
+
loss = offset_loss(pred, gt, mask)
|
| 80 |
+
assert loss.item() == 0.0
|
| 81 |
+
|
| 82 |
+
def test_nonzero_with_particles(self):
|
| 83 |
+
"""Offset loss should be nonzero when predictions differ from GT."""
|
| 84 |
+
pred = torch.randn(1, 2, 64, 64)
|
| 85 |
+
gt = torch.zeros(1, 2, 64, 64)
|
| 86 |
+
mask = torch.zeros(1, 64, 64, dtype=torch.bool)
|
| 87 |
+
mask[0, 32, 32] = True
|
| 88 |
+
|
| 89 |
+
loss = offset_loss(pred, gt, mask)
|
| 90 |
+
assert loss.item() > 0
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class TestTotalLoss:
|
| 94 |
+
def test_returns_three_values(self):
|
| 95 |
+
"""total_loss should return (total, hm_loss, off_loss)."""
|
| 96 |
+
hm_pred = torch.sigmoid(torch.randn(1, 2, 64, 64))
|
| 97 |
+
hm_gt = torch.zeros(1, 2, 64, 64)
|
| 98 |
+
off_pred = torch.randn(1, 2, 64, 64)
|
| 99 |
+
off_gt = torch.zeros(1, 2, 64, 64)
|
| 100 |
+
mask = torch.zeros(1, 64, 64, dtype=torch.bool)
|
| 101 |
+
|
| 102 |
+
total, hm_val, off_val = total_loss(
|
| 103 |
+
hm_pred, hm_gt, off_pred, off_gt, mask,
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
assert isinstance(total, torch.Tensor)
|
| 107 |
+
assert isinstance(hm_val, float)
|
| 108 |
+
assert isinstance(off_val, float)
|
| 109 |
+
assert total.requires_grad
|
tests/test_model.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Unit tests for model architecture."""
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from src.model import ImmunogoldCenterNet, BiFPN
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class TestModelForwardPass:
|
| 10 |
+
def test_output_shapes(self):
|
| 11 |
+
"""Verify output shapes match stride-2 specification."""
|
| 12 |
+
model = ImmunogoldCenterNet(pretrained_path=None)
|
| 13 |
+
x = torch.randn(1, 1, 512, 512)
|
| 14 |
+
hm, off = model(x)
|
| 15 |
+
|
| 16 |
+
assert hm.shape == (1, 2, 256, 256), f"Expected (1,2,256,256), got {hm.shape}"
|
| 17 |
+
assert off.shape == (1, 2, 256, 256), f"Expected (1,2,256,256), got {off.shape}"
|
| 18 |
+
|
| 19 |
+
def test_heatmap_sigmoid_range(self):
|
| 20 |
+
"""Heatmap outputs should be in [0, 1] from sigmoid."""
|
| 21 |
+
model = ImmunogoldCenterNet(pretrained_path=None)
|
| 22 |
+
x = torch.randn(1, 1, 512, 512)
|
| 23 |
+
hm, _ = model(x)
|
| 24 |
+
|
| 25 |
+
assert hm.min() >= 0.0
|
| 26 |
+
assert hm.max() <= 1.0
|
| 27 |
+
|
| 28 |
+
def test_batch_dimension(self):
|
| 29 |
+
"""Model should handle batch size > 1."""
|
| 30 |
+
model = ImmunogoldCenterNet(pretrained_path=None)
|
| 31 |
+
x = torch.randn(4, 1, 512, 512)
|
| 32 |
+
hm, off = model(x)
|
| 33 |
+
|
| 34 |
+
assert hm.shape[0] == 4
|
| 35 |
+
assert off.shape[0] == 4
|
| 36 |
+
|
| 37 |
+
def test_variable_input_size(self):
|
| 38 |
+
"""Model should handle different input sizes (multiples of 32)."""
|
| 39 |
+
model = ImmunogoldCenterNet(pretrained_path=None)
|
| 40 |
+
|
| 41 |
+
for size in [256, 384, 512]:
|
| 42 |
+
x = torch.randn(1, 1, size, size)
|
| 43 |
+
hm, off = model(x)
|
| 44 |
+
assert hm.shape == (1, 2, size // 2, size // 2)
|
| 45 |
+
|
| 46 |
+
def test_parameter_count(self):
|
| 47 |
+
"""Model should have approximately 25M parameters."""
|
| 48 |
+
model = ImmunogoldCenterNet(pretrained_path=None)
|
| 49 |
+
n_params = sum(p.numel() for p in model.parameters())
|
| 50 |
+
# ResNet-50 is ~25M, plus BiFPN and heads
|
| 51 |
+
assert 20_000_000 < n_params < 40_000_000
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class TestFreezeUnfreeze:
|
| 55 |
+
def test_freeze_encoder(self):
|
| 56 |
+
"""Frozen encoder should have no gradients."""
|
| 57 |
+
model = ImmunogoldCenterNet(pretrained_path=None)
|
| 58 |
+
model.freeze_encoder()
|
| 59 |
+
|
| 60 |
+
for name, param in model.named_parameters():
|
| 61 |
+
if any(x in name for x in ["stem", "layer1", "layer2", "layer3", "layer4"]):
|
| 62 |
+
assert not param.requires_grad, f"{name} should be frozen"
|
| 63 |
+
|
| 64 |
+
# BiFPN and heads should still be trainable
|
| 65 |
+
for name, param in model.bifpn.named_parameters():
|
| 66 |
+
assert param.requires_grad, f"bifpn.{name} should be trainable"
|
| 67 |
+
|
| 68 |
+
def test_unfreeze_deep(self):
|
| 69 |
+
"""Unfreezing deep layers should enable gradients for layer3/4."""
|
| 70 |
+
model = ImmunogoldCenterNet(pretrained_path=None)
|
| 71 |
+
model.freeze_encoder()
|
| 72 |
+
model.unfreeze_deep_layers()
|
| 73 |
+
|
| 74 |
+
for param in model.layer3.parameters():
|
| 75 |
+
assert param.requires_grad
|
| 76 |
+
for param in model.layer4.parameters():
|
| 77 |
+
assert param.requires_grad
|
| 78 |
+
# Stem and layer1/2 still frozen
|
| 79 |
+
for param in model.stem.parameters():
|
| 80 |
+
assert not param.requires_grad
|
| 81 |
+
|
| 82 |
+
def test_unfreeze_all(self):
|
| 83 |
+
"""Unfreeze all should enable all gradients."""
|
| 84 |
+
model = ImmunogoldCenterNet(pretrained_path=None)
|
| 85 |
+
model.freeze_encoder()
|
| 86 |
+
model.unfreeze_all()
|
| 87 |
+
|
| 88 |
+
for param in model.parameters():
|
| 89 |
+
assert param.requires_grad
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class TestBiFPN:
|
| 93 |
+
def test_bifpn_output_shapes(self):
|
| 94 |
+
"""BiFPN should output 4 feature maps at 128 channels."""
|
| 95 |
+
bifpn = BiFPN(
|
| 96 |
+
in_channels=[256, 512, 1024, 2048],
|
| 97 |
+
out_channels=128,
|
| 98 |
+
num_rounds=2,
|
| 99 |
+
)
|
| 100 |
+
features = [
|
| 101 |
+
torch.randn(1, 256, 128, 128), # P2: stride 4
|
| 102 |
+
torch.randn(1, 512, 64, 64), # P3: stride 8
|
| 103 |
+
torch.randn(1, 1024, 32, 32), # P4: stride 16
|
| 104 |
+
torch.randn(1, 2048, 16, 16), # P5: stride 32
|
| 105 |
+
]
|
| 106 |
+
|
| 107 |
+
outputs = bifpn(features)
|
| 108 |
+
assert len(outputs) == 4
|
| 109 |
+
for i, out in enumerate(outputs):
|
| 110 |
+
assert out.shape[1] == 128, f"P{i+2} channels should be 128"
|
| 111 |
+
assert out.shape[2:] == features[i].shape[2:], \
|
| 112 |
+
f"P{i+2} spatial dims should match input"
|