colin1842 committed on
Commit
10af265
·
verified ·
1 Parent(s): 485abd5

Upload script.py

Browse files
Files changed (1) hide show
  1. script.py +83 -183
script.py CHANGED
@@ -64,34 +64,12 @@ import hoho; hoho.setup() # YOU MUST CALL hoho.setup() BEFORE ANYTHING ELSE
64
  ### Here you can import any library or module you want.
65
  ### The code below is used to read and parse the input dataset.
66
  ### Please, do not modify it.
67
- import subprocess
68
- import sys
69
- import os
70
- # Setup environment and install necessary packages
71
- def setup_environment():
72
- subprocess.check_call([sys.executable, "-m", "pip", "install", "git+http://hf.co/usm3d/tools.git"])
73
- import hoho
74
- hoho.setup()
75
-
76
- subprocess.check_call([sys.executable, "-m", "pip", "install", "torch==2.0.1", "torchvision==0.15.2", "torchaudio==2.0.2", "-f", "https://download.pytorch.org/whl/cu117.html"])
77
- subprocess.check_call([sys.executable, "-m", "pip", "install", "scikit-learn"])
78
- subprocess.check_call([sys.executable, "-m", "pip", "install", "tqdm"])
79
- subprocess.check_call([sys.executable, "-m", "pip", "install", "scipy"])
80
- subprocess.check_call([sys.executable, "-m", "pip", "install", "open3d"])
81
- subprocess.check_call([sys.executable, "-m", "pip", "install", "easydict"])
82
-
83
- # pc_util_path = os.path.join(os.getcwd(), 'pc_util')
84
- # if os.path.isdir(pc_util_path):
85
- # os.chdir(pc_util_path)
86
- # subprocess.check_call([sys.executable, "setup.py", "install"])
87
- # else:
88
- # print(f"Directory {pc_util_path} does not exist")
89
 
90
  import webdataset as wds
91
  from tqdm import tqdm
92
  from typing import Dict
93
  import pandas as pd
94
- # from transformers import AutoTokenizer
95
  import os
96
  import time
97
  import io
@@ -101,168 +79,90 @@ import numpy as np
101
  from hoho.read_write_colmap import read_cameras_binary, read_images_binary, read_points3D_binary
102
  from hoho import proc, Sample
103
 
104
- ### Ours Import Settings
105
- import os
106
- import torch
107
- import torch.nn as nn
108
- import argparse
109
- import datetime
110
- import glob
111
- import torch.distributed as dist
112
- from dataset.data_utils import build_dataloader
113
- from test_util import test_model
114
- from model.roofnet import RoofNet
115
- from torch import optim
116
- from utils import common_utils
117
- from model import model_utils
118
-
119
- import webdataset as wds
120
- from tqdm import tqdm
121
- from typing import Dict
122
- import pandas as pd
123
- # from transformer import AutoTokenizer
124
- import os
125
- import time
126
- import io
127
- from PIL import Image as PImage
128
- import numpy as np
129
-
130
- from hoho.read_write_colmap import read_cameras_binary, read_images_binary, read_points3D_binary
131
- from hoho import proc, Sample
132
-
133
- def remove_z_outliers(pcd_data, low_threshold_percentage=50, high_threshold_percentage=0):
134
- """
135
- Remove outliers from a point cloud data based on z-value.
136
-
137
- Parameters:
138
- - pcd_data (np.array): Nx3 numpy array containing the point cloud data.
139
- - low_threshold_percentage (float): Percentage of points to be removed based on the lowest z-values.
140
- - high_threshold_percentage (float): Percentage of points to be removed based on the highest z-values.
141
-
142
- Returns:
143
- - np.array: Filtered point cloud data as a Nx3 numpy array.
144
- """
145
- num_std=3
146
- low_z_threshold = np.percentile(pcd_data[:, 2], low_threshold_percentage)
147
- high_z_threshold = np.percentile(pcd_data[:, 2], 100 - high_threshold_percentage)
148
- mean_z, std_z = np.mean(pcd_data[:, 2]), np.std(pcd_data[:, 2])
149
- z_range = (mean_z - num_std * std_z, mean_z + num_std * std_z)
150
-
151
- # filtered_pcd_data = pcd_data[(pcd_data[:, 2] > low_z_threshold) & (pcd_data[:, 2] < z_range[1])]
152
- filtered_pcd_data = pcd_data[(pcd_data[:, 2] > low_z_threshold)]
153
-
154
- return filtered_pcd_data
155
-
156
  def convert_entry_to_human_readable(entry):
