MRiabov committed on
Commit
efdb6a6
·
1 Parent(s): a527623

(optim) prefetch and speed up eval.

Browse files
Files changed (3) hide show
  1. SEGMENTATION_PLAN.md +1 -1
  2. configs/default.yaml +3 -3
  3. train.py +111 -53
SEGMENTATION_PLAN.md CHANGED
@@ -108,7 +108,7 @@ Focus: segmentation only (no dataset collection or inpainting).
108
  ## Configuration Surface (key)
109
  - Backbone/weights: `mit_b2` (pretrained ImageNet-1K).
110
  - Sizes: `p=768`, `coarse_train=512`, `coarse_test=1024`, `overlap=128`.
111
- - Conditioning: `use_binary_location=true`, `cond_from='coarse_logits_1x1'`, `cond_crop='patch'`.
112
  - MinMax: `enable=true`, `kernel=6`.
113
  - Label: `coarse_label_downsample='maxpool'`.
114
  - Training: `iters=40000`, `batch=8`, `lr=6e-5`, `wd=0.01`, `schedule='poly'`, `power=1.0`.
 
108
  ## Configuration Surface (key)
109
  - Backbone/weights: `mit_b2` (pretrained ImageNet-1K).
110
  - Sizes: `p=768`, `coarse_train=512`, `coarse_test=1024`, `overlap=128`.
111
+ - Conditioning: `cond_from='coarse_logits_1x1'`, `cond_crop='patch'`.
112
  - MinMax: `enable=true`, `kernel=6`.
113
  - Label: `coarse_label_downsample='maxpool'`.
114
  - Training: `iters=40000`, `batch=8`, `lr=6e-5`, `wd=0.01`, `schedule='poly'`, `power=1.0`.
configs/default.yaml CHANGED
@@ -37,9 +37,9 @@ optim:
37
  # training housekeeping
38
  seed: 42
39
  out_dir: runs/wireseghr
40
- eval_interval: 30
41
- ckpt_interval: 100
42
- resume: runs/wireseghr/ckpt_1500.pt # optional
43
 
44
  # dataset paths (placeholders)
45
  data:
 
37
  # training housekeeping
38
  seed: 42
39
  out_dir: runs/wireseghr
40
+ eval_interval: 100
41
+ ckpt_interval: 300
42
+ # resume: runs/wireseghr/ckpt_1500.pt # optional
43
 
44
  # dataset paths (placeholders)
45
  data:
train.py CHANGED
@@ -14,6 +14,7 @@ from tqdm import tqdm
14
  import random
15
  import torch.backends.cudnn as cudnn
16
  import cv2
 
17
 
18
  from src.wireseghr.model import WireSegHR
19
  from src.wireseghr.model.minmax import MinMaxLuminance
@@ -23,6 +24,46 @@ from src.wireseghr.data.sampler import BalancedPatchSampler
23
  from src.wireseghr.metrics import compute_metrics
24
 
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  def main():
27
  parser = argparse.ArgumentParser(description="WireSegHR training (skeleton)")
