"""
Data loading, annotation parsing, and preprocessing for immunogold TEM images.
The model receives raw images — the CEM500K backbone was pretrained on raw EM.
Top-hat preprocessing is only used by LodeStar (Stage 1).
"""
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np
import pandas as pd
import tifffile
# ---------------------------------------------------------------------------
# Data registry: robust discovery of images, masks, and annotations
# ---------------------------------------------------------------------------
@dataclass
class SynapseRecord:
    """Metadata for one synapse sample.

    Holds the paths discovered by discover_synapse_data() for a single
    synapse folder; no file contents are loaded at construction time.
    """
    synapse_id: str  # folder name under "analyzed synapses" (e.g. "S22")
    image_path: Path  # main EM TIF (largest non-mask/non-color TIF in the folder)
    mask_path: Optional[Path]  # tissue-mask TIF, or None when no mask file exists
    csv_6nm_paths: List[Path] = field(default_factory=list)  # 6nm bead annotation CSVs
    csv_12nm_paths: List[Path] = field(default_factory=list)  # 12nm bead annotation CSVs
    has_6nm: bool = False  # convenience flag: csv_6nm_paths is non-empty
    has_12nm: bool = False  # convenience flag: csv_12nm_paths is non-empty
def discover_synapse_data(root: str, synapse_ids: List[str]) -> List[SynapseRecord]:
    """
    Discover all TIF images, masks, and CSV annotations for each synapse.

    Handles naming inconsistencies:
    - S22: main image is S22_0003.tif, two Results folders
    - S25: 12nm CSV has no space ("Results12nm")
    - CSV patterns: "Results 6nm XY" vs "Results XY in microns 6nm"

    Args:
        root: Dataset root directory containing "analyzed synapses".
        synapse_ids: Folder names to look up, one per synapse.

    Returns:
        One SynapseRecord per id, in the order given.

    Raises:
        FileNotFoundError: When a synapse folder or its main image is missing.
    """
    analyzed_dir = Path(root) / "analyzed synapses"
    records: List[SynapseRecord] = []
    for sid in synapse_ids:
        folder = analyzed_dir / sid
        if not folder.exists():
            raise FileNotFoundError(f"Synapse folder not found: {folder}")

        # Main image: any TIF whose stem mentions neither "mask" nor "color".
        tifs = list(folder.glob("*.tif"))
        candidates = [
            t for t in tifs
            if "mask" not in t.stem.lower() and "color" not in t.stem.lower()
        ]
        if not candidates:
            raise FileNotFoundError(f"No main image found in {folder}")
        # When several remain, the largest file is the main EM image.
        image_path = max(candidates, key=lambda t: t.stat().st_size)

        # Mask: prefer a stem ending exactly in "mask" over "mask 1"/"mask 2".
        mask_tifs = [t for t in tifs if "mask" in t.stem.lower()]
        mask_path = None
        if mask_tifs:
            mask_path = next(
                (t for t in mask_tifs if t.stem.lower().endswith("mask")),
                mask_tifs[0],
            )

        # CSVs: scan every "Results*" subdirectory, split by bead size token.
        csv_6nm: List[Path] = []
        csv_12nm: List[Path] = []
        for results_dir in sorted(folder.glob("Results*")):
            if not results_dir.is_dir():
                continue
            for csv_file in results_dir.glob("*.csv"):
                lowered = csv_file.name.lower()
                if "6nm" in lowered:
                    csv_6nm.append(csv_file)
                elif "12nm" in lowered:
                    csv_12nm.append(csv_file)

        records.append(SynapseRecord(
            synapse_id=sid,
            image_path=image_path,
            mask_path=mask_path,
            csv_6nm_paths=csv_6nm,
            csv_12nm_paths=csv_12nm,
            has_6nm=len(csv_6nm) > 0,
            has_12nm=len(csv_12nm) > 0,
        ))
    return records
