MRiabov committed on
Commit
6928e82
·
1 Parent(s): 8e73ec9

(format) format

Browse files
src/wireseghr/data/dataset.py CHANGED
@@ -33,7 +33,12 @@ class WireSegDataset:
33
  mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
34
  assert mask is not None, f"Failed to read mask: {mask_path}"
35
  mask_bin = (mask > 0).astype(np.uint8)
36
- return {"image": img, "mask": mask_bin, "image_path": str(img_path), "mask_path": str(mask_path)}
 
 
 
 
 
37
 
38
  def _index_pairs(self) -> List[tuple[Path, Path]]:
39
  # Convention: numeric filenames; images are .jpg/.jpeg; masks (gts) are .png
@@ -56,5 +61,7 @@ class WireSegDataset:
56
  mp = self.masks_dir / f"{i}.png"
57
  assert mp.exists(), f"Missing mask for {i}: {mp}"
58
  pairs.append((ip, mp))
59
- assert len(pairs) > 0, f"No numeric pairs found in {self.images_dir} and {self.masks_dir}"
 
 
60
  return pairs
 
33
  mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
34
  assert mask is not None, f"Failed to read mask: {mask_path}"
35
  mask_bin = (mask > 0).astype(np.uint8)
36
+ return {
37
+ "image": img,
38
+ "mask": mask_bin,
39
+ "image_path": str(img_path),
40
+ "mask_path": str(mask_path),
41
+ }
42
 
43
  def _index_pairs(self) -> List[tuple[Path, Path]]:
44
  # Convention: numeric filenames; images are .jpg/.jpeg; masks (gts) are .png
 
61
  mp = self.masks_dir / f"{i}.png"
62
  assert mp.exists(), f"Missing mask for {i}: {mp}"
63
  pairs.append((ip, mp))
64
+ assert len(pairs) > 0, (
65
+ f"No numeric pairs found in {self.images_dir} and {self.masks_dir}"
66
+ )
67
  return pairs
src/wireseghr/data/transforms.py CHANGED
@@ -1,6 +1,7 @@
1
  # Training-time transforms: scaling, rotation, flip, photometric distortion
2
  # TODO: Implement deterministic transform composition for reproducibility
3
 
 
4
  class TrainTransforms:
5
  def __init__(self):
6
  pass
 
1
  # Training-time transforms: scaling, rotation, flip, photometric distortion
2
  # TODO: Implement deterministic transform composition for reproducibility
3
 
4
+
5
  class TrainTransforms:
6
  def __init__(self):
7
  pass
src/wireseghr/infer.py CHANGED
@@ -6,7 +6,9 @@ import yaml
6
 
7
  def main():
8
  parser = argparse.ArgumentParser(description="WireSegHR inference (skeleton)")
9
- parser.add_argument("--config", type=str, default="configs/default.yaml", help="Path to YAML config")
 
 
10
  parser.add_argument("--image", type=str, required=False, help="Path to input image")
11
  args = parser.parse_args()
12
 
@@ -20,7 +22,9 @@ def main():
20
  print("[WireSegHR][infer] Loaded config from:", cfg_path)
21
  pprint.pprint(cfg)
22
  print("[WireSegHR][infer] Image:", args.image)
23
- print("[WireSegHR][infer] Skeleton OK. Implement inference per SEGMENTATION_PLAN.md.")
 
 
24
 
25
 
26
  if __name__ == "__main__":
 
6
 
7
  def main():
8
  parser = argparse.ArgumentParser(description="WireSegHR inference (skeleton)")
9
+ parser.add_argument(
10
+ "--config", type=str, default="configs/default.yaml", help="Path to YAML config"
11
+ )
12
  parser.add_argument("--image", type=str, required=False, help="Path to input image")
13
  args = parser.parse_args()
14
 
 
22
  print("[WireSegHR][infer] Loaded config from:", cfg_path)
23
  pprint.pprint(cfg)
24
  print("[WireSegHR][infer] Image:", args.image)
25
+ print(
26
+ "[WireSegHR][infer] Skeleton OK. Implement inference per SEGMENTATION_PLAN.md."
27
+ )
28
 
29
 
30
  if __name__ == "__main__":
src/wireseghr/metrics.py CHANGED
@@ -29,4 +29,9 @@ def compute_metrics(pred_mask: np.ndarray, gt_mask: np.ndarray) -> Dict[str, flo
29
  denom_f1 = precision + recall
30
  f1 = (2 * precision * recall / denom_f1) if denom_f1 > 0 else 0.0
31
 
32
- return {"iou": float(iou), "f1": float(f1), "precision": float(precision), "recall": float(recall)}
 
 
 
 
 
 
29
  denom_f1 = precision + recall
30
  f1 = (2 * precision * recall / denom_f1) if denom_f1 > 0 else 0.0
31
 
32
+ return {
33
+ "iou": float(iou),
34
+ "f1": float(f1),
35
+ "precision": float(precision),
36
+ "recall": float(recall),
37
+ }
src/wireseghr/model/decoder.py CHANGED
@@ -14,7 +14,9 @@ import torch.nn.functional as F
14
  class _ConvBNReLU(nn.Module):
15
  def __init__(self, in_ch: int, out_ch: int, k: int, s: int = 1, p: int = 0):
16
  super().__init__()
17
- self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=k, stride=s, padding=p, bias=False)
 
 
18
  self.bn = nn.BatchNorm2d(out_ch)
19
  self.relu = nn.ReLU(inplace=True)
20
 
@@ -29,7 +31,9 @@ class _SegFormerHead(nn.Module):
29
  def __init__(self, in_chs: List[int], embed_dim: int = 128, num_classes: int = 2):
30
  super().__init__()
31
  assert len(in_chs) == 4
