Update script.py
Browse files
script.py
CHANGED
|
@@ -2,25 +2,30 @@
|
|
| 2 |
|
| 3 |
### You can change the rest of the code to define and test your solution.
|
| 4 |
### However, you should not change the signature of the provided function.
|
| 5 |
-
### The script saves "submission.
|
| 6 |
### You can use any additional files and subdirectories to organize your code.
|
| 7 |
|
| 8 |
from pathlib import Path
|
| 9 |
from tqdm import tqdm
|
| 10 |
-
import
|
| 11 |
import numpy as np
|
| 12 |
from datasets import load_dataset
|
| 13 |
from typing import Dict
|
| 14 |
-
from joblib import Parallel, delayed
|
| 15 |
-
import os
|
| 16 |
-
import json
|
| 17 |
-
import gc
|
| 18 |
from hoho2025.example_solutions import predict_wireframe
|
| 19 |
-
|
| 20 |
|
| 21 |
-
def empty_solution():
    '''Build the smallest valid wireframe: two vertices connected by one edge.

    Returns a pair (vertices, edges), where vertices is a (2, 3) array of
    zeros and edges is the single-edge list [(0, 1)].
    '''
    vertices = np.zeros((2, 3))
    edges = [(0, 1)]
    return vertices, edges
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
class Sample(Dict):
|
| 26 |
def pick_repr_data(self, x):
|
|
@@ -35,8 +40,10 @@ class Sample(Dict):
|
|
| 35 |
def __repr__(self):
|
| 36 |
# return str({k: v.shape if hasattr(v, 'shape') else [type(v[0])] if isinstance(v, list) else type(v) for k,v in self.items()})
|
| 37 |
return str({k: self.pick_repr_data(v) for k,v in self.items()})
|
| 38 |
-
|
| 39 |
|
|
|
|
|
|
|
|
|
|
| 40 |
if __name__ == "__main__":
|
| 41 |
print ("------------ Loading dataset------------ ")
|
| 42 |
param_path = Path('params.json')
|
|
@@ -92,54 +99,46 @@ if __name__ == "__main__":
|
|
| 92 |
writer_batch_size=100
|
| 93 |
)
|
| 94 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
print('load with webdataset')
|
| 96 |
-
|
| 97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
print(dataset, flush=True)
|
|
|
|
| 99 |
|
| 100 |
print('------------ Now you can do your solution ---------------')
|
| 101 |
solution = []
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
def process_sample(sample, i):
    '''Predict the wireframe for one sample and package it as a submission row.

    sample: dataset entry; it is passed to predict_wireframe and its
        'order_id' value is copied into the result.
    i: position of the sample in the loop; every 10th call triggers a
        garbage collection pass.

    Returns a dict with keys 'order_id', 'wf_vertices' (vertices converted
    with .tolist()) and 'wf_edges'. If prediction raises, the result of
    empty_solution() is used instead.
    '''
    try:
        pred_vertices, pred_edges = predict_wireframe(sample)
    except Exception:
        # Catch Exception rather than using a bare except, so that
        # KeyboardInterrupt / SystemExit still stop the run.
        pred_vertices, pred_edges = empty_solution()
    if i % 10 == 0:
        gc.collect()
    return {
        'order_id': sample['order_id'],
        'wf_vertices': pred_vertices.tolist(),
        'wf_edges': pred_edges,
    }
