IhorIvanyshyn01 commited on
Commit
db2ae8e
·
1 Parent(s): 6d0029c

Fix sys.path and networkx dependency for Hugging Face

Browse files
Files changed (2) hide show
  1. script.py +105 -447
  2. sklearn_submission.py +6 -0
script.py CHANGED
@@ -1,472 +1,130 @@
1
- ### This is example of the script that will be run in the test environment.
2
-
3
- ### You can change the rest of the code to define and test your solution.
4
- ### However, you should not change the signature of the provided function.
5
- ### The script saves "submission.json" file in the current directory.
6
- ### You can use any additional files and subdirectories to organize your code.
7
 
8
  from pathlib import Path
9
  from tqdm import tqdm
10
- import json
11
  import numpy as np
12
  from datasets import load_dataset
13
  from typing import Dict
14
- from joblib import Parallel, delayed
15
- from sklearn.cluster import DBSCAN
16
- from hoho2025 import example_solutions as hoho_example
17
-
18
- VERTEX_MERGE_EPS = 1.0
19
- EDGE_MIN_LENGTH = 0.5
20
- EDGE_MIN_SUPPORT_IMAGES = 1
21
- EDGE_MAX_ANGLE_DEG = 45.0
22
- VERTEX_MAX_COLMAP_DIST = 2.0
23
- VERTEX_MIN_EDGE_DEGREE = 1
24
- VERTEX_MIN_VIEW_COUNT = 1
25
-
26
- def empty_solution(sample):
27
- '''Return a minimal valid solution, i.e. 2 vertices and 1 edge.'''
28
- return np.zeros((2,3)), [(0, 1)]
29
-
30
- def _unit_vector(vector):
31
- norm = np.linalg.norm(vector)
32
- if norm == 0:
33
- return None
34
- return vector / norm
35
-
36
- def merge_vertices_dbscan(vertices_3d, edges, eps=VERTEX_MERGE_EPS):
37
- """Cluster nearby 3D vertices, replace clusters with centroids, and remap edges."""
38
- vertices_3d = np.asarray(vertices_3d, dtype=float)
39
- if vertices_3d.ndim != 2 or vertices_3d.shape[1] != 3 or len(vertices_3d) == 0:
40
- return vertices_3d.reshape((-1, 3)), []
41
-
42
- clustering = DBSCAN(eps=eps, min_samples=1).fit(vertices_3d)
43
- labels = clustering.labels_
44
- unique_labels = np.unique(labels)
45
- label_to_new_idx = {label: idx for idx, label in enumerate(unique_labels)}
46
-
47
- merged_vertices = np.stack(
48
- [vertices_3d[labels == label].mean(axis=0) for label in unique_labels],
49
- axis=0,
50
- )
51
-
52
- remapped_edges = []
53
- seen_edges = set()
54
- for a, b in edges:
55
- a = int(a)
56
- b = int(b)
57
- if a < 0 or b < 0 or a >= len(labels) or b >= len(labels):
58
- continue
59
-
60
- new_a = label_to_new_idx[labels[a]]
61
- new_b = label_to_new_idx[labels[b]]
62
- if new_a == new_b:
63
- continue
64
-
65
- edge = (min(new_a, new_b), max(new_a, new_b))
66
- if edge in seen_edges:
67
- continue
68
-
69
- seen_edges.add(edge)
70
- remapped_edges.append(edge)
71
-
72
- return merged_vertices, remapped_edges
73
-
74
- def merge_vertices_dbscan_with_edge_stats(vert_edge_per_image, eps=VERTEX_MERGE_EPS):
75
- """Merge multi-view vertices with DBSCAN and keep edge support/direction stats."""
76
- all_vertices = []
77
- old_edges = []
78
- vertex_types = []
79
- vertex_src_images = []
80
- cur_start = 0
81
-
82
- for img_idx, (vertices, connections, vertices_3d) in vert_edge_per_image.items():
83
- vertices_3d = np.asarray(vertices_3d, dtype=float).reshape((-1, 3))
84
- if len(vertices_3d) == 0:
85
- continue
86
-
87
- all_vertices.append(vertices_3d)
88
- vertex_types.extend([int(v.get('type') == 'apex') for v in vertices])
89
- vertex_src_images.extend([img_idx] * len(vertices_3d))
90
-
91
- for a, b in connections:
92
- a = int(a)
93
- b = int(b)
94
- if a < 0 or b < 0 or a >= len(vertices_3d) or b >= len(vertices_3d):
95
- continue
96
- old_edges.append((cur_start + a, cur_start + b, img_idx))
97
-
98
- cur_start += len(vertices_3d)
99
-
100
- if not all_vertices:
101
- return np.empty((0, 3)), [], np.array([], dtype=int), {}, {}
102
-
103
- all_vertices = np.concatenate(all_vertices, axis=0)
104
- labels = DBSCAN(eps=eps, min_samples=1).fit(all_vertices).labels_
105
-
106
- # Keep apex and non-apex vertices from collapsing into the same final corner.
107
- vertex_types = np.asarray(vertex_types, dtype=int)
108
- cluster_keys = [(int(label), int(vtype)) for label, vtype in zip(labels, vertex_types)]
109
- unique_keys = sorted(set(cluster_keys))
110
- key_to_new_idx = {key: idx for idx, key in enumerate(unique_keys)}
111
- old_to_new_idx = np.array([key_to_new_idx[key] for key in cluster_keys], dtype=int)
112
-
113
- merged_vertices = np.stack(
114
- [
115
- all_vertices[old_to_new_idx == idx].mean(axis=0)
116
- for idx in range(len(unique_keys))
117
- ],
118
- axis=0,
119
- )
120
-
121
- vertex_src_images = np.asarray(vertex_src_images)
122
- vertex_view_count = np.array(
123
- [
124
- len(set(vertex_src_images[old_to_new_idx == idx]))
125
- for idx in range(len(unique_keys))
126
- ],
127
- dtype=int,
128
- )
129
-
130
- edge_image_sets = {}
131
- edge_dirs = {}
132
- for old_a, old_b, img_idx in old_edges:
133
- new_a = int(old_to_new_idx[old_a])
134
- new_b = int(old_to_new_idx[old_b])
135
- if new_a == new_b:
136
- continue
137
-
138
- edge = (min(new_a, new_b), max(new_a, new_b))
139
- direction = _unit_vector(all_vertices[old_b] - all_vertices[old_a])
140
- if direction is None:
141
- continue
142
-
143
- final_direction = _unit_vector(merged_vertices[edge[1]] - merged_vertices[edge[0]])
144
- if final_direction is None:
145
- continue
146
- if np.dot(direction, final_direction) < 0:
147
- direction = -direction
148
-
149
- edge_image_sets.setdefault(edge, set()).add(img_idx)
150
- edge_dirs.setdefault(edge, []).append(direction)
151
-
152
- edge_vote_count = {edge: len(imgs) for edge, imgs in edge_image_sets.items()}
153
- edge_angle_ok = {
154
- edge: edge_directions_are_consistent(dirs, EDGE_MAX_ANGLE_DEG)
155
- for edge, dirs in edge_dirs.items()
156
- }
157
-
158
- return merged_vertices, list(edge_image_sets.keys()), vertex_view_count, edge_vote_count, edge_angle_ok
159
-
160
- def edge_directions_are_consistent(directions, max_angle_deg=EDGE_MAX_ANGLE_DEG):
161
- """Check whether per-image 3D edge directions agree after sign alignment."""
162
- if len(directions) <= 1:
163
- return True
164
-
165
- directions = np.asarray(directions, dtype=float)
166
- mean_direction = _unit_vector(directions.mean(axis=0))
167
- if mean_direction is None:
168
- return False
169
-
170
- min_cos = np.cos(np.deg2rad(max_angle_deg))
171
- return bool(np.all(directions @ mean_direction >= min_cos))
172
-
173
- def filter_edges_by_geometry(vertices_3d, edges, edge_vote_count=None, edge_angle_ok=None,
174
- min_len=EDGE_MIN_LENGTH, min_support=EDGE_MIN_SUPPORT_IMAGES):
175
- """Drop short, weakly supported, or direction-inconsistent edges."""
176
- vertices_3d = np.asarray(vertices_3d, dtype=float)
177
- if len(vertices_3d) == 0:
178
- return []
179
-
180
- filtered_edges = []
181
- seen_edges = set()
182
- for a, b in edges:
183
- a = int(a)
184
- b = int(b)
185
- if a < 0 or b < 0 or a >= len(vertices_3d) or b >= len(vertices_3d) or a == b:
186
- continue
187
-
188
- edge = (min(a, b), max(a, b))
189
- if edge in seen_edges:
190
- continue
191
-
192
- edge_length = np.linalg.norm(vertices_3d[edge[1]] - vertices_3d[edge[0]])
193
- if edge_length < min_len:
194
- continue
195
-
196
- if edge_vote_count is not None and edge_vote_count.get(edge, 0) < min_support:
197
- continue
198
-
199
- if edge_angle_ok is not None and not edge_angle_ok.get(edge, True):
200
- continue
201
-
202
- seen_edges.add(edge)
203
- filtered_edges.append(edge)
204
-
205
- return filtered_edges
206
-
207
- def prune_bad_vertices(vertices_3d, edges, colmap_rec, vertex_view_count,
208
- max_colmap_dist=VERTEX_MAX_COLMAP_DIST,
209
- min_edge_degree=VERTEX_MIN_EDGE_DEGREE,
210
- min_view_count=VERTEX_MIN_VIEW_COUNT):
211
- """Remove vertices with weak COLMAP, edge-degree, or multi-view support."""
212
- vertices_3d = np.asarray(vertices_3d, dtype=float)
213
- vertex_view_count = np.asarray(vertex_view_count, dtype=int)
214
- if len(vertices_3d) == 0:
215
- return vertices_3d.reshape((-1, 3)), []
216
-
217
- valid_edges = []
218
- for a, b in edges:
219
- a = int(a)
220
- b = int(b)
221
- if 0 <= a < len(vertices_3d) and 0 <= b < len(vertices_3d) and a != b:
222
- valid_edges.append((a, b))
223
-
224
- colmap_mask = np.ones(len(vertices_3d), dtype=bool)
225
- xyz_sfm = [point.xyz for point in colmap_rec.points3D.values()]
226
- if xyz_sfm:
227
- xyz_sfm = np.asarray(xyz_sfm, dtype=float)
228
- diff = vertices_3d[:, None, :] - xyz_sfm[None, :, :]
229
- min_dist = np.sqrt((diff ** 2).sum(axis=-1)).min(axis=1)
230
- colmap_mask = min_dist <= max_colmap_dist
231
-
232
- view_mask = vertex_view_count >= min_view_count
233
- keep_mask = colmap_mask & view_mask
234
-
235
- while True:
236
- degree = np.zeros(len(vertices_3d), dtype=int)
237
- for a, b in valid_edges:
238
- if keep_mask[a] and keep_mask[b]:
239
- degree[a] += 1
240
- degree[b] += 1
241
-
242
- next_keep_mask = keep_mask & (degree >= min_edge_degree)
243
- if np.array_equal(next_keep_mask, keep_mask):
244
- break
245
- keep_mask = next_keep_mask
246
-
247
- old_to_new = {}
248
- new_vertices = []
249
- for old_idx, keep in enumerate(keep_mask):
250
- if keep:
251
- old_to_new[old_idx] = len(new_vertices)
252
- new_vertices.append(vertices_3d[old_idx])
253
-
254
- new_edges = []
255
- seen_edges = set()
256
- for a, b in valid_edges:
257
- if a not in old_to_new or b not in old_to_new:
258
- continue
259
- edge = tuple(sorted((old_to_new[a], old_to_new[b])))
260
- if edge in seen_edges:
261
- continue
262
- seen_edges.add(edge)
263
- new_edges.append(edge)
264
-
265
- if not new_vertices:
266
- return np.empty((0, 3)), []
267
- return np.asarray(new_vertices), new_edges
268
-
269
- def predict_wireframe_filtered(sample, verbose=False):
270
- """Baseline prediction with DBSCAN vertex merge plus stricter edge filtering."""
271
- good_entry = hoho_example.convert_entry_to_human_readable(sample)
272
- vert_edge_per_image = {}
273
- colmap_rec = None
274
-
275
- for i, (gest, depth, img_id, ade_seg) in enumerate(zip(
276
- good_entry['gestalt'],
277
- good_entry['depth'],
278
- good_entry['image_ids'],
279
- good_entry['ade'],
280
- )):
281
- if 'colmap' in good_entry:
282
- colmap_rec = good_entry['colmap']
283
- else:
284
- colmap_rec = good_entry['colmap_binary']
285
-
286
- depth_size = (np.array(depth).shape[1], np.array(depth).shape[0])
287
- gest_seg_np = np.array(gest.resize(depth_size)).astype(np.uint8)
288
- vertices, connections = hoho_example.get_vertices_and_edges_from_segmentation(
289
- gest_seg_np, edge_th=10.
290
- )
291
-
292
- ade_seg_np = np.array(ade_seg.resize(depth_size)).astype(np.uint8)
293
- vertices, connections = hoho_example.filter_vertices_by_background(
294
- vertices, connections, ade_seg_np
295
- )
296
-
297
- if (len(vertices) < 2) or (len(connections) < 1):
298
- if verbose:
299
- print(f'Not enough vertices or connections found in image {i}, skipping.')
300
- vert_edge_per_image[i] = [], [], np.empty((0, 3))
301
- continue
302
-
303
- vertices_3d = hoho_example.create_3d_wireframe_single_image(
304
- vertices, connections, depth, colmap_rec, img_id, ade_seg, verbose=verbose
305
- )
306
- vert_edge_per_image[i] = vertices, connections, vertices_3d
307
-
308
- all_3d_vertices, connections_3d, vertex_view_count, edge_vote_count, edge_angle_ok = (
309
- merge_vertices_dbscan_with_edge_stats(vert_edge_per_image, VERTEX_MERGE_EPS)
310
- )
311
- connections_3d = filter_edges_by_geometry(
312
- all_3d_vertices,
313
- connections_3d,
314
- edge_vote_count=edge_vote_count,
315
- edge_angle_ok=edge_angle_ok,
316
- )
317
- if colmap_rec is None or len(all_3d_vertices) < 2 or len(connections_3d) < 1:
318
- return empty_solution(sample)
319
-
320
- all_3d_vertices_clean, connections_3d_clean = prune_bad_vertices(
321
- all_3d_vertices, connections_3d, colmap_rec, vertex_view_count
322
- )
323
- all_3d_vertices_clean, connections_3d_clean = hoho_example.prune_not_connected(
324
- all_3d_vertices_clean, connections_3d_clean, keep_largest=False
325
- )
326
-
327
- if (len(all_3d_vertices_clean) < 2) or len(connections_3d_clean) < 1:
328
- if verbose:
329
- print(f'Not enough vertices or connections in the 3D vertices')
330
- return empty_solution(sample)
331
 
