# dmytromishkin's picture
# Update script.py
# 0758125 verified
### This is an example of the script that will be run in the test environment.
### You can change the rest of the code to define and test your solution.
### However, you should not change the signature of the provided function.
### The script saves "submission.json" file in the current directory.
### You can use any additional files and subdirectories to organize your code.
from pathlib import Path
from tqdm import tqdm
import json
import numpy as np
from datasets import load_dataset
from typing import Dict
from joblib import Parallel, delayed
def empty_solution(sample):
    '''Build the smallest valid wireframe: two vertices joined by one edge.

    ``sample`` is accepted for signature compatibility but is not read.
    Returns a ``(vertices, edges)`` pair where ``vertices`` is a 2x3 array of
    zeros and ``edges`` is a list holding the single index pair ``(0, 1)``.
    '''
    vertices = np.zeros((2, 3))
    edges = [(0, 1)]
    return vertices, edges
def predict_wireframe_safely(sample):
    '''Predict a wireframe for one sample and tag it with the sample's order id.

    The prediction currently comes from ``empty_solution``. Returns a
    ``(vertices, edges, order_id)`` triple, where ``order_id`` is read from
    ``sample['order_id']``.
    '''
    vertices, raw_edges = empty_solution(sample)
    edges = []
    for start, end in raw_edges:
        # Cast to built-in ints so no numpy integer types (e.g. np.int64) remain.
        edges.append((int(start), int(end)))
    return vertices, edges, sample['order_id']
class Sample(Dict):
    '''Dictionary of sample fields whose repr shows a compact summary of each value.'''

    def pick_repr_data(self, x):
        '''Return a short stand-in for ``x`` to use when printing the sample.

        Objects exposing a ``shape`` attribute are summarised by that shape,
        str/float/int values are returned unchanged, a list becomes a list
        holding the type of its first element (or an empty list), and anything
        else is summarised by its type.
        '''
        try:
            return x.shape
        except AttributeError:
            pass
        if isinstance(x, list):
            # Only the first element is inspected; an empty list stays empty.
            return [type(first) for first in x[:1]]
        if isinstance(x, (str, float, int)):
            return x
        return type(x)

    def __repr__(self):
        '''Render the mapping with every value replaced by its summary.'''
        summary = {}
        for key, value in self.items():
            summary[key] = self.pick_repr_data(value)
        return str(summary)
import json
if __name__ == "__main__":
    # Script entry point: read params.json, make sure the dataset is present
    # under /tmp/data, run the predictor over every split and write all
    # predictions to "submission.json" in the current directory.
    import os

    print ("------------ Loading dataset------------ ")
    param_path = Path('params.json')
    print(param_path)
    # JSON is read as UTF-8 explicitly instead of relying on the locale default.
    with param_path.open(encoding='utf-8') as f:
        params = json.load(f)
    # NOTE(review): params is echoed in full; if it can carry credentials
    # (e.g. a token), consider redacting them before printing -- confirm.
    print(params)

    # Diagnostic listings of the working directory and the data directory.
    # os.system() children write straight to the stdout file descriptor while
    # print() is buffered, so every print here is flushed; otherwise headers
    # show up after the listings they introduce when stdout is piped to a log.
    # The bare numbers printed are the exit statuses returned by os.system().
    print('pwd:', flush=True)
    os.system('pwd')
    print(os.system('ls -lahtr'), flush=True)
    print('/tmp/data/', flush=True)
    print(os.system('ls -lahtr /tmp/data/'), flush=True)
    print('/tmp/data/data', flush=True)
    print(os.system('ls -lahtrR /tmp/data/data'), flush=True)

    data_path_test_server = Path('/tmp/data')
    # When /tmp/data already exists the script treats this as the test
    # environment; otherwise the dataset repo is downloaded into /tmp/data so
    # that both cases read from the same location.
    TEST_ENV = data_path_test_server.exists()
    if not TEST_ENV:
        from huggingface_hub import snapshot_download
        _ = snapshot_download(
            repo_id=params['dataset'],
            local_dir="/tmp/data",
            repo_type="dataset",
        )
    data_path = data_path_test_server
    print(data_path)

    # Tar shards found under "*public*" directories form the validation
    # split, those under "*private*" directories form the test split.
    data_files = {
        "validation": [str(p) for p in data_path.rglob('*public*/**/*.tar')],
        "test": [str(p) for p in data_path.rglob('*private*/**/*.tar')],
    }
    print(data_files)
    # The dataset is built by the loading script that sits in the data directory.
    dataset = load_dataset(
        str(data_path / 'hoho22k_2026_test_x_anon.py'),
        data_files=data_files,
        trust_remote_code=True,
        writer_batch_size=100
    )
    print('load with webdataset')
    print(dataset, flush=True)

    print('------------ Now you can do your solution ---------------')
    solution = []
    for subset_name in dataset:
        print (f"Predicitng on {subset_name}")
        # One prediction task per sample, spread over all available cores.
        preds = Parallel(n_jobs=-1, prefer="processes")(
            delayed(predict_wireframe_safely)(a) for a in tqdm(dataset[subset_name])
        )
        print ("Converting")
        for pred_vertices, pred_edges, order_id in preds:
            print (f'{order_id}: {len(pred_vertices)} verts, {len(pred_edges)} edges')
            solution.append({
                'order_id': order_id,
                # Vertices are converted to nested lists so json can serialise them.
                'wf_vertices': pred_vertices.tolist(),
                'wf_edges': pred_edges
            })

    print('------------ Saving results ---------------')
    with open("submission.json", "w") as f:
        json.dump(solution, f)
    print("------------ Done ------------ ")