MidasMap / scripts /verify_data.py
AnikS22's picture
Sync all source code, docs, and configs
1fc7794 verified
"""
Data verification script: loads all images and annotations,
validates counts, and saves visual overlays.
Usage:
python scripts/verify_data.py --config config/config.yaml
"""
import argparse
import sys
from pathlib import Path
import numpy as np
import yaml
sys.path.insert(0, str(Path(__file__).parent.parent))
from src.preprocessing import discover_synapse_data, load_synapse
from src.visualize import overlay_annotations
def _count_in_mask(coords, mask):
    """Count annotation points that fall on a truthy pixel of the mask.

    Args:
        coords: Iterable of ``(x, y)`` pairs; x selects the mask column
            and y selects the mask row.
        mask: Array with at least two dimensions, indexed ``mask[row, col]``.

    Returns:
        Number of points that lie inside the mask bounds and whose pixel
        is truthy.
    """
    height, width = mask.shape[0], mask.shape[1]
    return sum(
        1
        for x, y in coords
        # Range-check the raw coordinate before truncating: int() rounds
        # toward zero, so e.g. -0.5 would otherwise be accepted as index 0.
        if 0 <= y < height and 0 <= x < width and mask[int(y), int(x)]
    )


def main():
    """Verify the immunogold dataset and write one annotation overlay per synapse.

    Reads the YAML config given by ``--config``, discovers the synapse
    records from ``data.root`` / ``data.synapse_ids``, and for each record
    prints its file names, image shape and 6nm/12nm particle counts (plus
    how many particles land inside the mask when one is present). An
    overlay image is saved per record into ``--output-dir``. Finally the
    totals are printed together with a PASS/WARNING verdict.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="config/config.yaml")
    parser.add_argument(
        "--output-dir",
        default="results/verification",
        help="directory that receives the overlay images",
    )
    args = parser.parse_args()

    with open(args.config, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    print("=" * 60)
    print("Immunogold Data Verification")
    print("=" * 60)

    records = discover_synapse_data(cfg["data"]["root"], cfg["data"]["synapse_ids"])

    total_6nm = 0
    total_12nm = 0

    # Make sure the destination exists before any overlay is written to it.
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    for record in records:
        print(f"\n--- {record.synapse_id} ---")
        print(f" Image: {record.image_path.name}")
        print(f" Mask: {record.mask_path.name if record.mask_path else 'NONE'}")
        print(f" 6nm CSVs: {[p.name for p in record.csv_6nm_paths]}")
        print(f" 12nm CSVs: {[p.name for p in record.csv_12nm_paths]}")

        data = load_synapse(record)
        img = data["image"]
        annots = data["annotations"]

        n6 = len(annots["6nm"])
        n12 = len(annots["12nm"])
        total_6nm += n6
        total_12nm += n12

        print(f" Image shape: {img.shape}")
        print(f" 6nm particles: {n6}")
        print(f" 12nm particles: {n12}")

        mask = data["mask"]
        if mask is not None:
            # Check how many particles fall within mask
            for cls, coords in annots.items():
                # len() rather than truthiness: coords may be an array.
                if len(coords) == 0:
                    continue
                inside = _count_in_mask(coords, mask)
                print(f" {cls} in mask: {inside}/{len(coords)}")

        # Save overlay
        overlay_annotations(
            img, annots,
            title=f"{record.synapse_id}: {n6} 6nm, {n12} 12nm",
            save_path=output_dir / f"{record.synapse_id}_annotations.png",
        )

    print(f"\n{'=' * 60}")
    print(f"TOTAL: {total_6nm} 6nm + {total_12nm} 12nm = {total_6nm + total_12nm}")
    print("Expected: 403 6nm + 50 12nm = 453")
    if total_6nm + total_12nm >= 400:
        print("PASS: Particle counts look reasonable")
    else:
        print("WARNING: Total count is lower than expected")
    print(f"\nOverlays saved to: {output_dir}")
# Run the verification only when executed as a script, not when imported.
if __name__ == "__main__":
    main()