Improves training process and model initialization
Browse filesInitializes model weights using Xavier/Glorot initialization for better training.
Adds a check to skip saving patches if they already exist, preventing unnecessary writes.
Introduces a training script for the pointnet model.
Adds progress bar during training loop and shuffles the dataset for better performance.
- fast_pointnet.py +20 -0
- train.py +4 -2
- train_pnet.py +12 -0
fast_pointnet.py
CHANGED
|
@@ -226,6 +226,10 @@ def save_patches_dataset(patches: List[Dict], dataset_dir: str, entry_id: str):
|
|
| 226 |
filename = f"{entry_id}_patch_{i}.pkl"
|
| 227 |
filepath = os.path.join(dataset_dir, filename)
|
| 228 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
# Save patch data
|
| 230 |
with open(filepath, 'wb') as f:
|
| 231 |
pickle.dump(patch, f)
|
|
@@ -266,6 +270,22 @@ def train_pointnet(dataset_dir: str, model_save_path: str, epochs: int = 100, ba
|
|
| 266 |
|
| 267 |
# Initialize model with score prediction
|
| 268 |
model = FastPointNet(input_dim=7, output_dim=3, max_points=1024, predict_score=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 269 |
model.to(device)
|
| 270 |
|
| 271 |
# Loss functions
|
|
|
|
| 226 |
filename = f"{entry_id}_patch_{i}.pkl"
|
| 227 |
filepath = os.path.join(dataset_dir, filename)
|
| 228 |
|
| 229 |
+
# Skip if file already exists
|
| 230 |
+
if os.path.exists(filepath):
|
| 231 |
+
continue
|
| 232 |
+
|
| 233 |
# Save patch data
|
| 234 |
with open(filepath, 'wb') as f:
|
| 235 |
pickle.dump(patch, f)
|
|
|
|
| 270 |
|
| 271 |
# Initialize model with score prediction
|
| 272 |
model = FastPointNet(input_dim=7, output_dim=3, max_points=1024, predict_score=True)
|
| 273 |
+
|
| 274 |
+
# Initialize weights using Xavier/Glorot initialization
|
| 275 |
+
def init_weights(m):
|
| 276 |
+
if isinstance(m, nn.Conv1d):
|
| 277 |
+
nn.init.xavier_uniform_(m.weight)
|
| 278 |
+
if m.bias is not None:
|
| 279 |
+
nn.init.zeros_(m.bias)
|
| 280 |
+
elif isinstance(m, nn.Linear):
|
| 281 |
+
nn.init.xavier_uniform_(m.weight)
|
| 282 |
+
if m.bias is not None:
|
| 283 |
+
nn.init.zeros_(m.bias)
|
| 284 |
+
elif isinstance(m, nn.BatchNorm1d):
|
| 285 |
+
nn.init.ones_(m.weight)
|
| 286 |
+
nn.init.zeros_(m.bias)
|
| 287 |
+
|
| 288 |
+
model.apply(init_weights)
|
| 289 |
model.to(device)
|
| 290 |
|
| 291 |
# Loss functions
|
train.py
CHANGED
|
@@ -12,8 +12,10 @@ from utils import read_colmap_rec, empty_solution
|
|
| 12 |
#from hoho2025.example_solutions import predict_wireframe
|
| 13 |
from hoho2025.metric_helper import hss
|
| 14 |
from predict import predict_wireframe
|
|
|
|
| 15 |
|
| 16 |
-
ds = load_dataset("usm3d/hoho25k",
|
|
|
|
| 17 |
|
| 18 |
scores_hss = []
|
| 19 |
scores_f1 = []
|
|
@@ -22,7 +24,7 @@ scores_iou = []
|
|
| 22 |
show_visu = False
|
| 23 |
|
| 24 |
idx = 0
|
| 25 |
-
for a in ds['train']:
|
| 26 |
#plot_all_modalities(a)
|
| 27 |
#pred_vertices, pred_edges = predict_wireframe(a)
|
| 28 |
try:
|
|
|
|
| 12 |
#from hoho2025.example_solutions import predict_wireframe
|
| 13 |
from hoho2025.metric_helper import hss
|
| 14 |
from predict import predict_wireframe
|
| 15 |
+
from tqdm import tqdm
|
| 16 |
|
| 17 |
+
ds = load_dataset("usm3d/hoho25k", cache_dir='/home/skvrnjan/personal/hoho25k', trust_remote_code=True)
|
| 18 |
+
ds = ds.shuffle()
|
| 19 |
|
| 20 |
scores_hss = []
|
| 21 |
scores_f1 = []
|
|
|
|
| 24 |
show_visu = False
|
| 25 |
|
| 26 |
idx = 0
|
| 27 |
+
for a in tqdm(ds['train'], desc="Processing dataset"):
|
| 28 |
#plot_all_modalities(a)
|
| 29 |
#pred_vertices, pred_edges = predict_wireframe(a)
|
| 30 |
try:
|
train_pnet.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fast_pointnet import FastPointNet
|
| 2 |
+
|
| 3 |
+
if __name__ == "__main__":
|
| 4 |
+
# Initialize the FastPointNet model
|
| 5 |
+
model = FastPointNet()
|
| 6 |
+
|
| 7 |
+
# Load the dataset
|
| 8 |
+
dataset = model.load_dataset("/home/skvrnjan/personal/hohocustom/")
|
| 9 |
+
model_save_path = "/home/skvrnjan/personal/hoho_pnet/"
|
| 10 |
+
|
| 11 |
+
# Train the model
|
| 12 |
+
model.train(dataset, model_save_path, epochs=10)
|