|
|
from pathlib import Path |
|
|
from tqdm import tqdm |
|
|
import pandas as pd |
|
|
from datasets import load_dataset |
|
|
import os |
|
|
import json |
|
|
import gc |
|
|
|
|
|
from utils import empty_solution |
|
|
from predict import predict_wireframe |
|
|
|
|
|
from fast_pointnet_v2 import load_pointnet_model |
|
|
from fast_pointnet_class import load_pointnet_model as load_pointnet_class_model |
|
|
import torch |
|
|
|
|
|
if __name__ == "__main__":
    # Entry point: load the HoHo25k test dataset, run wireframe prediction on
    # every sample with the PointNet models, and write submission.parquet.
    print("------------ Loading dataset------------ ")

    # Runtime parameters (e.g. the HF dataset repo id) live in params.json.
    param_path = Path('params.json')
    print(param_path)
    with param_path.open() as f:
        params = json.load(f)
    print(params)

    # Debug output: show the working directory and what data is visible on disk.
    # (os is already imported at the top of the file — no local re-import needed.)
    print('pwd:')
    os.system('pwd')
    print(os.system('ls -lahtr'))
    print('/generic/path/to/data_dir/')
    print(os.system('ls -lahtr /generic/path/to/data_dir/'))
    print('/generic/path/to/data_dir/data')
    print(os.system('ls -lahtrR /generic/path/to/data_dir/data'))

    data_path_test_server = Path('/generic/path/to/data_dir')

    # On the evaluation server the data directory already exists; when running
    # locally, download the dataset snapshot into that same location first.
    if data_path_test_server.exists():
        TEST_ENV = True  # NOTE(review): flag is set but never read below
    else:
        TEST_ENV = False
        from huggingface_hub import snapshot_download
        _ = snapshot_download(
            repo_id=params['dataset'],
            local_dir="/generic/path/to/data_dir",
            repo_type="dataset",
        )
    data_path = data_path_test_server
    print(data_path)

    # Shard layout: "public" tar files form the validation split,
    # "private" tar files form the test split.
    data_files = {
        "validation": [str(p) for p in data_path.rglob('*public*/**/*.tar')],
        "test": [str(p) for p in data_path.rglob('*private*/**/*.tar')],
    }
    print(data_files)
    dataset = load_dataset(
        str(data_path / 'hoho25k_test_x.py'),
        data_files=data_files,
        trust_remote_code=True,
        writer_batch_size=100
    )

    print('load with webdataset')
    print(dataset, flush=True)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Vertex-scoring PointNet and classification PointNet; no voxel model is used.
    pnet_model = load_pointnet_model(model_path="pnet.pth", device=device, predict_score=True)
    pnet_class_model = load_pointnet_class_model(model_path="pnet_class.pth", device=device)
    voxel_model = None

    # Post-processing thresholds for turning model output into a wireframe.
    config = {'vertex_threshold': 0.59, 'edge_threshold': 0.65, 'only_predicted_connections': True}

    print('------------ Now you can do your solution ---------------')
    solution = []

    def process_sample(sample, i):
        """Predict a wireframe for one dataset sample.

        Falls back to ``empty_solution()`` on any prediction error so a single
        bad sample cannot abort the whole run. Returns a dict with the
        sample's order_id, vertex list, and edge list.
        """
        try:
            pred_vertices, pred_edges = predict_wireframe(sample, pnet_model, voxel_model, pnet_class_model, config)
        except Exception as e:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # still propagate; log the failure instead of hiding it silently.
            print(f"predict_wireframe failed on sample {i}: {e}", flush=True)
            pred_vertices, pred_edges = empty_solution()
        if i % 10 == 0:
            # Periodic GC keeps memory bounded over the long prediction loop.
            gc.collect()
        return {
            'order_id': sample['order_id'],
            'wf_vertices': pred_vertices.tolist(),
            'wf_edges': pred_edges
        }

    for subset_name in dataset.keys():
        print(f"Predicting {subset_name}")
        for i, sample in enumerate(tqdm(dataset[subset_name])):
            res = process_sample(sample, i)
            solution.append(res)

    print('------------ Saving results ---------------')
    sub = pd.DataFrame(solution, columns=["order_id", "wf_vertices", "wf_edges"])
    sub.to_parquet("submission.parquet")
    print("------------ Done ------------ ")