MRiabov commited on
Commit
d46d294
·
1 Parent(s): e160020

to mit_b2 as per paper

Browse files
SEGMENTATION_PLAN.md CHANGED
@@ -21,7 +21,7 @@ Focus: segmentation only (no dataset collection or inpainting).
21
 
22
  ## Project Structure
23
  - `configs/`
24
- - `default.yaml` (backbone=mit_b3, p=768, coarse_train=512, coarse_test=1024, alpha=0.01, minmax=true, kernel=6, maxpool_label=true, cond_variant=global+binary_mask)
25
  - `src/wireseghr/`
26
  - `model/`
27
  - `encoder.py` (SegFormer MiT-B3, N_in channels expansion)
@@ -106,7 +106,7 @@ Focus: segmentation only (no dataset collection or inpainting).
106
  - Ablations: MinMax on/off, MaxPool on/off, conditioning variant (Table `tables/logit.tex`).
107
 
108
  ## Configuration Surface (key)
109
- - Backbone/weights: `mit_b3` (pretrained ImageNet-1K).
110
  - Sizes: `p=768`, `coarse_train=512`, `coarse_test=1024`, `overlap=128`.
111
  - Conditioning: `use_binary_location=true`, `cond_from='coarse_logits_1x1'`, `cond_crop='patch'`.
112
  - MinMax: `enable=true`, `kernel=6`.
 
21
 
22
  ## Project Structure
23
  - `configs/`
24
+ - `default.yaml` (backbone=mit_b2, p=768, coarse_train=512, coarse_test=1024, alpha=0.01, minmax=true, kernel=6, maxpool_label=true, cond_variant=global+binary_mask)
25
  - `src/wireseghr/`
26
  - `model/`
27
  - `encoder.py` (SegFormer MiT-B2, N_in channels expansion)
 
106
  - Ablations: MinMax on/off, MaxPool on/off, conditioning variant (Table `tables/logit.tex`).
107
 
108
  ## Configuration Surface (key)
109
+ - Backbone/weights: `mit_b2` (pretrained ImageNet-1K).
110
  - Sizes: `p=768`, `coarse_train=512`, `coarse_test=1024`, `overlap=128`.
111
  - Conditioning: `use_binary_location=true`, `cond_from='coarse_logits_1x1'`, `cond_crop='patch'`.
112
  - MinMax: `enable=true`, `kernel=6`.
configs/default.yaml CHANGED
@@ -1,5 +1,5 @@
1
  # Default configuration for WireSegHR (segmentation-only)
2
- backbone: mit_b3
3
  pretrained: true # Uses HF SegFormer weights if available; else timm or tiny fallback
4
 
5
  coarse:
@@ -24,7 +24,7 @@ label:
24
 
25
  inference:
26
  alpha: 0.01
27
- prob_threshold: 0.5
28
  stitch: avg_logits
29
 
30
  optim:
 
1
  # Default configuration for WireSegHR (segmentation-only)
2
+ backbone: mit_b2
3
  pretrained: true # Uses HF SegFormer weights if available; else timm or tiny fallback
4
 
5
  coarse:
 
24
 
25
  inference:
26
  alpha: 0.01
27
+ prob_threshold: 0.3 # was 0.5; the paper does not specify a value.
28
  stitch: avg_logits
29
 
30
  optim:
src/wireseghr/model/encoder.py CHANGED
@@ -1,6 +1,6 @@
1
  """SegFormer MiT encoder wrapper with adjustable input channels.
2
 
3
- Uses timm to instantiate MiT (e.g., mit_b3) and returns a list of multi-scale
4
  features [C1, C2, C3, C4].
5
  """
6
 
@@ -14,7 +14,7 @@ import timm
14
  class SegFormerEncoder(nn.Module):
