File size: 3,950 Bytes
c22c8c5
 
 
 
 
 
 
 
 
 
 
b073e54
573ee90
 
 
c22c8c5
 
 
 
 
 
 
 
 
 
 
 
9518589
 
 
 
c22c8c5
9518589
 
c22c8c5
 
 
 
 
 
 
 
 
 
9518589
c22c8c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
573ee90
 
33113fd
573ee90
 
 
 
 
32fd25b
573ee90
c22c8c5
 
 
 
 
573ee90
c22c8c5
 
 
 
 
 
 
 
 
17323cc
c22c8c5
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from pathlib import Path
from tqdm import tqdm
import pandas as pd
from datasets import load_dataset
import os
import json
import gc

from utils import empty_solution
from predict import predict_wireframe

from fast_pointnet_v2 import load_pointnet_model
from fast_pointnet_class import load_pointnet_model as load_pointnet_class_model
import torch

if __name__ == "__main__":
    # Entry point: load the HoHo25k test dataset, run wireframe prediction
    # on every sample, and write the results to submission.parquet.
    print("------------ Loading dataset------------ ")
    param_path = Path('params.json')
    print(param_path)
    with param_path.open() as f:
        params = json.load(f)
    print(params)

    # Debug output: show the working directory and the contents of the
    # expected data directories so failures on the test server are traceable.
    print('pwd:')
    os.system('pwd')
    print(os.system('ls -lahtr'))
    print('/generic/path/to/data_dir/') # Placeholder for '/tmp/data/'
    print(os.system('ls -lahtr /generic/path/to/data_dir/')) # Placeholder for /tmp/data/
    print('/generic/path/to/data_dir/data') # Placeholder for '/tmp/data/data'
    print(os.system('ls -lahtrR /generic/path/to/data_dir/data')) # Placeholder for /tmp/data/data

    data_path_test_server = Path('/generic/path/to/data_dir') # Placeholder for Path('/tmp/data')
    data_path_local = Path("/generic/path/to/user_home") / '.cache/huggingface/datasets/usm3d___hoho25k_test_x/' # Placeholder for Path().home()

    if data_path_test_server.exists():
        TEST_ENV = True
    else:
        # Not on the test server: download the dataset into the same
        # directory the server would provide, so the single data_path
        # assignment below is valid in both environments.
        TEST_ENV = False
        from huggingface_hub import snapshot_download
        _ = snapshot_download(
            repo_id=params['dataset'],
            local_dir="/generic/path/to/data_dir", # Placeholder for "/tmp/data"
            repo_type="dataset",
        )
    # Both branches read from the server path: the local branch downloads
    # the dataset into that same directory above.
    data_path = data_path_test_server

    print(data_path)

    # "public" shards form the validation split, "private" shards the test split.
    data_files = {
        "validation": [str(p) for p in data_path.rglob('*public*/**/*.tar')],
        "test": [str(p) for p in data_path.rglob('*private*/**/*.tar')],
    }
    print(data_files)
    dataset = load_dataset(
        str(data_path / 'hoho25k_test_x.py'),
        data_files=data_files,
        trust_remote_code=True,
        writer_batch_size=100
    )

    print('load with webdataset')

    print(dataset, flush=True)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Vertex model (with score head) and class model; the voxel model is
    # intentionally disabled in this configuration.
    pnet_model = load_pointnet_model(model_path="pnet.pth", device=device, predict_score=True)
    pnet_class_model = load_pointnet_class_model(model_path="pnet_class.pth", device=device)
    voxel_model = None

    config = {'vertex_threshold': 0.59, 'edge_threshold': 0.65, 'only_predicted_connections': True}

    print('------------ Now you can do your solution ---------------')
    solution = []

    def process_sample(sample, i):
        """Predict a wireframe for one sample.

        Falls back to the empty solution on any per-sample failure so a
        single bad sample cannot abort the whole run. Returns a dict with
        keys order_id, wf_vertices, wf_edges.
        """
        try:
            pred_vertices, pred_edges = predict_wireframe(sample, pnet_model, voxel_model, pnet_class_model, config)
        except Exception:
            # Broad-by-design best-effort catch (narrowed from a bare
            # `except:` so KeyboardInterrupt/SystemExit still propagate).
            pred_vertices, pred_edges = empty_solution()
        if i % 10 == 0:
            gc.collect()  # periodically reclaim memory between samples
        return {
            'order_id': sample['order_id'],
            'wf_vertices': pred_vertices.tolist(),
            'wf_edges': pred_edges
        }

    for subset_name in dataset.keys():
        print(f"Predicting {subset_name}")
        for i, sample in enumerate(tqdm(dataset[subset_name])):
            solution.append(process_sample(sample, i))

    print('------------ Saving results ---------------')
    sub = pd.DataFrame(solution, columns=["order_id", "wf_vertices", "wf_edges"])
    sub.to_parquet("submission.parquet")
    print("------------ Done ------------ ")