jskvrna commited on
Commit
e36bb23
·
1 Parent(s): db367cd

Enables voxel model integration for prediction

Browse files

Integrates the voxel model into the prediction pipeline.
This change allows the system to leverage 3D convolutional
neural networks for improved wireframe prediction.
Also, adds visualization of HSS, F1, and IoU scores.

Files changed (1) hide show
  1. train.py +18 -12
train.py CHANGED
@@ -12,9 +12,10 @@ from utils import read_colmap_rec, empty_solution
12
 
13
  #from hoho2025.example_solutions import predict_wireframe
14
  from hoho2025.metric_helper import hss
15
- from predict import predict_wireframe
16
  from tqdm import tqdm
17
  from fast_pointnet import load_pointnet_model
 
18
  import torch
19
 
20
  #ds = load_dataset("usm3d/hoho25k", cache_dir="/media/skvrnjan/sd/hoho25k/", trust_remote_code=True)
@@ -25,23 +26,34 @@ scores_hss = []
25
  scores_f1 = []
26
  scores_iou = []
27
 
28
- show_visu = True
29
 
30
  device = "cuda" if torch.cuda.is_available() else "cpu"
31
 
32
  #pnet_model = load_pointnet_model(model_path="/home/skvrnjan/personal/hoho_pnet/initial_epoch_100.pth", device=device, predict_score=True)
33
  pnet_model = None
34
 
 
 
 
35
  idx = 0
36
  for a in tqdm(ds['train'], desc="Processing dataset"):
37
  #plot_all_modalities(a)
38
- #pred_vertices, pred_edges = predict_wireframe(a, pnet_model)
 
39
  try:
40
- pred_vertices, pred_edges = predict_wireframe(a, pnet_model)
 
41
  except:
42
  pred_vertices, pred_edges = empty_solution()
43
 
44
- if show_visu:
 
 
 
 
 
 
45
  colmap = read_colmap_rec(a['colmap_binary'])
46
  pcd, geometries = plot_reconstruction_local(None, colmap, points=True, cameras=True, crop_outliers=True)
47
  wireframe = plot_wireframe_local(None, a['wf_vertices'], a['wf_edges'], a['wf_classifications'])
@@ -49,13 +61,7 @@ for a in tqdm(ds['train'], desc="Processing dataset"):
49
  bpo_cams = plot_bpo_cameras_from_entry_local(None, a)
50
 
51
  visu_all = [pcd] + geometries + wireframe + bpo_cams + wireframe2
52
- o3d.visualization.draw_geometries(visu_all, window_name="3D Reconstruction")
53
-
54
- score = hss(pred_vertices, pred_edges, a['wf_vertices'], a['wf_edges'], vert_thresh=0.5, edge_thresh=0.5)
55
- print(f"Score: {score}")
56
- scores_hss.append(score.hss)
57
- scores_f1.append(score.f1)
58
- scores_iou.append(score.iou)
59
 
60
  for i in range(10):
61
  print("END OF DATASET")
 
12
 
13
  #from hoho2025.example_solutions import predict_wireframe
14
  from hoho2025.metric_helper import hss
15
+ from predict import predict_wireframe, predict_wireframe_old
16
  from tqdm import tqdm
17
  from fast_pointnet import load_pointnet_model
18
+ from fast_voxel import load_3dcnn_model
19
  import torch
20
 
21
  #ds = load_dataset("usm3d/hoho25k", cache_dir="/media/skvrnjan/sd/hoho25k/", trust_remote_code=True)
 
26
  scores_f1 = []
27
  scores_iou = []
28
 
29
+ show_visu = False
30
 
31
  device = "cuda" if torch.cuda.is_available() else "cpu"
32
 
33
  #pnet_model = load_pointnet_model(model_path="/home/skvrnjan/personal/hoho_pnet/initial_epoch_100.pth", device=device, predict_score=True)
34
  pnet_model = None
35
 
36
+ #voxel_model = load_3dcnn_model(model_path="/home/skvrnjan/personal/hoho_voxel/initial_epoch_100.pth", device=device, predict_score=True)
37
+ voxel_model = None
38
+
39
  idx = 0
40
  for a in tqdm(ds['train'], desc="Processing dataset"):
41
  #plot_all_modalities(a)
42
+ #pred_vertices, pred_edges = predict_wireframe_old(a)
43
+ #pred_vertices, pred_edges = predict_wireframe(a, pnet_model, voxel_model)
44
  try:
45
+ pred_vertices, pred_edges = predict_wireframe(a, pnet_model, voxel_model)
46
+ #pred_vertices, pred_edges = predict_wireframe_old(a)
47
  except:
48
  pred_vertices, pred_edges = empty_solution()
49
 
50
+ score = hss(pred_vertices, pred_edges, a['wf_vertices'], a['wf_edges'], vert_thresh=0.5, edge_thresh=0.5)
51
+ print(f"Score: {score}")
52
+ scores_hss.append(score.hss)
53
+ scores_f1.append(score.f1)
54
+ scores_iou.append(score.iou)
55
+
56
+ if show_visu and score.hss < 0.1:
57
  colmap = read_colmap_rec(a['colmap_binary'])
58
  pcd, geometries = plot_reconstruction_local(None, colmap, points=True, cameras=True, crop_outliers=True)
59
  wireframe = plot_wireframe_local(None, a['wf_vertices'], a['wf_edges'], a['wf_classifications'])
 
61
  bpo_cams = plot_bpo_cameras_from_entry_local(None, a)
62
 
63
  visu_all = [pcd] + geometries + wireframe + bpo_cams + wireframe2
64
+ o3d.visualization.draw_geometries(visu_all, window_name=f"3D Reconstruction - HSS: {score.hss:.4f}, F1: {score.f1:.4f}, IoU: {score.iou:.4f}")
 
 
 
 
 
 
65
 
66
  for i in range(10):
67
  print("END OF DATASET")