15
  def __init__(
16
  self,
17
- backbone: str = "mit_b3",
18
  in_channels: int = 7,
19
  pretrained: bool = True,
20
  out_indices: Tuple[int, int, int, int] = (0, 1, 2, 3),
@@ -125,7 +125,7 @@ class _HFEncoderWrapper(nn.Module):
125
  "mit_b0": "nvidia/mit-b0",
126
  "mit_b1": "nvidia/mit-b1",
127
  "mit_b2": "nvidia/mit-b2",
128
- "mit_b3": "nvidia/mit-b3",
129
  "mit_b4": "nvidia/mit-b4",
130
  "mit_b5": "nvidia/mit-b5",
131
  }
 
1
  """SegFormer MiT encoder wrapper with adjustable input channels.
2
 
3
+ Uses timm to instantiate MiT (e.g., mit_b2) and returns a list of multi-scale
4
  features [C1, C2, C3, C4].
5
  """
6
 
 
14
  class SegFormerEncoder(nn.Module):
15
  def __init__(
16
  self,
17
+ backbone: str = "mit_b2",
18
  in_channels: int = 7,
19
  pretrained: bool = True,
20
  out_indices: Tuple[int, int, int, int] = (0, 1, 2, 3),
 
125
  "mit_b0": "nvidia/mit-b0",
126
  "mit_b1": "nvidia/mit-b1",
127
  "mit_b2": "nvidia/mit-b2",
128
+ "mit_b3": "nvidia/mit-b3",
129
  "mit_b4": "nvidia/mit-b4",
130
  "mit_b5": "nvidia/mit-b5",
131
  }
src/wireseghr/model/model.py CHANGED
@@ -20,7 +20,7 @@ class WireSegHR(nn.Module):
20
  """
21
 
22
  def __init__(
23
- self, backbone: str = "mit_b3", in_channels: int = 7, pretrained: bool = True
24
  ):
25
  super().__init__()
26
  self.encoder = SegFormerEncoder(
 
20
  """
21
 
22
  def __init__(
23
+ self, backbone: str = "mit_b2", in_channels: int = 7, pretrained: bool = True
24
  ):
25
  super().__init__()
26
  self.encoder = SegFormerEncoder(
tests/test_model_forward.py CHANGED
@@ -5,7 +5,7 @@ from wireseghr.model import WireSegHR
5
 
6
  def test_wireseghr_forward_shapes():
7
  # Use small input to keep test light and avoid downloading weights
8
- model = WireSegHR(backbone="mit_b3", in_channels=3, pretrained=False)
9
 
10
  x = torch.randn(1, 3, 64, 64)
11
  logits_coarse, cond = model.forward_coarse(x)
 
5
 
6
  def test_wireseghr_forward_shapes():
7
  # Use small input to keep test light and avoid downloading weights
8
+ model = WireSegHR(backbone="mit_b2", in_channels=3, pretrained=False)
9
 
10
  x = torch.randn(1, 3, 64, 64)
11
  logits_coarse, cond = model.forward_coarse(x)
train.py CHANGED
@@ -8,7 +8,7 @@ import numpy as np
8
  import torch
9
  import torch.nn as nn
10
  import torch.nn.functional as F
11
- from torch.cuda.amp import autocast
12
  from torch.amp import GradScaler
13
  from tqdm import tqdm
14
  import random
@@ -46,6 +46,7 @@ def main():
46
  # Config
47
  coarse_train = int(cfg["coarse"]["train_size"]) # 512
48
  patch_size = int(cfg["fine"]["patch_size"]) # 768
 
49
  iters = int(cfg["optim"]["iters"]) # 40000
50
  batch_size = int(cfg["optim"]["batch_size"]) # 8
51
  base_lr = float(cfg["optim"]["lr"]) # 6e-5
@@ -126,7 +127,9 @@ def main():
126
  imgs, masks, coarse_train, patch_size, sampler, minmax, device
127
  )
128
 
129
- with autocast(enabled=(device.type == "cuda" and amp_flag)):
 
 
130
  logits_coarse, cond_map = model.forward_coarse(
131
  batch["x_coarse"]
132
  ) # (B,2,Hc/4,Wc/4) and (B,1,Hc/4,Wc/4)
@@ -134,7 +137,9 @@ def main():
134
  # Build fine inputs: crop cond from low-res map to patch, concat with patch RGB+MinMax and loc mask