32
- self.proj = nn.ModuleList([nn.Conv2d(c, embed_dim, kernel_size=1) for c in in_chs])
 
 
33
  self.fuse = _ConvBNReLU(embed_dim * 4, embed_dim, k=3, p=1)
34
  self.cls = nn.Conv2d(embed_dim, num_classes, kernel_size=1)
35
 
@@ -49,10 +53,20 @@ class _SegFormerHead(nn.Module):
49
 
50
 
51
  class CoarseDecoder(_SegFormerHead):
52
- def __init__(self, in_chs: List[int] = (64, 128, 320, 512), embed_dim: int = 128, num_classes: int = 2):
 
 
 
 
 
53
  super().__init__(list(in_chs), embed_dim, num_classes)
54
 
55
 
56
  class FineDecoder(_SegFormerHead):
57
- def __init__(self, in_chs: List[int] = (64, 128, 320, 512), embed_dim: int = 128, num_classes: int = 2):
 
 
 
 
 
58
  super().__init__(list(in_chs), embed_dim, num_classes)
 
14
  class _ConvBNReLU(nn.Module):
15
  def __init__(self, in_ch: int, out_ch: int, k: int, s: int = 1, p: int = 0):
16
  super().__init__()
17
+ self.conv = nn.Conv2d(
18
+ in_ch, out_ch, kernel_size=k, stride=s, padding=p, bias=False
19
+ )
20
  self.bn = nn.BatchNorm2d(out_ch)
21
  self.relu = nn.ReLU(inplace=True)
22
 
 
31
  def __init__(self, in_chs: List[int], embed_dim: int = 128, num_classes: int = 2):
32
  super().__init__()
33
  assert len(in_chs) == 4
34
+ self.proj = nn.ModuleList(
35
+ [nn.Conv2d(c, embed_dim, kernel_size=1) for c in in_chs]
36
+ )
37
  self.fuse = _ConvBNReLU(embed_dim * 4, embed_dim, k=3, p=1)
38
  self.cls = nn.Conv2d(embed_dim, num_classes, kernel_size=1)
39
 
 
53
 
54
 
55
  class CoarseDecoder(_SegFormerHead):
56
+ def __init__(
57
+ self,
58
+ in_chs: List[int] = (64, 128, 320, 512),
59
+ embed_dim: int = 128,
60
+ num_classes: int = 2,
61
+ ):
62
  super().__init__(list(in_chs), embed_dim, num_classes)
63
 
64
 
65
  class FineDecoder(_SegFormerHead):
66
+ def __init__(
67
+ self,
68
+ in_chs: List[int] = (64, 128, 320, 512),
69
+ embed_dim: int = 128,
70
+ num_classes: int = 2,
71
+ ):
72
  super().__init__(list(in_chs), embed_dim, num_classes)
src/wireseghr/model/encoder.py CHANGED
@@ -72,7 +72,9 @@ class SegFormerEncoder(nn.Module):
72
  def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
73
  if self.encoder is not None:
74
  feats = self.encoder(x)
75
- assert isinstance(feats, (list, tuple)) and len(feats) == len(self.out_indices)
 
 
76
  return list(feats)
77
  elif self.hf is not None:
78
  return self.hf(x)
@@ -106,7 +108,7 @@ class _TinyEncoder(nn.Module):
106
  )
107
 
108
  def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
109
- c0 = self.stem(x) # 1/4
110
  c1 = self.stage1(c0) # 1/8
111
  c2 = self.stage2(c1) # 1/16
112
  c3 = self.stage3(c2) # 1/32
@@ -144,7 +146,9 @@ class _HFEncoderWrapper(nn.Module):
144
  self.feature_dims = list(self.model.config.hidden_sizes)
145
 
146
  def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
147
- outputs = self.model(pixel_values=x, output_hidden_states=True, return_dict=True)
 
 
148
  feats = list(outputs.hidden_states)
149
  assert len(feats) == 4
150
  return feats
 
72
  def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
73
  if self.encoder is not None:
74
  feats = self.encoder(x)
75
+ assert isinstance(feats, (list, tuple)) and len(feats) == len(
76
+ self.out_indices
77
+ )
78
  return list(feats)
79
  elif self.hf is not None:
80
  return self.hf(x)
 
108
  )
109
 
110
  def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
111
+ c0 = self.stem(x) # 1/4
112
  c1 = self.stage1(c0) # 1/8
113
  c2 = self.stage2(c1) # 1/16
114
  c3 = self.stage3(c2) # 1/32
 
146
  self.feature_dims = list(self.model.config.hidden_sizes)
147
 
148
  def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
149
+ outputs = self.model(
150
+ pixel_values=x, output_hidden_states=True, return_dict=True
151
+ )
152
  feats = list(outputs.hidden_states)
153
  assert len(feats) == 4
154
  return feats
src/wireseghr/model/label_downsample.py CHANGED
@@ -20,6 +20,7 @@ def downsample_label_maxpool(mask: np.ndarray, out_h: int, out_w: int) -> np.nda
20
  assert mask.ndim == 2
21
  # Convert to float32 so area resize yields fractional averages > 0 if any positive present
22
  import cv2
 
23
  m = mask.astype(np.float32)
24
  r = cv2.resize(m, (out_w, out_h), interpolation=cv2.INTER_AREA)
25
  out = (r > 0.0).astype(np.uint8)
 
20
  assert mask.ndim == 2
21
  # Convert to float32 so area resize yields fractional averages > 0 if any positive present
22
  import cv2
23
+
24
  m = mask.astype(np.float32)
25
  r = cv2.resize(m, (out_w, out_h), interpolation=cv2.INTER_AREA)
26
  out = (r > 0.0).astype(np.uint8)
