| |
|
|
| |
| |
| |
| |
|
|
| from pathlib import Path |
| from tqdm import tqdm |
| import json |
| import numpy as np |
| from datasets import load_dataset |
| from typing import Dict |
| from joblib import Parallel, delayed |
|
|
def empty_solution(sample):
    """Produce the smallest wireframe the scorer accepts: 2 vertices, 1 edge.

    The *sample* argument is ignored; every building gets the same
    placeholder prediction.
    """
    vertices = np.zeros((2, 3))
    edges = [(0, 1)]
    return vertices, edges
|
|
def predict_wireframe_safely(sample):
    """Return a fallback prediction ``(vertices, edges, order_id)`` for *sample*.

    Always yields the minimal empty solution; edge endpoints are coerced to
    plain ``int`` so the result is JSON-serialisable downstream.
    """
    vertices, edges = empty_solution(sample)
    edges = [(int(u), int(v)) for u, v in edges]
    return vertices, edges, sample['order_id']
|
|
class Sample(Dict):
    """Dict subclass whose repr shows compact value summaries, not raw data."""

    def pick_repr_data(self, x):
        """Summarise *x* for display.

        Anything with a ``.shape`` (e.g. an array) -> its shape;
        str/float/int -> the value itself; a list -> a one-element list with
        the type of its first item (empty list stays empty); everything
        else -> its type.
        """
        if hasattr(x, 'shape'):
            return x.shape
        if isinstance(x, (str, float, int)):
            return x
        if isinstance(x, list):
            if x:
                return [type(x[0])]
            return []
        return type(x)

    def __repr__(self):
        # Render as a plain dict of per-key summaries.
        summary = {key: self.pick_repr_data(value) for key, value in self.items()}
        return str(summary)
| |
import os

if __name__ == "__main__":
    print("------------ Loading dataset------------ ")
    # Run parameters (e.g. the dataset repo id) are supplied via params.json
    # in the working directory.
    param_path = Path('params.json')
    print(param_path)
    with param_path.open() as f:
        params = json.load(f)
    print(params)

    # Debug: show where we are and what data is mounted.  Note that
    # print(os.system(...)) prints the command's exit status, after the
    # command itself has written its listing to stdout.
    print('pwd:')
    os.system('pwd')
    print(os.system('ls -lahtr'))
    print('/tmp/data/')
    print(os.system('ls -lahtr /tmp/data/'))
    print('/tmp/data/data')
    print(os.system('ls -lahtrR /tmp/data/data'))

    # On the evaluation server the dataset is pre-mounted at /tmp/data;
    # when running locally we first download it there from the Hub.
    data_path_test_server = Path('/tmp/data')
    if data_path_test_server.exists():
        TEST_ENV = True
    else:
        TEST_ENV = False
        from huggingface_hub import snapshot_download
        _ = snapshot_download(
            repo_id=params['dataset'],
            local_dir="/tmp/data",
            repo_type="dataset",
        )
    # Either way the data now lives under /tmp/data.
    data_path = data_path_test_server
    print(data_path)

    # Public shards are scored as the validation split, private ones as test.
    data_files = {
        "validation": [str(p) for p in data_path.rglob('*public*/**/*.tar')],
        "test": [str(p) for p in data_path.rglob('*private*/**/*.tar')],
    }
    print(data_files)
    dataset = load_dataset(
        str(data_path / 'hoho22k_2026_test_x_anon.py'),
        data_files=data_files,
        trust_remote_code=True,
        writer_batch_size=100,
    )

    print('load with webdataset')
    print(dataset, flush=True)

    print('------------ Now you can do your solution ---------------')
    solution = []
    for subset_name in dataset:
        print(f"Predicting on {subset_name}")
        # Fan the per-sample prediction out across all cores.
        preds = Parallel(n_jobs=-1, prefer="processes")(
            delayed(predict_wireframe_safely)(a) for a in tqdm(dataset[subset_name])
        )
        print("Converting")
        for pred_vertices, pred_edges, order_id in preds:
            print(f'{order_id}: {len(pred_vertices)} verts, {len(pred_edges)} edges')
            solution.append({
                'order_id': order_id,
                'wf_vertices': pred_vertices.tolist(),
                'wf_edges': pred_edges,
            })

    print('------------ Saving results ---------------')
    with open("submission.json", "w") as f:
        json.dump(solution, f)

    print("------------ Done ------------ ")