"""Run wireframe prediction over the HoHo25k test dataset and write submission.parquet."""
from pathlib import Path
from tqdm import tqdm
import pandas as pd
import numpy as np
from datasets import load_dataset
from typing import Dict
from joblib import Parallel, delayed  # NOTE(review): currently unused; processing below is sequential
import os
import json
import gc
import traceback

from utils import empty_solution
from predict import predict_wireframe


# Directory where the evaluation server mounts the data; when running elsewhere the
# dataset snapshot is downloaded into the same place so the rest of the script is identical.
DATA_PATH = Path('/tmp/data')

# Keys in params.json whose values must never be written to the logs.
_SECRET_PARAM_KEYS = ('token',)


def _redact_params(params: Dict) -> Dict:
    """Return a copy of *params* with secret values (e.g. an access token) masked for logging."""
    return {key: ('***' if key in _SECRET_PARAM_KEYS else value) for key, value in params.items()}


def _log_directory_listings() -> None:
    """Print the working directory and the data directory contents (debug aid on the server).

    Note: ``print(os.system(...))`` prints the command's exit status after its output.
    """
    print('pwd:')
    os.system('pwd')
    print(os.system('ls -lahtr'))
    print('/tmp/data/')
    print(os.system('ls -lahtr /tmp/data/'))
    print('/tmp/data/data')
    print(os.system('ls -lahtrR /tmp/data/data'))


def _resolve_data_path(params: Dict) -> Path:
    """Return the data directory, downloading the dataset snapshot if it is not already present.

    Args:
        params: Parsed ``params.json``; ``params['dataset']`` is the Hugging Face repo id.
    """
    if not DATA_PATH.exists():
        # Not on the evaluation server: fetch the dataset into the same location the
        # server would have mounted it, so both environments share one code path.
        from huggingface_hub import snapshot_download
        snapshot_download(
            repo_id=params['dataset'],
            local_dir=str(DATA_PATH),
            repo_type="dataset",
        )
    return DATA_PATH


def _load_test_dataset(data_path: Path):
    """Load the validation (public) and test (private) webdataset shards found under *data_path*."""
    data_files = {
        "validation": [str(p) for p in data_path.rglob('*public*/**/*.tar')],
        "test": [str(p) for p in data_path.rglob('*private*/**/*.tar')],
    }
    print(data_files)
    dataset = load_dataset(
        str(data_path / 'hoho25k_test_x.py'),
        data_files=data_files,
        trust_remote_code=True,
        writer_batch_size=100,
    )
    print('load with webdataset')
    print(dataset, flush=True)
    return dataset


def process_sample(sample, i):
    """Predict the wireframe for one sample, falling back to an empty solution on failure.

    Args:
        sample: One dataset row; must contain an ``'order_id'`` key.
        i: Position of the sample within its split; every 10th call triggers ``gc.collect()``.

    Returns:
        A dict with ``order_id``, ``wf_vertices`` (nested Python lists) and ``wf_edges``.
    """
    order_id = sample['order_id']
    try:
        pred_vertices, pred_edges = predict_wireframe(sample)
    except Exception:
        # Best-effort: one failing sample must not sink the whole submission, but the
        # failure is reported instead of being silently swallowed (a bare ``except:``
        # would also have trapped KeyboardInterrupt/SystemExit).
        print(f"predict_wireframe failed for order_id={order_id!r}; using empty solution", flush=True)
        traceback.print_exc()
        pred_vertices, pred_edges = empty_solution()
    if i % 10 == 0:
        # Periodically free memory held by large intermediate objects.
        gc.collect()
    return {
        'order_id': order_id,
        # asarray accepts both ndarrays and plain lists from the predictor.
        'wf_vertices': np.asarray(pred_vertices).tolist(),
        'wf_edges': pred_edges,
    }


def main() -> None:
    """Load params and data, predict every sample of every split, and save submission.parquet."""
    print("------------ Loading dataset------------ ")
    param_path = Path('params.json')
    print(param_path)
    with param_path.open(encoding='utf-8') as f:
        params = json.load(f)
    print(_redact_params(params))

    _log_directory_listings()

    data_path = _resolve_data_path(params)
    print(data_path)
    dataset = _load_test_dataset(data_path)

    print('------------ Now you can do your solution ---------------')
    solution = []
    for subset_name in dataset.keys():
        print(f"Predicting {subset_name}")
        for i, sample in enumerate(tqdm(dataset[subset_name])):
            solution.append(process_sample(sample, i))

    print('------------ Saving results ---------------')
    sub = pd.DataFrame(solution, columns=["order_id", "wf_vertices", "wf_edges"])
    sub.to_parquet("submission.parquet")
    print("------------ Done ------------ ")


if __name__ == "__main__":
    main()