jskvrna commited on
Commit
4949317
·
1 Parent(s): 5c7538d
Files changed (1) hide show
  1. train.py +78 -12
train.py CHANGED
@@ -6,6 +6,8 @@ import tempfile,zipfile
6
  import io
7
  import open3d as o3d
8
  import os
 
 
9
 
10
  from visu import plot_reconstruction_local, plot_wireframe_local, plot_bpo_cameras_from_entry_local, _plotly_rgb_to_normalized_o3d_color
11
  from utils import read_colmap_rec, empty_solution
@@ -21,8 +23,31 @@ from fast_pointnet_class_10d import load_pointnet_model as load_pointnet_class_m
21
  import torch
22
  import time
23
 
24
- #ds = load_dataset("usm3d/hoho25k", cache_dir="/media/skvrnjan/sd/hoho25k/", trust_remote_code=True)
25
- ds = load_dataset("usm3d/hoho25k", cache_dir="/mnt/personal/skvrnjan/hoho25k/", trust_remote_code=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  #ds = ds.shuffle()
27
 
28
  scores_hss = []
@@ -45,7 +70,6 @@ pnet_class_model = load_pointnet_class_model(model_path="pnet_class.pth", device
45
  #voxel_model = load_3dcnn_model(model_path="/home/skvrnjan/personal/hoho_voxel/initial_epoch_100.pth", device=device, predict_score=True)
46
  voxel_model = None
47
 
48
- config = {'vertex_threshold': 0.4, 'edge_threshold': 0.6, 'only_predicted_connections': False}
49
 
50
  idx = 0
51
  prediction_times = []
@@ -60,9 +84,13 @@ for a in tqdm(ds['train'], desc="Processing dataset"):
60
  end_time = time.time()
61
  prediction_time = end_time - start_time
62
  prediction_times.append(prediction_time)
63
- mean_time = np.mean(prediction_times)
64
- print(f"Prediction time: {prediction_time:.4f} seconds, Mean time: {mean_time:.4f} seconds")
65
- except:
 
 
 
 
66
  pred_vertices, pred_edges = empty_solution()
67
 
68
  score = hss(pred_vertices, pred_edges, a['wf_vertices'], a['wf_edges'], vert_thresh=0.5, edge_thresh=0.5)
@@ -82,13 +110,51 @@ for a in tqdm(ds['train'], desc="Processing dataset"):
82
  o3d.visualization.draw_geometries(visu_all, window_name=f"3D Reconstruction - HSS: {score.hss:.4f}, F1: {score.f1:.4f}, IoU: {score.iou:.4f}")
83
 
84
  idx += 1
85
- if idx >= 100: # Limit to first 10 samples for testing
 
86
  break
87
 
88
- for i in range(10):
89
  print("END OF DATASET")
90
- print(f"Mean HSS: {np.mean(scores_hss):.4f}")
91
- print(f"Mean F1: {np.mean(scores_f1):.4f}")
92
- print(f"Mean IoU: {np.mean(scores_iou):.4f}")
93
- print(config)
94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  import io
7
  import open3d as o3d
8
  import os
9
+ import argparse # Added for command-line arguments
10
+ import numpy as np # Make sure numpy is imported if not already implicitly
11
 
12
  from visu import plot_reconstruction_local, plot_wireframe_local, plot_bpo_cameras_from_entry_local, _plotly_rgb_to_normalized_o3d_color
13
  from utils import read_colmap_rec, empty_solution
 
23
  import torch
24
  import time
25
 
26
+ # --- Argument Parsing ---
27
+ parser = argparse.ArgumentParser(description="Train and evaluate HoHo model with custom config.")
28
+ parser.add_argument('--vertex_threshold', type=float, default=0.4, help='Vertex threshold for prediction.')
29
+ parser.add_argument('--edge_threshold', type=float, default=0.625, help='Edge threshold for prediction.')
30
+ parser.add_argument('--only_predicted_connections', type=lambda x: (str(x).lower() == 'true'), default=True, help='Use only predicted connections (True/False).')
31
+ parser.add_argument('--max_samples', type=int, default=100, help='Maximum number of samples to process.')
32
+ parser.add_argument('--results_dir', type=str, default="results", help='Directory to save result files.')
33
+
34
+
35
+ args = parser.parse_args()
36
+
37
+ # --- Configuration from Arguments ---
38
+ config = {
39
+ 'vertex_threshold': args.vertex_threshold,
40
+ 'edge_threshold': args.edge_threshold,
41
+ 'only_predicted_connections': args.only_predicted_connections
42
+ }
43
+ print(f"Running with configuration: {config}")
44
+
45
+ # Create results directory if it doesn't exist
46
+ os.makedirs(args.results_dir, exist_ok=True)
47
+
48
+
49
+ ds = load_dataset("usm3d/hoho25k", cache_dir="/media/skvrnjan/sd/hoho25k/", trust_remote_code=True)
50
+ #ds = load_dataset("usm3d/hoho25k", cache_dir="/mnt/personal/skvrnjan/hoho25k/", trust_remote_code=True)
51
  #ds = ds.shuffle()
52
 
53
  scores_hss = []
 
70
  #voxel_model = load_3dcnn_model(model_path="/home/skvrnjan/personal/hoho_voxel/initial_epoch_100.pth", device=device, predict_score=True)
71
  voxel_model = None
72
 
 
73
 
74
  idx = 0
75
  prediction_times = []
 
84
  end_time = time.time()
85
  prediction_time = end_time - start_time
86
  prediction_times.append(prediction_time)
87
+ if prediction_times: # ensure not empty before calculating mean
88
+ mean_time = np.mean(prediction_times)
89
+ print(f"Prediction time: {prediction_time:.4f} seconds, Mean time: {mean_time:.4f} seconds")
90
+ else:
91
+ print(f"Prediction time: {prediction_time:.4f} seconds")
92
+ except Exception as e: # Catch specific exceptions if possible, or log the error
93
+ print(f"Error during prediction: {e}")
94
  pred_vertices, pred_edges = empty_solution()
95
 
96
  score = hss(pred_vertices, pred_edges, a['wf_vertices'], a['wf_edges'], vert_thresh=0.5, edge_thresh=0.5)
 
110
  o3d.visualization.draw_geometries(visu_all, window_name=f"3D Reconstruction - HSS: {score.hss:.4f}, F1: {score.f1:.4f}, IoU: {score.iou:.4f}")
111
 
112
  idx += 1
113
+ if idx >= args.max_samples:
114
+ print(f"Reached max_samples limit: {args.max_samples}")
115
  break
116
 
117
+ for i in range(10): # This loop seems to be for console output spacing
118
  print("END OF DATASET")
 
 
 
 
119
 
120
+ mean_hss_val = np.mean(scores_hss) if scores_hss else 0.0
121
+ mean_f1_val = np.mean(scores_f1) if scores_f1 else 0.0
122
+ mean_iou_val = np.mean(scores_iou) if scores_iou else 0.0
123
+
124
+
125
+ print(f"Mean HSS: {mean_hss_val:.4f}")
126
+ print(f"Mean F1: {mean_f1_val:.4f}")
127
+ print(f"Mean IoU: {mean_iou_val:.4f}")
128
+ print(f"Final Config: {config}")
129
+ if prediction_times:
130
+ print(f"Overall Mean Prediction Time: {np.mean(prediction_times):.4f} seconds")
131
+
132
+
133
+ # --- Writing results to a file ---
134
+ vt_str = str(config['vertex_threshold']).replace('.', 'p')
135
+ et_str = str(config['edge_threshold']).replace('.', 'p')
136
+ opc_str = str(config['only_predicted_connections'])
137
+
138
+ results_filename = f"results_vt{vt_str}_et{et_str}_opc{opc_str}_samples{args.max_samples}.txt"
139
+ results_filepath = os.path.join(args.results_dir, results_filename)
140
+
141
+ with open(results_filepath, 'w') as f:
142
+ f.write(f"Configuration: {config}\n")
143
+ f.write(f"Max Samples Processed: {args.max_samples}\n")
144
+ f.write(f"Mean HSS: {mean_hss_val:.4f}\n")
145
+ f.write(f"Mean F1: {mean_f1_val:.4f}\n")
146
+ f.write(f"Mean IoU: {mean_iou_val:.4f}\n")
147
+ if prediction_times:
148
+ f.write(f"Overall Mean Prediction Time: {np.mean(prediction_times):.4f} seconds\n")
149
+ f.write("\nIndividual HSS Scores:\n")
150
+ for s_hss in scores_hss:
151
+ f.write(f"{s_hss:.4f}\n")
152
+ f.write("\nIndividual F1 Scores:\n")
153
+ for s_f1 in scores_f1:
154
+ f.write(f"{s_f1:.4f}\n")
155
+ f.write("\nIndividual IoU Scores:\n")
156
+ for s_iou in scores_iou:
157
+ f.write(f"{s_iou:.4f}\n")
158
+
159
+
160
+ print(f"Results saved to {results_filepath}")