Update inference.py - Split inference into two stages

#1
Files changed (1) hide show
  1. inference.py +248 -115
inference.py CHANGED
@@ -1,24 +1,24 @@
1
  import argparse
2
- import sys
3
  import os
4
-
5
- from typing import Dict, Optional, Tuple, List
6
  from omegaconf import OmegaConf
7
  from PIL import Image
8
  from dataclasses import dataclass
9
  from collections import defaultdict
10
  import torch
11
  import torch.utils.checkpoint
12
- from torchvision.utils import make_grid, save_image
13
- from accelerate.utils import set_seed
14
  from tqdm.auto import tqdm
15
  import torch.nn.functional as F
16
  from einops import rearrange
17
  from rembg import remove, new_session
18
- import pdb
19
  from mvdiffusion.pipelines.pipeline_mvdiffusion_unclip import StableUnCLIPImg2ImgPipeline
20
  from econdataset import SMPLDataset
21
  from reconstruct import ReMesh
 
22
  providers = [
23
  ('CUDAExecutionProvider', {
24
  'device_id': 0,
@@ -30,10 +30,27 @@ providers = [
30
  session = new_session(providers=providers)
31
 
32
  weight_dtype = torch.float16
33
- def tensor_to_numpy(tensor):
 
 
34
  return tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
35
 
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  @dataclass
38
  class TestConfig:
39
  pretrained_model_name_or_path: str
@@ -43,7 +60,6 @@ class TestConfig:
43
  seed: Optional[int]
44
  validation_batch_size: int
45
  dataloader_num_workers: int
46
- # save_single_views: bool
47
  save_mode: str
48
  local_rank: int
49
 
@@ -56,123 +72,233 @@ class TestConfig:
56
  num_views: int
57
  enable_xformers_memory_efficient_attention: bool
58
  with_smpl: Optional[bool]
59
-
60
  recon_opt: Dict
61
 
 
 
 
 
 
62
 
63
- def convert_to_numpy(tensor):
64
- return tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
65
 
66
- def convert_to_pil(tensor):
67
- return Image.fromarray(convert_to_numpy(tensor))
68
 
69
- def save_image(tensor, fp):
70
- ndarr = convert_to_numpy(tensor)
71
- # pdb.set_trace()
72
- save_image_numpy(ndarr, fp)
73
- return ndarr
74
 
75
- def save_image_numpy(ndarr, fp):
76
- im = Image.fromarray(ndarr)
77
- im.save(fp)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
- def run_inference(dataloader, econdata, pipeline, carving, cfg: TestConfig, save_dir):
80
- pipeline.set_progress_bar_config(disable=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
  if cfg.seed is None:
83
  generator = None
84
  else:
85
- generator = torch.Generator(device=pipeline.unet.device).manual_seed(cfg.seed)
86
-
87
- images_cond, pred_cat = [], defaultdict(list)
88
  for case_id, batch in tqdm(enumerate(dataloader)):
89
- images_cond.append(batch['imgs_in'][:, 0])
90
-
91
- imgs_in = torch.cat([batch['imgs_in']]*2, dim=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  num_views = imgs_in.shape[1]
93
- imgs_in = rearrange(imgs_in, "B Nv C H W -> (B Nv) C H W")# (B*Nv, 3, H, W)
 
94
  if cfg.with_smpl:
95
- smpl_in = torch.cat([batch['smpl_imgs_in']]*2, dim=0)
96
  smpl_in = rearrange(smpl_in, "B Nv C H W -> (B Nv) C H W")
97
  else:
98
  smpl_in = None
99
 
100
- normal_prompt_embeddings, clr_prompt_embeddings = batch['normal_prompt_embeddings'], batch['color_prompt_embeddings']
 
101
  prompt_embeddings = torch.cat([normal_prompt_embeddings, clr_prompt_embeddings], dim=0)
102
  prompt_embeddings = rearrange(prompt_embeddings, "B Nv N C -> (B Nv) N C")
103
 
104
  with torch.autocast("cuda"):
105
- # B*Nv images
106
  guidance_scale = cfg.validation_guidance_scales
107
  unet_out = pipeline(
108
- imgs_in, None, prompt_embeds=prompt_embeddings,
109
- dino_feature=None, smpl_in=smpl_in,
110
- generator=generator, guidance_scale=guidance_scale, output_type='pt', num_images_per_prompt=1,
111
- **cfg.pipe_validation_kwargs
 
 
 
 
 
 
112
  )
113
-
114
  out = unet_out.images
115
  bsz = out.shape[0] // 2
116
-
117
  normals_pred = out[:bsz]
118
- images_pred = out[bsz:]
119
- if cfg.save_mode == 'concat': ## save concatenated color and normal---------------------
120
- pred_cat[f"cfg{guidance_scale:.1f}"].append(torch.cat([normals_pred, images_pred], dim=-1)) # b, 3, h, w
121
- cur_dir = os.path.join(save_dir, f"cropsize-{cfg.validation_dataset.crop_size}-cfg{guidance_scale:.1f}-seed{cfg.seed}-smpl-{cfg.with_smpl}")
122
- os.makedirs(cur_dir, exist_ok=True)
123
- for i in range(bsz//num_views):
124
- scene = batch['filename'][i].split('.')[0]
125
-
126
- img_in_ = images_cond[-1][i].to(out.device)
127
- vis_ = [img_in_]
128
- for j in range(num_views):
129
- idx = i*num_views + j
130
- normal = normals_pred[idx]
131
- color = images_pred[idx]
132
-
133
- vis_.append(color)
134
- vis_.append(normal)
135
-
136
- out_filename = f"{cur_dir}/{scene}.png"
137
- vis_ = torch.stack(vis_, dim=0)
138
- vis_ = make_grid(vis_, nrow=len(vis_), padding=0, value_range=(0, 1))
139
- save_image(vis_, out_filename)
140
- elif cfg.save_mode == 'rgb':
141
- for i in range(bsz//num_views):
142
- scene = batch['filename'][i].split('.')[0]
143
-
144
- img_in_ = images_cond[-1][i].to(out.device)
145
- normals, colors = [], []
146
- for j in range(num_views):
147
- idx = i*num_views + j
148
- normal = normals_pred[idx]
149
- if j == 0:
150
- color = imgs_in[0].to(out.device)
151
- else:
152
- color = images_pred[idx]
153
- if j in [3, 4]:
154
- normal = torch.flip(normal, dims=[2])
155
- color = torch.flip(color, dims=[2])
156
-
157
- colors.append(color)
158
- if j == 6:
159
- normal = F.interpolate(normal.unsqueeze(0), size=(256, 256), mode='bilinear', align_corners=False).squeeze(0)
160
- normals.append(normal)
161
-
162
- ## save color and normal---------------------
163
- # normal_filename = f"normals_{view}_masked.png"
164
- # rgb_filename = f"color_{view}_masked.png"
165
- # save_image(normal, os.path.join(scene_dir, normal_filename))
166
- # save_image(color, os.path.join(scene_dir, rgb_filename))
167
- normals[0][:, :256, 256:512] = normals[-1]
168
-
169
- colors = [remove(convert_to_pil(tensor), session=session) for tensor in colors[:6]]
170
- normals = [remove(convert_to_pil(tensor), session=session) for tensor in normals[:6]]
171
- pose = econdata.__getitem__(case_id)
172
- carving.optimize_case(scene, pose, colors, normals)
173
- torch.cuda.empty_cache()
174
-
175
-
176
 
177
  def load_pshuman_pipeline(cfg):
178
  pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(cfg.pretrained_model_name_or_path, torch_dtype=weight_dtype)
@@ -181,43 +307,50 @@ def load_pshuman_pipeline(cfg):
181
  pipeline.to('cuda')
182
  return pipeline
183
 
184
- def main(
185
- cfg: TestConfig
186
- ):
187
 
188
- # If passed along, set the training seed now.
 
189
  if cfg.seed is not None:
190
  set_seed(cfg.seed)
191
- pipeline = load_pshuman_pipeline(cfg)
192
-
193
 
194
  if cfg.with_smpl:
195
  from mvdiffusion.data.testdata_with_smpl import SingleImageDataset
196
  else:
197
  from mvdiffusion.data.single_image_dataset import SingleImageDataset
198
-
199
- # Get the dataset
200
- validation_dataset = SingleImageDataset(
201
- **cfg.validation_dataset
202
- )
203
  validation_dataloader = torch.utils.data.DataLoader(
204
- validation_dataset, batch_size=cfg.validation_batch_size, shuffle=False, num_workers=cfg.dataloader_num_workers
 
 
 
205
  )
206
- dataset_param = {'image_dir': validation_dataset.root_dir, 'seg_dir': None, 'colab': False, 'has_det': True, 'hps_type': 'pixie'}
207
- econdata = SMPLDataset(dataset_param, device='cuda')
208
 
 
 
 
 
 
 
 
 
209
  carving = ReMesh(cfg.recon_opt, econ_dataset=econdata)
 
 
 
 
210
  run_inference(validation_dataloader, econdata, pipeline, carving, cfg, cfg.save_dir)
211
-
212
 
213
  if __name__ == '__main__':
214
  parser = argparse.ArgumentParser()
215
  parser.add_argument('--config', type=str, required=True)
216
  args, extras = parser.parse_known_args()
217
- from utils.misc import load_config
218
 
219
- # parse YAML config to OmegaConf
220
  cfg = load_config(args.config, cli_args=extras)
221
  schema = OmegaConf.structured(TestConfig)
222
  cfg = OmegaConf.merge(schema, cfg)
223
- main(cfg)
 
1
  import argparse
2
+ import json
3
  import os
4
+ from pathlib import Path
5
+ from typing import Dict, Optional, List
6
  from omegaconf import OmegaConf
7
  from PIL import Image
8
  from dataclasses import dataclass
9
  from collections import defaultdict
10
  import torch
11
  import torch.utils.checkpoint
12
+ from torchvision.utils import make_grid
13
+ from accelerate.utils import set_seed
14
  from tqdm.auto import tqdm
15
  import torch.nn.functional as F
16
  from einops import rearrange
17
  from rembg import remove, new_session
 
18
  from mvdiffusion.pipelines.pipeline_mvdiffusion_unclip import StableUnCLIPImg2ImgPipeline
19
  from econdataset import SMPLDataset
20
  from reconstruct import ReMesh
21
+
22
  providers = [
23
  ('CUDAExecutionProvider', {
24
  'device_id': 0,
 
30
  session = new_session(providers=providers)
31
 
32
  weight_dtype = torch.float16
33
+
34
+
35
+ def convert_to_numpy(tensor):
36
  return tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
37
 
38
 
39
+ def convert_to_pil(tensor):
40
+ return Image.fromarray(convert_to_numpy(tensor))
41
+
42
+
43
+ def save_tensor_image(tensor, fp):
44
+ ndarr = convert_to_numpy(tensor)
45
+ save_image_numpy(ndarr, fp)
46
+ return ndarr
47
+
48
+
49
+ def save_image_numpy(ndarr, fp):
50
+ im = Image.fromarray(ndarr)
51
+ im.save(fp)
52
+
53
+
54
  @dataclass
55
  class TestConfig:
56
  pretrained_model_name_or_path: str
 
60
  seed: Optional[int]
61
  validation_batch_size: int
62
  dataloader_num_workers: int
 
63
  save_mode: str
64
  local_rank: int
65
 
 
72
  num_views: int
73
  enable_xformers_memory_efficient_attention: bool
74
  with_smpl: Optional[bool]
75
+
76
  recon_opt: Dict
77
 
78
+ # new two-stage settings
79
+ run_mode: str = "full" # full | generate | reconstruct
80
+ multiview_tmp_dir: str = ""
81
+ prefer_edited_views: bool = True
82
+ save_multiview_metadata: bool = True
83
 
 
 
84
 
85
+ def ensure_rgba(img: Image.Image) -> Image.Image:
86
+ return img.convert("RGBA") if img.mode != "RGBA" else img
87
 
 
 
 
 
 
88
 
89
+ def get_scene_name(batch, sample_index: int) -> str:
90
+ return Path(batch['filename'][sample_index]).stem
91
+
92
+
93
+ def get_scene_dir(base_dir: str, scene: str) -> Path:
94
+ return Path(base_dir) / scene
95
+
96
+
97
+ def save_multiview_scene(base_dir: str, scene: str, colors: List[Image.Image], normals: List[Image.Image], meta: Optional[dict] = None):
98
+ scene_dir = get_scene_dir(base_dir, scene)
99
+ raw_dir = scene_dir / "raw"
100
+ edit_dir = scene_dir / "edit"
101
+ raw_dir.mkdir(parents=True, exist_ok=True)
102
+ edit_dir.mkdir(parents=True, exist_ok=True)
103
+
104
+ for idx, img in enumerate(colors):
105
+ img = ensure_rgba(img)
106
+ img.save(raw_dir / f"color_{idx:02d}.png")
107
+ img.save(edit_dir / f"color_{idx:02d}.png")
108
+
109
+ for idx, img in enumerate(normals):
110
+ img = ensure_rgba(img)
111
+ img.save(raw_dir / f"normal_{idx:02d}.png")
112
+ img.save(edit_dir / f"normal_{idx:02d}.png")
113
+
114
+ if meta is not None:
115
+ with open(scene_dir / "meta.json", "w", encoding="utf-8") as f:
116
+ json.dump(meta, f, indent=2)
117
+
118
+
119
+ def load_multiview_scene(base_dir: str, scene: str, prefer_edit=True):
120
+ scene_dir = get_scene_dir(base_dir, scene)
121
+ candidate_dirs = [scene_dir / ("edit" if prefer_edit else "raw"), scene_dir / ("raw" if prefer_edit else "edit")]
122
 
123
+ data_dir = None
124
+ for cdir in candidate_dirs:
125
+ if cdir.exists():
126
+ data_dir = cdir
127
+ break
128
+ if data_dir is None:
129
+ raise FileNotFoundError(f"No multiview directory found for scene '{scene}' under {scene_dir}")
130
+
131
+ color_paths = sorted(data_dir.glob("color_*.png"))
132
+ normal_paths = sorted(data_dir.glob("normal_*.png"))
133
+ if not color_paths or not normal_paths:
134
+ raise FileNotFoundError(f"No color/normal images found in {data_dir}")
135
+
136
+ colors = [ensure_rgba(Image.open(p)) for p in color_paths]
137
+ normals = [ensure_rgba(Image.open(p)) for p in normal_paths]
138
+ return colors, normals
139
+
140
+
141
+ def prepare_scene_views(batch, imgs_in, normals_pred, images_pred, out, cfg: TestConfig, save_dir, images_cond, case_id):
142
+ guidance_scale = cfg.validation_guidance_scales
143
+ num_views = imgs_in.shape[0] // (out.shape[0] // 2 // cfg.num_views) if False else None # unused safeguard
144
+ bsz = out.shape[0] // 2
145
+ num_views = cfg.num_views
146
+ scene_results = []
147
+
148
+ if cfg.save_mode == 'concat':
149
+ cur_dir = os.path.join(save_dir, f"cropsize-{cfg.validation_dataset.crop_size}-cfg{guidance_scale:.1f}-seed{cfg.seed}-smpl-{cfg.with_smpl}")
150
+ os.makedirs(cur_dir, exist_ok=True)
151
+ for i in range(bsz // num_views):
152
+ scene = get_scene_name(batch, i)
153
+ img_in_ = images_cond[i].to(out.device)
154
+ vis_ = [img_in_]
155
+ for j in range(num_views):
156
+ idx = i * num_views + j
157
+ normal = normals_pred[idx]
158
+ color = images_pred[idx]
159
+ vis_.append(color)
160
+ vis_.append(normal)
161
+
162
+ out_filename = f"{cur_dir}/{scene}.png"
163
+ vis_ = torch.stack(vis_, dim=0)
164
+ vis_ = make_grid(vis_, nrow=len(vis_), padding=0, value_range=(0, 1))
165
+ save_tensor_image(vis_, out_filename)
166
+ return scene_results
167
+
168
+ if cfg.save_mode != 'rgb':
169
+ raise ValueError(f"Unsupported save_mode for two-stage workflow: {cfg.save_mode}")
170
+
171
+ for i in range(bsz // num_views):
172
+ scene = get_scene_name(batch, i)
173
+ normals, colors = [], []
174
+
175
+ for j in range(num_views):
176
+ idx = i * num_views + j
177
+ normal = normals_pred[idx]
178
+ if j == 0:
179
+ color = imgs_in[i * num_views].to(out.device)
180
+ else:
181
+ color = images_pred[idx]
182
+
183
+ if j in [3, 4]:
184
+ normal = torch.flip(normal, dims=[2])
185
+ color = torch.flip(color, dims=[2])
186
+
187
+ colors.append(color)
188
+ if j == 6:
189
+ normal = F.interpolate(normal.unsqueeze(0), size=(256, 256), mode='bilinear', align_corners=False).squeeze(0)
190
+ normals.append(normal)
191
+
192
+ normals[0][:, :256, 256:512] = normals[-1]
193
+
194
+ color_pils = [ensure_rgba(remove(convert_to_pil(tensor), session=session)) for tensor in colors[:6]]
195
+ normal_pils = [ensure_rgba(remove(convert_to_pil(tensor), session=session)) for tensor in normals[:6]]
196
+
197
+ meta = None
198
+ if cfg.save_multiview_metadata:
199
+ meta = {
200
+ "scene": scene,
201
+ "case_id": case_id,
202
+ "num_colors": len(color_pils),
203
+ "num_normals": len(normal_pils),
204
+ "seed": cfg.seed,
205
+ "run_mode": cfg.run_mode,
206
+ "crop_size": cfg.validation_dataset.crop_size,
207
+ "with_smpl": cfg.with_smpl,
208
+ }
209
+
210
+ scene_results.append((scene, color_pils, normal_pils, meta))
211
+
212
+ return scene_results
213
+
214
+
215
+ def run_inference(dataloader, econdata, pipeline, carving, cfg: TestConfig, save_dir):
216
+ if pipeline is not None:
217
+ pipeline.set_progress_bar_config(disable=True)
218
 
219
  if cfg.seed is None:
220
  generator = None
221
  else:
222
+ device = pipeline.unet.device if pipeline is not None else "cuda"
223
+ generator = torch.Generator(device=device).manual_seed(cfg.seed)
224
+
225
  for case_id, batch in tqdm(enumerate(dataloader)):
226
+ if cfg.run_mode == "reconstruct":
227
+ batch_size = len(batch['filename'])
228
+ for i in range(batch_size):
229
+ scene = get_scene_name(batch, i)
230
+ colors, normals = load_multiview_scene(
231
+ cfg.multiview_tmp_dir,
232
+ scene,
233
+ prefer_edit=cfg.prefer_edited_views,
234
+ )
235
+ pose = econdata.__getitem__(case_id + i)
236
+ carving.optimize_case(scene, pose, colors, normals)
237
+ torch.cuda.empty_cache()
238
+ continue
239
+
240
+ images_cond = batch['imgs_in'][:, 0]
241
+ imgs_in = torch.cat([batch['imgs_in']] * 2, dim=0)
242
  num_views = imgs_in.shape[1]
243
+ imgs_in = rearrange(imgs_in, "B Nv C H W -> (B Nv) C H W")
244
+
245
  if cfg.with_smpl:
246
+ smpl_in = torch.cat([batch['smpl_imgs_in']] * 2, dim=0)
247
  smpl_in = rearrange(smpl_in, "B Nv C H W -> (B Nv) C H W")
248
  else:
249
  smpl_in = None
250
 
251
+ normal_prompt_embeddings = batch['normal_prompt_embeddings']
252
+ clr_prompt_embeddings = batch['color_prompt_embeddings']
253
  prompt_embeddings = torch.cat([normal_prompt_embeddings, clr_prompt_embeddings], dim=0)
254
  prompt_embeddings = rearrange(prompt_embeddings, "B Nv N C -> (B Nv) N C")
255
 
256
  with torch.autocast("cuda"):
 
257
  guidance_scale = cfg.validation_guidance_scales
258
  unet_out = pipeline(
259
+ imgs_in,
260
+ None,
261
+ prompt_embeds=prompt_embeddings,
262
+ dino_feature=None,
263
+ smpl_in=smpl_in,
264
+ generator=generator,
265
+ guidance_scale=guidance_scale,
266
+ output_type='pt',
267
+ num_images_per_prompt=1,
268
+ **cfg.pipe_validation_kwargs,
269
  )
270
+
271
  out = unet_out.images
272
  bsz = out.shape[0] // 2
 
273
  normals_pred = out[:bsz]
274
+ images_pred = out[bsz:]
275
+
276
+ scene_results = prepare_scene_views(
277
+ batch=batch,
278
+ imgs_in=imgs_in,
279
+ normals_pred=normals_pred,
280
+ images_pred=images_pred,
281
+ out=out,
282
+ cfg=cfg,
283
+ save_dir=save_dir,
284
+ images_cond=images_cond,
285
+ case_id=case_id,
286
+ )
287
+
288
+ if cfg.save_mode == 'concat':
289
+ continue
290
+
291
+ for i, (scene, colors, normals, meta) in enumerate(scene_results):
292
+ if cfg.run_mode == "generate":
293
+ save_multiview_scene(cfg.multiview_tmp_dir, scene, colors, normals, meta=meta)
294
+ print(f"[PSHuman] Saved multiview scene '{scene}' to {get_scene_dir(cfg.multiview_tmp_dir, scene)}")
295
+ continue
296
+
297
+ pose = econdata.__getitem__(case_id + i)
298
+ carving.optimize_case(scene, pose, colors, normals)
299
+ torch.cuda.empty_cache()
300
+
301
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
302
 
303
  def load_pshuman_pipeline(cfg):
304
  pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(cfg.pretrained_model_name_or_path, torch_dtype=weight_dtype)
 
307
  pipeline.to('cuda')
308
  return pipeline
309
 
 
 
 
310
 
311
+
312
+ def main(cfg: TestConfig):
313
  if cfg.seed is not None:
314
  set_seed(cfg.seed)
315
+
316
+ pipeline = None if cfg.run_mode == "reconstruct" else load_pshuman_pipeline(cfg)
317
 
318
  if cfg.with_smpl:
319
  from mvdiffusion.data.testdata_with_smpl import SingleImageDataset
320
  else:
321
  from mvdiffusion.data.single_image_dataset import SingleImageDataset
322
+
323
+ validation_dataset = SingleImageDataset(**cfg.validation_dataset)
 
 
 
324
  validation_dataloader = torch.utils.data.DataLoader(
325
+ validation_dataset,
326
+ batch_size=cfg.validation_batch_size,
327
+ shuffle=False,
328
+ num_workers=cfg.dataloader_num_workers,
329
  )
 
 
330
 
331
+ dataset_param = {
332
+ 'image_dir': validation_dataset.root_dir,
333
+ 'seg_dir': None,
334
+ 'colab': False,
335
+ 'has_det': True,
336
+ 'hps_type': 'pixie',
337
+ }
338
+ econdata = SMPLDataset(dataset_param, device='cuda')
339
  carving = ReMesh(cfg.recon_opt, econ_dataset=econdata)
340
+
341
+ if cfg.run_mode in {"generate", "reconstruct"} and not cfg.multiview_tmp_dir:
342
+ raise ValueError("multiview_tmp_dir must be provided for run_mode='generate' or 'reconstruct'.")
343
+
344
  run_inference(validation_dataloader, econdata, pipeline, carving, cfg, cfg.save_dir)
345
+
346
 
347
  if __name__ == '__main__':
348
  parser = argparse.ArgumentParser()
349
  parser.add_argument('--config', type=str, required=True)
350
  args, extras = parser.parse_known_args()
351
+ from utils.misc import load_config
352
 
 
353
  cfg = load_config(args.config, cli_args=extras)
354
  schema = OmegaConf.structured(TestConfig)
355
  cfg = OmegaConf.merge(schema, cfg)
356
+ main(cfg)