332
- connections_3d_clean = [(int(a), int(b)) for a, b in connections_3d_clean]
333
- return all_3d_vertices_clean, connections_3d_clean
 
 
334
 
335
- def predict_wireframe_safely(sample):
336
  try:
337
- pred_vertices, pred_edges = predict_wireframe_filtered(sample)
338
- except Exception as e:
339
- print (f"Failed due to {e}, returning empty solution")
340
- pred_vertices, pred_edges = empty_solution(sample)
341
- pred_edges = [(int(a), int(b)) for a, b in pred_edges] # to remove possible np.int64
342
- return pred_vertices, pred_edges, sample['order_id']
343
 
344
- class Sample(Dict):
345
- def pick_repr_data(self, x):
346
- if hasattr(x, 'shape'):
347
- return x.shape
348
- if isinstance(x, (str, float, int)):
349
- return x
350
- if isinstance(x, list):
351
- return [type(x[0])] if len(x) > 0 else []
352
- return type(x)
353
 
 
354
  def __repr__(self):
355
- # return str({k: v.shape if hasattr(v, 'shape') else [type(v[0])] if isinstance(v, list) else type(v) for k,v in self.items()})
356
- return str({k: self.pick_repr_data(v) for k,v in self.items()})
357
-
358
- def load_competition_dataset(params):
359
- """
360
- Loads dataset both:
361
- 1. Locally from public parquet files.
362
- 2. In official competition environment from /tmp/data.
363
- """
364
- import os
365
-
366
- data_path = Path("/tmp/data")
367
-
368
- print("------------ Dataset path check ------------")
369
- print("pwd:")
370
- os.system("pwd")
371
-
372
- print("/tmp/data:")
373
- os.system("ls -lahtr /tmp/data || true")
374
 
