Enables voxel model integration for prediction
Browse filesIntegrates the voxel model into the prediction pipeline.
This change allows the system to leverage 3D convolutional
neural networks for improved wireframe prediction.
Also, adds visualization of HSS, F1, and IoU scores.
train.py
CHANGED
|
@@ -12,9 +12,10 @@ from utils import read_colmap_rec, empty_solution
|
|
| 12 |
|
| 13 |
#from hoho2025.example_solutions import predict_wireframe
|
| 14 |
from hoho2025.metric_helper import hss
|
| 15 |
-
from predict import predict_wireframe
|
| 16 |
from tqdm import tqdm
|
| 17 |
from fast_pointnet import load_pointnet_model
|
|
|
|
| 18 |
import torch
|
| 19 |
|
| 20 |
#ds = load_dataset("usm3d/hoho25k", cache_dir="/media/skvrnjan/sd/hoho25k/", trust_remote_code=True)
|
|
@@ -25,23 +26,34 @@ scores_hss = []
|
|
| 25 |
scores_f1 = []
|
| 26 |
scores_iou = []
|
| 27 |
|
| 28 |
-
show_visu =
|
| 29 |
|
| 30 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 31 |
|
| 32 |
#pnet_model = load_pointnet_model(model_path="/home/skvrnjan/personal/hoho_pnet/initial_epoch_100.pth", device=device, predict_score=True)
|
| 33 |
pnet_model = None
|
| 34 |
|
|
|
|
|
|
|
|
|
|
| 35 |
idx = 0
|
| 36 |
for a in tqdm(ds['train'], desc="Processing dataset"):
|
| 37 |
#plot_all_modalities(a)
|
| 38 |
-
#pred_vertices, pred_edges =
|
|
|
|
| 39 |
try:
|
| 40 |
-
pred_vertices, pred_edges = predict_wireframe(a, pnet_model)
|
|
|
|
| 41 |
except:
|
| 42 |
pred_vertices, pred_edges = empty_solution()
|
| 43 |
|
| 44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
colmap = read_colmap_rec(a['colmap_binary'])
|
| 46 |
pcd, geometries = plot_reconstruction_local(None, colmap, points=True, cameras=True, crop_outliers=True)
|
| 47 |
wireframe = plot_wireframe_local(None, a['wf_vertices'], a['wf_edges'], a['wf_classifications'])
|
|
@@ -49,13 +61,7 @@ for a in tqdm(ds['train'], desc="Processing dataset"):
|
|
| 49 |
bpo_cams = plot_bpo_cameras_from_entry_local(None, a)
|
| 50 |
|
| 51 |
visu_all = [pcd] + geometries + wireframe + bpo_cams + wireframe2
|
| 52 |
-
o3d.visualization.draw_geometries(visu_all, window_name="3D Reconstruction")
|
| 53 |
-
|
| 54 |
-
score = hss(pred_vertices, pred_edges, a['wf_vertices'], a['wf_edges'], vert_thresh=0.5, edge_thresh=0.5)
|
| 55 |
-
print(f"Score: {score}")
|
| 56 |
-
scores_hss.append(score.hss)
|
| 57 |
-
scores_f1.append(score.f1)
|
| 58 |
-
scores_iou.append(score.iou)
|
| 59 |
|
| 60 |
for i in range(10):
|
| 61 |
print("END OF DATASET")
|
|
|
|
| 12 |
|
| 13 |
#from hoho2025.example_solutions import predict_wireframe
|
| 14 |
from hoho2025.metric_helper import hss
|
| 15 |
+
from predict import predict_wireframe, predict_wireframe_old
|
| 16 |
from tqdm import tqdm
|
| 17 |
from fast_pointnet import load_pointnet_model
|
| 18 |
+
from fast_voxel import load_3dcnn_model
|
| 19 |
import torch
|
| 20 |
|
| 21 |
#ds = load_dataset("usm3d/hoho25k", cache_dir="/media/skvrnjan/sd/hoho25k/", trust_remote_code=True)
|
|
|
|
| 26 |
scores_f1 = []
|
| 27 |
scores_iou = []
|
| 28 |
|
| 29 |
+
show_visu = False
|
| 30 |
|
| 31 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 32 |
|
| 33 |
#pnet_model = load_pointnet_model(model_path="/home/skvrnjan/personal/hoho_pnet/initial_epoch_100.pth", device=device, predict_score=True)
|
| 34 |
pnet_model = None
|
| 35 |
|
| 36 |
+
#voxel_model = load_3dcnn_model(model_path="/home/skvrnjan/personal/hoho_voxel/initial_epoch_100.pth", device=device, predict_score=True)
|
| 37 |
+
voxel_model = None
|
| 38 |
+
|
| 39 |
idx = 0
|
| 40 |
for a in tqdm(ds['train'], desc="Processing dataset"):
|
| 41 |
#plot_all_modalities(a)
|
| 42 |
+
#pred_vertices, pred_edges = predict_wireframe_old(a)
|
| 43 |
+
#pred_vertices, pred_edges = predict_wireframe(a, pnet_model, voxel_model)
|
| 44 |
try:
|
| 45 |
+
pred_vertices, pred_edges = predict_wireframe(a, pnet_model, voxel_model)
|
| 46 |
+
#pred_vertices, pred_edges = predict_wireframe_old(a)
|
| 47 |
except:
|
| 48 |
pred_vertices, pred_edges = empty_solution()
|
| 49 |
|
| 50 |
+
score = hss(pred_vertices, pred_edges, a['wf_vertices'], a['wf_edges'], vert_thresh=0.5, edge_thresh=0.5)
|
| 51 |
+
print(f"Score: {score}")
|
| 52 |
+
scores_hss.append(score.hss)
|
| 53 |
+
scores_f1.append(score.f1)
|
| 54 |
+
scores_iou.append(score.iou)
|
| 55 |
+
|
| 56 |
+
if show_visu and score.hss < 0.1:
|
| 57 |
colmap = read_colmap_rec(a['colmap_binary'])
|
| 58 |
pcd, geometries = plot_reconstruction_local(None, colmap, points=True, cameras=True, crop_outliers=True)
|
| 59 |
wireframe = plot_wireframe_local(None, a['wf_vertices'], a['wf_edges'], a['wf_classifications'])
|
|
|
|
| 61 |
bpo_cams = plot_bpo_cameras_from_entry_local(None, a)
|
| 62 |
|
| 63 |
visu_all = [pcd] + geometries + wireframe + bpo_cams + wireframe2
|
| 64 |
+
o3d.visualization.draw_geometries(visu_all, window_name=f"3D Reconstruction - HSS: {score.hss:.4f}, F1: {score.f1:.4f}, IoU: {score.iou:.4f}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
for i in range(10):
|
| 67 |
print("END OF DATASET")
|