dmytromishkin commited on
Commit
8e16bfb
·
verified ·
1 Parent(s): 25af80e

Update script.py

Browse files
Files changed (1) hide show
  1. script.py +51 -52
script.py CHANGED
@@ -2,25 +2,30 @@
2
 
3
  ### You can change the rest of the code to define and test your solution.
4
  ### However, you should not change the signature of the provided function.
5
- ### The script saves "submission.parquet" file in the current directory.
6
  ### You can use any additional files and subdirectories to organize your code.
7
 
8
  from pathlib import Path
9
  from tqdm import tqdm
10
- import pandas as pd
11
  import numpy as np
12
  from datasets import load_dataset
13
  from typing import Dict
14
- from joblib import Parallel, delayed
15
- import os
16
- import json
17
- import gc
18
  from hoho2025.example_solutions import predict_wireframe
19
- # check the https://github.com/s23dr/hoho2025/blob/main/hoho2025/example_solutions.py for the example solution
20
 
21
- def empty_solution():
22
  '''Return a minimal valid solution, i.e. 2 vertices and 1 edge.'''
23
- return np.zeros((2,3)), [(0, 1)]
 
 
 
 
 
 
 
 
 
24
 
25
  class Sample(Dict):
26
  def pick_repr_data(self, x):
@@ -35,8 +40,10 @@ class Sample(Dict):
35
  def __repr__(self):
36
  # return str({k: v.shape if hasattr(v, 'shape') else [type(v[0])] if isinstance(v, list) else type(v) for k,v in self.items()})
37
  return str({k: self.pick_repr_data(v) for k,v in self.items()})
38
-
39
 
 
 
 
40
  if __name__ == "__main__":
41
  print ("------------ Loading dataset------------ ")
42
  param_path = Path('params.json')
@@ -92,54 +99,46 @@ if __name__ == "__main__":
92
  writer_batch_size=100
93
  )
94
 
 
 
 
 
 
 
 
95
  print('load with webdataset')
96
-
97
-
 
 
 
 
 
 
 
 
 
98
  print(dataset, flush=True)
 
99
 
100
  print('------------ Now you can do your solution ---------------')
101
  solution = []
102
-
103
- def convert(obj):
104
- if isinstance(obj, np.ndarray):
105
- return obj.tolist()
106
- if isinstance(obj, np.integer):
107
- return int(obj)
108
- if isinstance(obj, np.floating):
109
- return float(obj)
110
- if isinstance(obj, dict):
111
- return {k: convert(v) for k, v in obj.items()}
112
- if isinstance(obj, (list, tuple)):
113
- return [convert(i) for i in obj]
114
- return obj
115
-
116
- def process_sample(sample, i):
117
- try:
118
- pred_vertices, pred_edges = predict_wireframe(sample)
119
- except:
120
- pred_vertices, pred_edges = empty_solution()
121
- if i %10 == 0:
122
- gc.collect()
123
- return {
124
- 'order_id': sample['order_id'],
125
- 'wf_vertices': pred_vertices.tolist(),
126
- 'wf_edges': pred_edges
127
- }
128
- num_cores = 4
129
-
130
- for subset_name in dataset.keys():
131
- print (f"Predicting {subset_name}")
132
- for i, sample in enumerate(tqdm(dataset[subset_name])):
133
- res = process_sample(sample, i)
134
- solution.append(res)
135
-
136
  print('------------ Saving results ---------------')
137
- # sub = pd.DataFrame(solution, columns=["order_id", "wf_vertices", "wf_edges"])
138
- # sub.to_parquet("submission.parquet")
139
-
140
- solution_as_lists = convert(solution)
141
-
142
  with open("submission.json", "w") as f:
143
- json.dump(solution_as_lists, f)
144
 
145
  print("------------ Done ------------ ")
 
2
 
3
  ### You can change the rest of the code to define and test your solution.
4
  ### However, you should not change the signature of the provided function.
5
+ ### The script saves "submission.json" file in the current directory.
6
  ### You can use any additional files and subdirectories to organize your code.
7
 
8
  from pathlib import Path
9
  from tqdm import tqdm
10
+ import json
11
  import numpy as np
12
  from datasets import load_dataset
13
  from typing import Dict
 
 
 
 
14
  from hoho2025.example_solutions import predict_wireframe
15
+ from joblib import Parallel, delayed
16
 
17
+ def empty_solution(sample):
18
  '''Return a minimal valid solution, i.e. 2 vertices and 1 edge.'''
19
+ return np.zeros((2,3)), [(0, 1)], sample['order_id']
20
+
21
+ def predict_wireframe_safely(sample):
22
+ try:
23
+ pred_vertices, pred_edges = predict_wireframe(sample)
24
+ except Exception as e:
25
+ print (f"Failed due to {e}, returning empty solution")
26
+ pred_vertices, pred_edges = empty_solution(sample)
27
+ pred_edges = [(int(a), int(b)) for a, b in pred_edges] # to remove possible np.int64
28
+ return pred_vertices, pred_edges, sample['order_id']
29
 
30
  class Sample(Dict):
31
  def pick_repr_data(self, x):
 
40
  def __repr__(self):
41
  # return str({k: v.shape if hasattr(v, 'shape') else [type(v[0])] if isinstance(v, list) else type(v) for k,v in self.items()})
42
  return str({k: self.pick_repr_data(v) for k,v in self.items()})
 
43
 
44
+
45
+
46
+ import json
47
  if __name__ == "__main__":
48
  print ("------------ Loading dataset------------ ")
49
  param_path = Path('params.json')
 
99
  writer_batch_size=100
100
  )
101
 
102
+ # if TEST_ENV:
103
+ # dataset = load_dataset(
104
+ # "webdataset",
105
+ # data_files=data_files,
106
+ # trust_remote_code=True,
107
+ # # streaming=True
108
+ # )
109
  print('load with webdataset')
110
+ # else:
111
+
112
+ # dataset = load_dataset(
113
+ # "arrow",
114
+ # data_files=data_files,
115
+ # trust_remote_code=True,
116
+ # # streaming=True
117
+ # )
118
+ # print('load with arrow')
119
+
120
+
121
  print(dataset, flush=True)
122
+ # dataset = load_dataset('webdataset', data_files={)
123
 
124
  print('------------ Now you can do your solution ---------------')
125
  solution = []
126
+ for subset_name in dataset:
127
+ print (f"Predicitng on {subset_name}")
128
+ preds = Parallel(n_jobs=-1, prefer="processes")(
129
+ delayed(predict_wireframe_safely)(a) for a in tqdm(dataset[subset_name])
130
+ )
131
+ print ("Converting")
132
+ for p in preds:
133
+ pred_vertices, pred_edges, order_id = p
134
+ print (f'{order_id}: {len(pred_vertices)} verts, {len(pred_edges)} edges')
135
+ solution.append({
136
+ 'order_id': order_id,
137
+ 'wf_vertices': pred_vertices.tolist(),
138
+ 'wf_edges': pred_edges
139
+ })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  print('------------ Saving results ---------------')
 
 
 
 
 
141
  with open("submission.json", "w") as f:
142
+ json.dump(solution, f)
143
 
144
  print("------------ Done ------------ ")