Commit
·
3150ada
1
Parent(s):
91b350c
separate cv_ins and cv_del + bump to 0.4
Browse files- hoho/wed.py +16 -10
- setup.py +1 -1
hoho/wed.py
CHANGED
|
@@ -28,7 +28,17 @@ def preregister_mean_std(verts_to_transform, target_verts, single_scale=True):
|
|
| 28 |
return transformed_verts
|
| 29 |
|
| 30 |
|
| 31 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
'''The function computes the Wireframe Edge Distance (WED) between two graphs.
|
| 33 |
pd_vertices: list of predicted vertices
|
| 34 |
pd_edges: list of predicted edges
|
|
@@ -51,13 +61,9 @@ def compute_WED(pd_vertices, pd_edges, gt_vertices, gt_edges, cv=-1/4, ce=1.0, n
|
|
| 51 |
if preregister:
|
| 52 |
pd_vertices = preregister_mean_std(pd_vertices, gt_vertices, single_scale=single_scale)
|
| 53 |
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
cv = -cv * diameter
|
| 58 |
-
elif cv == 0:
|
| 59 |
-
# Cost of adding or deleting a vertex is set to the average distance of the ground truth vertices from their mean
|
| 60 |
-
cv = np.linalg.norm(np.mean(gt_vertices, axis=0) - gt_vertices, axis=1).mean()
|
| 61 |
# Step 0: Prenormalize / preregister
|
| 62 |
|
| 63 |
|
|
@@ -74,11 +80,11 @@ def compute_WED(pd_vertices, pd_edges, gt_vertices, gt_edges, cv=-1/4, ce=1.0, n
|
|
| 74 |
|
| 75 |
# Additional: Vertex Deletion
|
| 76 |
unmatched_pd_indices = set(range(len(pd_vertices))) - set(row_ind)
|
| 77 |
-
deletion_costs =
|
| 78 |
|
| 79 |
# Step 3: Vertex Insertion
|
| 80 |
unmatched_gt_indices = set(range(len(gt_vertices))) - set(col_ind)
|
| 81 |
-
insertion_costs =
|
| 82 |
|
| 83 |
# Step 4: Edge Deletion and Insertion
|
| 84 |
updated_pd_edges = [(col_ind[np.where(row_ind == edge[0])[0][0]], col_ind[np.where(row_ind == edge[1])[0][0]]) for edge in pd_edges if edge[0] in row_ind and edge[1] in row_ind]
|
|
|
|
| 28 |
return transformed_verts
|
| 29 |
|
| 30 |
|
| 31 |
+
def update_cv(cv, gt_vertices):
|
| 32 |
+
if cv < 0:
|
| 33 |
+
diameter = cdist(gt_vertices, gt_vertices).max()
|
| 34 |
+
# Cost of adding or deleting a vertex is set to -cv times the diameter of the ground truth wireframe
|
| 35 |
+
cv = -cv * diameter
|
| 36 |
+
elif cv == 0:
|
| 37 |
+
# Cost of adding or deleting a vertex is set to the average distance of the ground truth vertices from their mean
|
| 38 |
+
cv = np.linalg.norm(np.mean(gt_vertices, axis=0) - gt_vertices, axis=1).mean()
|
| 39 |
+
return cv
|
| 40 |
+
|
| 41 |
+
def compute_WED(pd_vertices, pd_edges, gt_vertices, gt_edges, cv_ins=-1/2, cv_del=-1/4, ce=1.0, normalized=True, preregister=True, single_scale=True):
|
| 42 |
'''The function computes the Wireframe Edge Distance (WED) between two graphs.
|
| 43 |
pd_vertices: list of predicted vertices
|
| 44 |
pd_edges: list of predicted edges
|
|
|
|
| 61 |
if preregister:
|
| 62 |
pd_vertices = preregister_mean_std(pd_vertices, gt_vertices, single_scale=single_scale)
|
| 63 |
|
| 64 |
+
cv_del = update_cv(cv_del, gt_vertices)
|
| 65 |
+
cv_ins = update_cv(cv_ins, gt_vertices)
|
| 66 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
# Step 0: Prenormalize / preregister
|
| 68 |
|
| 69 |
|
|
|
|
| 80 |
|
| 81 |
# Additional: Vertex Deletion
|
| 82 |
unmatched_pd_indices = set(range(len(pd_vertices))) - set(row_ind)
|
| 83 |
+
deletion_costs = cv_del * len(unmatched_pd_indices)
|
| 84 |
|
| 85 |
# Step 3: Vertex Insertion
|
| 86 |
unmatched_gt_indices = set(range(len(gt_vertices))) - set(col_ind)
|
| 87 |
+
insertion_costs = cv_ins * len(unmatched_gt_indices)
|
| 88 |
|
| 89 |
# Step 4: Edge Deletion and Insertion
|
| 90 |
updated_pd_edges = [(col_ind[np.where(row_ind == edge[0])[0][0]], col_ind[np.where(row_ind == edge[1])[0][0]]) for edge in pd_edges if edge[0] in row_ind and edge[1] in row_ind]
|
setup.py
CHANGED
|
@@ -6,7 +6,7 @@ with open('requirements.txt') as f:
|
|
| 6 |
required = f.read().splitlines()
|
| 7 |
|
| 8 |
setup(name='hoho',
|
| 9 |
-
version='0.0.
|
| 10 |
description='Tools and utilites for the HoHo Dataset and S23DR Competition',
|
| 11 |
url='usm3d.github.io',
|
| 12 |
author='Jack Langerman, Dmytro Mishkin, S23DR Orgainizing Team',
|
|
|
|
| 6 |
required = f.read().splitlines()
|
| 7 |
|
| 8 |
setup(name='hoho',
|
| 9 |
+
version='0.0.4',
|
| 10 |
description='Tools and utilites for the HoHo Dataset and S23DR Competition',
|
| 11 |
url='usm3d.github.io',
|
| 12 |
author='Jack Langerman, Dmytro Mishkin, S23DR Orgainizing Team',
|