MRiabov committed on
Commit
e160020
·
1 Parent(s): 953508f

(debug) eval fixes

Browse files
Files changed (2) hide show
  1. configs/default.yaml +2 -2
  2. train.py +75 -9
configs/default.yaml CHANGED
@@ -29,7 +29,7 @@ inference:
29
 
30
  optim:
31
  iters: 2000
32
- batch_size: 8
33
  lr: 6e-5
34
  weight_decay: 0.01
35
  schedule: poly
@@ -40,7 +40,7 @@ seed: 42
40
  out_dir: runs/wireseghr
41
  eval_interval: 30
42
  ckpt_interval: 100
43
- # resume: runs/wireseghr/ckpt_1000.pt # optional
44
 
45
  # dataset paths (placeholders)
46
  data:
 
29
 
30
  optim:
31
  iters: 2000
32
+ batch_size: 4
33
  lr: 6e-5
34
  weight_decay: 0.01
35
  schedule: poly
 
40
  out_dir: runs/wireseghr
41
  eval_interval: 30
42
  ckpt_interval: 100
43
+ resume: runs/wireseghr/ckpt_1500.pt # optional
44
 
45
  # dataset paths (placeholders)
46
  data:
train.py CHANGED
@@ -87,6 +87,11 @@ def main():
87
  else None
88
  )
89
 
 
 
 
 
 
90
  # Model
91
  # Channel definition: RGB(3) + MinMax(2) + cond(1) + loc(1) = 7
92
  pretrained_flag = bool(cfg.get("pretrained", False))
@@ -160,7 +165,16 @@ def main():
160
  # Eval & Checkpoint
161
  if (step % eval_interval == 0) and (dset_val is not None):
162
  model.eval()
163
- val_stats = validate(model, dset_val, coarse_train, device, amp_flag)
 
 
 
 
 
 
 
 
 
164
  print(
165
  f"[Val @ {step}] IoU={val_stats['iou']:.4f} F1={val_stats['f1']:.4f} P={val_stats['precision']:.4f} R={val_stats['recall']:.4f}"
166
  )
@@ -194,6 +208,9 @@ def main():
194
  device,
195
  os.path.join(out_dir, f"test_vis_{step}"),
196
  amp_flag,
 
 
 
197
  max_samples=8,
198
  )
199
  model.train()
@@ -461,6 +478,9 @@ def validate(
461
  coarse_size: int,
462
  device: torch.device,
463
  amp_flag: bool,
 
 
 
464
  ) -> Dict[str, float]:
465
  # Coarse-only validation: resize image to coarse_size, predict coarse logits, upsample to full and compute metrics
466
  model = model.to(device)
@@ -482,11 +502,33 @@ def validate(
482
  t_img, size=(coarse_size, coarse_size), mode="bilinear", align_corners=False
483
  )[0]
484
  y_t = 0.299 * t_img[:, 0:1] + 0.587 * t_img[:, 1:2] + 0.114 * t_img[:, 2:3]
485
- y_c = F.interpolate(
486
- y_t, size=(coarse_size, coarse_size), mode="bilinear", align_corners=False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
487
  )[0]
488
  zeros_c = torch.zeros(1, coarse_size, coarse_size, device=device)
489
- x_t = torch.cat([rgb_c, y_c, y_c, zeros_c, zeros_c], dim=0).unsqueeze(0)
490
  with autocast(enabled=(device.type == "cuda" and amp_flag)):
491
  logits_c, _ = model.forward_coarse(x_t)
492
  prob = torch.softmax(logits_c, dim=1)[:, 1:2]
@@ -496,7 +538,7 @@ def validate(
496
  .cpu()
497
  .numpy()
498
  )
499
- pred = (prob_up > 0.5).astype(np.uint8)
500
  m = compute_metrics(pred, mask)
501
  for k in metrics_sum:
502
  metrics_sum[k] += m[k]