375
- print("/tmp/data/data:")
376
- os.system("ls -lahtr /tmp/data/data || true")
377
 
378
- # Case 1: local debugging with public parquet dataset
379
- parquet_dir = data_path / "data"
380
- train_parquet = list(parquet_dir.glob("train-*.parquet"))
381
- val_parquet = list(parquet_dir.glob("validation-*.parquet"))
382
 
383
- if len(train_parquet) > 0 or len(val_parquet) > 0:
384
- print("Loading local/public parquet dataset")
385
 
386
- data_files = {}
387
-
388
- if len(train_parquet) > 0:
389
- data_files["train"] = str(parquet_dir / "train-*.parquet")
390
-
391
- if len(val_parquet) > 0:
392
- data_files["validation"] = str(parquet_dir / "validation-*.parquet")
393
-
394
- dataset = load_dataset("parquet", data_files=data_files)
395
- return dataset
396
-
397
- # Case 2: official test environment with custom dataset script
398
- dataset_script_candidates = list(data_path.glob("*.py"))
399
-
400
- if len(dataset_script_candidates) > 0:
401
- dataset_script = dataset_script_candidates[0]
402
- print(f"Loading official dataset script: {dataset_script}")
403
-
404
- data_files = {
405
- "validation": [str(p) for p in data_path.rglob("*public*/**/*.tar")],
406
- "test": [str(p) for p in data_path.rglob("*private*/**/*.tar")],
407
- }
408
-
409
- print("data_files:", data_files)
410
-
411
- dataset = load_dataset(
412
- str(dataset_script),
413
- data_files=data_files,
414
- trust_remote_code=True,
415
- writer_batch_size=100,
416
- )
417
-
418
- return dataset
419
-
420
- # Case 3: fallback download for local run
421
- print("No local /tmp/data files found. Trying Hugging Face download.")
422
-
423
- from huggingface_hub import snapshot_download
424
-
425
- snapshot_download(
426
- repo_id=params["dataset"],
427
- local_dir="/tmp/data",
428
- repo_type="dataset",
429
- token=params.get("token", None),
430
- )
431
 