135
  B, _, hc4, wc4 = cond_map.shape
136
  x_fine = _build_fine_inputs(batch, cond_map, device)
137
- with autocast(enabled=(device.type == "cuda" and amp_flag)):
 
 
138
  logits_fine = model.forward_fine(x_fine)
139
 
140
  # Targets
@@ -174,9 +179,14 @@ def main():
174
  prob_thresh,
175
  mm_enable,
176
  mm_kernel,
 
 
177
  )
178
  print(
179
- f"[Val @ {step}] IoU={val_stats['iou']:.4f} F1={val_stats['f1']:.4f} P={val_stats['precision']:.4f} R={val_stats['recall']:.4f}"
 
 
 
180
  )
181
  # Save best
182
  if val_stats["f1"] > best_f1:
@@ -481,10 +491,13 @@ def validate(
481
  prob_thresh: float,
482
  minmax_enable: bool,
483
  minmax_kernel: int,
 
 
484
  ) -> Dict[str, float]:
485
  # Coarse-only validation: resize image to coarse_size, predict coarse logits, upsample to full and compute metrics
486
  model = model.to(device)
487
  metrics_sum = {"iou": 0.0, "f1": 0.0, "precision": 0.0, "recall": 0.0}
 
488
  n = 0
489
  for i in range(len(dset_val)):
490
  item = dset_val[i]
@@ -529,8 +542,10 @@ def validate(
529
  )[0]
530
  zeros_c = torch.zeros(1, coarse_size, coarse_size, device=device)
531
  x_t = torch.cat([rgb_c, y_min_c, y_max_c, zeros_c, zeros_c], dim=0).unsqueeze(0)
532
- with autocast(enabled=(device.type == "cuda" and amp_flag)):
533
- logits_c, _ = model.forward_coarse(x_t)
 
 
534
  prob = torch.softmax(logits_c, dim=1)[:, 1:2]
535
  prob_up = (
536
  F.interpolate(prob, size=(H, W), mode="bilinear", align_corners=False)[0, 0]
@@ -538,14 +553,98 @@ def validate(
538
  .cpu()
539
  .numpy()
540
  )
541
- pred = (prob_up > prob_thresh).astype(np.uint8)
542
- m = compute_metrics(pred, mask)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
543
  for k in metrics_sum:
544
- metrics_sum[k] += m[k]
545
  n += 1
546
  if n == 0:
547
  return {k: 0.0 for k in metrics_sum}
548
- return {k: v / float(n) for k, v in metrics_sum.items()}
 
 
 
 
 
 
 
 
 
549
 
550
 
551
  @torch.no_grad()
@@ -602,7 +701,9 @@ def save_test_visuals(
602
  )[0]
603
  zeros_c = torch.zeros(1, coarse_size, coarse_size, device=device)
604
  x_t = torch.cat([rgb_c, y_min_c, y_max_c, zeros_c, zeros_c], dim=0).unsqueeze(0)
605
- with autocast(enabled=(device.type == "cuda" and amp_flag)):
 
 
606
  logits_c, _ = model.forward_coarse(x_t)
607
  prob = torch.softmax(logits_c, dim=1)[:, 1:2]
608
  prob_up = (
 
8
  import torch
9
  import torch.nn as nn
10
  import torch.nn.functional as F
11
+ from torch.amp import autocast
12
  from torch.amp import GradScaler
13
  from tqdm import tqdm
14
  import random
 
46
  # Config
47
  coarse_train = int(cfg["coarse"]["train_size"]) # 512
48
  patch_size = int(cfg["fine"]["patch_size"]) # 768
49
+ overlap = int(cfg["fine"]["overlap"]) # e.g., 128
50
  iters = int(cfg["optim"]["iters"]) # 40000
51
  batch_size = int(cfg["optim"]["batch_size"]) # 8
52
  base_lr = float(cfg["optim"]["lr"]) # 6e-5
 
127
  imgs, masks, coarse_train, patch_size, sampler, minmax, device
128
  )
129
 
130
+ with autocast(
131
+ device_type=device.type, enabled=(device.type == "cuda" and amp_flag)
132
+ ):
133
  logits_coarse, cond_map = model.forward_coarse(
134
  batch["x_coarse"]
135
  ) # (B,2,Hc/4,Wc/4) and (B,1,Hc/4,Wc/4)
 
137
  # Build fine inputs: crop cond from low-res map to patch, concat with patch RGB+MinMax and loc mask
138
  B, _, hc4, wc4 = cond_map.shape
139
  x_fine = _build_fine_inputs(batch, cond_map, device)
140
+ with autocast(
141
+ device_type=device.type, enabled=(device.type == "cuda" and amp_flag)
142
+ ):
143
  logits_fine = model.forward_fine(x_fine)
144
 
145
  # Targets
 
179
  prob_thresh,
180
  mm_enable,
181
  mm_kernel,
182
+ patch_size,
183
+ overlap,
184
  )
