AnikS22 committed on
Commit
933831b
·
verified ·
1 Parent(s): d1fe61c

Upload src/dataset.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/dataset.py +438 -0
src/dataset.py ADDED
@@ -0,0 +1,438 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PyTorch Dataset for immunogold particle detection.
3
+
4
+ Implements patch-based training with:
5
+ - 70% hard mining (patches centered near particles)
6
+ - 30% random patches (background recognition)
7
+ - Copy-paste augmentation with Gaussian-blended bead bank
8
+ - Albumentations pipeline with keypoint co-transforms
9
+ """
10
+
11
+ import random
12
+ from pathlib import Path
13
+ from typing import Dict, List, Optional, Tuple
14
+
15
+ import albumentations as A
16
+ import cv2
17
+ import numpy as np
18
+ import torch
19
+ from torch.utils.data import Dataset
20
+
21
+ from src.heatmap import generate_heatmap_gt
22
+ from src.preprocessing import (
23
+ SynapseRecord,
24
+ load_all_annotations,
25
+ load_image,
26
+ load_mask,
27
+ )
28
+
29
+
30
+ # ---------------------------------------------------------------------------
31
+ # Augmentation pipeline
32
+ # ---------------------------------------------------------------------------
33
+
34
def get_train_augmentation() -> A.Compose:
    """
    Build the augmentation pipeline applied to training patches.

    Intensity limits are kept conservative on purpose: contrast delta is
    only 11-39 units on uint8. DO NOT use Cutout/Mixup/JPEG artifacts —
    they destroy or mimic particles.

    Returns:
        An ``A.Compose`` that co-transforms "xy" keypoints (labelled through
        the ``class_labels`` field) together with the image and drops
        keypoints that end up outside the image.
    """
    # Geometric stage: keypoints are co-transformed with the image.
    geometric = [
        A.RandomRotate90(p=1.0),  # EM is rotation invariant
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        # Only ±10° to avoid interpolation artifacts that destroy contrast
        A.Rotate(limit=10, border_mode=cv2.BORDER_REFLECT_101, p=0.5),
        # Mild elastic deformation (simulates section flatness variation)
        A.ElasticTransform(alpha=30, sigma=5, p=0.3),
    ]

    # Intensity stage: affects the image only.
    photometric = [
        # NOT the default brightness limit of 0.2
        A.RandomBrightnessContrast(
            brightness_limit=0.08,
            contrast_limit=0.08,
            p=0.7,
        ),
        # EM shot noise simulation
        A.GaussNoise(p=0.5),
        # Mild blur — simulate slight defocus
        A.GaussianBlur(blur_limit=(3, 3), p=0.2),
    ]

    keypoint_params = A.KeypointParams(
        format="xy",
        remove_invisible=True,
        label_fields=["class_labels"],
    )
    return A.Compose(geometric + photometric, keypoint_params=keypoint_params)
72
+
73
+
74
def get_val_augmentation() -> A.Compose:
    """
    Build the validation pipeline: an empty (identity) transform list.

    Keypoints are still declared as "xy" with a ``class_labels`` field so the
    pipeline can be called with the same arguments as the training one.
    """
    keypoint_params = A.KeypointParams(
        format="xy",
        remove_invisible=True,
        label_fields=["class_labels"],
    )
    no_transforms: list = []
    return A.Compose(no_transforms, keypoint_params=keypoint_params)
84
+
85
+
86
+ # ---------------------------------------------------------------------------
87
+ # Bead bank for copy-paste augmentation
88
+ # ---------------------------------------------------------------------------
89
+
90
class BeadBank:
    """
    Pre-extracted particle crops for copy-paste augmentation.

    Stores small square patches centered on annotated particles from
    training images. During training, random beads are pasted onto patches
    to increase particle density and address class imbalance.

    Attributes:
        crops: per-class list of ``(crop, half)`` tuples, where ``crop`` is a
            ``crop_size x crop_size`` array and ``half`` is ``crop_size // 2``.
        crop_sizes: side length in pixels of the crop stored for each class.
    """

    def __init__(self):
        self.crops: Dict[str, List[Tuple[np.ndarray, int]]] = {
            "6nm": [],
            "12nm": [],
        }
        self.crop_sizes = {"6nm": 32, "12nm": 48}

    @staticmethod
    def _alpha_mask(crop_size: int) -> np.ndarray:
        """Return the Gaussian blending weights for a crop of side ``crop_size``."""
        half = crop_size // 2
        yy, xx = np.mgrid[:crop_size, :crop_size]
        # The particle sits at index ``half`` of the crop (see
        # extract_from_image), so the weight peaks exactly on it.
        center = crop_size / 2
        sigma = half * 0.7
        return np.exp(-((xx - center) ** 2 + (yy - center) ** 2) / (2 * sigma ** 2))

    def extract_from_image(
        self,
        image: np.ndarray,
        annotations: Dict[str, np.ndarray],
    ):
        """
        Extract bead crops from a training image and add them to the bank.

        Args:
            image: source image; only crops whose shape is exactly
                ``(crop_size, crop_size)`` are kept, so a 2-D image is expected.
            annotations: mapping from class name to an iterable of ``(x, y)``
                particle coordinates. Classes without a configured crop size
                are ignored, and configured classes may be absent.
        """
        h, w = image.shape[:2]

        # Iterate over the classes the bank knows about rather than over the
        # annotation keys, so an unexpected extra key cannot raise KeyError.
        for cls, crop_size in self.crop_sizes.items():
            coords = annotations.get(cls)
            if coords is None:
                continue
            half = crop_size // 2

            for x, y in coords:
                xi, yi = int(round(x)), int(round(y))
                # Skip particles whose crop window would leave the image.
                if yi - half < 0 or yi + half > h or xi - half < 0 or xi + half > w:
                    continue

                crop = image[yi - half : yi + half, xi - half : xi + half].copy()
                if crop.shape == (crop_size, crop_size):
                    self.crops[cls].append((crop, half))

    def paste_beads(
        self,
        image: np.ndarray,
        coords_6nm: List[Tuple[float, float]],
        coords_12nm: List[Tuple[float, float]],
        class_labels: List[str],
        mask: Optional[np.ndarray] = None,
        n_paste_per_class: int = 5,
        rng: Optional[np.random.Generator] = None,
    ) -> Tuple[np.ndarray, List[Tuple[float, float]], List[Tuple[float, float]], List[str]]:
        """
        Paste random beads onto a copy of ``image`` with Gaussian alpha blending.

        Args:
            image: target patch; it is copied, never modified in place.
            coords_6nm: existing 6nm particle coordinates ``(x, y)`` in ``image``.
            coords_12nm: existing 12nm particle coordinates ``(x, y)`` in ``image``.
            class_labels: labels of the existing particles.
            mask: optional array indexed as ``mask[y, x]``; a bead is only
                pasted where the mask is truthy and the location is inside it.
            n_paste_per_class: maximum paste attempts per class (further
                capped by the number of crops stored for that class).
            rng: random generator; a fresh unseeded one is created if omitted.

        Returns:
            ``(image, coords_6nm, coords_12nm, labels)``: the augmented image,
            the two coordinate lists extended with the pasted beads, and the
            input labels followed by one label per pasted bead.
        """
        if rng is None:
            rng = np.random.default_rng()

        image = image.copy()
        h, w = image.shape[:2]
        new_coords_6nm = list(coords_6nm)
        new_coords_12nm = list(coords_12nm)
        new_labels = list(class_labels)

        for cls in ("6nm", "12nm"):
            bank = self.crops[cls]
            if not bank:
                continue

            crop_size = self.crop_sizes[cls]
            half = crop_size // 2

            # Valid range for the paste centre: the crop plus a 5 px margin
            # must fit. On an image that is too small the range is empty and
            # rng.integers would raise, so skip the class instead.
            x_lo, x_hi = half + 5, w - half - 5
            y_lo, y_hi = half + 5, h - half - 5
            if x_hi <= x_lo or y_hi <= y_lo:
                continue

            # Loop invariants: the blending weights and the overlap radius
            # depend only on the class.
            alpha = self._alpha_mask(crop_size)
            min_dist_sq = (half * 1.5) ** 2
            n_paste = min(n_paste_per_class, len(bank))

            for _ in range(n_paste):
                # Random paste location (within image bounds)
                px = int(rng.integers(x_lo, x_hi))
                py = int(rng.integers(y_lo, y_hi))

                # Skip if outside tissue mask
                if mask is not None:
                    if py >= mask.shape[0] or px >= mask.shape[1] or not mask[py, px]:
                        continue

                # Keep a minimum distance from every existing particle,
                # including beads pasted earlier in this call.
                too_close = any(
                    (ex - px) ** 2 + (ey - py) ** 2 < min_dist_sq
                    for ex, ey in new_coords_6nm + new_coords_12nm
                )
                if too_close:
                    continue

                # Select random crop
                crop, _ = bank[rng.integers(len(bank))]

                region = image[py - half : py + half, px - half : px + half]
                if region.shape != crop.shape:
                    continue

                # Convex combination of crop and background, so the result
                # stays within the range spanned by the inputs.
                blended = (alpha * crop + (1 - alpha) * region).astype(np.uint8)
                image[py - half : py + half, px - half : px + half] = blended

                # Register the pasted bead as a new annotation.
                if cls == "6nm":
                    new_coords_6nm.append((float(px), float(py)))
                else:
                    new_coords_12nm.append((float(px), float(py)))
                new_labels.append(cls)

        return image, new_coords_6nm, new_coords_12nm, new_labels
204
+
205
+
206
+ # ---------------------------------------------------------------------------
207
+ # Dataset
208
+ # ---------------------------------------------------------------------------
209
+
210
class ImmunogoldDataset(Dataset):
    """
    Patch-based dataset for immunogold particle detection.

    Sampling strategy (train mode):
    - ``hard_mining_fraction`` of patches (70% by default) are centered on a
      known particle, jittered by up to 128 px per axis (hard mining)
    - the remaining patches are taken at random locations (background
      recognition)

    In any other mode every patch is taken at a random location.

    This ensures the model sees particles in nearly every batch despite
    particles occupying <0.1% of image area (author's figure).
    """

    def __init__(
        self,
        records: List["SynapseRecord"],
        fold_id: str,
        mode: str = "train",
        patch_size: int = 512,
        stride: int = 2,
        hard_mining_fraction: float = 0.7,
        copy_paste_per_class: int = 5,
        sigmas: Optional[Dict[str, float]] = None,
        samples_per_epoch: int = 200,
        seed: int = 42,
    ):
        """
        Args:
            records: all SynapseRecord entries
            fold_id: synapse_id to hold out; 'train' uses every other record,
                'val' uses only this one, any other mode uses all records
            mode: 'train' or 'val'
            patch_size: training patch size
            stride: model output stride
            hard_mining_fraction: fraction of patches near particles
            copy_paste_per_class: beads to paste per class (forced to 0
                outside train mode)
            sigmas: heatmap Gaussian sigmas per class
            samples_per_epoch: virtual epoch size (value returned by len())
            seed: random seed
        """
        super().__init__()
        self.patch_size = patch_size
        self.stride = stride
        self.hard_mining_fraction = hard_mining_fraction
        self.copy_paste_per_class = copy_paste_per_class if mode == "train" else 0
        self.sigmas = sigmas or {"6nm": 1.0, "12nm": 1.5}
        self.samples_per_epoch = samples_per_epoch
        self.mode = mode
        self._base_seed = seed
        self.rng = np.random.default_rng(seed)
        # Number of training samples drawn by this instance; mixed into the
        # per-sample seed so an index yields a new patch on every visit.
        self._draw_count = 0

        # Split records
        if mode == "train":
            self.records = [r for r in records if r.synapse_id != fold_id]
        elif mode == "val":
            self.records = [r for r in records if r.synapse_id == fold_id]
        else:
            self.records = records

        # Pre-load all images, masks and annotations into memory.
        self.images = {}
        self.masks = {}
        self.annotations = {}

        for record in self.records:
            sid = record.synapse_id
            self.images[sid] = load_image(record.image_path)
            if record.mask_path:
                self.masks[sid] = load_mask(record.mask_path)
            self.annotations[sid] = load_all_annotations(record, self.images[sid].shape)

        # Build particle index for hard mining
        self._build_particle_index()

        # Build bead bank for copy-paste (only populated in train mode)
        self.bead_bank = BeadBank()
        if mode == "train":
            for sid in self.images:
                self.bead_bank.extract_from_image(
                    self.images[sid], self.annotations[sid]
                )

        # Augmentation
        if mode == "train":
            self.transform = get_train_augmentation()
        else:
            self.transform = get_val_augmentation()

    def _build_particle_index(self):
        """Build flat index of all particles for hard mining."""
        # Entries are (synapse_id, x, y, class).
        self.particle_list = [
            (sid, x, y, cls)
            for sid, annots in self.annotations.items()
            for cls in ("6nm", "12nm")
            for x, y in annots[cls]
        ]

    @staticmethod
    def worker_init_fn(worker_id: int):
        """Re-seed NumPy's global RNG per DataLoader worker to avoid identical sequences."""
        # Reduce modulo 2**32 *after* adding worker_id: np.random.seed
        # rejects values outside [0, 2**32 - 1].
        seed = (torch.initial_seed() + worker_id) % (2**32)
        np.random.seed(seed)

    def __len__(self) -> int:
        """Return the virtual epoch size."""
        return self.samples_per_epoch

    def _make_rng(self, idx: int) -> np.random.Generator:
        """Create the generator used for one sample."""
        # Seed from a list of entropy words instead of summing them: a sum
        # collides across DataLoader workers, whose torch seeds are
        # consecutive integers, and produces duplicate patches.
        entropy = [
            int(self._base_seed),
            int(idx) % (2**63),
            int(torch.initial_seed()),
        ]
        if self.mode == "train":
            # In the main process torch.initial_seed() is constant, so without
            # a counter the same patches would repeat every epoch. Validation
            # deliberately omits it and stays a function of idx.
            draw = getattr(self, "_draw_count", 0)
            self._draw_count = draw + 1
            entropy.append(draw)
        return np.random.default_rng(entropy)

    def _sample_center(self) -> Tuple[str, int, int]:
        """Choose the source image and the (unclamped) patch center."""
        half = self.patch_size // 2

        # Decide: hard or random patch
        do_hard = (
            self.rng.random() < self.hard_mining_fraction
            and len(self.particle_list) > 0
            and self.mode == "train"
        )

        if do_hard:
            # Pick random particle, center patch on it with jitter
            pidx = self.rng.integers(len(self.particle_list))
            sid, px, py, _ = self.particle_list[pidx]
            # Jitter center up to 128px
            jitter = 128
            cx = int(px + self.rng.integers(-jitter, jitter + 1))
            cy = int(py + self.rng.integers(-jitter, jitter + 1))
        else:
            # Random image and location
            image_ids = list(self.images)
            sid = image_ids[self.rng.integers(len(image_ids))]
            h, w = self.images[sid].shape[:2]
            # For images no larger than the patch the valid range is empty;
            # widen it to the single value ``half`` so rng.integers does not
            # raise and the padding path can handle the short patch.
            cx = int(self.rng.integers(half, max(half + 1, w - half)))
            cy = int(self.rng.integers(half, max(half + 1, h - half)))

        return sid, cx, cy

    def _extract_patch(
        self, sid: str, cx: int, cy: int
    ) -> Tuple[np.ndarray, int, int, int, int]:
        """Crop the patch around (cx, cy); return it with its x0, y0, x1, y1 window."""
        image = self.images[sid]
        h, w = image.shape[:2]
        half = self.patch_size // 2

        # Clamp to image bounds
        cx = max(half, min(w - half, cx))
        cy = max(half, min(h - half, cy))

        x0, x1 = cx - half, cx + half
        y0, y1 = cy - half, cy + half

        patch = image[y0:y1, x0:x1].copy()

        # Zero-pad at the bottom/right when the crop came out short
        # (image smaller than the patch, or odd patch_size).
        if patch.shape[0] != self.patch_size or patch.shape[1] != self.patch_size:
            padded = np.zeros((self.patch_size, self.patch_size), dtype=np.uint8)
            ph, pw = patch.shape[:2]
            padded[:ph, :pw] = patch
            patch = padded

        return patch, x0, y0, x1, y1

    def _local_keypoints(
        self, sid: str, x0: int, y0: int
    ) -> Tuple[List[Tuple[float, float]], List[str]]:
        """Collect annotations inside the patch, in patch-local coordinates."""
        keypoints = []
        class_labels = []
        for cls in ("6nm", "12nm"):
            for ax, ay in self.annotations[sid][cls]:
                lx = ax - x0
                ly = ay - y0
                if 0 <= lx < self.patch_size and 0 <= ly < self.patch_size:
                    keypoints.append((lx, ly))
                    class_labels.append(cls)
        return keypoints, class_labels

    def __getitem__(self, idx: int) -> dict:
        """
        Sample a patch with ground truth heatmap.

        The generator is re-created on every call from the base seed, ``idx``
        and the torch seed (plus a draw counter in train mode), so the patch
        returned for an index changes between training epochs.

        Returns dict with:
            'image': (1, patch_size, patch_size) float32 tensor scaled by 1/255
            'heatmap', 'offsets', 'offset_mask', 'conf_map': tensors wrapping
                the arrays returned by generate_heatmap_gt. NOTE(review): the
                original note gives their shapes as (2, S, S), (2, S, S),
                (S, S) bool and (2, S, S) with S = patch_size // stride —
                confirm against src.heatmap.
        """
        self.rng = self._make_rng(idx)

        sid, cx, cy = self._sample_center()
        patch, x0, y0, x1, y1 = self._extract_patch(sid, cx, cy)
        keypoints, class_labels = self._local_keypoints(sid, x0, y0)

        # Copy-paste augmentation (before geometric transforms)
        if self.copy_paste_per_class > 0 and self.mode == "train":
            local_6nm = [kp for kp, c in zip(keypoints, class_labels) if c == "6nm"]
            local_12nm = [kp for kp, c in zip(keypoints, class_labels) if c == "12nm"]
            mask_patch = None
            if sid in self.masks:
                mask_patch = self.masks[sid][y0:y1, x0:x1]

            patch, local_6nm, local_12nm, class_labels = self.bead_bank.paste_beads(
                patch, local_6nm, local_12nm, class_labels,
                mask=mask_patch,
                n_paste_per_class=self.copy_paste_per_class,
                rng=self.rng,
            )
            # Rebuild keypoints and labels from the updated per-class lists so
            # the two stay aligned.
            keypoints = [(x, y) for x, y in local_6nm] + [(x, y) for x, y in local_12nm]
            class_labels = ["6nm"] * len(local_6nm) + ["12nm"] * len(local_12nm)

        # Apply augmentation (co-transforms keypoints)
        transformed = self.transform(
            image=patch,
            keypoints=keypoints,
            class_labels=class_labels,
        )
        patch_aug = transformed["image"]
        kp_aug = transformed["keypoints"]
        cl_aug = transformed["class_labels"]

        # Separate keypoints by class; reshape keeps a (0, 2) array when a
        # class has no keypoints.
        coords_6nm = np.array(
            [(x, y) for (x, y), c in zip(kp_aug, cl_aug) if c == "6nm"],
            dtype=np.float64,
        ).reshape(-1, 2)
        coords_12nm = np.array(
            [(x, y) for (x, y), c in zip(kp_aug, cl_aug) if c == "12nm"],
            dtype=np.float64,
        ).reshape(-1, 2)

        # Generate heatmap GT from TRANSFORMED coordinates (never warp heatmap)
        heatmap, offsets, offset_mask, conf_map = generate_heatmap_gt(
            coords_6nm, coords_12nm,
            self.patch_size, self.patch_size,
            sigmas=self.sigmas,
            stride=self.stride,
        )

        # Convert to tensors
        patch_tensor = torch.from_numpy(patch_aug).float().unsqueeze(0) / 255.0

        return {
            "image": patch_tensor,
            "heatmap": torch.from_numpy(heatmap),
            "offsets": torch.from_numpy(offsets),
            "offset_mask": torch.from_numpy(offset_mask),
            "conf_map": torch.from_numpy(conf_map),
        }