# ---------------------------------------------------------------------------
# Image I/O
# ---------------------------------------------------------------------------
def load_image(path: Path) -> np.ndarray:
    """
    Load a TIF image as grayscale uint8.

    Handles:
    - RGB/multi-channel images (keeps the first channel — channels are
      identical in these images)
    - Palette-mode images
    - Already-grayscale images
    """
    data = tifffile.imread(str(path))
    if data.ndim == 3:
        # <=4 channels in the last axis means channels-last; otherwise
        # assume channels-first and take plane 0.
        data = data[..., 0] if data.shape[2] <= 4 else data[0]
    return data.astype(np.uint8)
def load_mask(path: Path) -> np.ndarray:
    """
    Load mask TIF as a boolean array (True = tissue/structural region).

    The mask is RGB where tissue regions have values < 250 in at least
    one channel; grayscale masks use the same threshold directly.
    """
    raw = tifffile.imread(str(path))
    below_white = raw < 250
    if raw.ndim == 2:
        return below_white
    # RGB: tissue wherever any channel is not (near-)white.
    return below_white.any(axis=-1)
# ---------------------------------------------------------------------------
# Annotation loading and coordinate conversion
# ---------------------------------------------------------------------------
def load_annotations_csv(csv_path: Path) -> pd.DataFrame:
    """
    Load an annotation CSV and return its X/Y coordinate columns.

    The exported CSVs have a leading unnamed index column and
    whitespace-padded headers (e.g. " ,X,Y"). Coordinate values are in
    microns — despite some filenames suggesting otherwise — and are
    converted to pixels by the caller via MICRONS_TO_PIXELS.

    Args:
        csv_path: Path to the annotation CSV.

    Returns:
        DataFrame with exactly the columns ["X", "Y"].

    Raises:
        ValueError: If the CSV lacks an X or Y column (clearer than the
            KeyError pandas would raise on selection).
    """
    df = pd.read_csv(csv_path)
    # Headers carry stray spaces; normalize before selecting.
    df.columns = [c.strip() for c in df.columns]
    missing = {"X", "Y"} - set(df.columns)
    if missing:
        raise ValueError(
            f"Annotation CSV missing column(s) {sorted(missing)}: {csv_path}"
        )
    return df[["X", "Y"]]
# Micron-to-pixel scale factor (1790 px per micron), identical across all
# synapses (verified against the researcher's color overlay TIFs). The CSV
# columns labeled "XY in microns" really ARE microns — multiply by this
# constant to obtain pixel coordinates.
MICRONS_TO_PIXELS = 1790.0
def load_all_annotations(
    record: SynapseRecord, image_shape: Tuple[int, int]
) -> Dict[str, np.ndarray]:
    """
    Load and convert annotations for one synapse to pixel coordinates.

    CSV coordinates are in microns (despite filename suggesting
    normalization). Multiply by MICRONS_TO_PIXELS (1790 px/micron).

    Args:
        record: SynapseRecord with CSV paths.
        image_shape: (height, width) of the corresponding image.

    Returns:
        Dictionary with keys '6nm' and '12nm', each containing an Nx2
        float64 array of (x, y) pixel coordinates ((0, 2)-shaped when
        no annotations exist for that class).

    Raises:
        ValueError: If converted coordinates fall outside the image
            bounds (a 10 px tolerance is allowed for edge annotations).
    """
    h, w = image_shape[:2]
    result = {"6nm": np.empty((0, 2), dtype=np.float64),
              "12nm": np.empty((0, 2), dtype=np.float64)}
    for cls, paths in [("6nm", record.csv_6nm_paths),
                       ("12nm", record.csv_12nm_paths)]:
        all_coords = []
        for csv_path in paths:
            df = load_annotations_csv(csv_path)
            if df.empty:
                # Nothing to convert — and .max() on an empty array raises.
                continue
            # Convert microns to pixels
            px_x = df["X"].values * MICRONS_TO_PIXELS
            px_y = df["Y"].values * MICRONS_TO_PIXELS
            # Validate with real exceptions: `assert` is stripped under -O,
            # which would silently disable this bounds check.
            if px_x.max() >= w + 10:
                raise ValueError(
                    f"X coords out of bounds ({px_x.max():.0f} > {w}) in {csv_path}")
            if px_y.max() >= h + 10:
                raise ValueError(
                    f"Y coords out of bounds ({px_y.max():.0f} > {h}) in {csv_path}")
            all_coords.append(np.stack([px_x, px_y], axis=1))
        if all_coords:
            coords = np.concatenate(all_coords, axis=0)
            # Deduplicate (for S22 merged results): remove within 3px
            if len(coords) > 1:
                coords = _deduplicate_coords(coords, min_dist=3.0)
            result[cls] = coords
    return result