28
  parser.add_argument(
@@ -66,6 +107,23 @@ def main():
66
  train_images = cfg["data"]["train_images"]
67
  train_masks = cfg["data"]["train_masks"]
68
  dset = WireSegDataset(train_images, train_masks, split="train")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  # Validation and test
70
  val_images = cfg["data"].get("val_images", None)
71
  val_masks = cfg["data"].get("val_masks", None)
@@ -120,9 +178,14 @@ def main():
120
  model.train()
121
  step = start_step
122
  pbar = tqdm(total=iters - step, initial=0, desc="Train", ncols=100)
 
123
  while step < iters:
124
  optim.zero_grad(set_to_none=True)
125
- imgs, masks = _sample_batch_same_size(dset, batch_size)
 
 
 
 
126
  batch = _prepare_batch(
127
  imgs, masks, coarse_train, patch_size, sampler, minmax, device
128
  )
@@ -169,6 +232,9 @@ def main():
169
 
170
  # Eval & Checkpoint
171
  if (step % eval_interval == 0) and (dset_val is not None):
 
 
 
172
  model.eval()
173
  val_stats = validate(
174
  model,
@@ -181,6 +247,7 @@ def main():
181
  mm_kernel,
182
  patch_size,
183
  overlap,
 
184
  )
185
  print(
186
  f"[Val @ {step}][Fine] IoU={val_stats['iou']:.4f} F1={val_stats['f1']:.4f} P={val_stats['precision']:.4f} R={val_stats['recall']:.4f}"
@@ -323,8 +390,8 @@ def _prepare_batch(
323
  )[0]
324
  zeros_coarse = torch.zeros(1, coarse_train, coarse_train, device=device)
325
  c_t = torch.cat(
326
- [rgb_coarse_t, y_min_c_t, y_max_c_t, zeros_coarse, zeros_coarse], dim=0
327
- ) # 7xHc x Wc
328
  xs_coarse.append(c_t)
329
 
330
  # Sample fine patch (CPU mask), then slice GPU min/max and transfer only patches
@@ -347,8 +414,6 @@ def _prepare_batch(
347
  )
348
  patches_min.append(ymin_patch)
349
  patches_max.append(ymax_patch)
350
- # Binary location mask (ones inside the patch)
351
- loc_masks.append(np.ones((patch_size, patch_size), dtype=np.float32))
352
  yx_list.append((y0, x0))
353
 
354
  x_coarse = torch.stack(xs_coarse, dim=0) # already on device
@@ -362,7 +427,6 @@ def _prepare_batch(
362
  "mask_patches": patches_mask,
363
  "ymin_patches": patches_min,
364
  "ymax_patches": patches_max,
365
- "loc_patches": loc_masks,
366
  "patch_yx": yx_list,
367
  "mask_full": masks,
368
  }
@@ -371,9 +435,9 @@ def _prepare_batch(
371
  def _build_fine_inputs(
372
  batch, cond_map: torch.Tensor, device: torch.device
373
  ) -> torch.Tensor:
374
- # Build fine input tensor Bx7xP x P; crop cond from low-res map, upsample to P
375
  B = cond_map.shape[0]
376
- P = batch["loc_patches"][0].shape[0]
377
  full_h, full_w = batch["full_h"], batch["full_w"]
378
  hc4, wc4 = cond_map.shape[2], cond_map.shape[3]
379
 
@@ -382,7 +446,6 @@ def _build_fine_inputs(
382
  rgb = batch["rgb_patches"][i]
383
  ymin = batch["ymin_patches"][i]
384
  ymax = batch["ymax_patches"][i]
385
- loc = batch["loc_patches"][i]
386
  y0, x0 = batch["patch_yx"][i]
387
 
388
  # Map full-res patch box to low-res cond grid, crop and upsample to P
@@ -402,8 +465,7 @@ def _build_fine_inputs(
402
  ) # 3xPxP
403
  ymin_t = torch.from_numpy(ymin)[None, ...].to(device).float() # 1xPxP
404
  ymax_t = torch.from_numpy(ymax)[None, ...].to(device).float() # 1xPxP
405
- loc_t = torch.from_numpy(loc)[None, ...].to(device).float() # 1xPxP
406
- x = torch.cat([rgb_t, ymin_t, ymax_t, cond_patch, loc_t], dim=0) # 7xPxP
407
  xs.append(x)
408
  x_fine = torch.stack(xs, dim=0)
409
  return x_fine
@@ -492,6 +554,7 @@ def validate(
492
  minmax_kernel: int,
493
  fine_patch_size: int,
494
  fine_overlap: int,
 
495
  ) -> Dict[str, float]:
496
  # Coarse-only validation: resize image to coarse_size, predict coarse logits, upsample to full and compute metrics
497
  model = model.to(device)
@@ -558,16 +621,16 @@ def validate(
558
  for k in coarse_sum:
559
  coarse_sum[k] += m_c[k]
560
 
561
- # Fine-stage tiled inference and stitching
562
  P = fine_patch_size
563
  stride = P - fine_overlap
564
  assert stride > 0
565
  assert H >= P and W >= P
566
- prob_sum = np.zeros((H, W), dtype=np.float32)
567
- weight = np.zeros((H, W), dtype=np.float32)
 
568
 
569
  # Prepare min/max on full-res (already computed above as y_min_full/y_max_full)
570
- # y_min_full, y_max_full exist in this scope from above branch
571
  hc4, wc4 = cond_map.shape[2], cond_map.shape[3]
572
 
573
  ys = list(range(0, max(H - P, 0) + 1, stride))
@@ -577,15 +640,16 @@ def validate(
577
  if xs[-1] != (W - P):
578
  xs.append(W - P)
579
 
 
580
  for y0 in ys:
581
  for x0 in xs:
582
- y1, x1 = y0 + P, x0 + P
583
- # Build fine input 1x7xP x P
584
- patch_rgb = img[y0:y1, x0:x1, :]
585
- ymin_patch = y_min_full[0, 0, y0:y1, x0:x1].detach().cpu().numpy()
586
- ymax_patch = y_max_full[0, 0, y0:y1, x0:x1].detach().cpu().numpy()
587
- loc = np.ones((P, P), dtype=np.float32)
588
 
 
 
 
 
 
589
  # Cond crop mapping (same as training _build_fine_inputs)
590
  y0c = (y0 * hc4) // H
591
  y1c = ((y1 * hc4) + H - 1) // H
@@ -596,37 +660,31 @@ def validate(
596
  cond_sub, size=(P, P), mode="bilinear", align_corners=False
597
  ).squeeze(1) # 1xPxP
598
 
599
- rgb_t = (
600
- torch.from_numpy(np.transpose(patch_rgb, (2, 0, 1)))
601
- .to(device)
602
- .float()
603
- ) # 3xPxP
604
- ymin_t = torch.from_numpy(ymin_patch)[None, ...].to(device).float()
605
- ymax_t = torch.from_numpy(ymax_patch)[None, ...].to(device).float()
606
- loc_t = torch.from_numpy(loc)[None, ...].to(device).float()
607
- x_f = torch.cat(
608
- [rgb_t, ymin_t, ymax_t, cond_patch, loc_t], dim=0
609
- ).unsqueeze(0)
610
-
611
- with autocast(
612
- device_type=device.type,
613
- enabled=(device.type == "cuda" and amp_flag),
614
- ):
615
- logits_f = model.forward_fine(x_f)
616
- prob_f = torch.softmax(logits_f, dim=1)[:, 1:2]
617
- prob_f_up = (
618
- F.interpolate(
619
- prob_f, size=(P, P), mode="bilinear", align_corners=False
620
- )[0, 0]
621
- .detach()
622
- .cpu()
623
- .numpy()
624
- )
625
-
626
- prob_sum[y0:y1, x0:x1] += prob_f_up
627
- weight[y0:y1, x0:x1] += 1.0
628
-
629
- prob_full = prob_sum / weight
630
  pred_fine = (prob_full > prob_thresh).astype(np.uint8)
631
  m_f = compute_metrics(pred_fine, mask)
632
  for k in metrics_sum:
@@ -699,7 +757,7 @@ def save_test_visuals(
699
  align_corners=False,
700
  )[0]
701
  zeros_c = torch.zeros(1, coarse_size, coarse_size, device=device)
702
- x_t = torch.cat([rgb_c, y_min_c, y_max_c, zeros_c, zeros_c], dim=0).unsqueeze(0)
703
  with autocast(
704
  device_type=device.type, enabled=(device.type == "cuda" and amp_flag)
705
  ):
 
14
  import random
15
  import torch.backends.cudnn as cudnn
16
  import cv2
17
+ from torch.utils.data import DataLoader
18
 
19
  from src.wireseghr.model import WireSegHR
20
  from src.wireseghr.model.minmax import MinMaxLuminance
 
24
  from src.wireseghr.metrics import compute_metrics
25
 
26
 
27
class SizeBatchSampler:
    """Batch sampler that groups indices by exact (H, W) so all samples in a
    batch share size.

    This enables DataLoader prefetching while preserving the existing
    assumption in `_prepare_batch()` that all items in a batch have the same
    full resolution. Only *full* batches are yielded; the remainder of each
    size bin is dropped for the epoch.
    """

    def __init__(self, dset: "WireSegDataset", batch_size: int):
        self.dset = dset
        self.batch_size = batch_size
        # Epoch length = total number of full batches across all size bins.
        # Partial batches are dropped in __iter__, hence the floor division.
        self._len = sum(
            len(idxs) // self.batch_size for idxs in self.dset.size_bins.values()
        )

    def __len__(self) -> int:
        return self._len

    def __iter__(self):
        # Build every full batch per size bin, then shuffle the assembled
        # batch list globally. The previous implementation only shuffled the
        # bin *keys*, so all batches of one image size were emitted
        # back-to-back and consecutive optimizer steps saw correlated sizes;
        # global shuffling mixes sizes across steps while each individual
        # batch still contains a single (H, W).
        batches = []
        for idxs in self.dset.size_bins.values():
            pool = list(idxs)
            random.shuffle(pool)
            n_full = len(pool) - (len(pool) % self.batch_size)
            for i in range(0, n_full, self.batch_size):
                batches.append(pool[i : i + self.batch_size])
        random.shuffle(batches)
        yield from batches
58
+
59
+
60
def collate_train(batch: List[Dict]):
    """Collate dataset samples into parallel lists of images and masks.

    Returns plain lists (no tensor stacking) so the downstream pipeline can
    keep consuming per-sample numpy arrays unchanged.
    """
    images = []
    targets = []
    for sample in batch:
        images.append(sample["image"])
        targets.append(sample["mask"])
    return images, targets
65
+
66
+
67
  def main():
68
  parser = argparse.ArgumentParser(description="WireSegHR training (skeleton)")
69
  parser.add_argument(
 
107
  train_images = cfg["data"]["train_images"]
108
  train_masks = cfg["data"]["train_masks"]
109
  dset = WireSegDataset(train_images, train_masks, split="train")
110
+ # DataLoader with prefetching and size-aware batching
111
+ loader_cfg = cfg.get("loader", {})
112
+ num_workers = int(loader_cfg.get("num_workers", 4))
113
+ prefetch_factor = int(loader_cfg.get("prefetch_factor", 2))
114
+ pin_memory = bool(loader_cfg.get("pin_memory", True))
115
+ persistent_workers = bool(loader_cfg.get("persistent_workers", True)) if num_workers > 0 else False
116
+ batch_sampler = SizeBatchSampler(dset, batch_size)
117
+ loader_kwargs = dict(
118
+ batch_sampler=batch_sampler,
119
+ num_workers=num_workers,
120
+ pin_memory=pin_memory,
121
+ persistent_workers=persistent_workers,
122
+ collate_fn=collate_train,
123
+ )
124
+ if num_workers > 0:
125
+ loader_kwargs["prefetch_factor"] = prefetch_factor
126
+ train_loader = DataLoader(dset, **loader_kwargs)
127
  # Validation and test
128
  val_images = cfg["data"].get("val_images", None)
129
  val_masks = cfg["data"].get("val_masks", None)
 
178
  model.train()
179
  step = start_step
180
  pbar = tqdm(total=iters - step, initial=0, desc="Train", ncols=100)
181
+ data_iter = iter(train_loader)
182
  while step < iters:
183
  optim.zero_grad(set_to_none=True)
184
+ try:
185
+ imgs, masks = next(data_iter)
186
+ except StopIteration:
187
+ data_iter = iter(train_loader)
188
+ imgs, masks = next(data_iter)
189
  batch = _prepare_batch(
190
  imgs, masks, coarse_train, patch_size, sampler, minmax, device
191
  )
 
232
 
233
  # Eval & Checkpoint
234
  if (step % eval_interval == 0) and (dset_val is not None):
235
+ # Free training-step tensors before eval to lower peak memory
236
+ del x_fine, logits_coarse, cond_map, logits_fine, y_coarse, y_fine, loss_coarse, loss_fine, loss
237
+ torch.cuda.empty_cache()
238
  model.eval()
239
  val_stats = validate(
240
  model,
 
247
  mm_kernel,
248
  patch_size,
249
  overlap,
250
+ batch_size,
251
  )
252
  print(
253
  f"[Val @ {step}][Fine] IoU={val_stats['iou']:.4f} F1={val_stats['f1']:.4f} P={val_stats['precision']:.4f} R={val_stats['recall']:.4f}"
 
390
  )[0]
391
  zeros_coarse = torch.zeros(1, coarse_train, coarse_train, device=device)
392
  c_t = torch.cat(
393
+ [rgb_coarse_t, y_min_c_t, y_max_c_t, zeros_coarse], dim=0
394
+ ) # 6xHc x Wc
395
  xs_coarse.append(c_t)
396
 
397
  # Sample fine patch (CPU mask), then slice GPU min/max and transfer only patches
 
414
  )
415
  patches_min.append(ymin_patch)
416
  patches_max.append(ymax_patch)
 
 
417
  yx_list.append((y0, x0))
418
 
419
  x_coarse = torch.stack(xs_coarse, dim=0) # already on device
 
427
  "mask_patches": patches_mask,
428
  "ymin_patches": patches_min,
429
  "ymax_patches": patches_max,
 
430
  "patch_yx": yx_list,
431
  "mask_full": masks,
432
  }
 
435
  def _build_fine_inputs(
436
  batch, cond_map: torch.Tensor, device: torch.device
437
  ) -> torch.Tensor:
438
+ # Build fine input tensor Bx6xP x P; crop cond from low-res map, upsample to P
439
  B = cond_map.shape[0]
440
+ P = batch["rgb_patches"][0].shape[0]
441
  full_h, full_w = batch["full_h"], batch["full_w"]
442
  hc4, wc4 = cond_map.shape[2], cond_map.shape[3]
443
 
 
446
  rgb = batch["rgb_patches"][i]
447
  ymin = batch["ymin_patches"][i]
448
  ymax = batch["ymax_patches"][i]
 
449
  y0, x0 = batch["patch_yx"][i]
450
 
451
  # Map full-res patch box to low-res cond grid, crop and upsample to P
 
465
  ) # 3xPxP
466
  ymin_t = torch.from_numpy(ymin)[None, ...].to(device).float() # 1xPxP
467
  ymax_t = torch.from_numpy(ymax)[None, ...].to(device).float() # 1xPxP
468
+ x = torch.cat([rgb_t, ymin_t, ymax_t, cond_patch], dim=0) # 6xPxP
 
469
  xs.append(x)
470
  x_fine = torch.stack(xs, dim=0)
471
  return x_fine
 
554
  minmax_kernel: int,
555
  fine_patch_size: int,
556
  fine_overlap: int,
557
+ fine_batch: int,
558
  ) -> Dict[str, float]:
559
  # Coarse-only validation: resize image to coarse_size, predict coarse logits, upsample to full and compute metrics
560
  model = model.to(device)
 
621
  for k in coarse_sum:
622
  coarse_sum[k] += m_c[k]
623
 
624
+ # Fine-stage tiled inference and stitching (BATCHED)
625
  P = fine_patch_size
626
  stride = P - fine_overlap
627
  assert stride > 0
628
  assert H >= P and W >= P
629
+ # Accumulate on device to avoid CPU<->GPU thrash
630
+ prob_sum_t = torch.zeros((H, W), device=device, dtype=torch.float32)
631
+ weight_t = torch.zeros((H, W), device=device, dtype=torch.float32)
632
 
633
  # Prepare min/max on full-res (already computed above as y_min_full/y_max_full)
 
634
  hc4, wc4 = cond_map.shape[2], cond_map.shape[3]
635
 
636
  ys = list(range(0, max(H - P, 0) + 1, stride))
 
640
  if xs[-1] != (W - P):
641
  xs.append(W - P)
642
 
643
+ coords: List[Tuple[int, int]] = []
644
  for y0 in ys:
645
  for x0 in xs:
646
+ coords.append((y0, x0))
 
 
 
 
 
647
 
648
+ for i0 in range(0, len(coords), fine_batch):
649
+ batch_coords = coords[i0 : i0 + fine_batch]
650
+ xs_list: List[torch.Tensor] = []
651
+ for (y0, x0) in batch_coords:
652
+ y1, x1 = y0 + P, x0 + P
653
  # Cond crop mapping (same as training _build_fine_inputs)
654
  y0c = (y0 * hc4) // H
655
  y1c = ((y1 * hc4) + H - 1) // H
 
660
  cond_sub, size=(P, P), mode="bilinear", align_corners=False
661
  ).squeeze(1) # 1xPxP
662
 
663
+ # Build fine input channels directly from on-device tensors
664
+ rgb_t = t_img[0, :, y0:y1, x0:x1] # 3xPxP
665
+ ymin_t = y_min_full[0, 0, y0:y1, x0:x1].float().unsqueeze(0) # 1xPxP
666
+ ymax_t = y_max_full[0, 0, y0:y1, x0:x1].float().unsqueeze(0) # 1xPxP
667
+ x_f = torch.cat([rgb_t, ymin_t, ymax_t, cond_patch], dim=0).unsqueeze(0)
668
+ xs_list.append(x_f)
669
+
670
+ x_f_batch = torch.cat(xs_list, dim=0) # Bx6xPxP
671
+
672
+ with autocast(
673
+ device_type=device.type,
674
+ enabled=(device.type == "cuda" and amp_flag),
675
+ ):
676
+ logits_f = model.forward_fine(x_f_batch)
677
+ prob_f = torch.softmax(logits_f, dim=1)[:, 1:2]
678
+ prob_f_up = F.interpolate(
679
+ prob_f, size=(P, P), mode="bilinear", align_corners=False
680
+ )[:, 0, :, :] # BxPxP
681
+
682
+ for bi, (y0, x0) in enumerate(batch_coords):
683
+ y1, x1 = y0 + P, x0 + P
684
+ prob_sum_t[y0:y1, x0:x1] += prob_f_up[bi]
685
+ weight_t[y0:y1, x0:x1] += 1.0
686
+
687
+ prob_full = (prob_sum_t / weight_t).detach().cpu().numpy()
 
 
 
 
 
 
688
  pred_fine = (prob_full > prob_thresh).astype(np.uint8)
689
  m_f = compute_metrics(pred_fine, mask)
690
  for k in metrics_sum:
 
757
  align_corners=False,
758
  )[0]
759
  zeros_c = torch.zeros(1, coarse_size, coarse_size, device=device)
760
+ x_t = torch.cat([rgb_c, y_min_c, y_max_c, zeros_c], dim=0).unsqueeze(0)
761
  with autocast(
762
  device_type=device.type, enabled=(device.type == "cuda" and amp_flag)
763
  ):