jskvrna commited on
Commit
2b67e89
·
1 Parent(s): 9b4ba1f

Improves training process and model initialization

Browse files

Initializes model weights using Xavier/Glorot initialization for better training.

Adds a check to skip saving patches if they already exist, preventing unnecessary writes.

Introduces a training script for the pointnet model.

Adds progress bar during training loop and shuffles the dataset for better performance.

Files changed (3) hide show
  1. fast_pointnet.py +20 -0
  2. train.py +4 -2
  3. train_pnet.py +12 -0
fast_pointnet.py CHANGED
@@ -226,6 +226,10 @@ def save_patches_dataset(patches: List[Dict], dataset_dir: str, entry_id: str):
226
  filename = f"{entry_id}_patch_{i}.pkl"
227
  filepath = os.path.join(dataset_dir, filename)
228
 
 
 
 
 
229
  # Save patch data
230
  with open(filepath, 'wb') as f:
231
  pickle.dump(patch, f)
@@ -266,6 +270,22 @@ def train_pointnet(dataset_dir: str, model_save_path: str, epochs: int = 100, ba
266
 
267
  # Initialize model with score prediction
268
  model = FastPointNet(input_dim=7, output_dim=3, max_points=1024, predict_score=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
269
  model.to(device)
270
 
271
  # Loss functions
 
226
  filename = f"{entry_id}_patch_{i}.pkl"
227
  filepath = os.path.join(dataset_dir, filename)
228
 
229
+ # Skip if file already exists
230
+ if os.path.exists(filepath):
231
+ continue
232
+
233
  # Save patch data
234
  with open(filepath, 'wb') as f:
235
  pickle.dump(patch, f)
 
270
 
271
  # Initialize model with score prediction
272
  model = FastPointNet(input_dim=7, output_dim=3, max_points=1024, predict_score=True)
273
+
274
+ # Initialize weights using Xavier/Glorot initialization
275
+ def init_weights(m):
276
+ if isinstance(m, nn.Conv1d):
277
+ nn.init.xavier_uniform_(m.weight)
278
+ if m.bias is not None:
279
+ nn.init.zeros_(m.bias)
280
+ elif isinstance(m, nn.Linear):
281
+ nn.init.xavier_uniform_(m.weight)
282
+ if m.bias is not None:
283
+ nn.init.zeros_(m.bias)
284
+ elif isinstance(m, nn.BatchNorm1d):
285
+ nn.init.ones_(m.weight)
286
+ nn.init.zeros_(m.bias)
287
+
288
+ model.apply(init_weights)
289
  model.to(device)
290
 
291
  # Loss functions
train.py CHANGED
@@ -12,8 +12,10 @@ from utils import read_colmap_rec, empty_solution
12
  #from hoho2025.example_solutions import predict_wireframe
13
  from hoho2025.metric_helper import hss
14
  from predict import predict_wireframe
 
15
 
16
- ds = load_dataset("usm3d/hoho25k", streaming=True, trust_remote_code=True)
 
17
 
18
  scores_hss = []
19
  scores_f1 = []
@@ -22,7 +24,7 @@ scores_iou = []
22
  show_visu = False
23
 
24
  idx = 0
25
- for a in ds['train']:
26
  #plot_all_modalities(a)
27
  #pred_vertices, pred_edges = predict_wireframe(a)
28
  try:
 
12
  #from hoho2025.example_solutions import predict_wireframe
13
  from hoho2025.metric_helper import hss
14
  from predict import predict_wireframe
15
+ from tqdm import tqdm
16
 
17
+ ds = load_dataset("usm3d/hoho25k", cache_dir='/home/skvrnjan/personal/hoho25k', trust_remote_code=True)
18
+ ds = ds.shuffle()
19
 
20
  scores_hss = []
21
  scores_f1 = []
 
24
  show_visu = False
25
 
26
  idx = 0
27
+ for a in tqdm(ds['train'], desc="Processing dataset"):
28
  #plot_all_modalities(a)
29
  #pred_vertices, pred_edges = predict_wireframe(a)
30
  try:
train_pnet.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fast_pointnet import FastPointNet
2
+
3
+ if __name__ == "__main__":
4
+ # Initialize the FastPointNet model
5
+ model = FastPointNet()
6
+
7
+ # Load the dataset
8
+ dataset = model.load_dataset("/home/skvrnjan/personal/hohocustom/")
9
+ model_save_path = "/home/skvrnjan/personal/hoho_pnet/"
10
+
11
+ # Train the model
12
+ model.train(dataset, model_save_path, epochs=10)