@@ -514,6 +556,9 @@ def save_test_visuals(
514
  device: torch.device,
515
  out_dir: str,
516
  amp_flag: bool,
 
 
 
517
  max_samples: int = 8,
518
  ):
519
  os.makedirs(out_dir, exist_ok=True)
@@ -531,11 +576,32 @@ def save_test_visuals(
531
  t_img, size=(coarse_size, coarse_size), mode="bilinear", align_corners=False
532
  )[0]
533
  y_t = 0.299 * t_img[:, 0:1] + 0.587 * t_img[:, 1:2] + 0.114 * t_img[:, 2:3]
534
- y_c = F.interpolate(
535
- y_t, size=(coarse_size, coarse_size), mode="bilinear", align_corners=False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
536
  )[0]
537
  zeros_c = torch.zeros(1, coarse_size, coarse_size, device=device)
538
- x_t = torch.cat([rgb_c, y_c, y_c, zeros_c, zeros_c], dim=0).unsqueeze(0)
539
  with autocast(enabled=(device.type == "cuda" and amp_flag)):
540
  logits_c, _ = model.forward_coarse(x_t)
541
  prob = torch.softmax(logits_c, dim=1)[:, 1:2]
@@ -545,7 +611,7 @@ def save_test_visuals(
545
  .cpu()
546
  .numpy()
547
  )
548
- pred = (prob_up > 0.5).astype(np.uint8) * 255
549
  # Save input and prediction
550
  img_bgr = (img[..., ::-1] * 255.0).astype(np.uint8)
551
  cv2.imwrite(os.path.join(out_dir, f"{i:03d}_input.jpg"), img_bgr)
 
87
  else None
88
  )
89
 
90
+ # Inference/eval settings from config
91
+ prob_thresh = float(cfg["inference"]["prob_threshold"])
92
+ mm_enable = bool(cfg["minmax"]["enable"])
93
+ mm_kernel = int(cfg["minmax"]["kernel"])
94
+
95
  # Model
96
  # Channel definition: RGB(3) + MinMax(2) + cond(1) + loc(1) = 7
97
  pretrained_flag = bool(cfg.get("pretrained", False))
 
165
  # Eval & Checkpoint
166
  if (step % eval_interval == 0) and (dset_val is not None):
167
  model.eval()
168
+ val_stats = validate(
169
+ model,
170
+ dset_val,
171
+ coarse_train,
172
+ device,
173
+ amp_flag,
174
+ prob_thresh,
175
+ mm_enable,
176
+ mm_kernel,
177
+ )
178
  print(
179
  f"[Val @ {step}] IoU={val_stats['iou']:.4f} F1={val_stats['f1']:.4f} P={val_stats['precision']:.4f} R={val_stats['recall']:.4f}"
180
  )
 
208
  device,
209
  os.path.join(out_dir, f"test_vis_{step}"),
210
  amp_flag,
211
+ mm_enable,
212
+ mm_kernel,
213
+ prob_thresh,
214
  max_samples=8,
215
  )
216
  model.train()
 
478
  coarse_size: int,
479
  device: torch.device,
480
  amp_flag: bool,
481
+ prob_thresh: float,
482
+ minmax_enable: bool,
483
+ minmax_kernel: int,
484
  ) -> Dict[str, float]:
485
  # Coarse-only validation: resize image to coarse_size, predict coarse logits, upsample to full and compute metrics
486
  model = model.to(device)
 
502
  t_img, size=(coarse_size, coarse_size), mode="bilinear", align_corners=False
503
  )[0]
504
  y_t = 0.299 * t_img[:, 0:1] + 0.587 * t_img[:, 1:2] + 0.114 * t_img[:, 2:3]
