Denys Rozumnyi commited on
Commit ·
b28c65a
1
Parent(s): fc034ff
Public release
Browse files- geom_solver.py +4 -0
- main.ipynb +16 -19
geom_solver.py
CHANGED
|
@@ -14,6 +14,10 @@ def my_empty_solution():
|
|
| 14 |
return np.zeros((20,3)), [(0, 0)]
|
| 15 |
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
def fully_connected_solution(vertices=None):
|
| 18 |
if vertices is None:
|
| 19 |
nverts = 20
|
|
|
|
| 14 |
return np.zeros((20,3)), [(0, 0)]
|
| 15 |
|
| 16 |
|
| 17 |
+
def one_line_solution(n):
|
| 18 |
+
return np.zeros((n,3)), list(itertools.product(list(range(n)), list(range(n))))
|
| 19 |
+
|
| 20 |
+
|
| 21 |
def fully_connected_solution(vertices=None):
|
| 22 |
if vertices is None:
|
| 23 |
nverts = 20
|
main.ipynb
CHANGED
|
@@ -23,14 +23,14 @@
|
|
| 23 |
},
|
| 24 |
{
|
| 25 |
"cell_type": "code",
|
| 26 |
-
"execution_count":
|
| 27 |
"id": "503c6bcb-aa46-46c6-8b86-566b0a470b43",
|
| 28 |
"metadata": {},
|
| 29 |
"outputs": [
|
| 30 |
{
|
| 31 |
"data": {
|
| 32 |
"application/vnd.jupyter.widget-view+json": {
|
| 33 |
-
"model_id": "
|
| 34 |
"version_major": 2,
|
| 35 |
"version_minor": 0
|
| 36 |
},
|
|
@@ -76,7 +76,7 @@
|
|
| 76 |
"\n",
|
| 77 |
"from my_solution import predict\n",
|
| 78 |
"from helpers import *\n",
|
| 79 |
-
"from geom_solver import GeomSolver\n",
|
| 80 |
"\n",
|
| 81 |
"import huggingface_hub \n",
|
| 82 |
"huggingface_hub.login()"
|
|
@@ -84,7 +84,7 @@
|
|
| 84 |
},
|
| 85 |
{
|
| 86 |
"cell_type": "code",
|
| 87 |
-
"execution_count":
|
| 88 |
"id": "88f4fc8f-efa9-404b-9073-c7d4a73f9075",
|
| 89 |
"metadata": {},
|
| 90 |
"outputs": [
|
|
@@ -94,16 +94,16 @@
|
|
| 94 |
"text": [
|
| 95 |
"2.1428452992319946 1.8753585980967309\n",
|
| 96 |
"1.2634583203026981 1.2634583203026981\n",
|
| 97 |
-
"2.
|
| 98 |
-
"1.
|
| 99 |
"1.6604589895130097 1.4945561543255432\n",
|
| 100 |
"1.342868026940067 1.5159383672128168\n",
|
| 101 |
-
"1.
|
| 102 |
-
"1.2228932592065895 1.
|
| 103 |
"2.1274713572398154 1.9577219841781546\n",
|
| 104 |
"1.2389391918918637 1.1907499853854275\n",
|
| 105 |
"1.301931556958707 1.5291655623486928\n",
|
| 106 |
-
"1.
|
| 107 |
"1.300528267625377 1.1364744990514501\n",
|
| 108 |
"1.9971234414064012 1.484359575186481\n",
|
| 109 |
"1.7864233783834014 1.2997637652539\n",
|
|
@@ -111,10 +111,10 @@
|
|
| 111 |
"1.2503083972410651 1.2672496454499096\n",
|
| 112 |
"1.9554601999832553 2.4290767592247864\n",
|
| 113 |
"1.5169825260067826 1.862860315670855\n",
|
| 114 |
-
"1.6424954827356966 1.
|
| 115 |
"1.40382235616253 0.9727095077709508\n",
|
| 116 |
"2.5531277383538464 2.2956668643388705\n",
|
| 117 |
-
"1.
|
| 118 |
"1.1352963385173234 1.3060154430256012\n",
|
| 119 |
"2.1787116657747303 1.7257473870225462\n",
|
| 120 |
"2.199844649970312 2.2529145560735175\n",
|
|
@@ -123,7 +123,7 @@
|
|
| 123 |
"1.6472782885605137 1.4709840223700106\n",
|
| 124 |
"1.332152878564964 1.6610485499747143\n",
|
| 125 |
"Averages\n",
|
| 126 |
-
"1.
|
| 127 |
]
|
| 128 |
}
|
| 129 |
],
|
|
@@ -144,15 +144,12 @@
|
|
| 144 |
" solver = GeomSolver()\n",
|
| 145 |
" vertices, edges = solver.solve(entry)\n",
|
| 146 |
" \n",
|
| 147 |
-
" nverts = entry['wf_vertices'].shape[0]\n",
|
| 148 |
" # nverts = vertices.shape[0]\n",
|
| 149 |
" nverts = 20\n",
|
| 150 |
-
" \n",
|
| 151 |
-
"
|
| 152 |
-
"
|
| 153 |
-
" edges0 = [edg for edg in edges0 if edg[0] < edg[1]]\n",
|
| 154 |
-
" \n",
|
| 155 |
-
" scores0 = (compute_WED(np.zeros((nverts,3)),\n",
|
| 156 |
" edges0,\n",
|
| 157 |
" np.array(entry['wf_vertices']),\n",
|
| 158 |
" np.array(entry['wf_edges'])))\n",
|
|
|
|
| 23 |
},
|
| 24 |
{
|
| 25 |
"cell_type": "code",
|
| 26 |
+
"execution_count": 11,
|
| 27 |
"id": "503c6bcb-aa46-46c6-8b86-566b0a470b43",
|
| 28 |
"metadata": {},
|
| 29 |
"outputs": [
|
| 30 |
{
|
| 31 |
"data": {
|
| 32 |
"application/vnd.jupyter.widget-view+json": {
|
| 33 |
+
"model_id": "bccfd9e961614341bcbd9a4dba7a4a99",
|
| 34 |
"version_major": 2,
|
| 35 |
"version_minor": 0
|
| 36 |
},
|
|
|
|
| 76 |
"\n",
|
| 77 |
"from my_solution import predict\n",
|
| 78 |
"from helpers import *\n",
|
| 79 |
+
"from geom_solver import GeomSolver, one_line_solution\n",
|
| 80 |
"\n",
|
| 81 |
"import huggingface_hub \n",
|
| 82 |
"huggingface_hub.login()"
|
|
|
|
| 84 |
},
|
| 85 |
{
|
| 86 |
"cell_type": "code",
|
| 87 |
+
"execution_count": 13,
|
| 88 |
"id": "88f4fc8f-efa9-404b-9073-c7d4a73f9075",
|
| 89 |
"metadata": {},
|
| 90 |
"outputs": [
|
|
|
|
| 94 |
"text": [
|
| 95 |
"2.1428452992319946 1.8753585980967309\n",
|
| 96 |
"1.2634583203026981 1.2634583203026981\n",
|
| 97 |
+
"2.5396091977186415 2.542730868722136\n",
|
| 98 |
+
"1.5784396320724732 1.3744402739344028\n",
|
| 99 |
"1.6604589895130097 1.4945561543255432\n",
|
| 100 |
"1.342868026940067 1.5159383672128168\n",
|
| 101 |
+
"1.3318850656966443 1.492159322757521\n",
|
| 102 |
+
"1.2228932592065895 1.3374845788105776\n",
|
| 103 |
"2.1274713572398154 1.9577219841781546\n",
|
| 104 |
"1.2389391918918637 1.1907499853854275\n",
|
| 105 |
"1.301931556958707 1.5291655623486928\n",
|
| 106 |
+
"1.3586449315580662 1.349185516938256\n",
|
| 107 |
"1.300528267625377 1.1364744990514501\n",
|
| 108 |
"1.9971234414064012 1.484359575186481\n",
|
| 109 |
"1.7864233783834014 1.2997637652539\n",
|
|
|
|
| 111 |
"1.2503083972410651 1.2672496454499096\n",
|
| 112 |
"1.9554601999832553 2.4290767592247864\n",
|
| 113 |
"1.5169825260067826 1.862860315670855\n",
|
| 114 |
+
"1.6424954827356966 1.887332771285391\n",
|
| 115 |
"1.40382235616253 0.9727095077709508\n",
|
| 116 |
"2.5531277383538464 2.2956668643388705\n",
|
| 117 |
+
"1.3602630905708255 1.367457468141674\n",
|
| 118 |
"1.1352963385173234 1.3060154430256012\n",
|
| 119 |
"2.1787116657747303 1.7257473870225462\n",
|
| 120 |
"2.199844649970312 2.2529145560735175\n",
|
|
|
|
| 123 |
"1.6472782885605137 1.4709840223700106\n",
|
| 124 |
"1.332152878564964 1.6610485499747143\n",
|
| 125 |
"Averages\n",
|
| 126 |
+
"1.6704256833331772 1.6384823283515602\n"
|
| 127 |
]
|
| 128 |
}
|
| 129 |
],
|
|
|
|
| 144 |
" solver = GeomSolver()\n",
|
| 145 |
" vertices, edges = solver.solve(entry)\n",
|
| 146 |
" \n",
|
| 147 |
+
" # nverts = entry['wf_vertices'].shape[0]\n",
|
| 148 |
" # nverts = vertices.shape[0]\n",
|
| 149 |
" nverts = 20\n",
|
| 150 |
+
" vertices0, edges0 = one_line_solution(nverts)\n",
|
| 151 |
+
"\n",
|
| 152 |
+
" scores0 = (compute_WED(vertices0,\n",
|
|
|
|
|
|
|
|
|
|
| 153 |
" edges0,\n",
|
| 154 |
" np.array(entry['wf_vertices']),\n",
|
| 155 |
" np.array(entry['wf_edges'])))\n",
|