432
- return load_competition_dataset(params)
433
 
434
- import json
435
  if __name__ == "__main__":
436
- print ("------------ Loading dataset------------ ")
437
  param_path = Path('params.json')
438
- print(param_path)
439
- with param_path.open() as f:
440
- params = json.load(f)
441
- safe_params = dict(params)
442
- if "token" in safe_params:
443
- safe_params["token"] = "hf_******"
444
-
445
- print(safe_params)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
446
 
447
- print("------------ Loading dataset ------------")
448
- dataset = load_competition_dataset(params)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
449
 
450
- print(dataset, flush=True)
451
-
452
- print('------------ Now you can do your solution ---------------')
453
  solution = []
454
- for subset_name in dataset:
455
- print (f"Predicitng on {subset_name}")
456
- preds = Parallel(n_jobs=-1, prefer="processes")(
457
- delayed(predict_wireframe_safely)(a) for a in tqdm(dataset[subset_name])
458
- )
459
- print ("Converting")
460
- for p in preds:
461
- pred_vertices, pred_edges, order_id = p
462
- print (f'{order_id}: {len(pred_vertices)} verts, {len(pred_edges)} edges')
463
- solution.append({
464
- 'order_id': order_id,
465
- 'wf_vertices': pred_vertices.tolist(),
466
- 'wf_edges': pred_edges
467
- })
468
- print('------------ Saving results ---------------')
469
- with open("submission.json", "w") as f:
470
  json.dump(solution, f)
 
 
 
 
 
 
 
 