src/wireseghr/model/minmax.py CHANGED
@@ -23,6 +23,7 @@ class MinMaxLuminance:
23
  y = (0.299 * r + 0.587 * g + 0.114 * b).astype(np.float32)
24
 
25
  import cv2 # lazy import to avoid test-time dependency at module import
 
26
  kernel = np.ones((self.kernel, self.kernel), dtype=np.uint8)
27
  y_min = cv2.erode(y, kernel, borderType=cv2.BORDER_REPLICATE)
28
  y_max = cv2.dilate(y, kernel, borderType=cv2.BORDER_REPLICATE)
 
23
  y = (0.299 * r + 0.587 * g + 0.114 * b).astype(np.float32)
24
 
25
  import cv2 # lazy import to avoid test-time dependency at module import
26
+
27
  kernel = np.ones((self.kernel, self.kernel), dtype=np.uint8)
28
  y_min = cv2.erode(y, kernel, borderType=cv2.BORDER_REPLICATE)
29
  y_max = cv2.dilate(y, kernel, borderType=cv2.BORDER_REPLICATE)
src/wireseghr/model/model.py CHANGED
@@ -19,16 +19,22 @@ class WireSegHR(nn.Module):
19
  Conditioning 1x1 is applied to coarse logits to produce a single-channel map.
20
  """
21
 
22
- def __init__(self, backbone: str = "mit_b3", in_channels: int = 7, pretrained: bool = True):
 
 
23
  super().__init__()
24
- self.encoder = SegFormerEncoder(backbone=backbone, in_channels=in_channels, pretrained=pretrained)
 
 
25
  # Use encoder-exposed feature dims for decoder projections
26
  in_chs = tuple(self.encoder.feature_dims)
27
  self.coarse_head = CoarseDecoder(in_chs=in_chs, embed_dim=128, num_classes=2)
28
  self.fine_head = FineDecoder(in_chs=in_chs, embed_dim=128, num_classes=2)
29
  self.cond1x1 = Conditioning1x1()
30
 
31
- def forward_coarse(self, x_coarse: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
 
 
32
  assert x_coarse.dim() == 4
33
  feats = self.encoder(x_coarse)
34
  logits_coarse = self.coarse_head(feats)
 
19
  Conditioning 1x1 is applied to coarse logits to produce a single-channel map.
20
  """
21
 
22
+ def __init__(
23
+ self, backbone: str = "mit_b3", in_channels: int = 7, pretrained: bool = True
24
+ ):
25
  super().__init__()
26
+ self.encoder = SegFormerEncoder(
27
+ backbone=backbone, in_channels=in_channels, pretrained=pretrained
28
+ )
29
  # Use encoder-exposed feature dims for decoder projections
30
  in_chs = tuple(self.encoder.feature_dims)
31
  self.coarse_head = CoarseDecoder(in_chs=in_chs, embed_dim=128, num_classes=2)
32
  self.fine_head = FineDecoder(in_chs=in_chs, embed_dim=128, num_classes=2)
33
  self.cond1x1 = Conditioning1x1()
34
 