|
| 128 |
-
num_cores = 4
|
| 129 |
-
|
| 130 |
-
for subset_name in dataset.keys():
|
| 131 |
-
print (f"Predicting {subset_name}")
|
| 132 |
-
for i, sample in enumerate(tqdm(dataset[subset_name])):
|
| 133 |
-
res = process_sample(sample, i)
|
| 134 |
-
solution.append(res)
|
| 135 |
-
|
| 136 |
print('------------ Saving results ---------------')
|
| 137 |
-
# sub = pd.DataFrame(solution, columns=["order_id", "wf_vertices", "wf_edges"])
|
| 138 |
-
# sub.to_parquet("submission.parquet")
|
| 139 |
-
|
| 140 |
-
solution_as_lists = convert(solution)
|
| 141 |
-
|
| 142 |
with open("submission.json", "w") as f:
|
| 143 |
-
json.dump(
|
| 144 |
|
| 145 |
print("------------ Done ------------ ")
|
|
|
|
| 2 |
|
| 3 |
### You can change the rest of the code to define and test your solution.
|
| 4 |
### However, you should not change the signature of the provided function.
|
| 5 |
+
### The script saves a "submission.json" file in the current directory.
|
| 6 |
### You can use any additional files and subdirectories to organize your code.
|
| 7 |
|
| 8 |
from pathlib import Path
|
| 9 |
from tqdm import tqdm
|
| 10 |
+
import json
|
| 11 |
import numpy as np
|
| 12 |
from datasets import load_dataset
|
| 13 |
from typing import Dict
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
from hoho2025.example_solutions import predict_wireframe
|
| 15 |
+
from joblib import Parallel, delayed
|
| 16 |
|
| 17 |
+
def empty_solution(sample):
    '''Build the fallback prediction for a sample: 2 zero vertices and 1 edge.

    sample: mapping that provides an 'order_id' entry.

    Returns a 3-tuple (vertices, edges, order_id): a (2, 3) array of zeros,
    the single-edge list [(0, 1)], and the value of sample['order_id'].
    '''
    placeholder_vertices = np.zeros((2, 3))
    placeholder_edges = [(0, 1)]
    return placeholder_vertices, placeholder_edges, sample['order_id']
|
| 20 |
+
|
| 21 |
+
def predict_wireframe_safely(sample):
    '''Predict a wireframe for one sample, falling back to the empty solution on failure.

    sample: dataset entry; it is passed to predict_wireframe and its
        'order_id' value is returned alongside the prediction.

    Returns a 3-tuple (pred_vertices, pred_edges, order_id). Edge endpoints
    are converted to plain Python ints so the result can be serialized
    with json.
    '''
    try:
        pred_vertices, pred_edges = predict_wireframe(sample)
    except Exception as e:
        print (f"Failed due to {e}, returning empty solution")
        # empty_solution(sample) returns (vertices, edges, order_id); keep only
        # the first two items. Unpacking all three into two names raised
        # ValueError inside this handler and defeated the fallback.
        pred_vertices, pred_edges = empty_solution(sample)[:2]
    pred_edges = [(int(a), int(b)) for a, b in pred_edges]  # to remove possible np.int64
    return pred_vertices, pred_edges, sample['order_id']
|
| 29 |
|
| 30 |
class Sample(Dict):
|
| 31 |
def pick_repr_data(self, x):
|
|
|
|
| 40 |
def __repr__(self):
|
| 41 |
# return str({k: v.shape if hasattr(v, 'shape') else [type(v[0])] if isinstance(v, list) else type(v) for k,v in self.items()})
|
| 42 |
return str({k: self.pick_repr_data(v) for k,v in self.items()})
|
|
|
|
| 43 |
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
import json
|
| 47 |
if __name__ == "__main__":
|
| 48 |
print ("------------ Loading dataset------------ ")
|
| 49 |
param_path = Path('params.json')
|
|
|
|
| 99 |
writer_batch_size=100
|
| 100 |
)
|
| 101 |
|
| 102 |
+
# if TEST_ENV:
|
| 103 |
+
# dataset = load_dataset(
|
| 104 |
+
# "webdataset",
|
| 105 |
+
# data_files=data_files,
|
| 106 |
+
# trust_remote_code=True,
|
| 107 |
+
# # streaming=True
|
| 108 |
+
# )
|
| 109 |
print('load with webdataset')
|
| 110 |
+
# else:
|
| 111 |
+
|
| 112 |
+
# dataset = load_dataset(
|
| 113 |
+
# "arrow",
|
| 114 |
+
# data_files=data_files,
|
| 115 |
+
# trust_remote_code=True,
|
| 116 |
+
# # streaming=True
|
| 117 |
+
# )
|
| 118 |
+
# print('load with arrow')
|
| 119 |
+
|
| 120 |
+
|
| 121 |
print(dataset, flush=True)
|
| 122 |
+
# dataset = load_dataset('webdataset', data_files={)
|
| 123 |
|
| 124 |
print('------------ Now you can do your solution ---------------')
|
| 125 |
solution = []
|
| 126 |
+
for subset_name in dataset:
|
| 127 |
+
print (f"Predicitng on {subset_name}")
|
| 128 |
+
preds = Parallel(n_jobs=-1, prefer="processes")(
|
| 129 |
+
delayed(predict_wireframe_safely)(a) for a in tqdm(dataset[subset_name])
|
| 130 |
+
)
|
| 131 |
+
print ("Converting")
|
| 132 |
+
for p in preds:
|
| 133 |
+
pred_vertices, pred_edges, order_id = p
|
| 134 |
+
print (f'{order_id}: {len(pred_vertices)} verts, {len(pred_edges)} edges')
|
| 135 |
+
solution.append({
|
| 136 |
+
'order_id': order_id,
|
| 137 |
+
'wf_vertices': pred_vertices.tolist(),
|
| 138 |
+
'wf_edges': pred_edges
|
| 139 |
+
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
print('------------ Saving results ---------------')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
with open("submission.json", "w") as f:
|
| 142 |
+
json.dump(solution, f)
|
| 143 |
|
| 144 |
print("------------ Done ------------ ")
|