471
 
472
- print("------------ Done ------------ ")
 
1
+ """S23DR 2026 submission sklearn edges + edge validation + improved heuristic."""
 
 
 
 
 
2
 
3
  from pathlib import Path
4
  from tqdm import tqdm
 
5
  import numpy as np
6
  from datasets import load_dataset
7
  from typing import Dict
8
+ import os
9
+ import json
10
+ import gc
11
+ import subprocess
12
+ import sys
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
+ # Ensure local imports work regardless of how script.py is invoked
15
+ current_dir = str(Path(__file__).parent.absolute())
16
+ if current_dir not in sys.path:
17
+ sys.path.insert(0, current_dir)
18
 
19
+ def install_if_missing(package):
20
  try:
21
+ __import__(package)
22
+ except ImportError:
23
+ subprocess.check_call([sys.executable, "-m", "pip", "install", package])
 
 
 
24
 
 
 
 
 
 
 
 
 
 
25
 
26
+ class Sample(Dict):
27
  def __repr__(self):
28
+ return str({k: v.shape if hasattr(v, 'shape') else [type(v[0])] if isinstance(v, list) else type(v) for k, v in self.items()})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
 
 
30
 
31
+ def empty_solution():
32
+ return np.zeros((2, 3)), [(0, 1)]
 
 
33
 
 
 