185
  print(
186
+ f"[Val @ {step}][Fine] IoU={val_stats['iou']:.4f} F1={val_stats['f1']:.4f} P={val_stats['precision']:.4f} R={val_stats['recall']:.4f}"
187
+ )
188
+ print(
189
+ f"[Val @ {step}][Coarse] IoU={val_stats['iou_coarse']:.4f} F1={val_stats['f1_coarse']:.4f} P={val_stats['precision_coarse']:.4f} R={val_stats['recall_coarse']:.4f}"
190
  )
191
  # Save best
192
  if val_stats["f1"] > best_f1:
 
491
  prob_thresh: float,
492
  minmax_enable: bool,
493
  minmax_kernel: int,
494
+ fine_patch_size: int,
495
+ fine_overlap: int,
496
  ) -> Dict[str, float]:
497
  # Coarse-only validation: resize image to coarse_size, predict coarse logits, upsample to full and compute metrics
498
  model = model.to(device)
499
  metrics_sum = {"iou": 0.0, "f1": 0.0, "precision": 0.0, "recall": 0.0}
500
+ coarse_sum = {"iou": 0.0, "f1": 0.0, "precision": 0.0, "recall": 0.0}
501
  n = 0
502
  for i in range(len(dset_val)):
503
  item = dset_val[i]
 
542
  )[0]
543
  zeros_c = torch.zeros(1, coarse_size, coarse_size, device=device)
544
  x_t = torch.cat([rgb_c, y_min_c, y_max_c, zeros_c, zeros_c], dim=0).unsqueeze(0)
545
+ with autocast(
546
+ device_type=device.type, enabled=(device.type == "cuda" and amp_flag)
547
+ ):
548
+ logits_c, cond_map = model.forward_coarse(x_t)
549
  prob = torch.softmax(logits_c, dim=1)[:, 1:2]
550
  prob_up = (
551
  F.interpolate(prob, size=(H, W), mode="bilinear", align_corners=False)[0, 0]
 
553
  .cpu()
554
  .numpy()
555
  )
