jskvrna commited on
Commit
573ee90
·
1 Parent(s): 7666e2c

Lets try the best solution we have currently.

Browse files
Files changed (1) hide show
  1. script.py +16 -1
script.py CHANGED
@@ -12,6 +12,11 @@ import gc
12
  from utils import empty_solution
13
  from predict import predict_wireframe
14
 
 
 
 
 
 
15
  if __name__ == "__main__":
16
  print ("------------ Loading dataset------------ ")
17
  param_path = Path('params.json')
@@ -72,12 +77,22 @@ if __name__ == "__main__":
72
 
73
  print(dataset, flush=True)
74
 
 
 
 
 
 
 
 
 
 
 
75
  print('------------ Now you can do your solution ---------------')
76
  solution = []
77
 
78
  def process_sample(sample, i):
79
  try:
80
- pred_vertices, pred_edges = predict_wireframe(sample)
81
  except:
82
  pred_vertices, pred_edges = empty_solution()
83
  if i %10 == 0:
 
12
  from utils import empty_solution
13
  from predict import predict_wireframe
14
 
15
+ from fast_pointnet import load_pointnet_model
16
+ from fast_voxel import load_3dcnn_model
17
+ from fast_pointnet_class import load_pointnet_model as load_pointnet_class_model
18
+ import torch
19
+
20
  if __name__ == "__main__":
21
  print ("------------ Loading dataset------------ ")
22
  param_path = Path('params.json')
 
77
 
78
  print(dataset, flush=True)
79
 
80
+ device = "cuda" if torch.cuda.is_available() else "cpu"
81
+
82
+ pnet_model = load_pointnet_model(model_path="pnet.pth", device=device, predict_score=True)
83
+
84
+ pnet_class_model = load_pointnet_class_model(model_path="pnet_class.pth", device=device)
85
+
86
+ voxel_model = None
87
+
88
+ config = {'vertex_threshold': 0.4, 'edge_threshold': 0.6, 'only_predicted_connections': False}
89
+
90
  print('------------ Now you can do your solution ---------------')
91
  solution = []
92
 
93
  def process_sample(sample, i):
94
  try:
95
+ pred_vertices, pred_edges = predict_wireframe(sample, pnet_model, voxel_model, pnet_class_model, config)
96
  except:
97
  pred_vertices, pred_edges = empty_solution()
98
  if i %10 == 0: