Save test visuals during final eval
Browse files
train.py
CHANGED
|
@@ -373,6 +373,28 @@ def main():
|
|
| 373 |
print(
|
| 374 |
f"[Test Final][Coarse] IoU={test_stats['iou_coarse']:.4f} F1={test_stats['f1_coarse']:.4f} P={test_stats['precision_coarse']:.4f} R={test_stats['recall_coarse']:.4f}"
|
| 375 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 376 |
model.train()
|
| 377 |
|
| 378 |
print("[WireSegHR][train] Done.")
|
|
@@ -764,5 +786,61 @@ def save_test_visuals(
|
|
| 764 |
cv2.imwrite(os.path.join(out_dir, f"{i:03d}_pred.png"), pred)
|
| 765 |
|
| 766 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 767 |
if __name__ == "__main__":
|
| 768 |
main()
|
|
|
|
| 373 |
print(
|
| 374 |
f"[Test Final][Coarse] IoU={test_stats['iou_coarse']:.4f} F1={test_stats['f1_coarse']:.4f} P={test_stats['precision_coarse']:.4f} R={test_stats['recall_coarse']:.4f}"
|
| 375 |
)
|
| 376 |
+
# Save final evaluation artifacts
|
| 377 |
+
final_out = os.path.join(out_dir, f"final_vis_{step}")
|
| 378 |
+
os.makedirs(final_out, exist_ok=True)
|
| 379 |
+
# Dump metrics for record
|
| 380 |
+
with open(os.path.join(final_out, "metrics.yaml"), "w") as f:
|
| 381 |
+
yaml.safe_dump({**test_stats, "step": step}, f, sort_keys=False)
|
| 382 |
+
# Save predictions (fine + coarse) for the whole test set
|
| 383 |
+
save_final_visuals(
|
| 384 |
+
model,
|
| 385 |
+
dset_test,
|
| 386 |
+
coarse_test,
|
| 387 |
+
device,
|
| 388 |
+
final_out,
|
| 389 |
+
amp_enabled,
|
| 390 |
+
amp_dtype,
|
| 391 |
+
prob_thresh,
|
| 392 |
+
mm_enable,
|
| 393 |
+
mm_kernel,
|
| 394 |
+
eval_patch_size,
|
| 395 |
+
overlap,
|
| 396 |
+
eval_fine_batch,
|
| 397 |
+
)
|
| 398 |
model.train()
|
| 399 |
|
| 400 |
print("[WireSegHR][train] Done.")
|
|
|
|
| 786 |
cv2.imwrite(os.path.join(out_dir, f"{i:03d}_pred.png"), pred)
|
| 787 |
|
| 788 |
|
| 789 |
+
@torch.no_grad()
|
| 790 |
+
def save_final_visuals(
|
| 791 |
+
model: WireSegHR,
|
| 792 |
+
dset_test: WireSegDataset,
|
| 793 |
+
coarse_size: int,
|
| 794 |
+
device: torch.device,
|
| 795 |
+
out_dir: str,
|
| 796 |
+
amp_flag: bool,
|
| 797 |
+
amp_dtype,
|
| 798 |
+
prob_thresh: float,
|
| 799 |
+
minmax_enable: bool,
|
| 800 |
+
minmax_kernel: int,
|
| 801 |
+
fine_patch_size: int,
|
| 802 |
+
fine_overlap: int,
|
| 803 |
+
fine_batch: int,
|
| 804 |
+
):
|
| 805 |
+
os.makedirs(out_dir, exist_ok=True)
|
| 806 |
+
for i in range(len(dset_test)):
|
| 807 |
+
item = dset_test[i]
|
| 808 |
+
img = item["image"].astype(np.float32) / 255.0
|
| 809 |
+
H, W = img.shape[:2]
|
| 810 |
+
# Coarse pass
|
| 811 |
+
prob_up, cond_map, t_img, y_min_full, y_max_full = _coarse_forward(
|
| 812 |
+
model,
|
| 813 |
+
img,
|
| 814 |
+
int(coarse_size),
|
| 815 |
+
bool(minmax_enable),
|
| 816 |
+
int(minmax_kernel),
|
| 817 |
+
device,
|
| 818 |
+
bool(amp_flag),
|
| 819 |
+
amp_dtype,
|
| 820 |
+
)
|
| 821 |
+
pred_coarse = ((prob_up > prob_thresh).to(torch.uint8) * 255).cpu().numpy()
|
| 822 |
+
# Fine pass (tiled)
|
| 823 |
+
prob_full = _tiled_fine_forward(
|
| 824 |
+
model,
|
| 825 |
+
t_img,
|
| 826 |
+
cond_map,
|
| 827 |
+
y_min_full,
|
| 828 |
+
y_max_full,
|
| 829 |
+
int(fine_patch_size),
|
| 830 |
+
int(fine_overlap),
|
| 831 |
+
int(fine_batch),
|
| 832 |
+
device,
|
| 833 |
+
bool(amp_flag),
|
| 834 |
+
amp_dtype,
|
| 835 |
+
)
|
| 836 |
+
pred_fine = ((prob_full > prob_thresh).to(torch.uint8) * 255).cpu().numpy()
|
| 837 |
+
# Save input and predictions
|
| 838 |
+
img_bgr = (img[..., ::-1] * 255.0).astype(np.uint8)
|
| 839 |
+
base = f"{i:03d}"
|
| 840 |
+
cv2.imwrite(os.path.join(out_dir, f"{base}_input.jpg"), img_bgr)
|
| 841 |
+
cv2.imwrite(os.path.join(out_dir, f"{base}_coarse_pred.png"), pred_coarse)
|
| 842 |
+
cv2.imwrite(os.path.join(out_dir, f"{base}_fine_pred.png"), pred_fine)
|
| 843 |
+
|
| 844 |
+
|
| 845 |
if __name__ == "__main__":
|
| 846 |
main()
|