157
- out = {}
158
- already_good = ['__key__', 'wf_vertices', 'wf_edges', 'edge_semantics', 'mesh_vertices', 'mesh_faces', 'face_semantics', 'K', 'R', 't']
159
- for k, v in entry.items():
160
- if k in already_good:
161
- out[k] = v
162
- continue
163
- if k == 'points3d':
164
- out[k] = read_points3D_binary(fid=io.BytesIO(v))
165
- if k == 'cameras':
166
- out[k] = read_cameras_binary(fid=io.BytesIO(v))
167
- if k == 'images':
168
- out[k] = read_images_binary(fid=io.BytesIO(v))
169
- if k in ['ade20k', 'gestalt']:
170
- out[k] = [PImage.open(io.BytesIO(x)).convert('RGB') for x in v]
171
- if k == 'depthcm':
172
- out[k] = [PImage.open(io.BytesIO(x)) for x in entry['depthcm']]
173
- return out
174
-
175
- def parse_config():
176
- parser = argparse.ArgumentParser()
177
- parser.add_argument('--data_path', type=str, default='Data/hoho_data_train', help='dataset path')
178
- parser.add_argument('--cfg_file', type=str, default='../model_cfg.yaml', help='model config for training')
179
- parser.add_argument('--batch_size', type=int, default=1, help='batch size for training')
180
- parser.add_argument('--gpu', type=str, default='0', help='gpu for training')
181
- parser.add_argument('--test_tag', type=str, default='hoho_test', help='extra tag for this experiment')
182
-
183
- args = parser.parse_args()
184
- cfg = common_utils.cfg_from_yaml_file(args.cfg_file)
185
- return args, cfg
 
 
 
 
 
 
186
 
 
 
 
 
 
 
 
 
187
  def save_submission(submission, path):
188
- """
189
- Saves the submission to a specified path.
190
-
191
- Parameters:
192
- submission (List[Dict[]]): The submission to save.
193
- path (str): The path to save the submission to.
194
- """
195
- sub = pd.DataFrame(submission, columns=["__key__", "wf_vertices", "wf_edges"])
196
- sub['wf_edges'] = sub['wf_edges'].apply(lambda x: x.tolist()) # Convert to list of lists
197
- sub.to_parquet(path)
198
- print(f"Submission saved to {path}")
199
-
200
-
201
- def main():
202
- # setup packages
203
  setup_environment()
204
 
