from pathlib import Path
from tqdm import tqdm
import pandas as pd
import numpy as np
from datasets import load_dataset
from typing import Dict
from joblib import Parallel, delayed
import os
import json
import gc

from utils import empty_solution
from predict import predict_wireframe


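# Entry point: load the HoHo25k test splits, run the wireframe predictor on
# every sample, and write the predictions to submission.parquet.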
if __name__ == "__main__":
    print("------------ Loading dataset ------------")
    param_path = Path('params.json')
    print(param_path)
    with param_path.open() as f:
        params = json.load(f)
    print(params)
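
    # Debug output: show the working directory and what is mounted under
    # /tmp/data on the evaluation server.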
    print('pwd:')
    os.system('pwd')
    # os.system returns the command's exit status; the listings themselves go
    # to stdout, so the return values are not printed.
    os.system('ls -lahtr')
    print('/tmp/data/')
    os.system('ls -lahtr /tmp/data/')
    print('/tmp/data/data')
    os.system('ls -lahtrR /tmp/data/data')

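    # The hidden test server provides the data under /tmp/data; when that path
    # is missing (local runs), the dataset named in params.json is downloaded
    # there instead.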
    data_path_test_server = Path('/tmp/data')
    data_path_local = Path.home() / '.cache/huggingface/datasets/usm3d___hoho25k_test_x/'

    if data_path_test_server.exists():
        TEST_ENV = True
    else:
        TEST_ENV = False
        from huggingface_hub import snapshot_download
        _ = snapshot_download(
            repo_id=params['dataset'],
            local_dir="/tmp/data",
            repo_type="dataset",
        )
    data_path = data_path_test_server

    print(data_path)

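    # Public ("validation") and private ("test") splits are shipped as tar
    # shards; gather every matching archive under the data root.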
    data_files = {
        "validation": [str(p) for p in data_path.rglob('*public*/**/*.tar')],
        "test": [str(p) for p in data_path.rglob('*private*/**/*.tar')],
    }
    print(data_files)
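    # Load via the dataset's own loading script (the shards are webdataset-style
    # tars); a modest writer_batch_size keeps memory bounded while the examples
    # are decoded and cached.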
    dataset = load_dataset(
        str(data_path / 'hoho25k_test_x.py'),
        data_files=data_files,
        trust_remote_code=True,
        writer_batch_size=100,
    )

    print('load with webdataset')

    print(dataset, flush=True)

    print('------------ Now you can do your solution ---------------')
    solution = []

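    # Predict a wireframe for a single sample; fall back to an empty wireframe
    # if the predictor raises, and run the garbage collector every 10 samples
    # to keep memory usage in check.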
    def process_sample(sample, i):
        try:
            pred_vertices, pred_edges = predict_wireframe(sample)
        except Exception:
            pred_vertices, pred_edges = empty_solution()
        if i % 10 == 0:
            gc.collect()
        return {
            'order_id': sample['order_id'],
            'wf_vertices': pred_vertices.tolist(),
            'wf_edges': pred_edges,
        }

    num_cores = 4

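    # Parallel, delayed, and num_cores are currently unused; the loop below
    # runs serially. A parallel variant (only a sketch, assuming
    # predict_wireframe is safe to run in worker processes) could look like:
    #
    #     results = Parallel(n_jobs=num_cores)(
    #         delayed(process_sample)(sample, i)
    #         for i, sample in enumerate(dataset[subset_name])
    #     )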
    for subset_name in dataset.keys():
        print(f"Predicting {subset_name}")
        for i, sample in enumerate(tqdm(dataset[subset_name])):
            res = process_sample(sample, i)
            solution.append(res)

    print('------------ Saving results ---------------')
    sub = pd.DataFrame(solution, columns=["order_id", "wf_vertices", "wf_edges"])
    sub.to_parquet("submission.parquet")
    print("------------ Done ------------")