painter3000 commited on
Commit
66043e5
·
verified ·
1 Parent(s): d233634

Update inference.py

Browse files

- New Version for Up- and Download Fotoset

Files changed (1) hide show
  1. inference.py +218 -190
inference.py CHANGED
@@ -1,24 +1,28 @@
1
  import argparse
2
- import json
3
  import os
 
4
  from pathlib import Path
5
- from typing import Dict, Optional, List
 
 
 
6
  from omegaconf import OmegaConf
7
  from PIL import Image
8
- from dataclasses import dataclass
9
- from collections import defaultdict
10
  import torch
11
  import torch.utils.checkpoint
 
12
  from torchvision.utils import make_grid
13
  from accelerate.utils import set_seed
14
  from tqdm.auto import tqdm
15
- import torch.nn.functional as F
16
  from einops import rearrange
17
  from rembg import remove, new_session
 
18
  from mvdiffusion.pipelines.pipeline_mvdiffusion_unclip import StableUnCLIPImg2ImgPipeline
19
  from econdataset import SMPLDataset
20
  from reconstruct import ReMesh
21
 
 
22
  providers = [
23
  ('CUDAExecutionProvider', {
24
  'device_id': 0,
@@ -32,24 +36,9 @@ session = new_session(providers=providers)
32
  weight_dtype = torch.float16
33
 
34
 
35
- def convert_to_numpy(tensor):
36
- return tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
37
-
38
-
39
- def convert_to_pil(tensor):
40
- return Image.fromarray(convert_to_numpy(tensor))
41
-
42
-
43
- def save_tensor_image(tensor, fp):
44
- ndarr = convert_to_numpy(tensor)
45
- save_image_numpy(ndarr, fp)
46
- return ndarr
47
-
48
-
49
- def save_image_numpy(ndarr, fp):
50
- im = Image.fromarray(ndarr)
51
- im.save(fp)
52
-
53
 
54
  @dataclass
55
  class TestConfig:
@@ -72,172 +61,200 @@ class TestConfig:
72
  num_views: int
73
  enable_xformers_memory_efficient_attention: bool
74
  with_smpl: Optional[bool]
75
-
76
  recon_opt: Dict
77
 
78
- # new two-stage settings
79
- run_mode: str = "full" # full | generate | reconstruct
80
- multiview_tmp_dir: str = ""
81
  prefer_edited_views: bool = True
82
- save_multiview_metadata: bool = True
83
 
84
 
85
- def ensure_rgba(img: Image.Image) -> Image.Image:
86
- return img.convert("RGBA") if img.mode != "RGBA" else img
 
 
 
 
87
 
88
 
89
- def get_scene_name(batch, sample_index: int) -> str:
90
- return Path(batch['filename'][sample_index]).stem
91
 
92
 
93
- def get_scene_dir(base_dir: str, scene: str) -> Path:
94
- return Path(base_dir) / scene
 
95
 
96
 
97
- def save_multiview_scene(base_dir: str, scene: str, colors: List[Image.Image], normals: List[Image.Image], meta: Optional[dict] = None):
98
- scene_dir = get_scene_dir(base_dir, scene)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  raw_dir = scene_dir / "raw"
100
  edit_dir = scene_dir / "edit"
101
- raw_dir.mkdir(parents=True, exist_ok=True)
102
- edit_dir.mkdir(parents=True, exist_ok=True)
 
 
 
 
 
 
 
103
 
104
  for idx, img in enumerate(colors):
105
- img = ensure_rgba(img)
106
- img.save(raw_dir / f"color_{idx:02d}.png")
107
- img.save(edit_dir / f"color_{idx:02d}.png")
 
108
 
109
  for idx, img in enumerate(normals):
110
- img = ensure_rgba(img)
111
- img.save(raw_dir / f"normal_{idx:02d}.png")
112
- img.save(edit_dir / f"normal_{idx:02d}.png")
113
-
114
- if meta is not None:
115
- with open(scene_dir / "meta.json", "w", encoding="utf-8") as f:
116
- json.dump(meta, f, indent=2)
117
-
118
-
119
- def load_multiview_scene(base_dir: str, scene: str, prefer_edit=True):
120
- scene_dir = get_scene_dir(base_dir, scene)
121
- candidate_dirs = [scene_dir / ("edit" if prefer_edit else "raw"), scene_dir / ("raw" if prefer_edit else "edit")]
122
-
123
- data_dir = None
124
- for cdir in candidate_dirs:
125
- if cdir.exists():
126
- data_dir = cdir
127
- break
128
- if data_dir is None:
129
- raise FileNotFoundError(f"No multiview directory found for scene '{scene}' under {scene_dir}")
130
-
131
- color_paths = sorted(data_dir.glob("color_*.png"))
132
- normal_paths = sorted(data_dir.glob("normal_*.png"))
133
- if not color_paths or not normal_paths:
134
- raise FileNotFoundError(f"No color/normal images found in {data_dir}")
135
-
136
- colors = [ensure_rgba(Image.open(p)) for p in color_paths]
137
- normals = [ensure_rgba(Image.open(p)) for p in normal_paths]
 
 
 
 
 
 
 
138
  return colors, normals
139
 
140
 
141
- def prepare_scene_views(batch, imgs_in, normals_pred, images_pred, out, cfg: TestConfig, save_dir, images_cond, case_id):
142
- guidance_scale = cfg.validation_guidance_scales
143
- num_views = imgs_in.shape[0] // (out.shape[0] // 2 // cfg.num_views) if False else None # unused safeguard
144
- bsz = out.shape[0] // 2
145
- num_views = cfg.num_views
146
- scene_results = []
147
-
148
- if cfg.save_mode == 'concat':
149
- cur_dir = os.path.join(save_dir, f"cropsize-{cfg.validation_dataset.crop_size}-cfg{guidance_scale:.1f}-seed{cfg.seed}-smpl-{cfg.with_smpl}")
150
- os.makedirs(cur_dir, exist_ok=True)
151
- for i in range(bsz // num_views):
152
- scene = get_scene_name(batch, i)
153
- img_in_ = images_cond[i].to(out.device)
154
- vis_ = [img_in_]
155
- for j in range(num_views):
156
- idx = i * num_views + j
157
- normal = normals_pred[idx]
158
- color = images_pred[idx]
159
- vis_.append(color)
160
- vis_.append(normal)
161
-
162
- out_filename = f"{cur_dir}/{scene}.png"
163
- vis_ = torch.stack(vis_, dim=0)
164
- vis_ = make_grid(vis_, nrow=len(vis_), padding=0, value_range=(0, 1))
165
- save_tensor_image(vis_, out_filename)
166
- return scene_results
167
-
168
- if cfg.save_mode != 'rgb':
169
- raise ValueError(f"Unsupported save_mode for two-stage workflow: {cfg.save_mode}")
170
-
171
- for i in range(bsz // num_views):
172
- scene = get_scene_name(batch, i)
173
- normals, colors = [], []
174
-
175
- for j in range(num_views):
176
- idx = i * num_views + j
177
- normal = normals_pred[idx]
178
- if j == 0:
179
- color = imgs_in[i * num_views].to(out.device)
180
- else:
181
- color = images_pred[idx]
182
-
183
- if j in [3, 4]:
184
- normal = torch.flip(normal, dims=[2])
185
- color = torch.flip(color, dims=[2])
186
-
187
- colors.append(color)
188
- if j == 6:
189
- normal = F.interpolate(normal.unsqueeze(0), size=(256, 256), mode='bilinear', align_corners=False).squeeze(0)
190
- normals.append(normal)
191
 
192
- normals[0][:, :256, 256:512] = normals[-1]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
 
194
- color_pils = [ensure_rgba(remove(convert_to_pil(tensor), session=session)) for tensor in colors[:6]]
195
- normal_pils = [ensure_rgba(remove(convert_to_pil(tensor), session=session)) for tensor in normals[:6]]
196
 
197
- meta = None
198
- if cfg.save_multiview_metadata:
199
- meta = {
200
- "scene": scene,
201
- "case_id": case_id,
202
- "num_colors": len(color_pils),
203
- "num_normals": len(normal_pils),
204
- "seed": cfg.seed,
205
- "run_mode": cfg.run_mode,
206
- "crop_size": cfg.validation_dataset.crop_size,
207
- "with_smpl": cfg.with_smpl,
208
- }
209
 
210
- scene_results.append((scene, color_pils, normal_pils, meta))
211
 
212
- return scene_results
213
 
 
 
 
214
 
215
  def run_inference(dataloader, econdata, pipeline, carving, cfg: TestConfig, save_dir):
216
- if pipeline is not None:
217
- pipeline.set_progress_bar_config(disable=True)
218
 
219
  if cfg.seed is None:
220
  generator = None
221
  else:
222
- device = pipeline.unet.device if pipeline is not None else "cuda"
223
- generator = torch.Generator(device=device).manual_seed(cfg.seed)
 
224
 
225
  for case_id, batch in tqdm(enumerate(dataloader)):
 
 
 
226
  if cfg.run_mode == "reconstruct":
227
- batch_size = len(batch['filename'])
228
- for i in range(batch_size):
229
- scene = get_scene_name(batch, i)
230
- colors, normals = load_multiview_scene(
231
- cfg.multiview_tmp_dir,
232
- scene,
233
- prefer_edit=cfg.prefer_edited_views,
234
- )
235
- pose = econdata.__getitem__(case_id + i)
236
- carving.optimize_case(scene, pose, colors, normals)
237
- torch.cuda.empty_cache()
238
  continue
239
 
240
- images_cond = batch['imgs_in'][:, 0]
241
  imgs_in = torch.cat([batch['imgs_in']] * 2, dim=0)
242
  num_views = imgs_in.shape[1]
243
  imgs_in = rearrange(imgs_in, "B Nv C H W -> (B Nv) C H W")
@@ -248,8 +265,7 @@ def run_inference(dataloader, econdata, pipeline, carving, cfg: TestConfig, save
248
  else:
249
  smpl_in = None
250
 
251
- normal_prompt_embeddings = batch['normal_prompt_embeddings']
252
- clr_prompt_embeddings = batch['color_prompt_embeddings']
253
  prompt_embeddings = torch.cat([normal_prompt_embeddings, clr_prompt_embeddings], dim=0)
254
  prompt_embeddings = rearrange(prompt_embeddings, "B Nv N C -> (B Nv) N C")
255
 
@@ -265,54 +281,68 @@ def run_inference(dataloader, econdata, pipeline, carving, cfg: TestConfig, save
265
  guidance_scale=guidance_scale,
266
  output_type='pt',
267
  num_images_per_prompt=1,
268
- **cfg.pipe_validation_kwargs,
269
  )
270
 
271
  out = unet_out.images
272
  bsz = out.shape[0] // 2
 
273
  normals_pred = out[:bsz]
274
  images_pred = out[bsz:]
275
 
276
- scene_results = prepare_scene_views(
277
- batch=batch,
278
- imgs_in=imgs_in,
279
- normals_pred=normals_pred,
280
- images_pred=images_pred,
281
- out=out,
282
- cfg=cfg,
283
- save_dir=save_dir,
284
- images_cond=images_cond,
285
- case_id=case_id,
286
- )
287
-
288
- if cfg.save_mode == 'concat':
289
- continue
290
-
291
- for i, (scene, colors, normals, meta) in enumerate(scene_results):
292
- if cfg.run_mode == "generate":
293
- save_multiview_scene(cfg.multiview_tmp_dir, scene, colors, normals, meta=meta)
294
- print(f"[PSHuman] Saved multiview scene '{scene}' to {get_scene_dir(cfg.multiview_tmp_dir, scene)}")
 
 
 
 
 
 
 
295
  continue
296
 
297
- pose = econdata.__getitem__(case_id + i)
298
- carving.optimize_case(scene, pose, colors, normals)
299
- torch.cuda.empty_cache()
 
 
 
 
 
 
300
 
 
 
 
301
 
302
-
303
- def load_pshuman_pipeline(cfg):
304
- pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(cfg.pretrained_model_name_or_path, torch_dtype=weight_dtype)
305
- pipeline.unet.enable_xformers_memory_efficient_attention()
306
- if torch.cuda.is_available():
307
- pipeline.to('cuda')
308
- return pipeline
309
-
310
 
311
 
312
  def main(cfg: TestConfig):
313
  if cfg.seed is not None:
314
  set_seed(cfg.seed)
315
 
 
316
  pipeline = None if cfg.run_mode == "reconstruct" else load_pshuman_pipeline(cfg)
317
 
318
  if cfg.with_smpl:
@@ -325,7 +355,7 @@ def main(cfg: TestConfig):
325
  validation_dataset,
326
  batch_size=cfg.validation_batch_size,
327
  shuffle=False,
328
- num_workers=cfg.dataloader_num_workers,
329
  )
330
 
331
  dataset_param = {
@@ -333,14 +363,11 @@ def main(cfg: TestConfig):
333
  'seg_dir': None,
334
  'colab': False,
335
  'has_det': True,
336
- 'hps_type': 'pixie',
337
  }
338
  econdata = SMPLDataset(dataset_param, device='cuda')
339
  carving = ReMesh(cfg.recon_opt, econ_dataset=econdata)
340
 
341
- if cfg.run_mode in {"generate", "reconstruct"} and not cfg.multiview_tmp_dir:
342
- raise ValueError("multiview_tmp_dir must be provided for run_mode='generate' or 'reconstruct'.")
343
-
344
  run_inference(validation_dataloader, econdata, pipeline, carving, cfg, cfg.save_dir)
345
 
346
 
@@ -348,6 +375,7 @@ if __name__ == '__main__':
348
  parser = argparse.ArgumentParser()
349
  parser.add_argument('--config', type=str, required=True)
350
  args, extras = parser.parse_known_args()
 
351
  from utils.misc import load_config
352
 
353
  cfg = load_config(args.config, cli_args=extras)
 
1
  import argparse
 
2
  import os
3
+ import shutil
4
  from pathlib import Path
5
+ from typing import Dict, Optional, List, Tuple
6
+ from collections import defaultdict
7
+ from dataclasses import dataclass
8
+
9
  from omegaconf import OmegaConf
10
  from PIL import Image
11
+
 
12
  import torch
13
  import torch.utils.checkpoint
14
+ import torch.nn.functional as F
15
  from torchvision.utils import make_grid
16
  from accelerate.utils import set_seed
17
  from tqdm.auto import tqdm
 
18
  from einops import rearrange
19
  from rembg import remove, new_session
20
+
21
  from mvdiffusion.pipelines.pipeline_mvdiffusion_unclip import StableUnCLIPImg2ImgPipeline
22
  from econdataset import SMPLDataset
23
  from reconstruct import ReMesh
24
 
25
+
26
  providers = [
27
  ('CUDAExecutionProvider', {
28
  'device_id': 0,
 
36
  weight_dtype = torch.float16
37
 
38
 
39
+ # ============================================================
40
+ # Config
41
+ # ============================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  @dataclass
44
  class TestConfig:
 
61
  num_views: int
62
  enable_xformers_memory_efficient_attention: bool
63
  with_smpl: Optional[bool]
 
64
  recon_opt: Dict
65
 
66
+ # New two-stage fields
67
+ run_mode: str = "full" # full | generate | reconstruct
68
+ multiview_tmp_dir: str = "./multiview"
69
  prefer_edited_views: bool = True
 
70
 
71
 
72
+ # ============================================================
73
+ # Image helpers
74
+ # ============================================================
75
+
76
+ def convert_to_numpy(tensor):
77
+ return tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
78
 
79
 
80
+ def convert_to_pil(tensor):
81
+ return Image.fromarray(convert_to_numpy(tensor))
82
 
83
 
84
+ def save_image_numpy(ndarr, fp):
85
+ im = Image.fromarray(ndarr)
86
+ im.save(fp)
87
 
88
 
89
+ def save_image_tensor(tensor, fp):
90
+ ndarr = convert_to_numpy(tensor)
91
+ save_image_numpy(ndarr, fp)
92
+ return ndarr
93
+
94
+
95
+ # ============================================================
96
+ # Multiview storage helpers
97
+ # ============================================================
98
+
99
+ def ensure_dir(path: Path):
100
+ path.mkdir(parents=True, exist_ok=True)
101
+
102
+
103
+ def save_multiview_scene(multiview_root: str, scene: str, colors: List[Image.Image], normals: List[Image.Image]):
104
+ scene_dir = Path(multiview_root) / scene
105
  raw_dir = scene_dir / "raw"
106
  edit_dir = scene_dir / "edit"
107
+
108
+ ensure_dir(raw_dir)
109
+ ensure_dir(edit_dir)
110
+
111
+ # Clean previous files to avoid stale leftovers
112
+ for folder in (raw_dir, edit_dir):
113
+ for p in folder.glob("*"):
114
+ if p.is_file():
115
+ p.unlink()
116
 
117
  for idx, img in enumerate(colors):
118
+ raw_color = raw_dir / f"color_{idx:02d}.png"
119
+ edit_color = edit_dir / f"color_{idx:02d}.png"
120
+ img.save(raw_color)
121
+ img.save(edit_color)
122
 
123
  for idx, img in enumerate(normals):
124
+ raw_normal = raw_dir / f"normal_{idx:02d}.png"
125
+ edit_normal = edit_dir / f"normal_{idx:02d}.png"
126
+ img.save(raw_normal)
127
+ img.save(edit_normal)
128
+
129
+ meta = {
130
+ "scene": scene,
131
+ "num_colors": len(colors),
132
+ "num_normals": len(normals),
133
+ "source": "PSHuman two-stage inference",
134
+ }
135
+ with open(scene_dir / "meta.json", "w", encoding="utf-8") as f:
136
+ import json
137
+ json.dump(meta, f, indent=2)
138
+
139
+
140
+ def load_multiview_scene(multiview_root: str, scene: str, prefer_edit=True) -> Tuple[List[Image.Image], List[Image.Image]]:
141
+ scene_dir = Path(multiview_root) / scene
142
+ preferred = scene_dir / ("edit" if prefer_edit else "raw")
143
+ fallback = scene_dir / ("raw" if prefer_edit else "edit")
144
+
145
+ base_dir = preferred if preferred.exists() else fallback
146
+ if not base_dir.exists():
147
+ raise FileNotFoundError(f"Kein Multiview-Ordner für Szene '{scene}' gefunden: {preferred}")
148
+
149
+ color_paths = sorted(base_dir.glob("color_*.png"))
150
+ normal_paths = sorted(base_dir.glob("normal_*.png"))
151
+
152
+ if not color_paths:
153
+ raise FileNotFoundError(f"Keine Color-Bilder gefunden in: {base_dir}")
154
+ if not normal_paths:
155
+ raise FileNotFoundError(f"Keine Normalmaps gefunden in: {base_dir}")
156
+
157
+ colors = [Image.open(p).convert("RGBA") for p in color_paths]
158
+ normals = [Image.open(p).convert("RGBA") for p in normal_paths]
159
  return colors, normals
160
 
161
 
162
+ # ============================================================
163
+ # Pipeline helpers
164
+ # ============================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
+ def load_pshuman_pipeline(cfg):
167
+ pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(
168
+ cfg.pretrained_model_name_or_path,
169
+ torch_dtype=weight_dtype
170
+ )
171
+ pipeline.unet.enable_xformers_memory_efficient_attention()
172
+ if torch.cuda.is_available():
173
+ pipeline.to('cuda')
174
+ return pipeline
175
+
176
+
177
+ def extract_scene_views_for_case(
178
+ batch,
179
+ out,
180
+ imgs_in,
181
+ i: int,
182
+ num_views: int,
183
+ ):
184
+ normals_pred = out[: out.shape[0] // 2]
185
+ images_pred = out[out.shape[0] // 2:]
186
+
187
+ scene = batch['filename'][i].split('.')[0]
188
+
189
+ normals, colors = [], []
190
+
191
+ for j in range(num_views):
192
+ idx = i * num_views + j
193
+ normal = normals_pred[idx]
194
+
195
+ # Fix from original code: use scene-local first input image
196
+ if j == 0:
197
+ color = imgs_in[i * num_views].to(out.device)
198
+ else:
199
+ color = images_pred[idx]
200
+
201
+ if j in [3, 4]:
202
+ normal = torch.flip(normal, dims=[2])
203
+ color = torch.flip(color, dims=[2])
204
+
205
+ colors.append(color)
206
+
207
+ if j == 6:
208
+ normal = F.interpolate(
209
+ normal.unsqueeze(0),
210
+ size=(256, 256),
211
+ mode='bilinear',
212
+ align_corners=False
213
+ ).squeeze(0)
214
 
215
+ normals.append(normal)
 
216
 
217
+ # Preserve original PSHuman behavior
218
+ if len(normals) >= 2:
219
+ normals[0][:, :256, 256:512] = normals[-1]
220
+
221
+ # Original code keeps first 6 views only
222
+ colors_pil = [remove(convert_to_pil(tensor), session=session) for tensor in colors[:6]]
223
+ normals_pil = [remove(convert_to_pil(tensor), session=session) for tensor in normals[:6]]
 
 
 
 
 
224
 
225
+ return scene, colors_pil, normals_pil
226
 
 
227
 
228
+ # ============================================================
229
+ # Main inference logic
230
+ # ============================================================
231
 
232
  def run_inference(dataloader, econdata, pipeline, carving, cfg: TestConfig, save_dir):
233
+ pipeline.set_progress_bar_config(disable=True)
 
234
 
235
  if cfg.seed is None:
236
  generator = None
237
  else:
238
+ generator = torch.Generator(device=pipeline.unet.device).manual_seed(cfg.seed)
239
+
240
+ images_cond, pred_cat = [], defaultdict(list)
241
 
242
  for case_id, batch in tqdm(enumerate(dataloader)):
243
+ images_cond.append(batch['imgs_in'][:, 0])
244
+
245
+ # Reconstruct-only path: skip diffusion, load saved views instead
246
  if cfg.run_mode == "reconstruct":
247
+ scene = batch['filename'][0].split('.')[0]
248
+ colors, normals = load_multiview_scene(
249
+ cfg.multiview_tmp_dir,
250
+ scene,
251
+ prefer_edit=cfg.prefer_edited_views
252
+ )
253
+ pose = econdata.__getitem__(case_id)
254
+ carving.optimize_case(scene, pose, colors, normals)
255
+ torch.cuda.empty_cache()
 
 
256
  continue
257
 
 
258
  imgs_in = torch.cat([batch['imgs_in']] * 2, dim=0)
259
  num_views = imgs_in.shape[1]
260
  imgs_in = rearrange(imgs_in, "B Nv C H W -> (B Nv) C H W")
 
265
  else:
266
  smpl_in = None
267
 
268
+ normal_prompt_embeddings, clr_prompt_embeddings = batch['normal_prompt_embeddings'], batch['color_prompt_embeddings']
 
269
  prompt_embeddings = torch.cat([normal_prompt_embeddings, clr_prompt_embeddings], dim=0)
270
  prompt_embeddings = rearrange(prompt_embeddings, "B Nv N C -> (B Nv) N C")
271
 
 
281
  guidance_scale=guidance_scale,
282
  output_type='pt',
283
  num_images_per_prompt=1,
284
+ **cfg.pipe_validation_kwargs
285
  )
286
 
287
  out = unet_out.images
288
  bsz = out.shape[0] // 2
289
+
290
  normals_pred = out[:bsz]
291
  images_pred = out[bsz:]
292
 
293
+ if cfg.save_mode == 'concat':
294
+ pred_cat[f"cfg{guidance_scale:.1f}"].append(torch.cat([normals_pred, images_pred], dim=-1))
295
+ cur_dir = os.path.join(
296
+ save_dir,
297
+ f"cropsize-{cfg.validation_dataset.crop_size}-cfg{guidance_scale:.1f}-seed{cfg.seed}-smpl-{cfg.with_smpl}"
298
+ )
299
+ os.makedirs(cur_dir, exist_ok=True)
300
+
301
+ for i in range(bsz // num_views):
302
+ scene = batch['filename'][i].split('.')[0]
303
+ img_in_ = images_cond[-1][i].to(out.device)
304
+ vis_ = [img_in_]
305
+
306
+ for j in range(num_views):
307
+ idx = i * num_views + j
308
+ normal = normals_pred[idx]
309
+ color = images_pred[idx]
310
+ vis_.append(color)
311
+ vis_.append(normal)
312
+
313
+ out_filename = f"{cur_dir}/{scene}.png"
314
+ vis_ = torch.stack(vis_, dim=0)
315
+ vis_ = make_grid(vis_, nrow=len(vis_), padding=0, value_range=(0, 1))
316
+ save_image_tensor(vis_, out_filename)
317
+
318
+ # concat mode is only for legacy visualization
319
  continue
320
 
321
+ elif cfg.save_mode == 'rgb':
322
+ for i in range(bsz // num_views):
323
+ scene, colors, normals = extract_scene_views_for_case(
324
+ batch=batch,
325
+ out=out,
326
+ imgs_in=imgs_in,
327
+ i=i,
328
+ num_views=num_views,
329
+ )
330
 
331
+ if cfg.run_mode == "generate":
332
+ save_multiview_scene(cfg.multiview_tmp_dir, scene, colors, normals)
333
+ continue
334
 
335
+ # full mode: original one-pass behavior
336
+ pose = econdata.__getitem__(case_id)
337
+ carving.optimize_case(scene, pose, colors, normals)
338
+ torch.cuda.empty_cache()
 
 
 
 
339
 
340
 
341
  def main(cfg: TestConfig):
342
  if cfg.seed is not None:
343
  set_seed(cfg.seed)
344
 
345
+ # Reconstruct mode does not need the diffusion pipeline at all
346
  pipeline = None if cfg.run_mode == "reconstruct" else load_pshuman_pipeline(cfg)
347
 
348
  if cfg.with_smpl:
 
355
  validation_dataset,
356
  batch_size=cfg.validation_batch_size,
357
  shuffle=False,
358
+ num_workers=cfg.dataloader_num_workers
359
  )
360
 
361
  dataset_param = {
 
363
  'seg_dir': None,
364
  'colab': False,
365
  'has_det': True,
366
+ 'hps_type': 'pixie'
367
  }
368
  econdata = SMPLDataset(dataset_param, device='cuda')
369
  carving = ReMesh(cfg.recon_opt, econ_dataset=econdata)
370
 
 
 
 
371
  run_inference(validation_dataloader, econdata, pipeline, carving, cfg, cfg.save_dir)
372
 
373
 
 
375
  parser = argparse.ArgumentParser()
376
  parser.add_argument('--config', type=str, required=True)
377
  args, extras = parser.parse_known_args()
378
+
379
  from utils.misc import load_config
380
 
381
  cfg = load_config(args.config, cli_args=extras)