def _deduplicate_coords(
coords: np.ndarray, min_dist: float = 3.0
) -> np.ndarray:
"""Remove duplicate coordinates within min_dist pixels."""
from scipy.spatial.distance import cdist
if len(coords) <= 1:
return coords
dists = cdist(coords, coords)
np.fill_diagonal(dists, np.inf)
keep = np.ones(len(coords), dtype=bool)
for i in range(len(coords)):
if not keep[i]:
continue
# Mark later duplicates
for j in range(i + 1, len(coords)):
if keep[j] and dists[i, j] < min_dist:
keep[j] = False
return coords[keep]
# ---------------------------------------------------------------------------
# Preprocessing transforms
# ---------------------------------------------------------------------------
def preprocess_image(img: np.ndarray, bead_class: str,
                     tophat_radii: Optional[Dict[str, int]] = None,
                     clahe_clip_limit: float = 0.03,
                     clahe_kernel_size: int = 64) -> np.ndarray:
    """
    White top-hat + CLAHE enhancement. Used ONLY by LodeStar (Stage 1).

    Not used for model training — the CEM500K backbone expects raw EM
    images.

    Args:
        img: Input image (uint8 grayscale expected — TODO confirm at callers).
        bead_class: "6nm" or "12nm"; selects the top-hat disk radius.
        tophat_radii: Optional per-class radius override.
        clahe_clip_limit: CLAHE clip limit.
        clahe_kernel_size: CLAHE kernel size.

    Returns:
        Enhanced uint8 image of the same shape.
    """
    from skimage import exposure
    from skimage.morphology import disk, white_tophat
    radii = {"6nm": 8, "12nm": 12} if tophat_radii is None else tophat_radii
    # Invert so dark structures become bright peaks (white_tophat extracts
    # bright features smaller than the structuring element).
    inverted = (255 - img).astype(np.float32)
    response = white_tophat(inverted, disk(radii[bead_class]))
    # Normalize to [0, 1] for equalize_adapthist; guard the all-zero case.
    peak = response.max()
    normalized = response / peak if peak > 0 else response
    contrast = exposure.equalize_adapthist(
        normalized,
        clip_limit=clahe_clip_limit,
        kernel_size=clahe_kernel_size,
    )
    return (contrast * 255).astype(np.uint8)
# ---------------------------------------------------------------------------
# Convenience: load everything for one synapse
# ---------------------------------------------------------------------------
def load_synapse(record: SynapseRecord) -> dict:
    """
    Load image, mask, and annotations for one synapse.

    Returns:
        Dict with keys 'synapse_id', 'image', 'mask', 'annotations',
        'image_shape'; 'mask' is None when the record has no mask file.
    """
    image = load_image(record.image_path)
    tissue_mask = None
    if record.mask_path:
        tissue_mask = load_mask(record.mask_path)
    return {
        "synapse_id": record.synapse_id,
        "image": image,
        "mask": tissue_mask,
        "annotations": load_all_annotations(record, image.shape),
        "image_shape": image.shape,
    }
|