34
 
35
+ def process_sample(sample, i, sklearn_model=None):
36
+ try:
37
+ from sklearn_submission import predict_wireframe_sklearn
38
+ pred_vertices, pred_edges = predict_wireframe_sklearn(sample, sklearn_model)
39
+ except Exception as e:
40
+ if i < 5:
41
+ print(f" Sample {i} sklearn failed: {e}", flush=True)
42
+ try:
43
+ from hoho2025.example_solutions import predict_wireframe
44
+ pred_vertices, pred_edges = predict_wireframe(sample)
45
+ except Exception:
46
+ pred_vertices, pred_edges = empty_solution()
47
+ if i % 10 == 0:
48
+ gc.collect()
49
+ return {
50
+ 'order_id': sample['order_id'],
51
+ 'wf_vertices': np.array(pred_vertices).tolist(),
52
+ 'wf_edges': [list(e) for e in pred_edges],
53
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
 
55
 
 
56
  if __name__ == "__main__":
57
+ print("------------ Loading dataset ------------", flush=True)
58
  param_path = Path('params.json')
59
+ if param_path.exists():
60
+ with param_path.open() as f:
61
+ params = json.load(f)
62
+ else:
63
+ params = {"dataset": "usm3d/hoho22k_2026_trainval", "output_path": "."}
64
+
65
+ data_path = Path('/tmp/data')
66
+ if not data_path.exists():
67
+ from huggingface_hub import snapshot_download
68
+ snapshot_download(repo_id=params['dataset'], local_dir="/tmp/data", repo_type="dataset")
69
+
70
+ os.system(f'ls -lahtrR {data_path}')
71
+
72
+ data_files = {}
73
+ public_tars = sorted([str(p) for p in data_path.rglob('*public*/**/*.tar')])
74
+ private_tars = sorted([str(p) for p in data_path.rglob('*private*/**/*.tar')])
75
+ if public_tars:
76
+ data_files["validation"] = public_tars
77
+ if private_tars:
78
+ data_files["test"] = private_tars
79
+
80
+ loading_scripts = sorted(data_path.rglob('*.py'))
81
+ loading_script = str(loading_scripts[0]) if loading_scripts else str(data_path)
82
+
83
+ dataset = load_dataset(
84
+ loading_script, data_files=data_files,
85
+ trust_remote_code=True, writer_batch_size=100,
86
+ )
87
+ print(f"Dataset: {dataset}", flush=True)
88
 
89
+ # Try to load sklearn model
90
+ sklearn_model = None
91
+ try:
92
+ install_if_missing('scikit-learn')
93
+ install_if_missing('networkx')
94
+ import pickle
95
+ model_path = Path(__file__).parent / 'sklearn_edge.pkl'
96
+ print(f"Looking for sklearn model at: {model_path} (exists={model_path.exists()})", flush=True)
97
+ if model_path.exists():
98
+ with open(model_path, 'rb') as f:
99
+ sklearn_model = pickle.load(f)
100
+ print("Loaded sklearn edge model OK", flush=True)
101
+ else:
102
+ print("sklearn model not found — using heuristic + edge validation only", flush=True)
103
+ except Exception as e:
104
+ print(f"sklearn failed: {e} — using heuristic + edge validation only", flush=True)
105
 
106
+ print("------------ Running predictions ---------------", flush=True)
 
 
107
  solution = []
108
+ for subset_name in dataset.keys():
109
+ print(f"Predicting {subset_name}", flush=True)
110
+ for i, sample in enumerate(tqdm(dataset[subset_name])):
111
+ res = process_sample(sample, i, sklearn_model)
112
+ solution.append(res)
113
+ if i % 50 == 0:
114
+ print(f" Processed {i} samples", flush=True)
115
+
116
+ print("------------ Saving results ---------------", flush=True)
117
+ output_path = Path(params.get('output_path', '.'))
118
+
119
+ with open(output_path / "submission.json", 'w') as f:
 
 
 
 
120
  json.dump(solution, f)
121
+ print(f"Saved {len(solution)} predictions to submission.json", flush=True)
122
+
123
+ try:
124
+ import pandas as pd
125
+ sub = pd.DataFrame(solution, columns=["order_id", "wf_vertices", "wf_edges"])
126
+ sub.to_parquet(output_path / "submission.parquet")
127
+ except Exception:
128
+ pass
129
 
130
+ print("------------ Done ------------", flush=True)
sklearn_submission.py CHANGED
@@ -3,6 +3,12 @@
3
  import numpy as np
4
  import cv2
5
  from typing import Tuple, List
 
 
 
 
 
 
6
 
7
  from hoho2025.example_solutions import (
8
  convert_entry_to_human_readable, empty_solution,
 
3
  import numpy as np
4
  import cv2
5
  from typing import Tuple, List
6
+ import sys
7
+ from pathlib import Path
8
+
9
+ _cur_dir = str(Path(__file__).parent.absolute())
10
+ if _cur_dir not in sys.path:
11
+ sys.path.insert(0, _cur_dir)
12
 
13
  from hoho2025.example_solutions import (
14
  convert_entry_to_human_readable, empty_solution,