# hoho / script.py
# jskvrna's picture
# Update the hyperparameters so that they match the winning solution :)
# 32fd25b
from pathlib import Path
from tqdm import tqdm
import pandas as pd
from datasets import load_dataset
import os
import json
import gc
from utils import empty_solution
from predict import predict_wireframe
from fast_pointnet_v2 import load_pointnet_model
from fast_pointnet_class import load_pointnet_model as load_pointnet_class_model
import torch
if __name__ == "__main__":
    # Script entry point, part 1: read the run parameters, locate (or
    # download) the dataset, load it, and build the models and settings that
    # the prediction step relies on.
    print ("------------ Loading dataset------------ ")
    param_path = Path('params.json')
    print(param_path)
    with param_path.open() as params_file:
        params = json.load(params_file)
    print(params)

    # Diagnostics for the run log. The shell commands write their listings
    # straight to stdout; what print() shows is only the value returned by
    # os.system (the command's status), not the listing itself.
    print('pwd:')
    os.system('pwd')
    print(os.system('ls -lahtr'))
    # NOTE: the directory paths below are placeholders (the source marks them
    # as standing in for '/tmp/data/' and '/tmp/data/data').
    for shown_path, listing_command in (
        ('/generic/path/to/data_dir/', 'ls -lahtr /generic/path/to/data_dir/'),
        ('/generic/path/to/data_dir/data', 'ls -lahtrR /generic/path/to/data_dir/data'),
    ):
        print(shown_path)
        print(os.system(listing_command))

    # Placeholders for Path('/tmp/data') and a cache directory under the
    # user's home, respectively. data_path_local is kept for reference only;
    # nothing below reads it.
    data_path_test_server = Path('/generic/path/to/data_dir')
    data_path_local = Path("/generic/path/to/user_home") / '.cache/huggingface/datasets/usm3d___hoho25k_test_x/'

    # When the server data directory is missing we are running locally, so
    # fetch the dataset snapshot into that same location first.
    TEST_ENV = data_path_test_server.exists()
    if not TEST_ENV:
        from huggingface_hub import snapshot_download
        _ = snapshot_download(
            repo_id=params['dataset'],
            local_dir="/generic/path/to/data_dir",  # Placeholder for "/tmp/data"
            repo_type="dataset",
        )

    # Either way the data is read from the server-style directory.
    data_path = data_path_test_server
    print(data_path)

    # Collect the .tar shards for each split: "public" directories feed the
    # validation split, "private" ones the test split. (An earlier variant,
    # previously left commented out here, also matched *.arrow files.)
    shard_patterns = {
        "validation": '*public*/**/*.tar',
        "test": '*private*/**/*.tar',
    }
    data_files = {
        split: [str(shard) for shard in data_path.rglob(pattern)]
        for split, pattern in shard_patterns.items()
    }
    print(data_files)

    dataset = load_dataset(
        str(data_path / 'hoho25k_test_x.py'),
        data_files=data_files,
        trust_remote_code=True,
        writer_batch_size=100
    )
    print('load with webdataset')
    print(dataset, flush=True)

    # Models and settings handed to predict_wireframe for every sample.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pnet_model = load_pointnet_model(model_path="pnet.pth", device=device, predict_score=True)
    pnet_class_model = load_pointnet_class_model(model_path="pnet_class.pth", device=device)
    voxel_model = None  # no voxel model is loaded; None is passed through
    config = {'vertex_threshold': 0.59, 'edge_threshold': 0.65, 'only_predicted_connections': True}

    print('------------ Now you can do your solution ---------------')
    solution = []  # one result dict per processed sample
def process_sample(sample, i):
    """Predict the wireframe for one sample and package it as a result row.

    The prediction is best-effort: if ``predict_wireframe`` raises, the
    failure is reported on stdout and the result of ``empty_solution()`` is
    used instead, so one bad sample does not abort the whole run.

    Args:
        sample: Dataset record. It is forwarded unchanged to
            ``predict_wireframe`` and must provide ``sample['order_id']``.
        i: Index of the sample as supplied by the caller; used to trigger a
            garbage collection on every 10th sample and in the failure
            message.

    Returns:
        dict: ``{'order_id', 'wf_vertices', 'wf_edges'}`` where
        ``wf_vertices`` is ``pred_vertices.tolist()`` and ``wf_edges`` is the
        predicted edges object as returned.
    """
    try:
        pred_vertices, pred_edges = predict_wireframe(
            sample, pnet_model, voxel_model, pnet_class_model, config
        )
    except Exception as exc:
        # Catch Exception rather than using a bare ``except:`` so that
        # KeyboardInterrupt / SystemExit still stop the script, and say why
        # the fallback was used instead of hiding the error.
        print(f"Prediction failed for sample {i}: {exc!r}", flush=True)
        pred_vertices, pred_edges = empty_solution()

    # Periodically reclaim memory between samples.
    if i % 10 == 0:
        gc.collect()

    return {
        'order_id': sample['order_id'],
        # assumes pred_vertices exposes .tolist() (e.g. an array type) in both
        # the predicted and the empty-solution case -- TODO confirm
        'wf_vertices': pred_vertices.tolist(),
        'wf_edges': pred_edges,
    }
if __name__ == "__main__":
    # Script entry point, part 2: run the predictor over every subset of the
    # loaded dataset, then write all collected rows to a parquet file.
    num_cores = 4  # not referenced anywhere in the visible code

    for subset_name in dataset:
        print (f"Predicting {subset_name}")
        progress = tqdm(dataset[subset_name])
        for i, sample in enumerate(progress):
            # i restarts at 0 for each subset; process_sample receives it as-is.
            solution.append(process_sample(sample, i))

    print('------------ Saving results ---------------')
    submission_columns = ["order_id", "wf_vertices", "wf_edges"]
    sub = pd.DataFrame(solution, columns=submission_columns)
    sub.to_parquet("submission.parquet")
    print("------------ Done ------------ ")