AnikS22 commited on
Commit
60eb8cf
·
verified ·
1 Parent(s): 933831b

Upload src/preprocessing.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/preprocessing.py +284 -0
src/preprocessing.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Data loading, annotation parsing, and preprocessing for immunogold TEM images.
3
+
4
+ The model receives raw images — the CEM500K backbone was pretrained on raw EM.
5
+ Top-hat preprocessing is only used by LodeStar (Stage 1).
6
+ """
7
+
8
+ from dataclasses import dataclass, field
9
+ from pathlib import Path
10
+ from typing import Dict, List, Optional, Tuple
11
+
12
+ import numpy as np
13
+ import pandas as pd
14
+ import tifffile
15
+
16
+
17
+ # ---------------------------------------------------------------------------
18
+ # Data registry: robust discovery of images, masks, and annotations
19
+ # ---------------------------------------------------------------------------
20
+
21
@dataclass
class SynapseRecord:
    """Metadata for one synapse sample."""
    # Synapse identifier, e.g. "S22" — also the folder name under
    # "analyzed synapses".
    synapse_id: str
    # Main EM TIF image (largest non-mask, non-color TIF in the folder).
    image_path: Path
    # Tissue mask TIF, or None when the folder contains no mask.
    mask_path: Optional[Path]
    # CSV annotation files per bead class; a synapse may have several
    # "Results*" folders, hence lists.
    csv_6nm_paths: List[Path] = field(default_factory=list)
    csv_12nm_paths: List[Path] = field(default_factory=list)
    # Convenience flags: True when the corresponding CSV list is non-empty.
    has_6nm: bool = False
    has_12nm: bool = False
31
+
32
+
33
def discover_synapse_data(root: str, synapse_ids: List[str]) -> List[SynapseRecord]:
    """
    Discover all TIF images, masks, and CSV annotations for each synapse.

    Handles naming inconsistencies:
    - S22: main image is S22_0003.tif, two Results folders
    - S25: 12nm CSV has no space ("Results12nm")
    - CSV patterns: "Results 6nm XY" vs "Results XY in microns 6nm"

    Raises:
        FileNotFoundError: when a synapse folder or its main image is missing.
    """
    base = Path(root) / "analyzed synapses"
    records: List[SynapseRecord] = []

    for sid in synapse_ids:
        folder = base / sid
        if not folder.exists():
            raise FileNotFoundError(f"Synapse folder not found: {folder}")

        tifs = list(folder.glob("*.tif"))

        # Main image: any TIF whose stem mentions neither 'mask' nor 'color'.
        candidates = []
        for tif in tifs:
            stem = tif.stem.lower()
            if "mask" in stem or "color" in stem:
                continue
            candidates.append(tif)
        if not candidates:
            raise FileNotFoundError(f"No main image found in {folder}")
        # If several remain, the largest file is the main EM image.
        image_path = max(candidates, key=lambda t: t.stat().st_size)

        # Mask: prefer a stem ending in "mask" (plain "mask.tif") over
        # numbered variants like "mask 1.tif".
        masks = [t for t in tifs if "mask" in t.stem.lower()]
        mask_path: Optional[Path] = None
        if masks:
            plain = [t for t in masks if t.stem.lower().endswith("mask")]
            mask_path = (plain or masks)[0]

        # CSVs: scan every Results* subdirectory; classify by bead size in
        # the filename. "6nm" is tested first so names mentioning both sizes
        # land in the 6nm bucket (same precedence as before).
        six: List[Path] = []
        twelve: List[Path] = []
        for results_dir in sorted(folder.glob("Results*")):
            if not results_dir.is_dir():
                continue
            for csv_file in results_dir.glob("*.csv"):
                lowered = csv_file.name.lower()
                if "6nm" in lowered:
                    six.append(csv_file)
                elif "12nm" in lowered:
                    twelve.append(csv_file)

        records.append(SynapseRecord(
            synapse_id=sid,
            image_path=image_path,
            mask_path=mask_path,
            csv_6nm_paths=six,
            csv_12nm_paths=twelve,
            has_6nm=bool(six),
            has_12nm=bool(twelve),
        ))

    return records
97
+
98
+
99
+ # ---------------------------------------------------------------------------
100
+ # Image I/O
101
+ # ---------------------------------------------------------------------------
102
+
103
def load_image(path: Path) -> np.ndarray:
    """
    Load a TIF image as grayscale uint8.

    Handles:
    - RGB images (take first channel)
    - Palette-mode images
    - Already-grayscale images
    """
    arr = tifffile.imread(str(path))
    if arr.ndim == 3:
        # Multi-channel image: a small last axis (<= 4) means channels-last,
        # so take channel 0 (channels are identical in these images);
        # otherwise treat axis 0 as the channel/page axis.
        if arr.shape[2] <= 4:
            arr = arr[..., 0]
        else:
            arr = arr[0]
    return arr.astype(np.uint8)
117
+
118
+
119
def load_mask(path: Path) -> np.ndarray:
    """
    Load mask TIF as binary array.

    Mask is RGB where tissue regions have values < 250 in at least one channel.
    Returns boolean array: True = tissue/structural region.
    """
    raw = tifffile.imread(str(path))
    below_white = raw < 250
    # Grayscale mask: threshold directly.
    if raw.ndim == 2:
        return below_white
    # RGB mask: tissue wherever any channel is not (near-)white.
    return below_white.any(axis=-1)
131
+
132
+
133
+ # ---------------------------------------------------------------------------
134
+ # Annotation loading and coordinate conversion
135
+ # ---------------------------------------------------------------------------
136
+
137
def load_annotations_csv(csv_path: Path) -> pd.DataFrame:
    """
    Load an annotation CSV and return its X/Y coordinate columns.

    The CSV header has a leading-space index column (" ,X,Y"), so column
    names are stripped before selection. Coordinates are in microns (see
    MICRONS_TO_PIXELS elsewhere in this module); the caller is responsible
    for converting them to pixels.

    Args:
        csv_path: Path to a "Results*" annotation CSV.

    Returns:
        DataFrame with exactly the columns ["X", "Y"].
    """
    df = pd.read_csv(csv_path)
    # Headers carry stray whitespace (" ,X,Y"); normalize before lookup.
    df.columns = [c.strip() for c in df.columns]
    # NOTE: the old rename of the unnamed index column to "idx" was dead
    # code — only X and Y are returned — so it has been removed.
    return df[["X", "Y"]]
151
+
152
+
153
# Micron-to-pixel scale factor: consistent across all synapses (verified
# against researcher's color overlay TIFs). The CSV columns labeled "XY in
# microns" really ARE microns — multiply by this constant to get pixels.
MICRONS_TO_PIXELS: float = 1790.0
157
+
158
+
159
def load_all_annotations(
    record: SynapseRecord, image_shape: Tuple[int, int]
) -> Dict[str, np.ndarray]:
    """
    Load and convert annotations for one synapse to pixel coordinates.

    CSV coordinates are in microns (despite filename suggesting normalization).
    Multiply by MICRONS_TO_PIXELS (1790 px/micron) to convert.

    Args:
        record: SynapseRecord with CSV paths.
        image_shape: (height, width) of the corresponding image.

    Returns:
        Dictionary with keys '6nm' and '12nm', each containing
        an Nx2 array of (x, y) pixel coordinates.

    Raises:
        ValueError: if converted coordinates fall outside the image bounds
            (plus a 10 px tolerance) — usually a sign the CSV is not in
            microns or belongs to a different image.
    """
    h, w = image_shape[:2]
    result = {"6nm": np.empty((0, 2), dtype=np.float64),
              "12nm": np.empty((0, 2), dtype=np.float64)}

    for cls, paths in [("6nm", record.csv_6nm_paths),
                       ("12nm", record.csv_12nm_paths)]:
        all_coords = []
        for csv_path in paths:
            df = load_annotations_csv(csv_path)
            if df.empty:
                # Empty annotation file: nothing to convert; previously
                # this crashed on a zero-size .max() reduction.
                continue
            # Convert microns to pixels
            px_x = df["X"].values * MICRONS_TO_PIXELS
            px_y = df["Y"].values * MICRONS_TO_PIXELS
            # Validate with real exceptions: `assert` is stripped under -O,
            # so out-of-bounds coords would silently pass there.
            if not px_x.max() < w + 10:
                raise ValueError(
                    f"X coords out of bounds ({px_x.max():.0f} > {w}) in {csv_path}")
            if not px_y.max() < h + 10:
                raise ValueError(
                    f"Y coords out of bounds ({px_y.max():.0f} > {h}) in {csv_path}")
            all_coords.append(np.stack([px_x, px_y], axis=1))

        if all_coords:
            coords = np.concatenate(all_coords, axis=0)
            # Deduplicate (for S22 merged results): remove within 3px
            if len(coords) > 1:
                coords = _deduplicate_coords(coords, min_dist=3.0)
            result[cls] = coords

    return result
203
+
204
+
205
+ def _deduplicate_coords(
206
+ coords: np.ndarray, min_dist: float = 3.0
207
+ ) -> np.ndarray:
208
+ """Remove duplicate coordinates within min_dist pixels."""
209
+ from scipy.spatial.distance import cdist
210
+
211
+ if len(coords) <= 1:
212
+ return coords
213
+ dists = cdist(coords, coords)
214
+ np.fill_diagonal(dists, np.inf)
215
+ keep = np.ones(len(coords), dtype=bool)
216
+ for i in range(len(coords)):
217
+ if not keep[i]:
218
+ continue
219
+ # Mark later duplicates
220
+ for j in range(i + 1, len(coords)):
221
+ if keep[j] and dists[i, j] < min_dist:
222
+ keep[j] = False
223
+ return coords[keep]
224
+
225
+
226
+ # ---------------------------------------------------------------------------
227
+ # Preprocessing transforms
228
+ # ---------------------------------------------------------------------------
229
+
230
def preprocess_image(img: np.ndarray, bead_class: str,
                     tophat_radii: Optional[Dict[str, int]] = None,
                     clahe_clip_limit: float = 0.03,
                     clahe_kernel_size: int = 64) -> np.ndarray:
    """
    Top-hat + CLAHE preprocessing. Used ONLY by LodeStar (Stage 1).

    Not used for model training — the CEM500K backbone expects raw EM images.
    """
    from skimage import exposure
    from skimage.morphology import disk, white_tophat

    radii = {"6nm": 8, "12nm": 12} if tophat_radii is None else tophat_radii

    # White top-hat on the inverted image responds to small structures that
    # are dark in the original (the gold beads, presumably — radius per class).
    inverted = (255 - img).astype(np.float32)
    response = white_tophat(inverted, disk(radii[bead_class]))

    # Scale into [0, 1] for equalize_adapthist; an all-zero response is
    # passed through unscaled to avoid division by zero.
    peak = response.max()
    normed = response / peak if peak > 0 else response

    contrasted = exposure.equalize_adapthist(
        normed,
        clip_limit=clahe_clip_limit,
        kernel_size=clahe_kernel_size,
    )
    return (contrasted * 255).astype(np.uint8)
261
+
262
+
263
+ # ---------------------------------------------------------------------------
264
+ # Convenience: load everything for one synapse
265
+ # ---------------------------------------------------------------------------
266
+
267
def load_synapse(record: SynapseRecord) -> dict:
    """
    Load image, mask, and annotations for one synapse.

    Returns dict with keys: 'image', 'mask', 'annotations',
    'synapse_id', 'image_shape'
    """
    image = load_image(record.image_path)
    # Mask is optional — some synapse folders have none.
    tissue_mask = None
    if record.mask_path:
        tissue_mask = load_mask(record.mask_path)
    annotations = load_all_annotations(record, image.shape)

    return {
        "synapse_id": record.synapse_id,
        "image": image,
        "mask": tissue_mask,
        "annotations": annotations,
        "image_shape": image.shape,
    }