556
+ # Coarse metrics
557
+ pred_coarse = (prob_up > prob_thresh).astype(np.uint8)
558
+ m_c = compute_metrics(pred_coarse, mask)
559
+ for k in coarse_sum:
560
+ coarse_sum[k] += m_c[k]
561
+
562
+ # Fine-stage tiled inference and stitching
563
+ P = fine_patch_size
564
+ stride = P - fine_overlap
565
+ assert stride > 0
566
+ assert H >= P and W >= P
567
+ prob_sum = np.zeros((H, W), dtype=np.float32)
568
+ weight = np.zeros((H, W), dtype=np.float32)
569
+
570
+ # Prepare min/max on full-res (already computed above as y_min_full/y_max_full)
571
+ # y_min_full, y_max_full exist in this scope from above branch
572
+ hc4, wc4 = cond_map.shape[2], cond_map.shape[3]
573
+
574
+ ys = list(range(0, max(H - P, 0) + 1, stride))
575
+ if ys[-1] != (H - P):
576
+ ys.append(H - P)
577
+ xs = list(range(0, max(W - P, 0) + 1, stride))
578
+ if xs[-1] != (W - P):
579
+ xs.append(W - P)
580
+
581
+ for y0 in ys:
582
+ for x0 in xs:
583
+ y1, x1 = y0 + P, x0 + P
584
+ # Build fine input 1x7xP x P
585
+ patch_rgb = img[y0:y1, x0:x1, :]
586
+ ymin_patch = y_min_full[0, 0, y0:y1, x0:x1].detach().cpu().numpy()
587
+ ymax_patch = y_max_full[0, 0, y0:y1, x0:x1].detach().cpu().numpy()
588
+ loc = np.ones((P, P), dtype=np.float32)
589
+
590
+ # Cond crop mapping (same as training _build_fine_inputs)
591
+ y0c = (y0 * hc4) // H
592
+ y1c = ((y1 * hc4) + H - 1) // H
593
+ x0c = (x0 * wc4) // W
594
+ x1c = ((x1 * wc4) + W - 1) // W
595
+ cond_sub = cond_map[:, :, y0c:y1c, x0c:x1c].float()
596
+ cond_patch = F.interpolate(
597
+ cond_sub, size=(P, P), mode="bilinear", align_corners=False
598
+ ).squeeze(1) # 1xPxP
599
+
600
+ rgb_t = (
601
+ torch.from_numpy(np.transpose(patch_rgb, (2, 0, 1)))
602
+ .to(device)
603
+ .float()
604
+ ) # 3xPxP
605
+ ymin_t = torch.from_numpy(ymin_patch)[None, ...].to(device).float()
606
+ ymax_t = torch.from_numpy(ymax_patch)[None, ...].to(device).float()
607
+ loc_t = torch.from_numpy(loc)[None, ...].to(device).float()
608
+ x_f = torch.cat(
609
+ [rgb_t, ymin_t, ymax_t, cond_patch, loc_t], dim=0
610
+ ).unsqueeze(0)
611
+
612
+ with autocast(
613
+ device_type=device.type,
614
+ enabled=(device.type == "cuda" and amp_flag),
615
+ ):
616
+ logits_f = model.forward_fine(x_f)
617
+ prob_f = torch.softmax(logits_f, dim=1)[:, 1:2]
618
+ prob_f_up = (
619
+ F.interpolate(
620
+ prob_f, size=(P, P), mode="bilinear", align_corners=False
621
+ )[0, 0]
622
+ .detach()
623
+ .cpu()
624
+ .numpy()
625
+ )
626
+
627
+ prob_sum[y0:y1, x0:x1] += prob_f_up
628
+ weight[y0:y1, x0:x1] += 1.0
629
+
630
+ prob_full = prob_sum / weight
631
+ pred_fine = (prob_full > prob_thresh).astype(np.uint8)
632
+ m_f = compute_metrics(pred_fine, mask)
633
  for k in metrics_sum:
634
+ metrics_sum[k] += m_f[k]
635
  n += 1
636
  if n == 0:
637
  return {k: 0.0 for k in metrics_sum}
638
+ out = {k: v / float(n) for k, v in metrics_sum.items()}
639
+ out.update(
640
+ {
641
+ "iou_coarse": coarse_sum["iou"] / float(n),
642
+ "f1_coarse": coarse_sum["f1"] / float(n),
643
+ "precision_coarse": coarse_sum["precision"] / float(n),
644
+ "recall_coarse": coarse_sum["recall"] / float(n),
645
+ }
646
+ )
647
+ return out
648
 
649
 
650
  @torch.no_grad()
 
701
  )[0]
702
  zeros_c = torch.zeros(1, coarse_size, coarse_size, device=device)
703
  x_t = torch.cat([rgb_c, y_min_c, y_max_c, zeros_c, zeros_c], dim=0).unsqueeze(0)
704
+ with autocast(
705
+ device_type=device.type, enabled=(device.type == "cuda" and amp_flag)
706
+ ):
707
  logits_c, _ = model.forward_coarse(x_t)
708
  prob = torch.softmax(logits_c, dim=1)[:, 1:2]
709
  prob_up = (