Denys Rozumnyi commited on
Commit
b28c65a
·
1 Parent(s): fc034ff

Public release

Browse files
Files changed (2) hide show
  1. geom_solver.py +4 -0
  2. main.ipynb +16 -19
geom_solver.py CHANGED
@@ -14,6 +14,10 @@ def my_empty_solution():
14
  return np.zeros((20,3)), [(0, 0)]
15
 
16
 
 
 
 
 
17
  def fully_connected_solution(vertices=None):
18
  if vertices is None:
19
  nverts = 20
 
14
  return np.zeros((20,3)), [(0, 0)]
15
 
16
 
17
+ def one_line_solution(n):
18
+ return np.zeros((n,3)), list(itertools.product(list(range(n)), list(range(n))))
19
+
20
+
21
  def fully_connected_solution(vertices=None):
22
  if vertices is None:
23
  nverts = 20
main.ipynb CHANGED
@@ -23,14 +23,14 @@
23
  },
24
  {
25
  "cell_type": "code",
26
- "execution_count": 2,
27
  "id": "503c6bcb-aa46-46c6-8b86-566b0a470b43",
28
  "metadata": {},
29
  "outputs": [
30
  {
31
  "data": {
32
  "application/vnd.jupyter.widget-view+json": {
33
- "model_id": "adf065a3fb644479b707047372091ea9",
34
  "version_major": 2,
35
  "version_minor": 0
36
  },
@@ -76,7 +76,7 @@
76
  "\n",
77
  "from my_solution import predict\n",
78
  "from helpers import *\n",
79
- "from geom_solver import GeomSolver\n",
80
  "\n",
81
  "import huggingface_hub \n",
82
  "huggingface_hub.login()"
@@ -84,7 +84,7 @@
84
  },
85
  {
86
  "cell_type": "code",
87
- "execution_count": 8,
88
  "id": "88f4fc8f-efa9-404b-9073-c7d4a73f9075",
89
  "metadata": {},
90
  "outputs": [
@@ -94,16 +94,16 @@
94
  "text": [
95
  "2.1428452992319946 1.8753585980967309\n",
96
  "1.2634583203026981 1.2634583203026981\n",
97
- "2.542730868722136 1.8969040131019905\n",
98
- "1.608507505603034 1.3744402739344028\n",
99
  "1.6604589895130097 1.4945561543255432\n",
100
  "1.342868026940067 1.5159383672128168\n",
101
- "1.3937990493440202 1.492159322757521\n",
102
- "1.2228932592065895 1.2545726800817392\n",
103
  "2.1274713572398154 1.9577219841781546\n",
104
  "1.2389391918918637 1.1907499853854275\n",
105
  "1.301931556958707 1.5291655623486928\n",
106
- "1.4968274710208265 1.349185516938256\n",
107
  "1.300528267625377 1.1364744990514501\n",
108
  "1.9971234414064012 1.484359575186481\n",
109
  "1.7864233783834014 1.2997637652539\n",
@@ -111,10 +111,10 @@
111
  "1.2503083972410651 1.2672496454499096\n",
112
  "1.9554601999832553 2.4290767592247864\n",
113
  "1.5169825260067826 1.862860315670855\n",
114
- "1.6424954827356966 1.782831036227262\n",
115
  "1.40382235616253 0.9727095077709508\n",
116
  "2.5531277383538464 2.2956668643388705\n",
117
- "1.4201505123873752 1.367457468141674\n",
118
  "1.1352963385173234 1.3060154430256012\n",
119
  "2.1787116657747303 1.7257473870225462\n",
120
  "2.199844649970312 2.2529145560735175\n",
@@ -123,7 +123,7 @@
123
  "1.6472782885605137 1.4709840223700106\n",
124
  "1.332152878564964 1.6610485499747143\n",
125
  "Averages\n",
126
- "1.6801981329818685 1.6107076453713232\n"
127
  ]
128
  }
129
  ],
@@ -144,15 +144,12 @@
144
  " solver = GeomSolver()\n",
145
  " vertices, edges = solver.solve(entry)\n",
146
  " \n",
147
- " nverts = entry['wf_vertices'].shape[0]\n",
148
  " # nverts = vertices.shape[0]\n",
149
  " nverts = 20\n",
150
- " \n",
151
- " all_verts = list(range(nverts))\n",
152
- " edges0 = list(itertools.product(all_verts, all_verts))\n",
153
- " edges0 = [edg for edg in edges0 if edg[0] < edg[1]]\n",
154
- " \n",
155
- " scores0 = (compute_WED(np.zeros((nverts,3)),\n",
156
  " edges0,\n",
157
  " np.array(entry['wf_vertices']),\n",
158
  " np.array(entry['wf_edges'])))\n",
 
23
  },
24
  {
25
  "cell_type": "code",
26
+ "execution_count": 11,
27
  "id": "503c6bcb-aa46-46c6-8b86-566b0a470b43",
28
  "metadata": {},
29
  "outputs": [
30
  {
31
  "data": {
32
  "application/vnd.jupyter.widget-view+json": {
33
+ "model_id": "bccfd9e961614341bcbd9a4dba7a4a99",
34
  "version_major": 2,
35
  "version_minor": 0
36
  },
 
76
  "\n",
77
  "from my_solution import predict\n",
78
  "from helpers import *\n",
79
+ "from geom_solver import GeomSolver, one_line_solution\n",
80
  "\n",
81
  "import huggingface_hub \n",
82
  "huggingface_hub.login()"
 
84
  },
85
  {
86
  "cell_type": "code",
87
+ "execution_count": 13,
88
  "id": "88f4fc8f-efa9-404b-9073-c7d4a73f9075",
89
  "metadata": {},
90
  "outputs": [
 
94
  "text": [
95
  "2.1428452992319946 1.8753585980967309\n",
96
  "1.2634583203026981 1.2634583203026981\n",
97
+ "2.5396091977186415 2.542730868722136\n",
98
+ "1.5784396320724732 1.3744402739344028\n",
99
  "1.6604589895130097 1.4945561543255432\n",
100
  "1.342868026940067 1.5159383672128168\n",
101
+ "1.3318850656966443 1.492159322757521\n",
102
+ "1.2228932592065895 1.3374845788105776\n",
103
  "2.1274713572398154 1.9577219841781546\n",
104
  "1.2389391918918637 1.1907499853854275\n",
105
  "1.301931556958707 1.5291655623486928\n",
106
+ "1.3586449315580662 1.349185516938256\n",
107
  "1.300528267625377 1.1364744990514501\n",
108
  "1.9971234414064012 1.484359575186481\n",
109
  "1.7864233783834014 1.2997637652539\n",
 
111
  "1.2503083972410651 1.2672496454499096\n",
112
  "1.9554601999832553 2.4290767592247864\n",
113
  "1.5169825260067826 1.862860315670855\n",
114
+ "1.6424954827356966 1.887332771285391\n",
115
  "1.40382235616253 0.9727095077709508\n",
116
  "2.5531277383538464 2.2956668643388705\n",
117
+ "1.3602630905708255 1.367457468141674\n",
118
  "1.1352963385173234 1.3060154430256012\n",
119
  "2.1787116657747303 1.7257473870225462\n",
120
  "2.199844649970312 2.2529145560735175\n",
 
123
  "1.6472782885605137 1.4709840223700106\n",
124
  "1.332152878564964 1.6610485499747143\n",
125
  "Averages\n",
126
+ "1.6704256833331772 1.6384823283515602\n"
127
  ]
128
  }
129
  ],
 
144
  " solver = GeomSolver()\n",
145
  " vertices, edges = solver.solve(entry)\n",
146
  " \n",
147
+ " # nverts = entry['wf_vertices'].shape[0]\n",
148
  " # nverts = vertices.shape[0]\n",
149
  " nverts = 20\n",
150
+ " vertices0, edges0 = one_line_solution(nverts)\n",
151
+ "\n",
152
+ " scores0 = (compute_WED(vertices0,\n",
 
 
 
153
  " edges0,\n",
154
  " np.array(entry['wf_vertices']),\n",
155
  " np.array(entry['wf_edges'])))\n",