505
+ if minmax_enable:
506
+ # Asymmetric padding for even kernel to keep same HxW
507
+ k = int(minmax_kernel)
508
+ if (k % 2) == 0:
509
+ pad = (k // 2 - 1, k // 2, k // 2 - 1, k // 2)
510
+ else:
511
+ pad = (k // 2, k // 2, k // 2, k // 2)
512
+ y_p = F.pad(y_t, pad, mode="replicate")
513
+ y_max_full = F.max_pool2d(y_p, kernel_size=k, stride=1)
514
+ y_min_full = -F.max_pool2d(-y_p, kernel_size=k, stride=1)
515
+ else:
516
+ y_min_full = y_t
517
+ y_max_full = y_t
518
+ y_min_c = F.interpolate(
519
+ y_min_full,
520
+ size=(coarse_size, coarse_size),
521
+ mode="bilinear",
522
+ align_corners=False,
523
+ )[0]
524
+ y_max_c = F.interpolate(
525
+ y_max_full,
526
+ size=(coarse_size, coarse_size),
527
+ mode="bilinear",
528
+ align_corners=False,
529
  )[0]
530
  zeros_c = torch.zeros(1, coarse_size, coarse_size, device=device)
531
+ x_t = torch.cat([rgb_c, y_min_c, y_max_c, zeros_c, zeros_c], dim=0).unsqueeze(0)
532
  with autocast(enabled=(device.type == "cuda" and amp_flag)):
533
  logits_c, _ = model.forward_coarse(x_t)
534
  prob = torch.softmax(logits_c, dim=1)[:, 1:2]
 
538
  .cpu()
539
  .numpy()
540
  )
541
+ pred = (prob_up > prob_thresh).astype(np.uint8)
542
  m = compute_metrics(pred, mask)
543
  for k in metrics_sum:
544
  metrics_sum[k] += m[k]
 
556
  device: torch.device,
557
  out_dir: str,
558
  amp_flag: bool,
559
+ minmax_enable: bool,
560
+ minmax_kernel: int,
561
+ prob_thresh: float,
562
  max_samples: int = 8,
563
  ):
564
  os.makedirs(out_dir, exist_ok=True)
 
576
  t_img, size=(coarse_size, coarse_size), mode="bilinear", align_corners=False
577
  )[0]
578
  y_t = 0.299 * t_img[:, 0:1] + 0.587 * t_img[:, 1:2] + 0.114 * t_img[:, 2:3]
579
+ if minmax_enable:
580
+ k = int(minmax_kernel)
581
+ if (k % 2) == 0:
582
+ pad = (k // 2 - 1, k // 2, k // 2 - 1, k // 2)
583
+ else:
584
+ pad = (k // 2, k // 2, k // 2, k // 2)
585
+ y_p = F.pad(y_t, pad, mode="replicate")
586
+ y_max_full = F.max_pool2d(y_p, kernel_size=k, stride=1)
587
+ y_min_full = -F.max_pool2d(-y_p, kernel_size=k, stride=1)
588
+ else:
589
+ y_min_full = y_t
590
+ y_max_full = y_t
591
+ y_min_c = F.interpolate(
592
+ y_min_full,
593
+ size=(coarse_size, coarse_size),
594
+ mode="bilinear",
595
+ align_corners=False,
596
+ )[0]
597
+ y_max_c = F.interpolate(
598
+ y_max_full,
599
+ size=(coarse_size, coarse_size),
600
+ mode="bilinear",
601
+ align_corners=False,
602
  )[0]
603
  zeros_c = torch.zeros(1, coarse_size, coarse_size, device=device)
604
+ x_t = torch.cat([rgb_c, y_min_c, y_max_c, zeros_c, zeros_c], dim=0).unsqueeze(0)
605
  with autocast(enabled=(device.type == "cuda" and amp_flag)):
606
  logits_c, _ = model.forward_coarse(x_t)
607
  prob = torch.softmax(logits_c, dim=1)[:, 1:2]
 
611
  .cpu()
612
  .numpy()
613
  )
614
+ pred = (prob_up > prob_thresh).astype(np.uint8) * 255
615
  # Save input and prediction
616
  img_bgr = (img[..., ::-1] * 255.0).astype(np.uint8)
617
  cv2.imwrite(os.path.join(out_dir, f"{i:03d}_input.jpg"), img_bgr)