35
+ def forward_coarse(
36
+ self, x_coarse: torch.Tensor
37
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
38
  assert x_coarse.dim() == 4
39
  feats = self.encoder(x_coarse)
40
  logits_coarse = self.coarse_head(feats)
src/wireseghr/train.py CHANGED
@@ -24,7 +24,9 @@ from wireseghr.metrics import compute_metrics
24
 
25
  def main():
26
  parser = argparse.ArgumentParser(description="WireSegHR training (skeleton)")
27
- parser.add_argument("--config", type=str, default="configs/default.yaml", help="Path to YAML config")
 
 
28
  args = parser.parse_args()
29
 
30
  cfg_path = args.config
@@ -42,12 +44,12 @@ def main():
42
 
43
  # Config
44
  coarse_train = int(cfg["coarse"]["train_size"]) # 512
45
- patch_size = int(cfg["fine"]["patch_size"]) # 768
46
- iters = int(cfg["optim"]["iters"]) # 40000
47
- batch_size = int(cfg["optim"]["batch_size"]) # 8
48
- base_lr = float(cfg["optim"]["lr"]) # 6e-5
49
  weight_decay = float(cfg["optim"]["weight_decay"]) # 0.01
50
- power = float(cfg["optim"]["power"]) # 1.0
51
  amp_flag = bool(cfg["optim"].get("amp", True))
52
 
53
  # Housekeeping
@@ -67,15 +69,29 @@ def main():
67
  val_masks = cfg["data"].get("val_masks", None)
68
  test_images = cfg["data"].get("test_images", None)
69
  test_masks = cfg["data"].get("test_masks", None)
70
- dset_val = WireSegDataset(val_images, val_masks, split="val") if val_images and val_masks else None
71
- dset_test = WireSegDataset(test_images, test_masks, split="test") if test_images and test_masks else None
 
 
 
 
 
 
 
 
72
  sampler = BalancedPatchSampler(patch_size=patch_size, min_wire_ratio=0.01)
73
- minmax = MinMaxLuminance(kernel=cfg["minmax"]["kernel"]) if cfg["minmax"]["enable"] else None
 
 
 
 
74
 
75
  # Model
76
  # Channel definition: RGB(3) + MinMax(2) + cond(1) + loc(1) = 7
77
  pretrained_flag = bool(cfg.get("pretrained", False))
78
- model = WireSegHR(backbone=cfg["backbone"], in_channels=7, pretrained=pretrained_flag)
 
 
79
  model = model.to(device)
80
 
81
  # Optimizer and loss
@@ -89,7 +105,9 @@ def main():
89
  resume_path = cfg.get("resume", None)
90
  if resume_path and os.path.isfile(resume_path):
91
  print(f"[WireSegHR][train] Resuming from {resume_path}")
92
- start_step, best_f1 = _load_checkpoint(resume_path, model, optim, scaler, device)
 
 
93
 
94
  # Training loop
95
  model.train()
@@ -98,14 +116,21 @@ def main():
98
  while step < iters:
99
  optim.zero_grad(set_to_none=True)
100
  imgs, masks = _sample_batch_same_size(dset, batch_size)
101
- batch = _prepare_batch(imgs, masks, coarse_train, patch_size, sampler, minmax, device)
 
 
102
 
103
- logits_coarse, cond_map = model.forward_coarse(batch["x_coarse"]) # (B,2,Hc/4,Wc/4) and (B,1,Hc/4,Wc/4)
 
 
104
 
105
  # Upsample cond to full-res to crop the fine patch-aligned conditioning
106
  B, _, hc4, wc4 = cond_map.shape
107
  cond_up = F.interpolate(
108
- cond_map.detach(), size=(batch["full_h"], batch["full_w"]), mode="bilinear", align_corners=False
 
 
 
109
  )
110
 
111
  # Build fine inputs: crop cond to patch, concat with patch RGB+MinMax and loc mask
@@ -114,7 +139,9 @@ def main():
114
 
115
  # Targets
116
  y_coarse = _build_coarse_targets(batch["mask_full"], hc4, wc4, device)
117
- y_fine = _build_fine_targets(batch["mask_patches"], logits_fine.shape[2], logits_fine.shape[3], device)
 
 
118
 
119
  with autocast(enabled=(device.type == "cuda" and amp_flag)):
120
  loss_coarse = ce(logits_coarse, y_coarse)
@@ -131,25 +158,47 @@ def main():
131
  pg["lr"] = lr
132
 
133
  if step % 50 == 0:
134
- print(
135
- f"[Iter {step}/{iters}] lr={lr:.6e}"
136
- )
137
 
138
  # Eval & Checkpoint
139
  if (step % eval_interval == 0) and (dset_val is not None):
140
  model.eval()
141
  val_stats = validate(model, dset_val, coarse_train, device, amp_flag)
142
- print(f"[Val @ {step}] IoU={val_stats['iou']:.4f} F1={val_stats['f1']:.4f} P={val_stats['precision']:.4f} R={val_stats['recall']:.4f}")
 
 
143
  # Save best
144
  if val_stats["f1"] > best_f1:
145
  best_f1 = val_stats["f1"]
146
- _save_checkpoint(os.path.join(out_dir, "best.pt"), step, model, optim, scaler, best_f1)
 
 
 
 
 
 
 
147
  # Save periodic ckpt
148
  if ckpt_interval > 0 and (step % ckpt_interval == 0):
149
- _save_checkpoint(os.path.join(out_dir, f"ckpt_{step}.pt"), step, model, optim, scaler, best_f1)
 
 
 
 
 
 
 
150
  # Save test visualizations
151
  if dset_test is not None:
152
- save_test_visuals(model, dset_test, coarse_train, device, os.path.join(out_dir, f"test_vis_{step}"), amp_flag, max_samples=8)
 
 
 
 
 
 
 
 
153
  model.train()
154
 
155
  step += 1
@@ -158,7 +207,9 @@ def main():
158
  print("[WireSegHR][train] Done.")
159
 
160
 
161
- def _sample_batch_same_size(dset: WireSegDataset, batch_size: int) -> Tuple[List[np.ndarray], List[np.ndarray]]:
 
 
162
  # Select a seed sample, then fill the batch with samples of the same (H,W)
163
  assert len(dset) > 0
164
  seed_idx = int(np.random.randint(0, len(dset)))
@@ -213,20 +264,35 @@ def _prepare_batch(
213
  if minmax is not None:
214
  y_min, y_max = minmax(imgf)
215
  else:
216
- y = (0.299 * imgf[..., 0] + 0.587 * imgf[..., 1] + 0.114 * imgf[..., 2]).astype(np.float32)
 
 
217
  y_min, y_max = y, y
218
 
219
  # Coarse input: resize RGB + MinMax to coarse_train, pad cond+loc zeros to reach 7 channels
220
- rgb_coarse = cv2.resize(imgf, (coarse_train, coarse_train), interpolation=cv2.INTER_LINEAR)
221
- y_min_c = cv2.resize(y_min, (coarse_train, coarse_train), interpolation=cv2.INTER_LINEAR)
222
- y_max_c = cv2.resize(y_max, (coarse_train, coarse_train), interpolation=cv2.INTER_LINEAR)
223
- c = np.concatenate([
224
- np.transpose(rgb_coarse, (2, 0, 1)), # 3xHxW
225
- y_min_c[None, ...], # 1xHxW
226
- y_max_c[None, ...], # 1xHxW
227
- np.zeros((1, coarse_train, coarse_train), np.float32), # cond placeholder
228
- np.zeros((1, coarse_train, coarse_train), np.float32), # loc placeholder
229
- ], axis=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
230
  xs_coarse.append(torch.from_numpy(c))
231
 
232
  # Sample fine patch
@@ -258,7 +324,9 @@ def _prepare_batch(
258
  }
259
 
260
 
261
- def _build_fine_inputs(batch, cond_up: torch.Tensor, device: torch.device) -> torch.Tensor:
 
 
262
  # Build fine input tensor Bx7xP x P from per-sample numpy buffers and upsampled cond maps
263
  B = cond_up.shape[0]
264
  P = batch["loc_patches"][0].shape[0]
@@ -277,14 +345,18 @@ def _build_fine_inputs(batch, cond_up: torch.Tensor, device: torch.device) -> to
277
  rgb_t = torch.from_numpy(np.transpose(rgb, (2, 0, 1))) # 3xPxP
278
  ymin_t = torch.from_numpy(ymin)[None, ...] # 1xPxP
279
  ymax_t = torch.from_numpy(ymax)[None, ...] # 1xPxP
280
- loc_t = torch.from_numpy(loc)[None, ...] # 1xPxP
281
- x = torch.cat([rgb_t, ymin_t, ymax_t, cond_patch.cpu(), loc_t], dim=0).float() # 7xPxP
 
 
282
  xs.append(x)
283
  x_fine = torch.stack(xs, dim=0).to(device)
284
  return x_fine
285
 
286
 
287
- def _build_coarse_targets(masks: List[np.ndarray], out_h: int, out_w: int, device: torch.device) -> torch.Tensor:
 
 
288
  ys: List[torch.Tensor] = []
289
  for m in masks:
290
  dm = downsample_label_maxpool(m, out_h, out_w)
@@ -293,7 +365,9 @@ def _build_coarse_targets(masks: List[np.ndarray], out_h: int, out_w: int, devic
293
  return y
294
 
295
 
296
- def _build_fine_targets(mask_patches: List[np.ndarray], out_h: int, out_w: int, device: torch.device) -> torch.Tensor:
 
 
297
  ys: List[torch.Tensor] = []
298
  for m in mask_patches:
299
  dm = downsample_label_maxpool(m, out_h, out_w)
@@ -302,10 +376,6 @@ def _build_fine_targets(mask_patches: List[np.ndarray], out_h: int, out_w: int,
302
  return y
303
 
304
 
305
- if __name__ == "__main__":
306
- main()
307
-
308
-
309
  def set_seed(seed: int):
310
  random.seed(seed)
311
  np.random.seed(seed)
@@ -316,7 +386,14 @@ def set_seed(seed: int):
316
  cudnn.deterministic = True
317
 
318
 
319
- def _save_checkpoint(path: str, step: int, model: nn.Module, optim: torch.optim.Optimizer, scaler: GradScaler, best_f1: float):
 
 
 
 
 
 
 
320
  os.makedirs(os.path.dirname(path), exist_ok=True)
321
  state = {
322
  "step": step,
@@ -329,7 +406,13 @@ def _save_checkpoint(path: str, step: int, model: nn.Module, optim: torch.optim.
329
  print(f"[WireSegHR][train] Saved checkpoint: {path}")
330
 
331
 
332
- def _load_checkpoint(path: str, model: nn.Module, optim: torch.optim.Optimizer, scaler: GradScaler, device: torch.device) -> Tuple[int, float]:
 
 
 
 
 
 
333
  ckpt = torch.load(path, map_location=device)
334
  model.load_state_dict(ckpt["model"])
335
  optim.load_state_dict(ckpt["optim"])
@@ -343,7 +426,13 @@ def _load_checkpoint(path: str, model: nn.Module, optim: torch.optim.Optimizer,
343
 
344
 
345
  @torch.no_grad()
346
- def validate(model: WireSegHR, dset_val: WireSegDataset, coarse_size: int, device: torch.device, amp_flag: bool) -> Dict[str, float]:
 
 
 
 
 
 
347
  # Coarse-only validation: resize image to coarse_size, predict coarse logits, upsample to full and compute metrics
348
  model = model.to(device)
349
  metrics_sum = {"iou": 0.0, "f1": 0.0, "precision": 0.0, "recall": 0.0}
@@ -354,22 +443,36 @@ def validate(model: WireSegHR, dset_val: WireSegDataset, coarse_size: int, devic
354
  mask = item["mask"].astype(np.uint8)
355
  H, W = mask.shape
356
  # Build coarse input (zeros for cond+loc)
357
- rgb_c = cv2.resize(img, (coarse_size, coarse_size), interpolation=cv2.INTER_LINEAR)
358
- y = (0.299 * img[..., 0] + 0.587 * img[..., 1] + 0.114 * img[..., 2]).astype(np.float32)
359
- y_min = cv2.resize(y, (coarse_size, coarse_size), interpolation=cv2.INTER_LINEAR)
 
 
 
 
 
 
360
  y_max = y_min
361
- x = np.concatenate([
362
- np.transpose(rgb_c, (2, 0, 1)),
363
- y_min[None, ...],
364
- y_max[None, ...],
365
- np.zeros((1, coarse_size, coarse_size), np.float32),
366
- np.zeros((1, coarse_size, coarse_size), np.float32),
367
- ], axis=0)
 
 
 
368
  x_t = torch.from_numpy(x)[None, ...].to(device)
369
  with autocast(enabled=(device.type == "cuda" and amp_flag)):
370
  logits_c, _ = model.forward_coarse(x_t)
371
  prob = torch.softmax(logits_c, dim=1)[:, 1:2]
372
- prob_up = F.interpolate(prob, size=(H, W), mode="bilinear", align_corners=False)[0, 0].detach().cpu().numpy()
 
 
 
 
 
373
  pred = (prob_up > 0.5).astype(np.uint8)
374
  m = compute_metrics(pred, mask)
375
  for k in metrics_sum:
@@ -381,30 +484,56 @@ def validate(model: WireSegHR, dset_val: WireSegDataset, coarse_size: int, devic
381
 
382
 
383
  @torch.no_grad()
384
- def save_test_visuals(model: WireSegHR, dset_test: WireSegDataset, coarse_size: int, device: torch.device, out_dir: str, amp_flag: bool, max_samples: int = 8):
 
 
 
 
 
 
 
 
385
  os.makedirs(out_dir, exist_ok=True)
386
  for i in range(min(max_samples, len(dset_test))):
387
  item = dset_test[i]
388
  img = item["image"].astype(np.float32) / 255.0
389
  H, W = img.shape[:2]
390
- rgb_c = cv2.resize(img, (coarse_size, coarse_size), interpolation=cv2.INTER_LINEAR)
391
- y = (0.299 * img[..., 0] + 0.587 * img[..., 1] + 0.114 * img[..., 2]).astype(np.float32)
392
- y_min = cv2.resize(y, (coarse_size, coarse_size), interpolation=cv2.INTER_LINEAR)
 
 
 
 
 
 
393
  y_max = y_min
394
- x = np.concatenate([
395
- np.transpose(rgb_c, (2, 0, 1)),
396
- y_min[None, ...],
397
- y_max[None, ...],
398
- np.zeros((1, coarse_size, coarse_size), np.float32),
399
- np.zeros((1, coarse_size, coarse_size), np.float32),
400
- ], axis=0)
 
 
 
401
  x_t = torch.from_numpy(x)[None, ...].to(device)
402
  with autocast(enabled=(device.type == "cuda" and amp_flag)):
403
  logits_c, _ = model.forward_coarse(x_t)
404
  prob = torch.softmax(logits_c, dim=1)[:, 1:2]
405
- prob_up = F.interpolate(prob, size=(H, W), mode="bilinear", align_corners=False)[0, 0].detach().cpu().numpy()
 
 
 
 
 
406
  pred = (prob_up > 0.5).astype(np.uint8) * 255
407
  # Save input and prediction
408
  img_bgr = (img[..., ::-1] * 255.0).astype(np.uint8)
409
  cv2.imwrite(os.path.join(out_dir, f"{i:03d}_input.jpg"), img_bgr)
410
  cv2.imwrite(os.path.join(out_dir, f"{i:03d}_pred.png"), pred)
 
 
 
 
 
24
 
25
  def main():
26
  parser = argparse.ArgumentParser(description="WireSegHR training (skeleton)")
27
+ parser.add_argument(
28
+ "--config", type=str, default="configs/default.yaml", help="Path to YAML config"
29
+ )
30
  args = parser.parse_args()
31
 
32
  cfg_path = args.config
 
44
 
45
  # Config
46
  coarse_train = int(cfg["coarse"]["train_size"]) # 512
47
+ patch_size = int(cfg["fine"]["patch_size"]) # 768
48
+ iters = int(cfg["optim"]["iters"]) # 40000
49
+ batch_size = int(cfg["optim"]["batch_size"]) # 8
50
+ base_lr = float(cfg["optim"]["lr"]) # 6e-5
51
  weight_decay = float(cfg["optim"]["weight_decay"]) # 0.01
52
+ power = float(cfg["optim"]["power"]) # 1.0
53
  amp_flag = bool(cfg["optim"].get("amp", True))
54
 
55
  # Housekeeping
 
69
  val_masks = cfg["data"].get("val_masks", None)
70
  test_images = cfg["data"].get("test_images", None)
71
  test_masks = cfg["data"].get("test_masks", None)
72
+ dset_val = (
73
+ WireSegDataset(val_images, val_masks, split="val")
74
+ if val_images and val_masks
75
+ else None
76
+ )
77
+ dset_test = (
78
+ WireSegDataset(test_images, test_masks, split="test")
79
+ if test_images and test_masks
80
+ else None
81
+ )
82
  sampler = BalancedPatchSampler(patch_size=patch_size, min_wire_ratio=0.01)
83
+ minmax = (
84
+ MinMaxLuminance(kernel=cfg["minmax"]["kernel"])
85
+ if cfg["minmax"]["enable"]
86
+ else None
87
+ )
88
 
89
  # Model
90
  # Channel definition: RGB(3) + MinMax(2) + cond(1) + loc(1) = 7
91
  pretrained_flag = bool(cfg.get("pretrained", False))
92
+ model = WireSegHR(
93
+ backbone=cfg["backbone"], in_channels=7, pretrained=pretrained_flag
94
+ )
95
  model = model.to(device)
96
 
97
  # Optimizer and loss
 
105
  resume_path = cfg.get("resume", None)
106
  if resume_path and os.path.isfile(resume_path):
107
  print(f"[WireSegHR][train] Resuming from {resume_path}")
108
+ start_step, best_f1 = _load_checkpoint(
109
+ resume_path, model, optim, scaler, device
110
+ )
111
 
112
  # Training loop
113
  model.train()
 
116
  while step < iters:
117
  optim.zero_grad(set_to_none=True)
118
  imgs, masks = _sample_batch_same_size(dset, batch_size)
119
+ batch = _prepare_batch(
120
+ imgs, masks, coarse_train, patch_size, sampler, minmax, device
121
+ )
122
 
123
+ logits_coarse, cond_map = model.forward_coarse(
124
+ batch["x_coarse"]
125
+ ) # (B,2,Hc/4,Wc/4) and (B,1,Hc/4,Wc/4)
126
 
127
  # Upsample cond to full-res to crop the fine patch-aligned conditioning
128
  B, _, hc4, wc4 = cond_map.shape
129
  cond_up = F.interpolate(
130
+ cond_map.detach(),
131
+ size=(batch["full_h"], batch["full_w"]),
132
+ mode="bilinear",
133
+ align_corners=False,
134
  )
135
 
136
  # Build fine inputs: crop cond to patch, concat with patch RGB+MinMax and loc mask
 
139
 
140
  # Targets
141
  y_coarse = _build_coarse_targets(batch["mask_full"], hc4, wc4, device)
142
+ y_fine = _build_fine_targets(
143
+ batch["mask_patches"], logits_fine.shape[2], logits_fine.shape[3], device
144
+ )
145
 
146
  with autocast(enabled=(device.type == "cuda" and amp_flag)):
147
  loss_coarse = ce(logits_coarse, y_coarse)
 
158
  pg["lr"] = lr
159
 
160
  if step % 50 == 0:
161
+ print(f"[Iter {step}/{iters}] lr={lr:.6e}")
 
 
162
 
163
  # Eval & Checkpoint
164
  if (step % eval_interval == 0) and (dset_val is not None):
165
  model.eval()
166
  val_stats = validate(model, dset_val, coarse_train, device, amp_flag)
167
+ print(
168
+ f"[Val @ {step}] IoU={val_stats['iou']:.4f} F1={val_stats['f1']:.4f} P={val_stats['precision']:.4f} R={val_stats['recall']:.4f}"
169
+ )
170
  # Save best
171
  if val_stats["f1"] > best_f1:
172
  best_f1 = val_stats["f1"]
173
+ _save_checkpoint(
174
+ os.path.join(out_dir, "best.pt"),
175
+ step,
176
+ model,
177
+ optim,
178
+ scaler,
179
+ best_f1,
180
+ )
181
  # Save periodic ckpt
182
  if ckpt_interval > 0 and (step % ckpt_interval == 0):
183
+ _save_checkpoint(
184
+ os.path.join(out_dir, f"ckpt_{step}.pt"),
185
+ step,
186
+ model,
187
+ optim,
188
+ scaler,
189
+ best_f1,
190
+ )
191
  # Save test visualizations
192
  if dset_test is not None:
193
+ save_test_visuals(
194
+ model,
195
+ dset_test,
196
+ coarse_train,
197
+ device,
198
+ os.path.join(out_dir, f"test_vis_{step}"),
199
+ amp_flag,
200
+ max_samples=8,
201
+ )
202
  model.train()
203
 
204
  step += 1
 
207
  print("[WireSegHR][train] Done.")
208
 
209
 
210
+ def _sample_batch_same_size(
211
+ dset: WireSegDataset, batch_size: int
212
+ ) -> Tuple[List[np.ndarray], List[np.ndarray]]:
213
  # Select a seed sample, then fill the batch with samples of the same (H,W)
214
  assert len(dset) > 0
215
  seed_idx = int(np.random.randint(0, len(dset)))
 
264
  if minmax is not None:
265
  y_min, y_max = minmax(imgf)
266
  else:
267
+ y = (
268
+ 0.299 * imgf[..., 0] + 0.587 * imgf[..., 1] + 0.114 * imgf[..., 2]
269
+ ).astype(np.float32)
270
  y_min, y_max = y, y
271
 
272
  # Coarse input: resize RGB + MinMax to coarse_train, pad cond+loc zeros to reach 7 channels
273
+ rgb_coarse = cv2.resize(
274
+ imgf, (coarse_train, coarse_train), interpolation=cv2.INTER_LINEAR
275
+ )
276
+ y_min_c = cv2.resize(
277
+ y_min, (coarse_train, coarse_train), interpolation=cv2.INTER_LINEAR
278
+ )
279
+ y_max_c = cv2.resize(
280
+ y_max, (coarse_train, coarse_train), interpolation=cv2.INTER_LINEAR
281
+ )
282
+ c = np.concatenate(
283
+ [
284
+ np.transpose(rgb_coarse, (2, 0, 1)), # 3xHxW
285
+ y_min_c[None, ...], # 1xHxW
286
+ y_max_c[None, ...], # 1xHxW
287
+ np.zeros(
288
+ (1, coarse_train, coarse_train), np.float32
289
+ ), # cond placeholder
290
+ np.zeros(
291
+ (1, coarse_train, coarse_train), np.float32
292
+ ), # loc placeholder
293
+ ],
294
+ axis=0,
295
+ )
296
  xs_coarse.append(torch.from_numpy(c))
297
 
298
  # Sample fine patch
 
324
  }
325
 
326
 
327
+ def _build_fine_inputs(
328
+ batch, cond_up: torch.Tensor, device: torch.device
329
+ ) -> torch.Tensor:
330
  # Build fine input tensor Bx7xP x P from per-sample numpy buffers and upsampled cond maps
331
  B = cond_up.shape[0]
332
  P = batch["loc_patches"][0].shape[0]
 
345
  rgb_t = torch.from_numpy(np.transpose(rgb, (2, 0, 1))) # 3xPxP
346
  ymin_t = torch.from_numpy(ymin)[None, ...] # 1xPxP
347
  ymax_t = torch.from_numpy(ymax)[None, ...] # 1xPxP
348
+ loc_t = torch.from_numpy(loc)[None, ...] # 1xPxP
349
+ x = torch.cat(
350
+ [rgb_t, ymin_t, ymax_t, cond_patch.cpu(), loc_t], dim=0
351
+ ).float() # 7xPxP
352
  xs.append(x)
353
  x_fine = torch.stack(xs, dim=0).to(device)
354
  return x_fine
355
 
356
 
357
+ def _build_coarse_targets(
358
+ masks: List[np.ndarray], out_h: int, out_w: int, device: torch.device
359
+ ) -> torch.Tensor:
360
  ys: List[torch.Tensor] = []
361
  for m in masks:
362
  dm = downsample_label_maxpool(m, out_h, out_w)
 
365
  return y
366
 
367
 
368
+ def _build_fine_targets(
369
+ mask_patches: List[np.ndarray], out_h: int, out_w: int, device: torch.device
370
+ ) -> torch.Tensor:
371
  ys: List[torch.Tensor] = []
372
  for m in mask_patches:
373
  dm = downsample_label_maxpool(m, out_h, out_w)
 
376
  return y
377
 
378
 
 
 
 
 
379
  def set_seed(seed: int):
380
  random.seed(seed)
381
  np.random.seed(seed)
 
386
  cudnn.deterministic = True
387
 
388
 
389
+ def _save_checkpoint(
390
+ path: str,
391
+ step: int,
392
+ model: nn.Module,
393
+ optim: torch.optim.Optimizer,
394
+ scaler: GradScaler,
395
+ best_f1: float,
396
+ ):
397
  os.makedirs(os.path.dirname(path), exist_ok=True)
398
  state = {
399
  "step": step,
 
406
  print(f"[WireSegHR][train] Saved checkpoint: {path}")
407
 
408
 
409
+ def _load_checkpoint(
410
+ path: str,
411
+ model: nn.Module,
412
+ optim: torch.optim.Optimizer,
413
+ scaler: GradScaler,
414
+ device: torch.device,
415
+ ) -> Tuple[int, float]:
416
  ckpt = torch.load(path, map_location=device)
417
  model.load_state_dict(ckpt["model"])
418
  optim.load_state_dict(ckpt["optim"])
 
426
 
427
 
428
  @torch.no_grad()
429
+ def validate(
430
+ model: WireSegHR,
431
+ dset_val: WireSegDataset,
432
+ coarse_size: int,
433
+ device: torch.device,
434
+ amp_flag: bool,
435
+ ) -> Dict[str, float]:
436
  # Coarse-only validation: resize image to coarse_size, predict coarse logits, upsample to full and compute metrics
437
  model = model.to(device)
438
  metrics_sum = {"iou": 0.0, "f1": 0.0, "precision": 0.0, "recall": 0.0}
 
443
  mask = item["mask"].astype(np.uint8)
444
  H, W = mask.shape
445
  # Build coarse input (zeros for cond+loc)
446
+ rgb_c = cv2.resize(
447
+ img, (coarse_size, coarse_size), interpolation=cv2.INTER_LINEAR
448
+ )
449
+ y = (0.299 * img[..., 0] + 0.587 * img[..., 1] + 0.114 * img[..., 2]).astype(
450
+ np.float32
451
+ )
452
+ y_min = cv2.resize(
453
+ y, (coarse_size, coarse_size), interpolation=cv2.INTER_LINEAR
454
+ )
455
  y_max = y_min
456
+ x = np.concatenate(
457
+ [
458
+ np.transpose(rgb_c, (2, 0, 1)),
459
+ y_min[None, ...],
460
+ y_max[None, ...],
461
+ np.zeros((1, coarse_size, coarse_size), np.float32),
462
+ np.zeros((1, coarse_size, coarse_size), np.float32),
463
+ ],
464
+ axis=0,
465
+ )
466
  x_t = torch.from_numpy(x)[None, ...].to(device)
467
  with autocast(enabled=(device.type == "cuda" and amp_flag)):
468
  logits_c, _ = model.forward_coarse(x_t)
469
  prob = torch.softmax(logits_c, dim=1)[:, 1:2]
470
+ prob_up = (
471
+ F.interpolate(prob, size=(H, W), mode="bilinear", align_corners=False)[0, 0]
472
+ .detach()
473
+ .cpu()
474
+ .numpy()
475
+ )
476
  pred = (prob_up > 0.5).astype(np.uint8)
477
  m = compute_metrics(pred, mask)
478
  for k in metrics_sum:
 
484
 
485
 
486
  @torch.no_grad()
487
+ def save_test_visuals(
488
+ model: WireSegHR,
489
+ dset_test: WireSegDataset,
490
+ coarse_size: int,
491
+ device: torch.device,
492
+ out_dir: str,
493
+ amp_flag: bool,
494
+ max_samples: int = 8,
495
+ ):
496
  os.makedirs(out_dir, exist_ok=True)
497
  for i in range(min(max_samples, len(dset_test))):
498
  item = dset_test[i]
499
  img = item["image"].astype(np.float32) / 255.0
500
  H, W = img.shape[:2]
501
+ rgb_c = cv2.resize(
502
+ img, (coarse_size, coarse_size), interpolation=cv2.INTER_LINEAR
503
+ )
504
+ y = (0.299 * img[..., 0] + 0.587 * img[..., 1] + 0.114 * img[..., 2]).astype(
505
+ np.float32
506
+ )
507
+ y_min = cv2.resize(
508
+ y, (coarse_size, coarse_size), interpolation=cv2.INTER_LINEAR
509
+ )
510
  y_max = y_min
511
+ x = np.concatenate(
512
+ [
513
+ np.transpose(rgb_c, (2, 0, 1)),
514
+ y_min[None, ...],
515
+ y_max[None, ...],
516
+ np.zeros((1, coarse_size, coarse_size), np.float32),
517
+ np.zeros((1, coarse_size, coarse_size), np.float32),
518
+ ],
519
+ axis=0,
520
+ )
521
  x_t = torch.from_numpy(x)[None, ...].to(device)
522
  with autocast(enabled=(device.type == "cuda" and amp_flag)):
523
  logits_c, _ = model.forward_coarse(x_t)
524
  prob = torch.softmax(logits_c, dim=1)[:, 1:2]
525
+ prob_up = (
526
+ F.interpolate(prob, size=(H, W), mode="bilinear", align_corners=False)[0, 0]
527
+ .detach()
528
+ .cpu()
529
+ .numpy()
530
+ )
531
  pred = (prob_up > 0.5).astype(np.uint8) * 255
532
  # Save input and prediction
533
  img_bgr = (img[..., ::-1] * 255.0).astype(np.uint8)
534
  cv2.imwrite(os.path.join(out_dir, f"{i:03d}_input.jpg"), img_bgr)
535
  cv2.imwrite(os.path.join(out_dir, f"{i:03d}_pred.png"), pred)
536
+
537
+
538
+ if __name__ == "__main__":
539
+ main()