(debug) eval fixes
Browse files- configs/default.yaml +2 -2
- train.py +75 -9
configs/default.yaml
CHANGED
|
@@ -29,7 +29,7 @@ inference:
|
|
| 29 |
|
| 30 |
optim:
|
| 31 |
iters: 2000
|
| 32 |
-
batch_size:
|
| 33 |
lr: 6e-5
|
| 34 |
weight_decay: 0.01
|
| 35 |
schedule: poly
|
|
@@ -40,7 +40,7 @@ seed: 42
|
|
| 40 |
out_dir: runs/wireseghr
|
| 41 |
eval_interval: 30
|
| 42 |
ckpt_interval: 100
|
| 43 |
-
|
| 44 |
|
| 45 |
# dataset paths (placeholders)
|
| 46 |
data:
|
|
|
|
| 29 |
|
| 30 |
optim:
|
| 31 |
iters: 2000
|
| 32 |
+
batch_size: 4
|
| 33 |
lr: 6e-5
|
| 34 |
weight_decay: 0.01
|
| 35 |
schedule: poly
|
|
|
|
| 40 |
out_dir: runs/wireseghr
|
| 41 |
eval_interval: 30
|
| 42 |
ckpt_interval: 100
|
| 43 |
+
resume: runs/wireseghr/ckpt_1500.pt # optional
|
| 44 |
|
| 45 |
# dataset paths (placeholders)
|
| 46 |
data:
|
train.py
CHANGED
|
@@ -87,6 +87,11 @@ def main():
|
|
| 87 |
else None
|
| 88 |
)
|
| 89 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
# Model
|
| 91 |
# Channel definition: RGB(3) + MinMax(2) + cond(1) + loc(1) = 7
|
| 92 |
pretrained_flag = bool(cfg.get("pretrained", False))
|
|
@@ -160,7 +165,16 @@ def main():
|
|
| 160 |
# Eval & Checkpoint
|
| 161 |
if (step % eval_interval == 0) and (dset_val is not None):
|
| 162 |
model.eval()
|
| 163 |
-
val_stats = validate(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
print(
|
| 165 |
f"[Val @ {step}] IoU={val_stats['iou']:.4f} F1={val_stats['f1']:.4f} P={val_stats['precision']:.4f} R={val_stats['recall']:.4f}"
|
| 166 |
)
|
|
@@ -194,6 +208,9 @@ def main():
|
|
| 194 |
device,
|
| 195 |
os.path.join(out_dir, f"test_vis_{step}"),
|
| 196 |
amp_flag,
|
|
|
|
|
|
|
|
|
|
| 197 |
max_samples=8,
|
| 198 |
)
|
| 199 |
model.train()
|
|
@@ -461,6 +478,9 @@ def validate(
|
|
| 461 |
coarse_size: int,
|
| 462 |
device: torch.device,
|
| 463 |
amp_flag: bool,
|
|
|
|
|
|
|
|
|
|
| 464 |
) -> Dict[str, float]:
|
| 465 |
# Coarse-only validation: resize image to coarse_size, predict coarse logits, upsample to full and compute metrics
|
| 466 |
model = model.to(device)
|
|
@@ -482,11 +502,33 @@ def validate(
|
|
| 482 |
t_img, size=(coarse_size, coarse_size), mode="bilinear", align_corners=False
|
| 483 |
)[0]
|
| 484 |
y_t = 0.299 * t_img[:, 0:1] + 0.587 * t_img[:, 1:2] + 0.114 * t_img[:, 2:3]
|
| 485 |
-
|
| 486 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 487 |
)[0]
|
| 488 |
zeros_c = torch.zeros(1, coarse_size, coarse_size, device=device)
|
| 489 |
-
x_t = torch.cat([rgb_c,
|
| 490 |
with autocast(enabled=(device.type == "cuda" and amp_flag)):
|
| 491 |
logits_c, _ = model.forward_coarse(x_t)
|
| 492 |
prob = torch.softmax(logits_c, dim=1)[:, 1:2]
|
|
@@ -496,7 +538,7 @@ def validate(
|
|
| 496 |
.cpu()
|
| 497 |
.numpy()
|
| 498 |
)
|
| 499 |
-
pred = (prob_up >
|
| 500 |
m = compute_metrics(pred, mask)
|
| 501 |
for k in metrics_sum:
|
| 502 |
metrics_sum[k] += m[k]
|
|
@@ -514,6 +556,9 @@ def save_test_visuals(
|
|
| 514 |
device: torch.device,
|
| 515 |
out_dir: str,
|
| 516 |
amp_flag: bool,
|
|
|
|
|
|
|
|
|
|
| 517 |
max_samples: int = 8,
|
| 518 |
):
|
| 519 |
os.makedirs(out_dir, exist_ok=True)
|
|
@@ -531,11 +576,32 @@ def save_test_visuals(
|
|
| 531 |
t_img, size=(coarse_size, coarse_size), mode="bilinear", align_corners=False
|
| 532 |
)[0]
|
| 533 |
y_t = 0.299 * t_img[:, 0:1] + 0.587 * t_img[:, 1:2] + 0.114 * t_img[:, 2:3]
|
| 534 |
-
|
| 535 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 536 |
)[0]
|
| 537 |
zeros_c = torch.zeros(1, coarse_size, coarse_size, device=device)
|
| 538 |
-
x_t = torch.cat([rgb_c,
|
| 539 |
with autocast(enabled=(device.type == "cuda" and amp_flag)):
|
| 540 |
logits_c, _ = model.forward_coarse(x_t)
|
| 541 |
prob = torch.softmax(logits_c, dim=1)[:, 1:2]
|
|
@@ -545,7 +611,7 @@ def save_test_visuals(
|
|
| 545 |
.cpu()
|
| 546 |
.numpy()
|
| 547 |
)
|
| 548 |
-
pred = (prob_up >
|
| 549 |
# Save input and prediction
|
| 550 |
img_bgr = (img[..., ::-1] * 255.0).astype(np.uint8)
|
| 551 |
cv2.imwrite(os.path.join(out_dir, f"{i:03d}_input.jpg"), img_bgr)
|
|
|
|
| 87 |
else None
|
| 88 |
)
|
| 89 |
|
| 90 |
+
# Inference/eval settings from config
|
| 91 |
+
prob_thresh = float(cfg["inference"]["prob_threshold"])
|
| 92 |
+
mm_enable = bool(cfg["minmax"]["enable"])
|
| 93 |
+
mm_kernel = int(cfg["minmax"]["kernel"])
|
| 94 |
+
|
| 95 |
# Model
|
| 96 |
# Channel definition: RGB(3) + MinMax(2) + cond(1) + loc(1) = 7
|
| 97 |
pretrained_flag = bool(cfg.get("pretrained", False))
|
|
|
|
| 165 |
# Eval & Checkpoint
|
| 166 |
if (step % eval_interval == 0) and (dset_val is not None):
|
| 167 |
model.eval()
|
| 168 |
+
val_stats = validate(
|
| 169 |
+
model,
|
| 170 |
+
dset_val,
|
| 171 |
+
coarse_train,
|
| 172 |
+
device,
|
| 173 |
+
amp_flag,
|
| 174 |
+
prob_thresh,
|
| 175 |
+
mm_enable,
|
| 176 |
+
mm_kernel,
|
| 177 |
+
)
|
| 178 |
print(
|
| 179 |
f"[Val @ {step}] IoU={val_stats['iou']:.4f} F1={val_stats['f1']:.4f} P={val_stats['precision']:.4f} R={val_stats['recall']:.4f}"
|
| 180 |
)
|
|
|
|
| 208 |
device,
|
| 209 |
os.path.join(out_dir, f"test_vis_{step}"),
|
| 210 |
amp_flag,
|
| 211 |
+
mm_enable,
|
| 212 |
+
mm_kernel,
|
| 213 |
+
prob_thresh,
|
| 214 |
max_samples=8,
|
| 215 |
)
|
| 216 |
model.train()
|
|
|
|
| 478 |
coarse_size: int,
|
| 479 |
device: torch.device,
|
| 480 |
amp_flag: bool,
|
| 481 |
+
prob_thresh: float,
|
| 482 |
+
minmax_enable: bool,
|
| 483 |
+
minmax_kernel: int,
|
| 484 |
) -> Dict[str, float]:
|
| 485 |
# Coarse-only validation: resize image to coarse_size, predict coarse logits, upsample to full and compute metrics
|
| 486 |
model = model.to(device)
|
|
|
|
| 502 |
t_img, size=(coarse_size, coarse_size), mode="bilinear", align_corners=False
|
| 503 |
)[0]
|
| 504 |
y_t = 0.299 * t_img[:, 0:1] + 0.587 * t_img[:, 1:2] + 0.114 * t_img[:, 2:3]
|
| 505 |
+
if minmax_enable:
|
| 506 |
+
# Asymmetric padding for even kernel to keep same HxW
|
| 507 |
+
k = int(minmax_kernel)
|
| 508 |
+
if (k % 2) == 0:
|
| 509 |
+
pad = (k // 2 - 1, k // 2, k // 2 - 1, k // 2)
|
| 510 |
+
else:
|
| 511 |
+
pad = (k // 2, k // 2, k // 2, k // 2)
|
| 512 |
+
y_p = F.pad(y_t, pad, mode="replicate")
|
| 513 |
+
y_max_full = F.max_pool2d(y_p, kernel_size=k, stride=1)
|
| 514 |
+
y_min_full = -F.max_pool2d(-y_p, kernel_size=k, stride=1)
|
| 515 |
+
else:
|
| 516 |
+
y_min_full = y_t
|
| 517 |
+
y_max_full = y_t
|
| 518 |
+
y_min_c = F.interpolate(
|
| 519 |
+
y_min_full,
|
| 520 |
+
size=(coarse_size, coarse_size),
|
| 521 |
+
mode="bilinear",
|
| 522 |
+
align_corners=False,
|
| 523 |
+
)[0]
|
| 524 |
+
y_max_c = F.interpolate(
|
| 525 |
+
y_max_full,
|
| 526 |
+
size=(coarse_size, coarse_size),
|
| 527 |
+
mode="bilinear",
|
| 528 |
+
align_corners=False,
|
| 529 |
)[0]
|
| 530 |
zeros_c = torch.zeros(1, coarse_size, coarse_size, device=device)
|
| 531 |
+
x_t = torch.cat([rgb_c, y_min_c, y_max_c, zeros_c, zeros_c], dim=0).unsqueeze(0)
|
| 532 |
with autocast(enabled=(device.type == "cuda" and amp_flag)):
|
| 533 |
logits_c, _ = model.forward_coarse(x_t)
|
| 534 |
prob = torch.softmax(logits_c, dim=1)[:, 1:2]
|
|
|
|
| 538 |
.cpu()
|
| 539 |
.numpy()
|
| 540 |
)
|
| 541 |
+
pred = (prob_up > prob_thresh).astype(np.uint8)
|
| 542 |
m = compute_metrics(pred, mask)
|
| 543 |
for k in metrics_sum:
|
| 544 |
metrics_sum[k] += m[k]
|
|
|
|
| 556 |
device: torch.device,
|
| 557 |
out_dir: str,
|
| 558 |
amp_flag: bool,
|
| 559 |
+
minmax_enable: bool,
|
| 560 |
+
minmax_kernel: int,
|
| 561 |
+
prob_thresh: float,
|
| 562 |
max_samples: int = 8,
|
| 563 |
):
|
| 564 |
os.makedirs(out_dir, exist_ok=True)
|
|
|
|
| 576 |
t_img, size=(coarse_size, coarse_size), mode="bilinear", align_corners=False
|
| 577 |
)[0]
|
| 578 |
y_t = 0.299 * t_img[:, 0:1] + 0.587 * t_img[:, 1:2] + 0.114 * t_img[:, 2:3]
|
| 579 |
+
if minmax_enable:
|
| 580 |
+
k = int(minmax_kernel)
|
| 581 |
+
if (k % 2) == 0:
|
| 582 |
+
pad = (k // 2 - 1, k // 2, k // 2 - 1, k // 2)
|
| 583 |
+
else:
|
| 584 |
+
pad = (k // 2, k // 2, k // 2, k // 2)
|
| 585 |
+
y_p = F.pad(y_t, pad, mode="replicate")
|
| 586 |
+
y_max_full = F.max_pool2d(y_p, kernel_size=k, stride=1)
|
| 587 |
+
y_min_full = -F.max_pool2d(-y_p, kernel_size=k, stride=1)
|
| 588 |
+
else:
|
| 589 |
+
y_min_full = y_t
|
| 590 |
+
y_max_full = y_t
|
| 591 |
+
y_min_c = F.interpolate(
|
| 592 |
+
y_min_full,
|
| 593 |
+
size=(coarse_size, coarse_size),
|
| 594 |
+
mode="bilinear",
|
| 595 |
+
align_corners=False,
|
| 596 |
+
)[0]
|
| 597 |
+
y_max_c = F.interpolate(
|
| 598 |
+
y_max_full,
|
| 599 |
+
size=(coarse_size, coarse_size),
|
| 600 |
+
mode="bilinear",
|
| 601 |
+
align_corners=False,
|
| 602 |
)[0]
|
| 603 |
zeros_c = torch.zeros(1, coarse_size, coarse_size, device=device)
|
| 604 |
+
x_t = torch.cat([rgb_c, y_min_c, y_max_c, zeros_c, zeros_c], dim=0).unsqueeze(0)
|
| 605 |
with autocast(enabled=(device.type == "cuda" and amp_flag)):
|
| 606 |
logits_c, _ = model.forward_coarse(x_t)
|
| 607 |
prob = torch.softmax(logits_c, dim=1)[:, 1:2]
|
|
|
|
| 611 |
.cpu()
|
| 612 |
.numpy()
|
| 613 |
)
|
| 614 |
+
pred = (prob_up > prob_thresh).astype(np.uint8) * 255
|
| 615 |
# Save input and prediction
|
| 616 |
img_bgr = (img[..., ::-1] * 255.0).astype(np.uint8)
|
| 617 |
cv2.imwrite(os.path.join(out_dir, f"{i:03d}_input.jpg"), img_bgr)
|