File size: 3,116 Bytes
c22c8c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from pathlib import Path
from tqdm import tqdm
import pandas as pd
import numpy as np
from datasets import load_dataset
from typing import Dict
from joblib import Parallel, delayed
import os
import json
import gc

from utils import empty_solution
from predict import predict_wireframe

def redact_secrets(params):
    """Return a shallow copy of ``params`` that is safe to print.

    Values whose key contains ``"token"`` (case-insensitive) are replaced by
    ``"***"``. An earlier ``load_dataset`` call in this script read an auth
    token from ``params['token']``, so the raw config must not be echoed to
    the logs. Inputs that are not dicts are returned unchanged.
    """
    if not isinstance(params, dict):
        return params
    return {
        key: '***' if 'token' in str(key).lower() else value
        for key, value in params.items()
    }


if __name__ == "__main__":
    print ("------------ Loading dataset------------ ")
    param_path = Path('params.json')
    print(param_path)
    with param_path.open() as f:
        params = json.load(f)
    # Echo the configuration for debugging, with credentials masked.
    print(redact_secrets(params))

    # Debug listing of the working directory and the data directories.
    # Each header is flushed before its command runs because the child
    # process writes straight to the inherited stdout, bypassing Python's
    # buffer; without the flush the headers show up after the listings when
    # output is piped. A missing directory is expected on a fresh machine, so
    # a failing command is only reported, never fatal.
    for header, command in (
        ('pwd:', 'pwd'),
        (None, 'ls -lahtr'),
        ('/tmp/data/', 'ls -lahtr /tmp/data/'),
        ('/tmp/data/data', 'ls -lahtrR /tmp/data/data'),
    ):
        if header is not None:
            print(header, flush=True)
        status = os.system(command)
        if status != 0:
            print(f'`{command}` exited with status {status}', flush=True)

    data_path_test_server = Path('/tmp/data')

    # If /tmp/data already exists the data is taken as pre-provisioned
    # (presumably the test server -- confirm); otherwise the dataset repo is
    # downloaded into that same location so the rest of the script is
    # identical in both environments.
    TEST_ENV = data_path_test_server.exists()
    if not TEST_ENV:
        # Imported lazily: only needed when the data has to be fetched.
        from huggingface_hub import snapshot_download
        snapshot_download(
            repo_id=params['dataset'],
            local_dir=str(data_path_test_server),
            repo_type="dataset",
        )
    data_path = data_path_test_server

    print(data_path)

    # Split assignment is driven by directory name: tar shards under a
    # "*public*" directory form the validation split, shards under a
    # "*private*" directory form the test split.
    data_files = {
        "validation": [str(p) for p in data_path.rglob('*public*/**/*.tar')],
        "test": [str(p) for p in data_path.rglob('*private*/**/*.tar')],
    }
    print(data_files)
    # Load through the dataset's own loading script shipped next to the data.
    dataset = load_dataset(
        str(data_path / 'hoho25k_test_x.py'),
        data_files=data_files,
        trust_remote_code=True,
        writer_batch_size=100
    )

    print('load with webdataset')


    print(dataset, flush=True)
    
    print('------------ Now you can do your solution ---------------')
    # Accumulates one submission row (dict) per processed sample.
    solution = []

    def process_sample(sample, i):
        """Predict the wireframe for one sample and build its submission row.

        Args:
            sample: Dataset record. It is passed unchanged to
                ``predict_wireframe`` and must provide ``sample['order_id']``.
            i: Position of the sample within its split. Used to trigger a
                garbage collection on every 10th sample and to label the
                error message when prediction fails.

        Returns:
            A dict with the keys ``order_id``, ``wf_vertices`` and
            ``wf_edges``. If the predictor raises, the values come from
            ``empty_solution()`` instead.
        """
        try:
            pred_vertices, pred_edges = predict_wireframe(sample)
        except Exception as exc:
            # Best-effort by design: one failing sample must not abort the
            # whole submission. Catching ``Exception`` rather than using a
            # bare ``except`` lets KeyboardInterrupt/SystemExit stop the run,
            # and the failure is reported instead of silently vanishing.
            print(f"predict_wireframe failed on sample {i}: {exc!r}", flush=True)
            pred_vertices, pred_edges = empty_solution()
        if i % 10 == 0:
            # Periodic collection to keep memory in check over a long run.
            gc.collect()
        return {
            'order_id': sample['order_id'],
            # NOTE(review): assumes vertices expose ``.tolist()`` (e.g. a
            # numpy array) on both the predictor and the fallback path.
            'wf_vertices': pred_vertices.tolist(),
            'wf_edges': pred_edges
        }
    # NOTE(review): not referenced by the sequential loop below.
    num_cores = 4

    # Walk every split of the loaded dataset and collect one row per sample.
    for subset_name in dataset:
        print(f"Predicting {subset_name}")
        subset = dataset[subset_name]
        for i, sample in enumerate(tqdm(subset)):
            solution.append(process_sample(sample, i))

    print('------------ Saving results ---------------')
    # Fix the column order of the submission file explicitly.
    submission_columns = ["order_id", "wf_vertices", "wf_edges"]
    sub = pd.DataFrame(solution, columns=submission_columns)
    sub.to_parquet("submission.parquet")
    print("------------ Done ------------ ")