train.py
CHANGED
|
@@ -6,6 +6,8 @@ import tempfile,zipfile
|
|
| 6 |
import io
|
| 7 |
import open3d as o3d
|
| 8 |
import os
|
|
|
|
|
|
|
| 9 |
|
| 10 |
from visu import plot_reconstruction_local, plot_wireframe_local, plot_bpo_cameras_from_entry_local, _plotly_rgb_to_normalized_o3d_color
|
| 11 |
from utils import read_colmap_rec, empty_solution
|
|
@@ -21,8 +23,31 @@ from fast_pointnet_class_10d import load_pointnet_model as load_pointnet_class_m
|
|
| 21 |
import torch
|
| 22 |
import time
|
| 23 |
|
| 24 |
-
#
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
#ds = ds.shuffle()
|
| 27 |
|
| 28 |
scores_hss = []
|
|
@@ -45,7 +70,6 @@ pnet_class_model = load_pointnet_class_model(model_path="pnet_class.pth", device
|
|
| 45 |
#voxel_model = load_3dcnn_model(model_path="/home/skvrnjan/personal/hoho_voxel/initial_epoch_100.pth", device=device, predict_score=True)
|
| 46 |
voxel_model = None
|
| 47 |
|
| 48 |
-
config = {'vertex_threshold': 0.4, 'edge_threshold': 0.6, 'only_predicted_connections': False}
|
| 49 |
|
| 50 |
idx = 0
|
| 51 |
prediction_times = []
|
|
@@ -60,9 +84,13 @@ for a in tqdm(ds['train'], desc="Processing dataset"):
|
|
| 60 |
end_time = time.time()
|
| 61 |
prediction_time = end_time - start_time
|
| 62 |
prediction_times.append(prediction_time)
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
pred_vertices, pred_edges = empty_solution()
|
| 67 |
|
| 68 |
score = hss(pred_vertices, pred_edges, a['wf_vertices'], a['wf_edges'], vert_thresh=0.5, edge_thresh=0.5)
|
|
@@ -82,13 +110,51 @@ for a in tqdm(ds['train'], desc="Processing dataset"):
|
|
| 82 |
o3d.visualization.draw_geometries(visu_all, window_name=f"3D Reconstruction - HSS: {score.hss:.4f}, F1: {score.f1:.4f}, IoU: {score.iou:.4f}")
|
| 83 |
|
| 84 |
idx += 1
|
| 85 |
-
if idx >=
|
|
|
|
| 86 |
break
|
| 87 |
|
| 88 |
-
for i in range(10):
|
| 89 |
print("END OF DATASET")
|
| 90 |
-
print(f"Mean HSS: {np.mean(scores_hss):.4f}")
|
| 91 |
-
print(f"Mean F1: {np.mean(scores_f1):.4f}")
|
| 92 |
-
print(f"Mean IoU: {np.mean(scores_iou):.4f}")
|
| 93 |
-
print(config)
|
| 94 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
import io
|
| 7 |
import open3d as o3d
|
| 8 |
import os
|
| 9 |
+
import argparse # Added for command-line arguments
|
| 10 |
+
import numpy as np # Make sure numpy is imported if not already implicitly
|
| 11 |
|
| 12 |
from visu import plot_reconstruction_local, plot_wireframe_local, plot_bpo_cameras_from_entry_local, _plotly_rgb_to_normalized_o3d_color
|
| 13 |
from utils import read_colmap_rec, empty_solution
|
|
|
|
| 23 |
import torch
|
| 24 |
import time
|
| 25 |
|
| 26 |
+
# --- Argument Parsing ---
|
| 27 |
+
parser = argparse.ArgumentParser(description="Train and evaluate HoHo model with custom config.")
|
| 28 |
+
parser.add_argument('--vertex_threshold', type=float, default=0.4, help='Vertex threshold for prediction.')
|
| 29 |
+
parser.add_argument('--edge_threshold', type=float, default=0.625, help='Edge threshold for prediction.')
|
| 30 |
+
parser.add_argument('--only_predicted_connections', type=lambda x: (str(x).lower() == 'true'), default=True, help='Use only predicted connections (True/False).')
|
| 31 |
+
parser.add_argument('--max_samples', type=int, default=100, help='Maximum number of samples to process.')
|
| 32 |
+
parser.add_argument('--results_dir', type=str, default="results", help='Directory to save result files.')
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
args = parser.parse_args()
|
| 36 |
+
|
| 37 |
+
# --- Configuration from Arguments ---
|
| 38 |
+
config = {
|
| 39 |
+
'vertex_threshold': args.vertex_threshold,
|
| 40 |
+
'edge_threshold': args.edge_threshold,
|
| 41 |
+
'only_predicted_connections': args.only_predicted_connections
|
| 42 |
+
}
|
| 43 |
+
print(f"Running with configuration: {config}")
|
| 44 |
+
|
| 45 |
+
# Create results directory if it doesn't exist
|
| 46 |
+
os.makedirs(args.results_dir, exist_ok=True)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
ds = load_dataset("usm3d/hoho25k", cache_dir="/media/skvrnjan/sd/hoho25k/", trust_remote_code=True)
|
| 50 |
+
#ds = load_dataset("usm3d/hoho25k", cache_dir="/mnt/personal/skvrnjan/hoho25k/", trust_remote_code=True)
|
| 51 |
#ds = ds.shuffle()
|
| 52 |
|
| 53 |
scores_hss = []
|
|
|
|
| 70 |
#voxel_model = load_3dcnn_model(model_path="/home/skvrnjan/personal/hoho_voxel/initial_epoch_100.pth", device=device, predict_score=True)
|
| 71 |
voxel_model = None
|
| 72 |
|
|
|
|
| 73 |
|
| 74 |
idx = 0
|
| 75 |
prediction_times = []
|
|
|
|
| 84 |
end_time = time.time()
|
| 85 |
prediction_time = end_time - start_time
|
| 86 |
prediction_times.append(prediction_time)
|
| 87 |
+
if prediction_times: # ensure not empty before calculating mean
|
| 88 |
+
mean_time = np.mean(prediction_times)
|
| 89 |
+
print(f"Prediction time: {prediction_time:.4f} seconds, Mean time: {mean_time:.4f} seconds")
|
| 90 |
+
else:
|
| 91 |
+
print(f"Prediction time: {prediction_time:.4f} seconds")
|
| 92 |
+
except Exception as e: # Catch specific exceptions if possible, or log the error
|
| 93 |
+
print(f"Error during prediction: {e}")
|
| 94 |
pred_vertices, pred_edges = empty_solution()
|
| 95 |
|
| 96 |
score = hss(pred_vertices, pred_edges, a['wf_vertices'], a['wf_edges'], vert_thresh=0.5, edge_thresh=0.5)
|
|
|
|
| 110 |
o3d.visualization.draw_geometries(visu_all, window_name=f"3D Reconstruction - HSS: {score.hss:.4f}, F1: {score.f1:.4f}, IoU: {score.iou:.4f}")
|
| 111 |
|
| 112 |
idx += 1
|
| 113 |
+
if idx >= args.max_samples:
|
| 114 |
+
print(f"Reached max_samples limit: {args.max_samples}")
|
| 115 |
break
|
| 116 |
|
| 117 |
+
for i in range(10): # This loop seems to be for console output spacing
|
| 118 |
print("END OF DATASET")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
|
| 120 |
+
mean_hss_val = np.mean(scores_hss) if scores_hss else 0.0
|
| 121 |
+
mean_f1_val = np.mean(scores_f1) if scores_f1 else 0.0
|
| 122 |
+
mean_iou_val = np.mean(scores_iou) if scores_iou else 0.0
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
print(f"Mean HSS: {mean_hss_val:.4f}")
|
| 126 |
+
print(f"Mean F1: {mean_f1_val:.4f}")
|
| 127 |
+
print(f"Mean IoU: {mean_iou_val:.4f}")
|
| 128 |
+
print(f"Final Config: {config}")
|
| 129 |
+
if prediction_times:
|
| 130 |
+
print(f"Overall Mean Prediction Time: {np.mean(prediction_times):.4f} seconds")
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
# --- Writing results to a file ---
|
| 134 |
+
vt_str = str(config['vertex_threshold']).replace('.', 'p')
|
| 135 |
+
et_str = str(config['edge_threshold']).replace('.', 'p')
|
| 136 |
+
opc_str = str(config['only_predicted_connections'])
|
| 137 |
+
|
| 138 |
+
results_filename = f"results_vt{vt_str}_et{et_str}_opc{opc_str}_samples{args.max_samples}.txt"
|
| 139 |
+
results_filepath = os.path.join(args.results_dir, results_filename)
|
| 140 |
+
|
| 141 |
+
with open(results_filepath, 'w') as f:
|
| 142 |
+
f.write(f"Configuration: {config}\n")
|
| 143 |
+
f.write(f"Max Samples Processed: {args.max_samples}\n")
|
| 144 |
+
f.write(f"Mean HSS: {mean_hss_val:.4f}\n")
|
| 145 |
+
f.write(f"Mean F1: {mean_f1_val:.4f}\n")
|
| 146 |
+
f.write(f"Mean IoU: {mean_iou_val:.4f}\n")
|
| 147 |
+
if prediction_times:
|
| 148 |
+
f.write(f"Overall Mean Prediction Time: {np.mean(prediction_times):.4f} seconds\n")
|
| 149 |
+
f.write("\nIndividual HSS Scores:\n")
|
| 150 |
+
for s_hss in scores_hss:
|
| 151 |
+
f.write(f"{s_hss:.4f}\n")
|
| 152 |
+
f.write("\nIndividual F1 Scores:\n")
|
| 153 |
+
for s_f1 in scores_f1:
|
| 154 |
+
f.write(f"{s_f1:.4f}\n")
|
| 155 |
+
f.write("\nIndividual IoU Scores:\n")
|
| 156 |
+
for s_iou in scores_iou:
|
| 157 |
+
f.write(f"{s_iou:.4f}\n")
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
print(f"Results saved to {results_filepath}")
|