hoho / script.py
jskvrna's picture
Adds inference script and utils
c22c8c5
raw
history blame
3.12 kB
from pathlib import Path
from tqdm import tqdm
import pandas as pd
import numpy as np
from datasets import load_dataset
from typing import Dict
from joblib import Parallel, delayed
import os
import json
import gc
from utils import empty_solution
from predict import predict_wireframe
if __name__ == "__main__":
    print ("------------ Loading dataset------------ ")

    # Run configuration (dataset repo id, auth token, ...) lives in params.json
    # next to this script.
    param_path = Path('params.json')
    print(param_path)
    with param_path.open() as f:
        params = json.load(f)
    print(params)

    # Debug output: dump the working directory and the mounted dataset layout
    # so failures on the evaluation server are diagnosable from the logs.
    print('pwd:')
    os.system('pwd')
    print(os.system('ls -lahtr'))
    print('/tmp/data/')
    print(os.system('ls -lahtr /tmp/data/'))
    print('/tmp/data/data')
    print(os.system('ls -lahtrR /tmp/data/data'))

    # On the test server the data is pre-mounted at /tmp/data; when running
    # locally we download the dataset snapshot into the same location so the
    # rest of the script is identical in both environments.
    data_path_test_server = Path('/tmp/data')
    data_path_local = Path().home() / '.cache/huggingface/datasets/usm3d___hoho25k_test_x/'

    if data_path_test_server.exists():
        TEST_ENV = True
    else:
        TEST_ENV = False
        from huggingface_hub import snapshot_download
        _ = snapshot_download(
            repo_id=params['dataset'],
            local_dir="/tmp/data",
            repo_type="dataset",
        )
    data_path = data_path_test_server
    print(data_path)

    # Shard mapping: 'public' tars form the validation split, 'private' tars
    # the test split (loaded via the dataset's own loading script below).
    data_files = {
        "validation": [str(p) for p in data_path.rglob('*public*/**/*.tar')],
        "test": [str(p) for p in data_path.rglob('*private*/**/*.tar')],
    }
    print(data_files)

    dataset = load_dataset(
        str(data_path / 'hoho25k_test_x.py'),
        data_files=data_files,
        trust_remote_code=True,
        writer_batch_size=100,
    )
    print('load with webdataset')
    print(dataset, flush=True)

    print('------------ Now you can do your solution ---------------')
    solution = []

    def process_sample(sample, i):
        """Predict a wireframe for one sample and return a submission row.

        Falls back to the empty solution on any prediction failure so a single
        bad sample cannot abort the whole run.
        """
        try:
            pred_vertices, pred_edges = predict_wireframe(sample)
        except Exception as exc:
            # Was a bare `except:`, which also swallowed KeyboardInterrupt /
            # SystemExit; keep the best-effort fallback but log the failure.
            print(f'predict_wireframe failed on sample {i}: {exc!r}', flush=True)
            pred_vertices, pred_edges = empty_solution()
        if i % 10 == 0:
            # Periodic collection keeps peak memory in check over long runs.
            gc.collect()
        return {
            'order_id': sample['order_id'],
            'wf_vertices': pred_vertices.tolist(),
            'wf_edges': pred_edges,
        }

    for subset_name in dataset.keys():
        print (f"Predicting {subset_name}")
        for i, sample in enumerate(tqdm(dataset[subset_name])):
            solution.append(process_sample(sample, i))

    print('------------ Saving results ---------------')
    sub = pd.DataFrame(solution, columns=["order_id", "wf_vertices", "wf_edges"])
    sub.to_parquet("submission.parquet")
    print("------------ Done ------------ ")