205
- args, cfg = parse_config()
206
- os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
207
-
208
- extra_tag = args.test_tag
209
- output_dir = cfg.ROOT_DIR / 'output' / extra_tag
210
- assert output_dir.exists(), '%s does not exist!!!' % str(output_dir)
211
- ckpt_dir = output_dir #/ 'ckpt'
212
- output_dir = output_dir / 'test'
213
- output_dir.mkdir(parents=True, exist_ok=True)
214
-
215
- log_file = output_dir / 'log.txt'
216
- logger = common_utils.create_logger(log_file)
217
-
218
- logger.info('**********************Start logging**********************')
219
- for key, val in vars(args).items():
220
- logger.info('{:16} {}'.format(key, val))
221
- common_utils.log_config_to_file(cfg, logger=logger)
222
-
223
- print ("------------ Loading dataset------------ ")
224
- params = hoho.get_params()
225
- dataset = hoho.get_dataset(decode=None, split='val', dataset_type='webdataset')
226
- # dataset = dataset.decode()
227
- # dataset = dataset.map(proc)
228
-
229
- for entry in tqdm(dataset, desc="Processing entries"):
230
- human_entry = convert_entry_to_human_readable(entry)
231
- # human_entry = entry
232
- key = human_entry['__key__']
233
- points3D = human_entry['points3d']
234
- xyz_ = np.stack([p.xyz for p in points3D.values()])
235
- xyz = remove_z_outliers(xyz_, low_threshold_percentage=30, high_threshold_percentage=1.0)
236
- #TODO: from webd dataset to ours dataloader roofn3d_dataset.py L152
237
- test_loader = build_dataloader(key, xyz, args.batch_size, cfg.DATA, logger=logger)
238
- net = RoofNet(cfg.MODEL)
239
- net.cuda()
240
- net.eval()
241
-
242
- ckpt_list = glob.glob(str(ckpt_dir / '*checkpoint_epoch_*.pth'))
243
- if len(ckpt_list) > 0:
244
- ckpt_list.sort(key=os.path.getmtime)
245
- model_utils.load_params(net, ckpt_list[-1], logger=logger)
246
-
247
- logger.info('**********************Start testing**********************')
248
- logger.info(net)
249
-
250
- solution = []
251
-
252
- for sample in tqdm(test_loader):
253
- key, pred_vertices, pred_edges = test_model(net, test_loader, logger)
254
- solution.append({
255
- '__key__': key,
256
- 'wf_vertices': pred_vertices.tolist(),
257
- 'wf_edges': pred_edges
258
- })
259
- print(f"predict solution: {key}")
260
-
261
- # save_submission(solution, output_dir / "submission.parquet")
262
- save_submission(solution, "submission.parquet")
263
-
264
- # test_model(net, test_loader, logger)
265
-
266
-
267
- if __name__ == '__main__':
268
- main()
 
64
  ### Here you can import any library or module you want.
65
  ### The code below is used to read and parse the input dataset.
66
  ### Please, do not modify it.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
  import webdataset as wds
69
  from tqdm import tqdm
70
  from typing import Dict
71
  import pandas as pd
72
+ from transformers import AutoTokenizer
73
  import os
74
  import time
75
  import io
 
79
  from hoho.read_write_colmap import read_cameras_binary, read_images_binary, read_points3D_binary
80
  from hoho import proc, Sample
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
def convert_entry_to_human_readable(entry):
    """
    Decode one raw dataset entry into ready-to-use Python objects.

    Parameters:
        entry (dict): Mapping from field name to raw value. The binary
            fields are wrapped in ``io.BytesIO`` before parsing, so they
            are expected to be bytes-like.

    Returns:
        dict: Fields listed in ``already_good`` are passed through
        unchanged; ``points3d``, ``cameras`` and ``images`` are parsed with
        the COLMAP binary readers; ``ade20k`` and ``gestalt`` become lists
        of PIL images converted to RGB; ``depthcm`` becomes a list of PIL
        images left in their stored mode. Any other key is dropped.
    """
    out = {}
    already_good = {
        '__key__', 'wf_vertices', 'wf_edges', 'edge_semantics',
        'mesh_vertices', 'mesh_faces', 'face_semantics', 'K', 'R', 't',
    }
    for k, v in entry.items():
        # A key can match at most one branch, so stop testing after the first hit.
        if k in already_good:
            out[k] = v
        elif k == 'points3d':
            out[k] = read_points3D_binary(fid=io.BytesIO(v))
        elif k == 'cameras':
            out[k] = read_cameras_binary(fid=io.BytesIO(v))
        elif k == 'images':
            out[k] = read_images_binary(fid=io.BytesIO(v))
        elif k in ('ade20k', 'gestalt'):
            out[k] = [PImage.open(io.BytesIO(x)).convert('RGB') for x in v]
        elif k == 'depthcm':
            # No .convert() here: depth images keep whatever mode they were stored in.
            out[k] = [PImage.open(io.BytesIO(x)) for x in v]
    return out
