Spaces:
Runtime error
Runtime error
query colors from RGB image
Browse files- apps/infer.py +33 -46
apps/infer.py
CHANGED
|
@@ -33,6 +33,7 @@ from apps.Normal import Normal
|
|
| 33 |
from apps.IFGeo import IFGeo
|
| 34 |
from pytorch3d.ops import SubdivideMeshes
|
| 35 |
from lib.common.config import cfg
|
|
|
|
| 36 |
from lib.common.train_util import init_loss, load_normal_networks, load_networks
|
| 37 |
from lib.common.BNI import BNI
|
| 38 |
from lib.common.BNI_utils import save_normal_tensor
|
|
@@ -93,14 +94,14 @@ if __name__ == "__main__":
|
|
| 93 |
"vol_res": cfg.vol_res,
|
| 94 |
"single": args.multi,
|
| 95 |
}
|
| 96 |
-
|
| 97 |
if cfg.bni.use_ifnet:
|
| 98 |
print(colored("Use IF-Nets (Implicit)+ for completion", "green"))
|
| 99 |
else:
|
| 100 |
print(colored("Use SMPL-X (Explicit) for completion", "green"))
|
| 101 |
|
| 102 |
dataset = TestDataset(dataset_param, device)
|
| 103 |
-
|
| 104 |
print(colored(f"Dataset Size: {len(dataset)}", "green"))
|
| 105 |
|
| 106 |
pbar = tqdm(dataset)
|
|
@@ -130,11 +131,7 @@ if __name__ == "__main__":
|
|
| 130 |
|
| 131 |
os.makedirs(osp.join(args.out_dir, cfg.name, "obj"), exist_ok=True)
|
| 132 |
|
| 133 |
-
in_tensor = {
|
| 134 |
-
"smpl_faces": data["smpl_faces"],
|
| 135 |
-
"image": data["img_icon"].to(device),
|
| 136 |
-
"mask": data["img_mask"].to(device)
|
| 137 |
-
}
|
| 138 |
|
| 139 |
# The optimizer and variables
|
| 140 |
optimed_pose = data["body_pose"].requires_grad_(True)
|
|
@@ -158,7 +155,7 @@ if __name__ == "__main__":
|
|
| 158 |
N_body, N_pose = optimed_pose.shape[:2]
|
| 159 |
|
| 160 |
smpl_path = f"{args.out_dir}/{cfg.name}/obj/{data['name']}_smpl_00.obj"
|
| 161 |
-
|
| 162 |
if osp.exists(smpl_path):
|
| 163 |
|
| 164 |
smpl_verts_lst = []
|
|
@@ -183,7 +180,7 @@ if __name__ == "__main__":
|
|
| 183 |
|
| 184 |
in_tensor["smpl_verts"] = batch_smpl_verts * torch.tensor([1., -1., 1.]).to(device)
|
| 185 |
in_tensor["smpl_faces"] = batch_smpl_faces[:, :, [0, 2, 1]]
|
| 186 |
-
|
| 187 |
else:
|
| 188 |
# smpl optimization
|
| 189 |
loop_smpl = tqdm(range(args.loop_smpl))
|
|
@@ -252,16 +249,14 @@ if __name__ == "__main__":
|
|
| 252 |
|
| 253 |
# BUG: PyTorch3D silhouette renderer generates dilated mask
|
| 254 |
bg_value = in_tensor["T_normal_F"][0, 0, 0, 0]
|
| 255 |
-
smpl_arr_fake = torch.cat(
|
| 256 |
-
|
| 257 |
-
dim=-1)
|
| 258 |
|
| 259 |
body_overlap = (gt_arr * smpl_arr_fake.gt(0.0)).sum(dim=[1, 2]) / smpl_arr_fake.gt(0.0).sum(dim=[1, 2])
|
| 260 |
body_overlap_mask = (gt_arr * smpl_arr_fake).unsqueeze(1)
|
| 261 |
body_overlap_flag = body_overlap < cfg.body_overlap_thres
|
| 262 |
|
| 263 |
-
losses["normal"]["value"] = (diff_F_smpl * body_overlap_mask[..., :512] +
|
| 264 |
-
diff_B_smpl * body_overlap_mask[..., 512:]).mean() / 2.0
|
| 265 |
|
| 266 |
losses["silhouette"]["weight"] = [0 if flag else 1.0 for flag in body_overlap_flag]
|
| 267 |
occluded_idx = torch.where(body_overlap_flag)[0]
|
|
@@ -308,18 +303,15 @@ if __name__ == "__main__":
|
|
| 308 |
|
| 309 |
img_crop_path = osp.join(args.out_dir, cfg.name, "png", f"{data['name']}_crop.png")
|
| 310 |
torchvision.utils.save_image(
|
| 311 |
-
torch.cat(
|
| 312 |
-
data["img_crop"][:, :3], (in_tensor['normal_F'].detach().cpu() + 1.0) * 0.5,
|
| 313 |
-
|
| 314 |
-
],
|
| 315 |
-
dim=3), img_crop_path)
|
| 316 |
|
| 317 |
rgb_norm_F = blend_rgb_norm(in_tensor["normal_F"], data)
|
| 318 |
rgb_norm_B = blend_rgb_norm(in_tensor["normal_B"], data)
|
| 319 |
|
| 320 |
img_overlap_path = osp.join(args.out_dir, cfg.name, f"png/{data['name']}_overlap.png")
|
| 321 |
-
torchvision.utils.save_image(
|
| 322 |
-
torch.Tensor([data["img_raw"], rgb_norm_F, rgb_norm_B]).permute(0, 3, 1, 2) / 255., img_overlap_path)
|
| 323 |
|
| 324 |
smpl_obj_lst = []
|
| 325 |
|
|
@@ -397,12 +389,7 @@ if __name__ == "__main__":
|
|
| 397 |
)
|
| 398 |
|
| 399 |
# BNI process
|
| 400 |
-
BNI_object = BNI(
|
| 401 |
-
dir_path=osp.join(args.out_dir, cfg.name, "BNI"),
|
| 402 |
-
name=data["name"],
|
| 403 |
-
BNI_dict=BNI_dict,
|
| 404 |
-
cfg=cfg.bni,
|
| 405 |
-
device=device)
|
| 406 |
|
| 407 |
BNI_object.extract_surface(False)
|
| 408 |
|
|
@@ -419,16 +406,11 @@ if __name__ == "__main__":
|
|
| 419 |
side_mesh = apply_face_mask(side_mesh, ~SMPLX_object.smplx_eyeball_fid_mask)
|
| 420 |
|
| 421 |
# mesh completion via IF-net
|
| 422 |
-
in_tensor.update(
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
occupancies = VoxelGrid.from_mesh(
|
| 429 |
-
side_mesh, cfg.vol_res, loc=[
|
| 430 |
-
0,
|
| 431 |
-
] * 3, scale=2.0).data.transpose(2, 1, 0)
|
| 432 |
occupancies = np.flip(occupancies, axis=1)
|
| 433 |
|
| 434 |
in_tensor["body_voxels"] = torch.tensor(occupancies.copy()).float().unsqueeze(0).to(device)
|
|
@@ -446,10 +428,9 @@ if __name__ == "__main__":
|
|
| 446 |
else:
|
| 447 |
side_mesh = apply_vertex_mask(
|
| 448 |
side_mesh,
|
| 449 |
-
(SMPLX_object.front_flame_vertex_mask + SMPLX_object.mano_vertex_mask +
|
| 450 |
-
SMPLX_object.eyeball_vertex_mask).eq(0).float(),
|
| 451 |
)
|
| 452 |
-
|
| 453 |
#register side_mesh to BNI surfaces
|
| 454 |
side_mesh = Meshes(
|
| 455 |
verts=[torch.tensor(side_mesh.vertices).float()],
|
|
@@ -458,7 +439,6 @@ if __name__ == "__main__":
|
|
| 458 |
sm = SubdivideMeshes(side_mesh)
|
| 459 |
side_mesh = register(BNI_object.F_B_trimesh, sm(side_mesh), device)
|
| 460 |
|
| 461 |
-
|
| 462 |
side_verts = torch.tensor(side_mesh.vertices).float().to(device)
|
| 463 |
side_faces = torch.tensor(side_mesh.faces).long().to(device)
|
| 464 |
|
|
@@ -469,7 +449,6 @@ if __name__ == "__main__":
|
|
| 469 |
|
| 470 |
# export intermediate meshes
|
| 471 |
BNI_object.F_B_trimesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_BNI.obj")
|
| 472 |
-
|
| 473 |
full_lst = []
|
| 474 |
|
| 475 |
if "face" in cfg.bni.use_smpl:
|
|
@@ -479,8 +458,7 @@ if __name__ == "__main__":
|
|
| 479 |
face_mesh.vertices = face_mesh.vertices - np.array([0, 0, cfg.bni.thickness])
|
| 480 |
|
| 481 |
# remove face neighbor triangles
|
| 482 |
-
BNI_object.F_B_trimesh = part_removal(
|
| 483 |
-
BNI_object.F_B_trimesh, face_mesh, cfg.bni.face_thres, device, smplx_mesh, region="face")
|
| 484 |
side_mesh = part_removal(side_mesh, face_mesh, cfg.bni.face_thres, device, smplx_mesh, region="face")
|
| 485 |
face_mesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_face.obj")
|
| 486 |
full_lst += [face_mesh]
|
|
@@ -497,8 +475,7 @@ if __name__ == "__main__":
|
|
| 497 |
hand_mesh = apply_vertex_mask(hand_mesh, hand_mask)
|
| 498 |
|
| 499 |
# remove hand neighbor triangles
|
| 500 |
-
BNI_object.F_B_trimesh = part_removal(
|
| 501 |
-
BNI_object.F_B_trimesh, hand_mesh, cfg.bni.hand_thres, device, smplx_mesh, region="hand")
|
| 502 |
side_mesh = part_removal(side_mesh, hand_mesh, cfg.bni.hand_thres, device, smplx_mesh, region="hand")
|
| 503 |
hand_mesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_hand.obj")
|
| 504 |
full_lst += [hand_mesh]
|
|
@@ -528,6 +505,16 @@ if __name__ == "__main__":
|
|
| 528 |
rotate_recon_lst = dataset.render.get_image(cam_type="four")
|
| 529 |
per_loop_lst.extend([in_tensor['image'][idx:idx + 1]] + rotate_recon_lst)
|
| 530 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 531 |
# for video rendering
|
| 532 |
in_tensor["BNI_verts"].append(torch.tensor(final_mesh.vertices).float())
|
| 533 |
in_tensor["BNI_faces"].append(torch.tensor(final_mesh.faces).long())
|
|
|
|
| 33 |
from apps.IFGeo import IFGeo
|
| 34 |
from pytorch3d.ops import SubdivideMeshes
|
| 35 |
from lib.common.config import cfg
|
| 36 |
+
from lib.common.render import query_color
|
| 37 |
from lib.common.train_util import init_loss, load_normal_networks, load_networks
|
| 38 |
from lib.common.BNI import BNI
|
| 39 |
from lib.common.BNI_utils import save_normal_tensor
|
|
|
|
| 94 |
"vol_res": cfg.vol_res,
|
| 95 |
"single": args.multi,
|
| 96 |
}
|
| 97 |
+
|
| 98 |
if cfg.bni.use_ifnet:
|
| 99 |
print(colored("Use IF-Nets (Implicit)+ for completion", "green"))
|
| 100 |
else:
|
| 101 |
print(colored("Use SMPL-X (Explicit) for completion", "green"))
|
| 102 |
|
| 103 |
dataset = TestDataset(dataset_param, device)
|
| 104 |
+
|
| 105 |
print(colored(f"Dataset Size: {len(dataset)}", "green"))
|
| 106 |
|
| 107 |
pbar = tqdm(dataset)
|
|
|
|
| 131 |
|
| 132 |
os.makedirs(osp.join(args.out_dir, cfg.name, "obj"), exist_ok=True)
|
| 133 |
|
| 134 |
+
in_tensor = {"smpl_faces": data["smpl_faces"], "image": data["img_icon"].to(device), "mask": data["img_mask"].to(device)}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
|
| 136 |
# The optimizer and variables
|
| 137 |
optimed_pose = data["body_pose"].requires_grad_(True)
|
|
|
|
| 155 |
N_body, N_pose = optimed_pose.shape[:2]
|
| 156 |
|
| 157 |
smpl_path = f"{args.out_dir}/{cfg.name}/obj/{data['name']}_smpl_00.obj"
|
| 158 |
+
|
| 159 |
if osp.exists(smpl_path):
|
| 160 |
|
| 161 |
smpl_verts_lst = []
|
|
|
|
| 180 |
|
| 181 |
in_tensor["smpl_verts"] = batch_smpl_verts * torch.tensor([1., -1., 1.]).to(device)
|
| 182 |
in_tensor["smpl_faces"] = batch_smpl_faces[:, :, [0, 2, 1]]
|
| 183 |
+
|
| 184 |
else:
|
| 185 |
# smpl optimization
|
| 186 |
loop_smpl = tqdm(range(args.loop_smpl))
|
|
|
|
| 249 |
|
| 250 |
# BUG: PyTorch3D silhouette renderer generates dilated mask
|
| 251 |
bg_value = in_tensor["T_normal_F"][0, 0, 0, 0]
|
| 252 |
+
smpl_arr_fake = torch.cat([in_tensor["T_normal_F"][:, 0].ne(bg_value).float(), in_tensor["T_normal_B"][:, 0].ne(bg_value).float()],
|
| 253 |
+
dim=-1)
|
|
|
|
| 254 |
|
| 255 |
body_overlap = (gt_arr * smpl_arr_fake.gt(0.0)).sum(dim=[1, 2]) / smpl_arr_fake.gt(0.0).sum(dim=[1, 2])
|
| 256 |
body_overlap_mask = (gt_arr * smpl_arr_fake).unsqueeze(1)
|
| 257 |
body_overlap_flag = body_overlap < cfg.body_overlap_thres
|
| 258 |
|
| 259 |
+
losses["normal"]["value"] = (diff_F_smpl * body_overlap_mask[..., :512] + diff_B_smpl * body_overlap_mask[..., 512:]).mean() / 2.0
|
|
|
|
| 260 |
|
| 261 |
losses["silhouette"]["weight"] = [0 if flag else 1.0 for flag in body_overlap_flag]
|
| 262 |
occluded_idx = torch.where(body_overlap_flag)[0]
|
|
|
|
| 303 |
|
| 304 |
img_crop_path = osp.join(args.out_dir, cfg.name, "png", f"{data['name']}_crop.png")
|
| 305 |
torchvision.utils.save_image(
|
| 306 |
+
torch.cat(
|
| 307 |
+
[data["img_crop"][:, :3], (in_tensor['normal_F'].detach().cpu() + 1.0) * 0.5, (in_tensor['normal_B'].detach().cpu() + 1.0) * 0.5],
|
| 308 |
+
dim=3), img_crop_path)
|
|
|
|
|
|
|
| 309 |
|
| 310 |
rgb_norm_F = blend_rgb_norm(in_tensor["normal_F"], data)
|
| 311 |
rgb_norm_B = blend_rgb_norm(in_tensor["normal_B"], data)
|
| 312 |
|
| 313 |
img_overlap_path = osp.join(args.out_dir, cfg.name, f"png/{data['name']}_overlap.png")
|
| 314 |
+
torchvision.utils.save_image(torch.Tensor([data["img_raw"], rgb_norm_F, rgb_norm_B]).permute(0, 3, 1, 2) / 255., img_overlap_path)
|
|
|
|
| 315 |
|
| 316 |
smpl_obj_lst = []
|
| 317 |
|
|
|
|
| 389 |
)
|
| 390 |
|
| 391 |
# BNI process
|
| 392 |
+
BNI_object = BNI(dir_path=osp.join(args.out_dir, cfg.name, "BNI"), name=data["name"], BNI_dict=BNI_dict, cfg=cfg.bni, device=device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 393 |
|
| 394 |
BNI_object.extract_surface(False)
|
| 395 |
|
|
|
|
| 406 |
side_mesh = apply_face_mask(side_mesh, ~SMPLX_object.smplx_eyeball_fid_mask)
|
| 407 |
|
| 408 |
# mesh completion via IF-net
|
| 409 |
+
in_tensor.update(dataset.depth_to_voxel({"depth_F": BNI_object.F_depth.unsqueeze(0), "depth_B": BNI_object.B_depth.unsqueeze(0)}))
|
| 410 |
+
|
| 411 |
+
occupancies = VoxelGrid.from_mesh(side_mesh, cfg.vol_res, loc=[
|
| 412 |
+
0,
|
| 413 |
+
] * 3, scale=2.0).data.transpose(2, 1, 0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 414 |
occupancies = np.flip(occupancies, axis=1)
|
| 415 |
|
| 416 |
in_tensor["body_voxels"] = torch.tensor(occupancies.copy()).float().unsqueeze(0).to(device)
|
|
|
|
| 428 |
else:
|
| 429 |
side_mesh = apply_vertex_mask(
|
| 430 |
side_mesh,
|
| 431 |
+
(SMPLX_object.front_flame_vertex_mask + SMPLX_object.mano_vertex_mask + SMPLX_object.eyeball_vertex_mask).eq(0).float(),
|
|
|
|
| 432 |
)
|
| 433 |
+
|
| 434 |
#register side_mesh to BNI surfaces
|
| 435 |
side_mesh = Meshes(
|
| 436 |
verts=[torch.tensor(side_mesh.vertices).float()],
|
|
|
|
| 439 |
sm = SubdivideMeshes(side_mesh)
|
| 440 |
side_mesh = register(BNI_object.F_B_trimesh, sm(side_mesh), device)
|
| 441 |
|
|
|
|
| 442 |
side_verts = torch.tensor(side_mesh.vertices).float().to(device)
|
| 443 |
side_faces = torch.tensor(side_mesh.faces).long().to(device)
|
| 444 |
|
|
|
|
| 449 |
|
| 450 |
# export intermediate meshes
|
| 451 |
BNI_object.F_B_trimesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_BNI.obj")
|
|
|
|
| 452 |
full_lst = []
|
| 453 |
|
| 454 |
if "face" in cfg.bni.use_smpl:
|
|
|
|
| 458 |
face_mesh.vertices = face_mesh.vertices - np.array([0, 0, cfg.bni.thickness])
|
| 459 |
|
| 460 |
# remove face neighbor triangles
|
| 461 |
+
BNI_object.F_B_trimesh = part_removal(BNI_object.F_B_trimesh, face_mesh, cfg.bni.face_thres, device, smplx_mesh, region="face")
|
|
|
|
| 462 |
side_mesh = part_removal(side_mesh, face_mesh, cfg.bni.face_thres, device, smplx_mesh, region="face")
|
| 463 |
face_mesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_face.obj")
|
| 464 |
full_lst += [face_mesh]
|
|
|
|
| 475 |
hand_mesh = apply_vertex_mask(hand_mesh, hand_mask)
|
| 476 |
|
| 477 |
# remove hand neighbor triangles
|
| 478 |
+
BNI_object.F_B_trimesh = part_removal(BNI_object.F_B_trimesh, hand_mesh, cfg.bni.hand_thres, device, smplx_mesh, region="hand")
|
|
|
|
| 479 |
side_mesh = part_removal(side_mesh, hand_mesh, cfg.bni.hand_thres, device, smplx_mesh, region="hand")
|
| 480 |
hand_mesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_hand.obj")
|
| 481 |
full_lst += [hand_mesh]
|
|
|
|
| 505 |
rotate_recon_lst = dataset.render.get_image(cam_type="four")
|
| 506 |
per_loop_lst.extend([in_tensor['image'][idx:idx + 1]] + rotate_recon_lst)
|
| 507 |
|
| 508 |
+
# coloring the final mesh
|
| 509 |
+
final_colors = query_color(
|
| 510 |
+
torch.tensor(final_mesh.vertices).float(),
|
| 511 |
+
torch.tensor(final_mesh.faces).long(),
|
| 512 |
+
in_tensor["image"][idx:idx + 1],
|
| 513 |
+
device=device,
|
| 514 |
+
)
|
| 515 |
+
final_mesh.visual.vertex_colors = final_colors
|
| 516 |
+
final_mesh.export(final_path)
|
| 517 |
+
|
| 518 |
# for video rendering
|
| 519 |
in_tensor["BNI_verts"].append(torch.tensor(final_mesh.vertices).float())
|
| 520 |
in_tensor["BNI_faces"].append(torch.tensor(final_mesh.faces).long())
|