MRiabov commited on
Commit
ba1bfde
·
1 Parent(s): a2999cc

Save test visuals during final eval

Browse files
Files changed (1) hide show
  1. train.py +78 -0
train.py CHANGED
@@ -373,6 +373,28 @@ def main():
373
  print(
374
  f"[Test Final][Coarse] IoU={test_stats['iou_coarse']:.4f} F1={test_stats['f1_coarse']:.4f} P={test_stats['precision_coarse']:.4f} R={test_stats['recall_coarse']:.4f}"
375
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
376
  model.train()
377
 
378
  print("[WireSegHR][train] Done.")
@@ -764,5 +786,61 @@ def save_test_visuals(
764
  cv2.imwrite(os.path.join(out_dir, f"{i:03d}_pred.png"), pred)
765
 
766
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
767
  if __name__ == "__main__":
768
  main()
 
373
  print(
374
  f"[Test Final][Coarse] IoU={test_stats['iou_coarse']:.4f} F1={test_stats['f1_coarse']:.4f} P={test_stats['precision_coarse']:.4f} R={test_stats['recall_coarse']:.4f}"
375
  )
376
+ # Save final evaluation artifacts
377
+ final_out = os.path.join(out_dir, f"final_vis_{step}")
378
+ os.makedirs(final_out, exist_ok=True)
379
+ # Dump metrics for record
380
+ with open(os.path.join(final_out, "metrics.yaml"), "w") as f:
381
+ yaml.safe_dump({**test_stats, "step": step}, f, sort_keys=False)
382
+ # Save predictions (fine + coarse) for the whole test set
383
+ save_final_visuals(
384
+ model,
385
+ dset_test,
386
+ coarse_test,
387
+ device,
388
+ final_out,
389
+ amp_enabled,
390
+ amp_dtype,
391
+ prob_thresh,
392
+ mm_enable,
393
+ mm_kernel,
394
+ eval_patch_size,
395
+ overlap,
396
+ eval_fine_batch,
397
+ )
398
  model.train()
399
 
400
  print("[WireSegHR][train] Done.")
 
786
  cv2.imwrite(os.path.join(out_dir, f"{i:03d}_pred.png"), pred)
787
 
788
 
789
+ @torch.no_grad()
790
+ def save_final_visuals(
791
+ model: WireSegHR,
792
+ dset_test: WireSegDataset,
793
+ coarse_size: int,
794
+ device: torch.device,
795
+ out_dir: str,
796
+ amp_flag: bool,
797
+ amp_dtype,
798
+ prob_thresh: float,
799
+ minmax_enable: bool,
800
+ minmax_kernel: int,
801
+ fine_patch_size: int,
802
+ fine_overlap: int,
803
+ fine_batch: int,
804
+ ):
805
+ os.makedirs(out_dir, exist_ok=True)
806
+ for i in range(len(dset_test)):
807
+ item = dset_test[i]
808
+ img = item["image"].astype(np.float32) / 255.0
809
+ H, W = img.shape[:2]
810
+ # Coarse pass
811
+ prob_up, cond_map, t_img, y_min_full, y_max_full = _coarse_forward(
812
+ model,
813
+ img,
814
+ int(coarse_size),
815
+ bool(minmax_enable),
816
+ int(minmax_kernel),
817
+ device,
818
+ bool(amp_flag),
819
+ amp_dtype,
820
+ )
821
+ pred_coarse = ((prob_up > prob_thresh).to(torch.uint8) * 255).cpu().numpy()
822
+ # Fine pass (tiled)
823
+ prob_full = _tiled_fine_forward(
824
+ model,
825
+ t_img,
826
+ cond_map,
827
+ y_min_full,
828
+ y_max_full,
829
+ int(fine_patch_size),
830
+ int(fine_overlap),
831
+ int(fine_batch),
832
+ device,
833
+ bool(amp_flag),
834
+ amp_dtype,
835
+ )
836
+ pred_fine = ((prob_full > prob_thresh).to(torch.uint8) * 255).cpu().numpy()
837
+ # Save input and predictions
838
+ img_bgr = (img[..., ::-1] * 255.0).astype(np.uint8)
839
+ base = f"{i:03d}"
840
+ cv2.imwrite(os.path.join(out_dir, f"{base}_input.jpg"), img_bgr)
841
+ cv2.imwrite(os.path.join(out_dir, f"{base}_coarse_pred.png"), pred_coarse)
842
+ cv2.imwrite(os.path.join(out_dir, f"{base}_fine_pred.png"), pred_fine)
843
+
844
+
845
  if __name__ == "__main__":
846
  main()