100
+
101
+ '''---end of compulsory---'''
102
+
103
+ ### The part below is used to define and test your solution.
104
+ import subprocess
105
+ import sys
106
+ import os
def setup_environment():
    """
    Install the packages this script needs and build the local ``pc_util`` extension.

    Every step runs through the current interpreter (``sys.executable``).
    ``subprocess.check_call`` raises ``subprocess.CalledProcessError`` as soon
    as one command exits with a non-zero status, so a failed install aborts
    the whole setup.
    """
    pip_install = [sys.executable, "-m", "pip", "install"]

    # Same order as before: install the hoho tools, then call hoho.setup()
    # before anything else is installed.
    subprocess.check_call(pip_install + ["git+http://hf.co/usm3d/tools.git"])
    import hoho
    hoho.setup()

    subprocess.check_call(pip_install + [
        "torch==2.0.1", "torchvision==0.15.2", "torchaudio==2.0.2",
        "-f", "https://download.pytorch.org/whl/cu117.html",
    ])
    # One pip invocation per package, exactly as before, so a failure names the culprit.
    for package in ("scikit-learn", "tqdm", "scipy", "open3d", "easydict"):
        subprocess.check_call(pip_install + [package])

    pc_util_path = os.path.join(os.getcwd(), 'pc_util')
    if os.path.isdir(pc_util_path):
        # Run the build with cwd= instead of os.chdir(): the process working
        # directory stays untouched, so relative paths used later in the
        # script still resolve where the caller expects.
        subprocess.check_call([sys.executable, "setup.py", "install"], cwd=pc_util_path)
    else:
        print(f"Directory {pc_util_path} does not exist")
125
+
126
+ from pathlib import Path
def save_submission(submission, path):
    """
    Saves the submission to a specified path.

    Parameters:
        submission (List[Dict]): One record per sample; the "__key__",
            "wf_vertices" and "wf_edges" fields become the columns.
        path (str or Path): Local parquet file to write. Missing parent
            directories are created first.
    """
    # The output directory is not guaranteed to exist yet; create it so the
    # final write cannot fail after all predictions have been computed.
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    sub = pd.DataFrame(submission, columns=["__key__", "wf_vertices", "wf_edges"])
    sub.to_parquet(path)
    print(f"Submission saved to {path}")
138
+
if __name__ == "__main__":
    setup_environment()

    # Imported only after the environment is ready.
    from handcrafted_solution import predict
    print ("------------ Loading dataset------------ ")
    params = hoho.get_params()
    dataset = hoho.get_dataset(decode=None, split='all', dataset_type='webdataset')

    print('------------ Now you can do your solution ---------------')
    from concurrent.futures import ProcessPoolExecutor
    solution = []
    with ProcessPoolExecutor(max_workers=8) as pool:
        # Queue every sample up front; results are then collected in submission order.
        results = [pool.submit(predict, sample, visualize=False) for sample in tqdm(dataset)]

        for i, result in enumerate(tqdm(results)):
            key, pred_vertices, pred_edges = result.result()
            record = {
                '__key__': key,
                'wf_vertices': pred_vertices.tolist(),
                'wf_edges': pred_edges,
            }
            solution.append(record)
            if i % 100 == 0:
                # Progress report only: the incremental save below is currently disabled,
                # so nothing is written to disk until the final save.
                print(f"Processed {i} samples")
                # save_submission(solution, Path(params['output_path']) / "submission.parquet")
    print('------------ Saving results ---------------')
    save_submission(solution, Path(params['output_path']) / "submission.parquet")
    print("------------ Done ------------ ")