Upload 115 files
Browse files. This view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- PixHtLab-Src/.gitignore +5 -0
- PixHtLab-Src/Data/data_prepare.py +21 -0
- PixHtLab-Src/Demo/PixhtLab/Demo.ipynb +0 -0
- PixHtLab-Src/Demo/PixhtLab/Examples/009_depth.npy +3 -0
- PixHtLab-Src/Demo/PixhtLab/Examples/009_depth_valid_mask.npy +3 -0
- PixHtLab-Src/Demo/PixhtLab/Examples/009_pixht_new.npy +3 -0
- PixHtLab-Src/Demo/PixhtLab/Examples/010_depth.npy +3 -0
- PixHtLab-Src/Demo/PixhtLab/Examples/010_depth_valid_mask.npy +3 -0
- PixHtLab-Src/Demo/PixhtLab/Examples/010_pixht_new.npy +3 -0
- PixHtLab-Src/Demo/PixhtLab/Examples/011_depth.npy +3 -0
- PixHtLab-Src/Demo/PixhtLab/Examples/011_depth_valid_mask.npy +3 -0
- PixHtLab-Src/Demo/PixhtLab/Examples/011_pixht_new.npy +3 -0
- PixHtLab-Src/Demo/PixhtLab/Examples/c82c09fc01e84282bc8870c263dcf81b_bg.jpg +3 -0
- PixHtLab-Src/Demo/PixhtLab/GSSN/__init__.py +0 -0
- PixHtLab-Src/Demo/PixhtLab/GSSN/inference_shadow.py +70 -0
- PixHtLab-Src/Demo/PixhtLab/ShadowStyle/__init__.py +0 -0
- PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/__init__.py +0 -0
- PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/inference_shadow.py +70 -0
- PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/params.py +92 -0
- PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/ssn/__init__.py +0 -0
- PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/ssn/perturb_touch.py +24 -0
- PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/ssn/random_pattern.py +92 -0
- PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/ssn/ssn.py +146 -0
- PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/ssn/ssn_dataset.py +290 -0
- PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/ssn/ssn_submodule.py +282 -0
- PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/ssn/test.py +24 -0
- PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/ssn/test_dataset.py +21 -0
- PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/__init__.py +0 -0
- PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/html.py +61 -0
- PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/html_server.py +9 -0
- PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/imgs +1 -0
- PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/index.html +0 -0
- PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/make_html.py +133 -0
- PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/net_utils.py +70 -0
- PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/tensorboard_utils.py +29 -0
- PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/time_utils.py +6 -0
- PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/utils_file.py +59 -0
- PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/vis_test_results.py +21 -0
- PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/visdom_utils.py +53 -0
- PixHtLab-Src/Demo/PixhtLab/Torch_Render/hshadow_cuda.cpp +98 -0
- PixHtLab-Src/Demo/PixhtLab/Torch_Render/hshadow_cuda_kernel.cu +682 -0
- PixHtLab-Src/Demo/PixhtLab/Torch_Render/plane_visualize.cpp +26 -0
- PixHtLab-Src/Demo/PixhtLab/Torch_Render/plane_visualize_cuda.cu +237 -0
- PixHtLab-Src/Demo/PixhtLab/Torch_Render/setup.py +29 -0
- PixHtLab-Src/Demo/PixhtLab/Torch_Render/test_ground.py +33 -0
- PixHtLab-Src/Demo/PixhtLab/Torch_Render/test_hshadow.py +130 -0
- PixHtLab-Src/Demo/PixhtLab/camera.py +246 -0
- PixHtLab-Src/Demo/PixhtLab/gssn_demo.py +32 -0
- PixHtLab-Src/Demo/PixhtLab/hshadow_render.py +268 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
PixHtLab-Src/Demo/PixhtLab/Examples/c82c09fc01e84282bc8870c263dcf81b_bg.jpg filter=lfs diff=lfs merge=lfs -text
|
PixHtLab-Src/.gitignore
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.DS_Store
|
| 2 |
+
.idea
|
| 3 |
+
*.log
|
| 4 |
+
tmp/
|
| 5 |
+
*__pycache__*
|
PixHtLab-Src/Data/data_prepare.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import objaverse
import multiprocessing
import random

# Number of parallel download workers used by objaverse.
processes = 8

# Fixed seed so the sampled 100-object subset is reproducible across runs.
random.seed(0)
uids = objaverse.load_uids()
random_object_uids = random.sample(uids, 100)

# Download (or load from the local cache) the sampled objects.
# NOTE: the original script invoked load_objects twice with identical
# arguments — the second call only re-read the cache — and ended with a
# leftover `import pdb; pdb.set_trace()` debug trap; both were removed.
objects = objaverse.load_objects(
    uids=random_object_uids,
    download_processes=processes,
)
|
PixHtLab-Src/Demo/PixhtLab/Demo.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
PixHtLab-Src/Demo/PixhtLab/Examples/009_depth.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:17338c5be1e9aa07c8d68ed26076af44f433374a92a90130b3f803d66f8d2442
|
| 3 |
+
size 1048704
|
PixHtLab-Src/Demo/PixhtLab/Examples/009_depth_valid_mask.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:85f888fa80c5da36acf97938cda8f55323758aefda263fc859ab83b2955c92c1
|
| 3 |
+
size 262272
|
PixHtLab-Src/Demo/PixhtLab/Examples/009_pixht_new.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:25dfedce61e31739fec2164898d9975ab9224c02aea1540029eb7b04d0505118
|
| 3 |
+
size 2097280
|
PixHtLab-Src/Demo/PixhtLab/Examples/010_depth.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:872277399788bc63b43a1fc63e817d8c12d18e6960b6d7e51a41d2ea2fae21f6
|
| 3 |
+
size 1048704
|
PixHtLab-Src/Demo/PixhtLab/Examples/010_depth_valid_mask.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2ceb9b34576604a240e557f2bb2fa87b7328e30a28a44357d799c6a3271c6bf6
|
| 3 |
+
size 262272
|
PixHtLab-Src/Demo/PixhtLab/Examples/010_pixht_new.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bb2031249ece53825e983e62e1a695c53bdfcd5ae12d3a1350dff80e710f0aa3
|
| 3 |
+
size 2097280
|
PixHtLab-Src/Demo/PixhtLab/Examples/011_depth.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bfdabce8a328c02e7e52d156a2f2f03e1302fb695e9e70f35355d1aced7de24f
|
| 3 |
+
size 1048704
|
PixHtLab-Src/Demo/PixhtLab/Examples/011_depth_valid_mask.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:12eaae90b70229aa83450d8ea63fd77c7e8fdc6240e473aafab77021c2891efe
|
| 3 |
+
size 262272
|
PixHtLab-Src/Demo/PixhtLab/Examples/011_pixht_new.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9ef1a1f21f4ac53dc680fc6a965cce4f8043bd313a89e52ef733c5986be16595
|
| 3 |
+
size 2097280
|
PixHtLab-Src/Demo/PixhtLab/Examples/c82c09fc01e84282bc8870c263dcf81b_bg.jpg
ADDED
|
Git LFS Details
|
PixHtLab-Src/Demo/PixhtLab/GSSN/__init__.py
ADDED
|
File without changes
|
PixHtLab-Src/Demo/PixhtLab/GSSN/inference_shadow.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from torch.utils.data import Dataset, DataLoader
|
| 9 |
+
import torch.optim as optim
|
| 10 |
+
import torchvision
|
| 11 |
+
import torchvision.transforms as T
|
| 12 |
+
import argparse
|
| 13 |
+
import time
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
import numpy as np
|
| 16 |
+
import os
|
| 17 |
+
from os.path import join
|
| 18 |
+
|
| 19 |
+
import math
|
| 20 |
+
|
| 21 |
+
import cv2
|
| 22 |
+
import random
|
| 23 |
+
|
| 24 |
+
import sys
|
| 25 |
+
# sys.path.insert(0, '../../Training/app/models')
|
| 26 |
+
sys.path.insert(0, '/home/ysheng/Documents/Research/GSSN/Training/app/models')
|
| 27 |
+
|
| 28 |
+
from SSN_v1 import SSN_v1
|
| 29 |
+
from SSN import SSN
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class SSN_Infernece():
    """Inference wrapper around a pretrained SSN soft-shadow network.

    Loads the checkpoint onto ``device`` and exposes :meth:`render_ss`,
    which maps an H x W x C input image plus a scalar softness value to a
    rendered soft-shadow image.

    NOTE(review): the (misspelled) public class name is kept for backward
    compatibility with existing callers.
    """

    def __init__(self, ckpt, device=torch.device('cuda:0')):
        self.device = device
        self.model = SSN(3, 1, mid_act='gelu', out_act='null', resnet=False)

        # map_location keeps a GPU-saved checkpoint loadable on CPU-only
        # hosts (and on a different GPU index).
        weight = torch.load(ckpt, map_location=device)
        self.model.to(device)
        self.model.load_state_dict(weight['model'])
        # Inference-only wrapper: put norm/dropout layers in eval mode.
        self.model.eval()

        # Softness in [0, MAX_RAD) is discretized into BINs buckets; each
        # bucket maps to a Gaussian "soft one-hot" row that conditions the
        # network.
        BINs = 100
        MAX_RAD = 20
        self.size_interval = MAX_RAD / BINs
        self.soft_distribution = [[np.exp(-0.2 * (i - j) ** 2) for i in np.arange(BINs)]
                                  for j in np.arange(BINs)]

    def render_ss(self, input_np, softness):
        """Render a soft shadow for ``input_np``.

        Args:
            input_np: H x W x C array (channel meaning defined by the
                trained SSN model — not verifiable from this file).
            softness: scalar in [0, MAX_RAD); selects the conditioning bin.

        Returns:
            256 x 256 x C_out numpy array predicted by the network.
        """
        input_tensor = torch.tensor(input_np.transpose((2, 0, 1)))[None, ...].float().to(self.device)
        transform = T.Resize((256, 256))

        # Soft label row for the requested softness bucket.
        l = torch.from_numpy(
            np.array(self.soft_distribution[int(softness / self.size_interval)]).astype(np.float32)
        ).unsqueeze(dim=0).to(self.device)

        input_tensor = transform(input_tensor)
        # Inference only: skip autograd graph construction.
        with torch.no_grad():
            output_tensor = self.model(input_tensor, l)
        output_np = output_tensor[0].detach().cpu().numpy().transpose((1, 2, 0))

        return output_np
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
if __name__ == '__main__':
    # Smoke test: construct the wrapper from the shipped checkpoint.
    # Requires a CUDA device (default device argument is 'cuda:0') and the
    # weight file to exist relative to the working directory.
    model = SSN_Infernece('weights/0000000700.pt')
|
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/__init__.py
ADDED
|
File without changes
|
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/__init__.py
ADDED
|
File without changes
|
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/inference_shadow.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from torch.utils.data import Dataset, DataLoader
|
| 9 |
+
import torch.optim as optim
|
| 10 |
+
import torchvision
|
| 11 |
+
import torchvision.transforms as transforms
|
| 12 |
+
import argparse
|
| 13 |
+
import time
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
import numpy as np
|
| 16 |
+
import os
|
| 17 |
+
from os.path import join
|
| 18 |
+
|
| 19 |
+
import math
|
| 20 |
+
|
| 21 |
+
import cv2
|
| 22 |
+
import random
|
| 23 |
+
from .ssn.ssn import Relight_SSN
|
| 24 |
+
device = torch.device('cuda:0')
|
| 25 |
+
|
| 26 |
+
# The Gaussian "soft one-hot" label table is pure and moderately expensive
# (100 x 100 exp evaluations); the original rebuilt it on every call to
# net_render_np, so it is now built lazily once and cached.
_SOFT_BINS = 100
_soft_distribution_table = None


def _soft_distribution():
    """Lazily build and return the _SOFT_BINS x _SOFT_BINS soft-label table."""
    global _soft_distribution_table
    if _soft_distribution_table is None:
        _soft_distribution_table = [
            [np.exp(-0.2 * (i - j) ** 2) for i in np.arange(_SOFT_BINS)]
            for j in np.arange(_SOFT_BINS)
        ]
    return _soft_distribution_table


def net_render_np(model, mask_np, hard_shadow_np, size, orientation):
    """Run the soft-shadow network on numpy inputs.

    input:
        mask_np shape: b x c x h x w (uint8 masks are rescaled to [0, 1])
        hard_shadow_np shape: b x c x h x w
        size: scalar softness, expected in [0, 0.5)
        orientation: scalar angle, expected in [0, pi)
    output:
        shadow_predict shape: b x c x h x w (numpy array)
    """
    # Bucket widths: 100 bins over [0, 0.5) for size and [0, pi) for angle.
    size_interval = 0.5 / 100
    ori_interval = np.pi / 100

    soft_distribution = _soft_distribution()

    if mask_np.dtype == np.uint8:
        mask_np = mask_np / 255.0

    mask, h_shadow = torch.Tensor(mask_np), torch.Tensor(hard_shadow_np)
    # Soft "one-hot" rows conditioning the network on size / orientation.
    size_soft = torch.Tensor(np.array(soft_distribution[int(size / size_interval)])).unsqueeze(0)
    ori_soft = torch.Tensor(np.array(soft_distribution[int(orientation / ori_interval)])).unsqueeze(0)

    with torch.no_grad():
        I_m, I_h, size_t, ori = mask.to(device), h_shadow.to(device), size_soft.to(device), ori_soft.to(device)
        predicted_img = model(I_h, I_m, size_t, ori)

    return predicted_img.detach().cpu().numpy()
|
| 58 |
+
|
| 59 |
+
def init_models(ckpt):
    """Build a single-channel Relight_SSN and restore weights from ``ckpt``.

    map_location keeps a GPU-saved checkpoint loadable on CPU-only hosts.
    Returns the model on the module-level ``device``, in eval mode (this is
    the inference entry point).
    """
    baseline_model = Relight_SSN(1, 1, is_training=False)
    baseline_checkpoint = torch.load(ckpt, map_location=device)
    baseline_model.to(device)
    baseline_model.load_state_dict(baseline_checkpoint['model_state_dict'])
    baseline_model.eval()
    return baseline_model
|
| 65 |
+
|
| 66 |
+
if __name__ == '__main__':
    # Smoke test with random inputs: the sharpest softness level and zero
    # orientation. Requires the weight file and a CUDA device (module-level
    # ``device`` is 'cuda:0').
    softness = [0.02, 0.2, 0.3, 0.4]
    model = init_models('weights/human_baseline123.pt')
    mask, hard_shadow, size, orientation = np.random.randn(1,1,256,256), np.random.randn(1,1,256,256), softness[0], 0
    shadow = net_render_np(model, mask, hard_shadow, size, orientation)
|
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/params.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
|
| 3 |
+
class params():
    """Singleton holder for experiment hyper-parameters.

    Every ``params()`` construction shares the single ``__params`` instance,
    so options parsed once (see ``parse_params``) are visible everywhere.
    """

    class __params():
        def __init__(self):
            # Defaults used before (or without) command-line parsing.
            self.norm = 'group_norm'
            self.prelu = False
            self.weight_decay = 5e-4
            self.small_ds = False
            self.multi_gpu = False
            self.log = False
            # NOTE: the original assigned input_channel twice; the duplicate
            # assignment was removed (same value, no behavior change).
            self.input_channel = 1
            self.vis_port = 8002
            self.cpu = False
            self.pred_touch = False
            self.tbaseline = False
            self.touch_loss = False

        def set_params(self, options):
            """Copy the relevant fields from an argparse-style namespace."""
            self.options = options
            self.norm = options.norm
            self.prelu = options.prelu
            self.weight_decay = options.weight_decay
            self.small_ds = options.small_ds
            self.multi_gpu = options.multi_gpu
            self.log = options.log
            self.input_channel = options.input_channel
            self.vis_port = options.vis_port
            self.cpu = options.cpu
            self.ds_folder = options.ds_folder
            self.pred_touch = options.pred_touch
            self.tbaseline = options.tbaseline
            self.touch_loss = options.touch_loss

        def __str__(self):
            return 'norm: {} prelu: {} weight decay: {} small ds: {}'.format(self.norm, self.prelu, self.weight_decay, self.small_ds)

    # Private static variable holding the one shared instance.
    param_instance = None

    def __init__(self):
        # Lazily create the shared instance on first construction.
        if not params.param_instance:
            params.param_instance = params.__params()

    def get_params(self):
        """Return the shared parameter instance."""
        return params.param_instance

    def set_params(self, options):
        """Forward parsed options into the shared instance."""
        params.param_instance.set_params(options)
|
| 53 |
+
|
| 54 |
+
def parse_params():
    """Define and parse the command-line options, publish them to the shared
    ``params`` singleton, and return the raw argparse namespace.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    def flag(name, help_text):
        # All boolean switches share this store_true shape.
        add(name, action='store_true', help=help_text)

    add('--workers', type=int, help='number of data loading workers', default=16)
    add('--batch_size', type=int, default=28, help='input batch size during training')
    add('--epochs', type=int, default=10000, help='number of epochs to train for')
    add('--lr', type=float, default=0.003, help='learning rate, default=0.005')
    add('--beta1', type=float, default=0.9, help='momentum for SGD, default=0.9')
    flag('--resume', 'resume training')
    flag('--relearn', 'forget previous best validation loss')
    add('--weight_file', type=str, help='weight file')
    flag('--multi_gpu', 'use multiple GPU training')
    add('--timers', type=int, default=1, help='number of epochs to train for')
    flag('--use_schedule', 'use automatic schedule')
    add('--patience', type=int, default=2, help='use automatic schedule')
    add('--exp_name', type=str, default='l1 loss', help='experiment name')
    add('--norm', type=str, default='group_norm', help='use group norm')
    add('--ds_folder', type=str, default='./dataset/general_dataset', help='Dataset folder')
    add('--hd_dir', type=str, default='/mnt/yifan/data/Adobe/HD_styleshadow/', help='Dataset folder')
    flag('--prelu', 'use prelu')
    flag('--small_ds', 'small dataset')
    flag('--log', 'log information')
    add('--vis_port', default=8002, type=int, help='visdom port')
    add('--weight_decay', type=float, default=4e-5, help='weight decay for model weight')
    flag('--save', 'save batch results?')
    flag('--cpu', 'Force training on CPU')
    flag('--pred_touch', 'Use touching surface')
    add('--input_channel', type=int, default=1, help='how many input channels')

    # based on baseline method, for fine tuning
    flag('--from_baseline', 'training from baseline')
    flag('--tbaseline', 'T-baseline, input two channels')
    flag('--touch_loss', 'Use touching loss')

    arguments = parser.parse_args()
    parameter = params()
    parameter.set_params(arguments)

    return arguments
|
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/ssn/__init__.py
ADDED
|
File without changes
|
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/ssn/perturb_touch.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import cv2
|
| 3 |
+
import random
|
| 4 |
+
|
| 5 |
+
random.seed(19920208)
|
| 6 |
+
|
| 7 |
+
def random_kernel():
    """Return a square all-ones morphology kernel with side length 1..3."""
    side = random.randint(1, 3)
    return np.ones((side, side))
|
| 11 |
+
|
| 12 |
+
def random_perturb(img):
    """Identity pass-through.

    NOTE(review): a random erode/dilate perturbation used to live here and
    is currently disabled; the input is returned unchanged.
    """
    return img
|
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/ssn/random_pattern.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
import time
|
| 3 |
+
import numbergen as ng
|
| 4 |
+
import imagen as ig
|
| 5 |
+
import numpy as np
|
| 6 |
+
import cv2
|
| 7 |
+
from param.parameterized import get_logger
|
| 8 |
+
import logging
|
| 9 |
+
|
| 10 |
+
get_logger().setLevel(logging.ERROR)
|
| 11 |
+
|
| 12 |
+
class random_pattern():
    """Generate random Gaussian-blob IBL (image-based lighting) patterns.

    Uses the third-party ``imagen`` (ig) / ``numbergen`` (ng) packages;
    ``ng.UniformRandom(seed=...)`` draws are seeded per call, so the exact
    seed offsets (+1 .. +4) are behavior-critical and must not be reordered.
    """

    def __init__(self, maximum_blob=50):
        # Pre-building a generator list was experimented with and disabled;
        # construction is currently a no-op.
        pass

    def y_transform(self, y):
        # Placeholder — never implemented.
        pass

    def get_pattern(self, w, h, x_density=512, y_density=128, num=50, scale=3.0, size=0.1, energy=3500,
                    mitsuba=False, seed=None, dataset=False):
        # Returns (ibl, size, orientation).
        # NOTE(review): `scale` and `energy` parameters are unused in this
        # version (scale is hard-coded to 1.0, normalize uses 30), and the
        # `mitsuba` branch below returns the same tuple in both arms —
        # presumably to_mts_ibl was meant for the mitsuba case; confirm.
        if seed is None:
            seed = random.randint(0, 19920208)
        else:
            # Mix in wall-clock time so repeated calls with the same seed
            # still vary (intentionally non-reproducible).
            seed = seed + int(time.time())

        if num == 0:
            # No blobs: black IBL with a random orientation only.
            ibl = np.zeros((y_density, x_density))
            orientation = np.pi * ng.UniformRandom(seed=seed + 3)()
        else:
            # Fraction of the canonical 256-row canvas actually kept.
            y_fact = y_density / 256
            # NOTE(review): num is forced to 1 here, so exactly one Gaussian
            # blob is composited regardless of the requested count.
            num = 1
            size = size * ng.UniformRandom(seed=seed + 4)()
            orientation = np.pi * ng.UniformRandom(seed=seed + 3)()
            gs = ig.Composite(operator=np.add,
                              generators=[ig.Gaussian(
                                  size=size,
                                  scale=1.0,
                                  x=ng.UniformRandom(seed=seed + i + 1) - 0.5,
                                  y=((1.0 - ng.UniformRandom(seed=seed + i + 2) * y_fact) - 0.5),
                                  aspect_ratio=0.7,
                                  orientation=orientation,
                              ) for i in range(num)],
                              position=(0, 0),
                              xdensity=512)
            # Keep only the top y_density rows of the rendered pattern.
            ibl = gs()[:y_density, :]

        # prepare to fix energy inconsistent
        if dataset:
            ibl = self.to_dataset(ibl, w, h)

        if mitsuba:
            return ibl, size, orientation
        else:
            return ibl, size, orientation

    def to_mts_ibl(self, ibl):
        """ Input: 256 x 512 pattern generated ibl
            Output: the ibl in mitsuba ibl
        """
        # Replicate the single channel into RGB.
        return np.repeat(ibl[:, :, np.newaxis], 3, axis=2)

    def normalize(self, ibl, energy=30.0):
        """Scale ``ibl`` so its total intensity equals ``energy``.

        Returns an all-zero image when the input is (numerically) black to
        avoid dividing by ~0.
        """
        total_energy = np.sum(ibl)
        if total_energy < 1e-3:
            print('small energy: ', total_energy)
            h, w = ibl.shape
            return np.zeros((h, w))

        return ibl * energy / total_energy

    def to_dataset(self, ibl, w, h):
        """Resize to (w, h), flip vertically, and normalize to energy 30."""
        return self.normalize(cv2.flip(cv2.resize(ibl, (w, h)), 0), 30)
|
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/ssn/ssn.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
from .ssn_submodule import Conv, Up, Up_Stream, get_layer_info, add_coords
|
| 6 |
+
import copy
|
| 7 |
+
|
| 8 |
+
class Relight_SSN(nn.Module):
    """ Implementation of Relighting Net """
    # Encoder/decoder network: two parallel 7x7 input stems (one for the
    # mask I_m, one for the hard shadow I_h) whose 16-channel features are
    # concatenated and progressively downsampled 256 -> 128 -> 64 -> 32 -> 16,
    # then decoded by Up_Stream conditioned on a (size, angle) style vector.
    # NOTE(review): attribute names (in_conv1, down_256_128, ...) are part of
    # the checkpoint state_dict interface — do not rename.

    def __init__(self, n_channels=3, out_channels=3, is_training=True, activation_func = 'relu'):
        super(Relight_SSN, self).__init__()
        # When True, forward() expects a (batch, fake_batch, ...) layout and
        # flattens the two leading dims; see forward().
        self.is_training = is_training

        # get_layer_info is project-defined (ssn_submodule); it returns an
        # optional norm layer plus an activation for 16 channels.
        norm_layer1, activation_func1 = get_layer_info(16, activation_func)
        norm_layer2, activation_func2 = get_layer_info(16, activation_func)
        # Stem for I_m (mask); the norm layer is inserted only when provided.
        if norm_layer1 is not None:
            self.in_conv1 = nn.Sequential(
                nn.Conv2d(n_channels, 16, kernel_size=7, padding=3, bias=True),
                norm_layer1,
                activation_func1
            )
        elif norm_layer1 is None:
            self.in_conv1 = nn.Sequential(
                nn.Conv2d(n_channels, 16, kernel_size=7, padding=3, bias=True),
                activation_func1
            )

        # Stem for I_h (hard shadow), same construction.
        if norm_layer2 is not None:
            self.in_conv2 = nn.Sequential(
                nn.Conv2d(n_channels, 16, kernel_size=7, padding=3, bias=True),
                norm_layer2,
                activation_func2
            )
        elif norm_layer2 is None:
            self.in_conv2 = nn.Sequential(
                nn.Conv2d(n_channels, 16, kernel_size=7, padding=3, bias=True),
                activation_func2
            )

        # Encoder ladder; names encode spatial resolution (e.g. 256 -> 128).
        self.down_256_128 = Conv(32, 64, conv_stride=2)
        self.down_128_128 = Conv(64, 64, conv_stride=1)
        self.down_128_64 = Conv(64, 128, conv_stride=2)
        self.down_64_64 = Conv(128, 128, conv_stride=1)
        self.down_64_32 = Conv(128, 256, conv_stride=2)
        self.down_32_32 = Conv(256, 256, conv_stride=1)
        self.down_32_16 = Conv(256, 512, conv_stride=2)
        self.down_16_16_1 = Conv(512, 512, conv_stride=1)
        self.down_16_16_2 = Conv(512, 512, conv_stride=1)
        self.down_16_16_3 = Conv(512, 512, conv_stride=1)
        # NOTE(review): to_bottleneck is registered but never used in
        # forward() below — possibly kept for checkpoint compatibility.
        self.to_bottleneck = Conv(512, 2, conv_stride=1)

        self.up_stream = Up_Stream(out_channels)

    """
    Input is (source image, target light, source light, )
    Output is: predicted new image, predicted source light, self-supervision image
    """

    def forward(self, I_h, I_m, size, angle):
        # Training batches arrive as (bs, fake_bs, ...) and are flattened to
        # (bs * fake_bs, ...); at inference the inputs are already 4-D.
        if self.is_training:
            bs, fake_bs, dim = size.shape
            bs, fake_bs, ch, w, h = I_h.shape
            size = size.reshape(bs * fake_bs, dim)
            angle = angle.reshape(bs * fake_bs, dim)
            I_h = I_h.reshape(bs * fake_bs, ch, w, h)
            I_m = I_m.reshape(bs * fake_bs, ch, w, h)
        else:
            # No reshape needed at inference; these are deliberate no-ops.
            size = size
            angle = angle
            I_h = I_h
            I_m = I_m
        # Style vector conditioning the decoder: concatenated soft labels.
        style = torch.cat((size, angle), dim=1)
        x1 = self.in_conv1(I_m) # 29 x 256 x 256
        x2 = self.in_conv2(I_h)

        x1 = torch.cat((x1, x2), dim=1) # 32 x 256 x 256

        # Encoder; each Conv also receives x1 (the full-resolution features)
        # as a second argument — see ssn_submodule.Conv for its use.
        x2 = self.down_256_128(x1, x1) # 64 x 128 x 128

        x3 = self.down_128_128(x2, x1) # 64 x 128 x 128

        x4 = self.down_128_64(x3, x1) # 128 x 64 x 64

        x5 = self.down_64_64(x4, x1) # 128 x 64 x 64

        x6 = self.down_64_32(x5, x1) # 256 x 32 x 32

        x7 = self.down_32_32(x6, x1) # 256 x 32 x 32

        x8 = self.down_32_16(x7, x1) # 512 x 16 x 16

        x9 = self.down_16_16_1(x8, x1) # 512 x 16 x 16

        x10 = self.down_16_16_2(x9, x1) # 512 x 16 x 16

        x11 = self.down_16_16_3(x10, x1) # 512 x 16 x 16
        # Decoder with skip connections from every encoder stage.
        ty = self.up_stream(x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, style)

        return ty
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def baseline_2_tbaseline(model):
    """ change input layer to be two channels
    """
    # NOTE(review): this assigns ``model.in_conv``, but Relight_SSN in this
    # file only defines ``in_conv1`` / ``in_conv2`` — on that class this adds
    # an unused submodule instead of replacing the stem. It likely targets an
    # older SSN baseline class; verify against the intended model before use.
    input_channel = 2
    tbase_inlayer = nn.Sequential(
        nn.Conv2d(input_channel, 32 - input_channel, kernel_size=7, padding=3, bias=True),
        nn.GroupNorm(1, 32 - input_channel),
        nn.ReLU()
    )
    model.in_conv = tbase_inlayer
    return model
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def baseline_2_touchloss(model):
    """Replace the decoder's output head with a 2-channel one.

    Mutates ``model`` in place and returns it for chaining. The second
    channel presumably carries the touching-surface prediction (see the
    --touch_loss option) — confirm against training code.
    """
    two_channel_head = nn.Sequential(
        nn.Conv2d(64, 2, stride=1, kernel_size=3, padding=1, bias=True),
        nn.GroupNorm(1, 2),
        nn.ReLU(),
    )
    model.up_stream.out_conv = two_channel_head
    return model
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
if __name__ == '__main__':
    # Smoke test: build the baseline and the two derived variants, run each
    # on zero inputs, and print output shapes.
    # NOTE(review): Relight_SSN.forward in this file takes
    # (I_h, I_m, size, angle) and returns a single tensor, but the calls
    # below pass two arguments and unpack two results — this block looks
    # stale relative to the class above; confirm against the original SSN
    # baseline it was written for.
    mask_test, touch_test = torch.zeros((1, 1, 256, 256)), torch.zeros((1, 1, 256, 256))
    ibl = torch.zeros((1, 1, 16, 32))

    I_s = mask_test
    baseline = Relight_SSN(1, 1)
    baseline_output, _ = baseline(I_s, ibl)

    # Two-channel-input variant (mask + touch map concatenated).
    tbaseline = baseline_2_tbaseline(copy.deepcopy(baseline))
    I_s = torch.cat((mask_test, touch_test), axis=1)
    tbaseline_output, _ = tbaseline(I_s, ibl)

    # Two-channel-output variant.
    t_loss_baseline = baseline_2_touchloss(copy.deepcopy(baseline))
    I_s = mask_test
    tloss_output, _ = t_loss_baseline(I_s, ibl)

    print('baseline output: ', baseline_output.shape)
    print('tbaseline output: ', tbaseline_output.shape)
    print('tloss output: ', tloss_output.shape)
|
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/ssn/ssn_dataset.py
ADDED
|
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
|
| 3 |
+
sys.path.append("..")
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
from os.path import join
|
| 7 |
+
import torch
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
from torch.utils.data import Dataset, DataLoader
|
| 11 |
+
from torchvision import transforms, utils
|
| 12 |
+
import time
|
| 13 |
+
import random
|
| 14 |
+
# import matplotlib.pyplot as plt
|
| 15 |
+
import cv2
|
| 16 |
+
from params import params
|
| 17 |
+
from .random_pattern import random_pattern
|
| 18 |
+
from .perturb_touch import random_perturb
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class ToTensor(object):
    """Convert a batched numpy image array into a torch Tensor."""

    def __call__(self, img, is_transpose=True):
        # numpy batches are laid out N x H x W x C while torch expects
        # N x C x H x W, so move the channel axis forward unless the
        # caller opts out.
        if is_transpose:
            img = np.transpose(img, (0, 3, 1, 2))
        return torch.Tensor(img)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class SSN_Dataset(Dataset):
    """Soft-shadow dataset that renders paired hard/soft shadows on the fly.

    Each item is a 'fake batch' of ``fake_batch_size`` samples built from one
    metadata entry: a soft shadow composited under a random light pattern,
    the matching hard shadow, the object mask, and soft (label-smoothed)
    class distributions for the light size and orientation.
    """

    def __init__(self, ds_dir, hd_dir, is_training, fake_batch_size=8):
        # ds_dir: accepted but currently not scanned (see init_meta).
        # hd_dir: dataset root holding base/mask/hard/touch subfolders.
        start = time.time()
        self.fake_batch_size = fake_batch_size
        # # of samples in each group
        # magic number here
        self.ibl_group_size = 16

        parameter = params().get_params()

        # (shadow_path, mask_path)
        self.meta_data = self.init_meta(ds_dir, hd_dir)

        self.is_training = is_training
        self.to_tensor = ToTensor()

        end = time.time()
        print("Dataset initialize spent: {} ms".format(end - start))

        # fake random: fixed seed so the train/validation split is reproducible
        np.random.seed(19950220)
        np.random.shuffle(self.meta_data)

        # one valid_divide-th of the shuffled data is held out for validation
        self.valid_divide = 10
        if parameter.small_ds:
            self.meta_data = self.meta_data[:len(self.meta_data) // self.valid_divide]

        self.training_num = len(self.meta_data) - len(self.meta_data) // self.valid_divide
        print('training: {}, validation: {}'.format(self.training_num, len(self.meta_data) // self.valid_divide))

        self.random_pattern_generator = random_pattern()

        self.thread_id = os.getpid()
        self.seed = os.getpid()
        # perturb touch maps only when neither touch prediction nor touch loss is active
        self.perturb = not parameter.pred_touch and not parameter.touch_loss
        # bin widths used to discretize light size and orientation into 100 bins
        self.size_interval = 0.5 / 100
        self.ori_interval = np.pi / 100

        # soft_distribution[j][i] = exp(-0.4 * (i - j)^2): a smoothed one-hot
        # row per discrete bin, shared by the size and orientation labels
        self.soft_distribution = [[np.exp(-0.4 * (i - j) ** 2) for i in np.arange(0.5 / self.size_interval)]
                                  for j in np.arange(0.5 / self.size_interval)]

    def __len__(self):
        # Training iterates the first training_num entries; validation the rest.
        if self.is_training:
            return self.training_num
        else:
            # return len(self.meta_data) - self.training_num
            return len(self.meta_data) // self.valid_divide

    def __getitem__(self, idx):
        """Return one fake batch of tensors for metadata entry ``idx``."""
        # NOTE(review): likely meant ``idx >= self.training_num``, and this
        # only prints instead of raising — confirm.
        if self.is_training and idx > self.training_num:
            print("error")
        # offset to validation set
        if not self.is_training:
            idx = self.training_num + idx

        # per-item, time-varying seed so dataloader workers diverge
        cur_seed = idx * 1234 + os.getpid() + time.time()
        random.seed(cur_seed)

        # random ibls
        shadow_path, mask_path, hard_path, touch_path = self.meta_data[idx]
        hard_folder = hard_path.replace(hard_path.split('/')[-1], '')
        if os.path.exists(hard_folder):
            # Pre-rendered hard shadows exist on disk: load the file matching
            # the sampled peak light position.
            # hard_shadow = cv2.imread(hard_folder)

            mask_img = cv2.imread(mask_path)
            mask_img = mask_img[:, :, 0]
            if mask_img.dtype == np.uint8:
                mask_img = mask_img / 255.0
            mask_img, shadow_bases = np.expand_dims(mask_img, axis=2), np.load(shadow_path)

            # NOTE(review): axis meaning of shadow_bases assumed to be
            # (h, w, ibl_cols, ibl_rows) — confirm against the data pipeline.
            w, h, c, m = shadow_bases.shape
            shadow_soft_list = []
            shadow_hard_list = []
            size_list = []
            orientation_list = []
            mask_img_list = []
            for i in range(int(self.fake_batch_size)):
                shadow_img, light_img, size, orientation = self.render_new_shadow(shadow_bases)

                h, w = mask_img.shape[0], mask_img.shape[1]
                hi, wi = np.where(light_img == light_img.max())

                # re-sample until the light pattern has a unique peak pixel
                while len(hi) > 1:
                    shadow_img, light_img, size, orientation = self.render_new_shadow(shadow_bases)
                    hi, wi = np.where(light_img == light_img.max())
                # soft class labels for the discretized size/orientation bins
                size_soft = np.array(self.soft_distribution[int(size / self.size_interval)])
                ori_soft = np.array(self.soft_distribution[int(orientation / self.ori_interval)])
                # map the peak light position to the pre-rendered hard-shadow filename
                prefix = '_ibli_' + str(int(wi * 8)) + '_iblj_' + str(int(hi * 8) + 128) + '_shadow.png'
                shadow_hard_path = hard_path.replace('_shadow.png', prefix)
                shadow_base = cv2.imread(shadow_hard_path, -1)[:, :, 0] / 255.0
                shadow_base = np.expand_dims(shadow_base, axis=2)
                shadow_base = self.line_aug(shadow_base)
                shadow_soft_list.append(shadow_img)
                shadow_hard_list.append(shadow_base)
                size_list.append(size_soft)
                orientation_list.append(ori_soft)
                mask_img_list.append(mask_img)
            shadow_softs = np.array(shadow_soft_list)
            shadow_hards = np.array(shadow_hard_list)
            sizes = np.array(size_list)
            orientations = np.array(orientation_list)
            mask_imgs = np.array(mask_img_list)

            # touch_img = self.read_img(touch_path)
            # touch_img = touch_img[:, :, 0:1]

            # if self.perturb:
            #     touch_img = random_perturb(touch_img)

            # input_img = np.concatenate((mask_img, touch_img), axis=2)
            size = torch.Tensor(sizes)
            ori = torch.Tensor(orientations)
            hard_shadow, soft_shadow, mask_img = self.to_tensor(shadow_hards), self.to_tensor(
                shadow_softs), self.to_tensor(
                mask_imgs)
            return {"hard_shadow": hard_shadow, "soft_shadow": soft_shadow, "mask_img": mask_img, "size": size,
                    "angle": ori}
        else:
            # No pre-rendered hard shadows: derive a hard shadow by
            # thresholding the (inverted) shadow basis at the peak light.
            mask_img = cv2.imread(mask_path)
            mask_img = mask_img[:, :, 0]
            if mask_img.dtype == np.uint8:
                mask_img = mask_img / 255.0
            mask_img, shadow_bases = np.expand_dims(mask_img, axis=2), 1.0 - np.load(shadow_path)

            w, h, c, m = shadow_bases.shape
            shadow_soft_list = []
            shadow_hard_list = []
            size_list = []
            orientation_list = []
            mask_img_list = []
            for i in range(int(self.fake_batch_size)):
                shadow_img, light_img, size, orientation = self.render_new_shadow(shadow_bases)

                h, w = mask_img.shape[0], mask_img.shape[1]
                hi, wi = np.where(light_img == light_img.max())

                while len(hi) > 1:
                    shadow_img, light_img, size, orientation = self.render_new_shadow(shadow_bases)
                    # NOTE(review): 3-way unpack here vs the 2-way unpack above;
                    # light_img looks 2-D, so this would raise if the loop
                    # repeats — confirm.
                    hi, wi, _ = np.where(light_img == light_img[:, :, :].max())
                size_soft = np.array(self.soft_distribution[int(size / self.size_interval)])
                ori_soft = np.array(self.soft_distribution[int(orientation / self.ori_interval)])

                shadow_base = shadow_bases[:, :, wi, hi]
                # NOTE(review): thresholds overlap — values in (0.3, 0.4) are
                # first set to 1 and then never zeroed; confirm intent.
                shadow_base[shadow_base > 0.3] = 1
                shadow_base[shadow_base < 0.4] = 0
                shadow_base = self.line_aug(shadow_base)
                mask_img = np.expand_dims(cv2.resize(mask_img, (512, 512)), axis=2)
                shadow_base = np.expand_dims(cv2.resize(shadow_base, (512, 512)), axis=2)
                shadow_img = np.expand_dims(cv2.resize(shadow_img, (512, 512)), axis=2)
                shadow_soft_list.append(shadow_img)
                shadow_hard_list.append(shadow_base)
                size_list.append(size_soft)
                orientation_list.append(ori_soft)
                mask_img_list.append(mask_img)
            shadow_softs = np.array(shadow_soft_list)
            shadow_hards = np.array(shadow_hard_list)
            sizes = np.array(size_list)
            orientations = np.array(orientation_list)
            mask_imgs = np.array(mask_img_list)

            # touch_img = self.read_img(touch_path)
            # touch_img = touch_img[:, :, 0:1]

            # if self.perturb:
            #     touch_img = random_perturb(touch_img)

            # input_img = np.concatenate((mask_img, touch_img), axis=2)
            size = torch.Tensor(sizes)
            ori = torch.Tensor(orientations)

            hard_shadow, soft_shadow, mask_img = self.to_tensor(shadow_hards), self.to_tensor(
                shadow_softs), self.to_tensor(
                mask_imgs)

            return {"hard_shadow": hard_shadow, "soft_shadow": soft_shadow, "mask_img": mask_img, "size": size,
                    "angle": ori}

    def init_meta(self, ds_dir, hd_dir):
        """Scan ``hd_dir`` and return a list of
        (shadow_base.npy, mask.png, hard_shadow.png, touch.png) path tuples.

        ``ds_dir`` is accepted for interface compatibility but not scanned —
        the equivalent loop below is commented out.
        """
        metadata = []
        # base_folder = join(ds_dir, 'base')
        # mask_folder = join(ds_dir, 'mask')
        # hard_folder = join(ds_dir, 'hard')
        # touch_folder = join(ds_dir, 'touch')
        # model_list = [f for f in os.listdir(base_folder) if os.path.isdir(join(base_folder, f))]
        # for m in model_list:
        #     shadow_folder, cur_mask_folder = join(base_folder, m), join(mask_folder, m)
        #     shadows = [f for f in os.listdir(shadow_folder) if f.find('_shadow.npy') != -1]
        #     for s in shadows:
        #         prefix = s[:s.find('_shadow')]
        #         metadata.append((join(shadow_folder, s),
        #                          join(cur_mask_folder, prefix + '_mask.png'),
        #                          join(join(hard_folder, m), prefix + '_shadow.png'),
        #                          join(join(touch_folder, m), prefix + '_touch.png')))

        base_folder = join(hd_dir, 'base')
        mask_folder = join(hd_dir, 'mask')
        hard_folder = join(hd_dir, 'hard')
        touch_folder = join(hd_dir, 'touch')
        model_list = [f for f in os.listdir(base_folder) if os.path.isdir(join(base_folder, f))]
        for m in model_list:
            shadow_folder, cur_mask_folder = join(base_folder, m), join(mask_folder, m)
            shadows = [f for f in os.listdir(shadow_folder) if f.find('_shadow.npy') != -1]
            for s in shadows:
                prefix = s[:s.find('_shadow')]
                metadata.append((join(shadow_folder, s),
                                 join(cur_mask_folder, prefix + '_mask.png'),
                                 join(join(hard_folder, m), prefix + '_shadow.png'),
                                 join(join(touch_folder, m), prefix + '_touch.png')))

        return metadata

    def line_aug(self, shadow):
        """With ~40% probability, zero out random straight lines in ``shadow``.

        A random slope ``k`` and up to 20 random intercepts define lines
        y = k*x + b; pixels whose residual from a line is < 1 are zeroed.
        """
        p = np.random.random()
        if p > 0.6:
            # random slope covering the full angle range via tan
            k = np.tan(min((np.random.random() + 0.000000001), 0.999) * np.pi - np.pi / 2)
            x, y, c = shadow.shape
            b_max = y - x * k
            line_num = np.random.randint(1, 20)
            b_list = np.random.random(line_num) * b_max
            x_coord = np.tile(np.arange(shadow.shape[1])[None, :], (shadow.shape[0], 1))
            y_coord = np.tile(np.arange(shadow.shape[0])[:, None], (1, shadow.shape[1]))

            for b in b_list:
                # residual of each pixel from the line y = k*x + b
                mask_res = y_coord - k * x_coord - b
                shadow[np.abs(mask_res) < 1] = 0
        return shadow

    def get_prefix(self, path):
        """Return the path up to (excluding) the first '_' in the basename."""
        folder = os.path.dirname(path)
        basename = os.path.basename(path)
        return os.path.join(folder, basename[:basename.find('_')])

    def render_new_shadow(self, shadow_bases):
        """Sample a random light pattern and composite a soft shadow.

        Returns (shadow with trailing channel dim, pattern image, size,
        orientation) where size/orientation describe the sampled light.
        """
        shadow_bases = shadow_bases[:, :, :, :]
        h, w, iw, ih = shadow_bases.shape

        num = random.randint(0, 50)
        pattern_img, size, orientation = self.random_pattern_generator.get_pattern(iw, ih, num=num, size=0.5,
                                                                                   mitsuba=False)

        # flip to mitsuba ibl orientation, then normalize total light energy
        pattern_img = self.normalize_energy(cv2.flip(cv2.resize(pattern_img, (iw, ih)), 0))
        # weighted sum of the shadow bases over the light pattern
        shadow = np.tensordot(shadow_bases, pattern_img, axes=([2, 3], [1, 0]))
        # pattern_img = np.expand_dims(cv2.resize(pattern_img, (iw, 16)), 2)

        return np.expand_dims(shadow, 2), pattern_img, size, orientation

    def get_min_max(self, batch_data, name):
        """Debug helper: print the min and max of ``batch_data``."""
        print('{} min: {}, max: {}'.format(name, np.min(batch_data), np.max(batch_data)))

    def log(self, log_info):
        """Append ``log_info`` to log.txt in the working directory."""
        with open('log.txt', 'a+') as f:
            f.write(log_info)

    def normalize_energy(self, ibl, energy=30.0):
        """Rescale ``ibl`` so its sum equals ``energy`` (no-op if near zero)."""
        if np.sum(ibl) < 1e-3:
            return ibl
        return ibl * energy / np.sum(ibl)
|
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/ssn/ssn_submodule.py
ADDED
|
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
def get_layer_info(out_channels, activation_func='relu'):
    """Return a (GroupNorm, activation) pair sized for ``out_channels``.

    Layers with >= 32 channels get 32 groups; narrower layers fall back to a
    single group. The strings 'relu' and 'prelu' are resolved to modules; any
    other value is passed back unchanged.
    """
    group_num = 32 if out_channels >= 32 else 1
    norm_layer = nn.GroupNorm(group_num, out_channels)

    if activation_func == 'relu':
        activation_func = nn.ReLU()
    elif activation_func == 'prelu':
        activation_func = nn.PReLU(out_channels)

    return norm_layer, activation_func
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# add coord_conv
|
| 22 |
+
# add coord_conv
class add_coords(nn.Module):
    """Append two normalized coordinate channels (CoordConv-style) to the input.

    For an input of shape (b, c, h, w) the output is (b, c + 2, h, w); the two
    extra channels hold the row and column indices scaled to [-1, 1].
    """

    def __init__(self, use_cuda=True):
        super(add_coords, self).__init__()
        self.use_cuda = use_cuda

    def forward(self, input_tensor):
        b, c, dim_y, dim_x = input_tensor.shape
        xx_ones = torch.ones([1, 1, 1, dim_x], dtype=torch.int32)
        yy_ones = torch.ones([1, 1, 1, dim_y], dtype=torch.int32)

        xx_range = torch.arange(dim_y, dtype=torch.int32)
        yy_range = torch.arange(dim_x, dtype=torch.int32)
        xx_range = xx_range[None, None, :, None]
        yy_range = yy_range[None, None, :, None]

        # outer products build (1, 1, dim_y, dim_x) index grids
        xx_channel = torch.matmul(xx_range, xx_ones)
        yy_channel = torch.matmul(yy_range, yy_ones)

        # transpose y so column indices vary along the width axis
        yy_channel = yy_channel.permute(0, 1, 3, 2)

        # normalize indices to [0, 1], then rescale to [-1, 1]
        xx_channel = xx_channel.float() / (dim_y - 1)
        yy_channel = yy_channel.float() / (dim_x - 1)

        xx_channel = xx_channel * 2 - 1
        yy_channel = yy_channel * 2 - 1

        xx_channel = xx_channel.repeat(b, 1, 1, 1)
        yy_channel = yy_channel.repeat(b, 1, 1, 1)

        # Bug fix: torch.cuda.is_available is a function; the original tested
        # the function object itself (always truthy), so .cuda() was attempted
        # even on CPU-only machines and crashed. Call it instead.
        if torch.cuda.is_available() and self.use_cuda:
            input_tensor = input_tensor.cuda()
            xx_channel = xx_channel.cuda()
            yy_channel = yy_channel.cuda()
        out = torch.cat([input_tensor, xx_channel, yy_channel], dim=1)
        return out
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class Conv(nn.Module):
    """(convolution => [GroupNorm] => activation), optionally style-modulated.

    With ``style=True`` the block runs a Conv2DMod modulated convolution
    driven by a per-sample style vector followed by LeakyReLU; otherwise it
    is a plain Conv2d + GroupNorm + activation stack.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, conv_stride=1, padding=1, bias=True,
                 activation_func='relu', style=False):
        super().__init__()

        self.style = style
        norm_layer, activation_func = get_layer_info(out_channels, activation_func)
        if style:
            self.styleconv = Conv2DMod(in_channels, out_channels, kernel_size)
            self.relu = nn.LeakyReLU(0.2, inplace=True)
        else:
            if norm_layer is not None:
                self.conv = nn.Sequential(
                    nn.Conv2d(in_channels, out_channels, stride=conv_stride, kernel_size=kernel_size, padding=padding,
                              bias=bias),
                    norm_layer,
                    activation_func)
            else:
                self.conv = nn.Sequential(
                    nn.Conv2d(in_channels, out_channels, stride=conv_stride, kernel_size=kernel_size, padding=padding,
                              bias=bias),
                    activation_func)

    def forward(self, x, style_fea=None):
        """Apply the block; ``style_fea`` is only consulted when style=True.

        Bug fix: ``style_fea`` now defaults to None so plain (style=False)
        instances can be called with a single argument — PSP and PSPUpsample
        in this file call ``self.conv(x)`` with one argument, which raised a
        TypeError under the original required-positional signature.
        """
        if self.style:
            res = self.styleconv(x, style_fea)
            res = self.relu(res)
            return res
        else:
            return self.conv(x)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class Conv2DMod(nn.Module):
    """Modulated 2D convolution with optional weight demodulation.

    The per-sample style vector ``y`` scales the convolution weights
    (modulation); with ``demod=True`` the weights are renormalized per output
    filter (demodulation), matching the StyleGAN2 modulated-conv scheme.
    """

    def __init__(self, in_chan, out_chan, kernel, demod=True, stride=1, dilation=1, eps=1e-8, **kwargs):
        super().__init__()
        self.filters = out_chan
        self.demod = demod
        self.kernel = kernel
        self.stride = stride
        self.dilation = dilation
        # one shared weight bank; per-sample modulation happens in forward
        self.weight = nn.Parameter(torch.randn((out_chan, in_chan, kernel, kernel)))
        self.eps = eps
        nn.init.kaiming_normal_(self.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')

    def _get_same_padding(self, size, kernel, dilation, stride):
        # 'same'-style padding for the given geometry
        return ((size - 1) * (stride - 1) + dilation * (kernel - 1)) // 2

    def forward(self, x, y):
        # x: (b, c, h, w) input; y: (b, in_chan) per-sample style vector
        b, c, h, w = x.shape

        # modulate: scale each input channel's weights by (style + 1)
        w1 = y[:, None, :, None, None]
        w2 = self.weight[None, :, :, :, :]
        weights = w2 * (w1 + 1)

        if self.demod:
            # demodulate: normalize each output filter's weights per sample
            d = torch.rsqrt((weights ** 2).sum(dim=(2, 3, 4), keepdim=True) + self.eps)
            weights = weights * d

        # fold the batch into the channel dim and run one grouped conv so each
        # sample is convolved with its own modulated weights
        x = x.reshape(1, -1, h, w)

        _, _, *ws = weights.shape
        weights = weights.reshape(b * self.filters, *ws)

        padding = self._get_same_padding(h, self.kernel, self.dilation, self.stride)
        x = F.conv2d(x, weights, padding=padding, groups=b)

        x = x.reshape(-1, self.filters, h, w)
        return x
|
| 130 |
+
|
| 131 |
+
class Up(nn.Module):
    """Bilinear 2x upsample followed by a conv that quarters the channels."""

    def __init__(self, in_channels, out_channels, activation_func='relu', style=False):
        super().__init__()
        self.style = style
        # NOTE: the activation argument is deliberately forced to 'relu',
        # matching the original behavior; out_channels only sizes the
        # (unused) norm layer returned by get_layer_info.
        norm_layer, act = get_layer_info(out_channels, 'relu')

        self.up_layer = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.up = Conv(in_channels, in_channels // 4, activation_func=act, style=style)

    def forward(self, x, style_fea):
        # The styled and plain paths are identical: upsample, then conv
        # (Conv itself decides what to do with style_fea).
        return self.up(self.up_layer(x), style_fea)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class PSPUpsample(nn.Module):
    """Resize a feature map by fixed integer factors, then apply a Conv block.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the conv.
        scale_x: horizontal scale factor.
        scale_y: vertical scale factor.
    """

    def __init__(self, in_channels, out_channels, scale_x, scale_y):
        super().__init__()
        self.conv = Conv(in_channels, out_channels)
        self.scale_x = scale_x
        self.scale_y = scale_y

    def forward(self, x):
        h, w = self.scale_y * x.size(2), self.scale_x * x.size(3)
        # F.interpolate replaces the deprecated F.upsample (same semantics).
        p = F.interpolate(input=x, size=(h, w), mode='bilinear')
        # Bug fix: Conv.forward takes a style feature argument; the plain
        # (style=False) path ignores it, so pass None explicitly instead of
        # calling with one argument (which raised TypeError).
        return self.conv(p, None)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
class PSP(nn.Module):
    """Pyramid pooling: pool to 2/4/8 grids, compress, upsample back, concat.

    The concatenated output has the same channel count as the input
    (four branches of in_channels // 4 each).
    """

    def __init__(self, in_channels):
        super().__init__()
        # pooling at three pyramid scales
        self.pool2 = nn.AdaptiveAvgPool2d((2, 2))
        self.pool4 = nn.AdaptiveAvgPool2d((4, 4))
        self.pool8 = nn.AdaptiveAvgPool2d((8, 8))

        # conv -> compress each branch to a quarter of the channels
        avg_channel = in_channels // 4
        self.conv2 = Conv(in_channels, avg_channel)
        self.conv4 = Conv(in_channels, avg_channel)
        self.conv8 = Conv(in_channels, avg_channel)
        self.conv16 = Conv(in_channels, avg_channel)

        # up sample -> match dimension
        # NOTE(review): scale factors assume a 16x16 input feature map — confirm.
        self.up2 = PSPUpsample(avg_channel, avg_channel, 16 // 2, 16 // 2)
        self.up4 = PSPUpsample(avg_channel, avg_channel, 16 // 4, 16 // 4)
        self.up8 = PSPUpsample(avg_channel, avg_channel, 16 // 8, 16 // 8)

    def forward(self, x):
        # Bug fix: Conv.forward takes a style feature argument; pass None on
        # the plain path (it is ignored when style=False) instead of calling
        # with one argument, which raised TypeError.
        x2 = self.up2(self.conv2(self.pool2(x), None))
        x4 = self.up4(self.conv4(self.pool4(x), None))
        x8 = self.up8(self.conv8(self.pool8(x), None))
        x16 = self.conv16(x, None)
        return torch.cat((x2, x4, x8, x16), dim=1)
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
class Up_Stream(nn.Module):
    """ Up Stream Sequence.

    Decoder that fuses the encoder skip features x1..x11 while injecting a
    style vector at four resolutions via style-modulated Conv blocks.
    Non-style Conv/Up blocks still take a second argument; the feature map
    itself is passed there and ignored.
    """

    def __init__(self, out_channels=3, activation_func = 'relu'):
        super(Up_Stream, self).__init__()

        input_channel = 512
        # dimensionality of the incoming style vector
        fea_dim = 200
        norm_layer, activation_func = get_layer_info(input_channel, activation_func)
        # per-resolution linear maps from the style vector to modulation weights
        self.to_style1 = nn.Linear(in_features=fea_dim, out_features=input_channel)
        self.up_16_16_1 = Conv(input_channel, 256, activation_func=activation_func, style=True)
        self.up_16_16_2 = Conv(768, 512, activation_func=activation_func)
        self.up_16_16_3 = Conv(1024, 512, activation_func=activation_func)

        self.up_16_32 = Up(1024, 256, activation_func=activation_func)
        self.to_style2 = nn.Linear(in_features=fea_dim, out_features=512)
        self.up_32_32_1 = Conv(512, 256, activation_func=activation_func, style=True)

        self.up_32_64 = Up(512, 128, activation_func=activation_func)
        self.to_style3 = nn.Linear(in_features=fea_dim, out_features=256)
        self.up_64_64_1 = Conv(256, 128, activation_func=activation_func, style=True)

        self.up_64_128 = Up(256, 64, activation_func=activation_func)
        self.to_style4 = nn.Linear(in_features=fea_dim, out_features=128)
        self.up_128_128_1 = Conv(128, 64, activation_func=activation_func, style=True)

        self.up_128_256 = Up(128, 32, activation_func=activation_func)
        self.out_conv = Conv(64, out_channels, activation_func='relu')

    def forward(self, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, style):
        # x1..x11: encoder features from finest (x1) to coarsest (x11);
        # style: (batch, 200) style vector.
        batch_size, c, h, w = x1.size()

        # 16x16 stage: style-modulated conv on the deepest feature
        style1 = self.to_style1(style)
        y = self.up_16_16_1(x11, style1)  # 256 x 16 x 16

        y = torch.cat((x10, y), dim=1)  # 768 x 16 x 16
        # non-style blocks: second argument is ignored
        y = self.up_16_16_2(y, y)  # 512 x 16 x 16

        y = torch.cat((x9, y), dim=1)  # 1024 x 16 x 16
        y = self.up_16_16_3(y, y)  # 512 x 16 x 16

        y = torch.cat((x8, y), dim=1)  # 1024 x 16 x 16
        y = self.up_16_32(y, y)  # 256 x 32 x 32

        # 32x32 stage with style injection
        y = torch.cat((x7, y), dim=1)
        style2 = self.to_style2(style)
        y = self.up_32_32_1(y, style2)  # 256 x 32 x 32

        y = torch.cat((x6, y), dim=1)
        y = self.up_32_64(y, y)

        # 64x64 stage with style injection
        y = torch.cat((x5, y), dim=1)
        style3 = self.to_style3(style)
        y = self.up_64_64_1(y, style3)  # 128 x 64 x 64

        y = torch.cat((x4, y), dim=1)
        y = self.up_64_128(y, y)

        # 128x128 stage with style injection
        y = torch.cat((x3, y), dim=1)
        style4 = self.to_style4(style)
        y = self.up_128_128_1(y, style4)  # 64 x 128 x 128

        y = torch.cat((x2, y), dim=1)
        y = self.up_128_256(y, y)  # 32 x 256 x 256

        y = torch.cat((x1, y), dim=1)

        y = self.out_conv(y, y)  # 3 x 256 x 256

        return y
|
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/ssn/test.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
import ssn_dataset
from torchvision import transforms, utils
import numpy as np

# Quick sanity check: construct both dataset splits and report their sizes
# and the first metadata entry.
csv_file = "~/Dataset/soft_shadow/train/metadata.csv"
# compose_transform = None
train_set = ssn_dataset.SSN_Dataset(csv_file, is_training=True)
test_set = ssn_dataset.SSN_Dataset(csv_file, is_training=False)

print('training set size: ', len(train_set))
print('testing set size: ', len(test_set))

print(len(train_set.meta_data))
print(train_set.meta_data[0])
|
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/ssn/test_dataset.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import ssn_dataset
import time

if __name__ == '__main__':
    # Iterate both splits once to verify every sample can be loaded.
    start = time.time()
    csv_file = "~/Dataset/soft_shadow/single_human/metadata.csv"
    train_set = ssn_dataset.SSN_Dataset(csv_file, is_training=True)
    valid_set = ssn_dataset.SSN_Dataset(csv_file, is_training=False)

    print("Training dataset num: ", len(train_set))
    print("Testing dataset num: ", len(valid_set))

    for i in range(len(train_set)):
        data = train_set[i]
        print('Training set: successfully iterate {} \r'.format(i), flush=True, end='')

    for i in range(len(valid_set)):
        data = valid_set[i]
        print('Validation set: successfully iterate {} \r'.format(i), flush=True, end='')

    end = time.time()
|
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/__init__.py
ADDED
|
File without changes
|
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/html.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import dominate
|
| 2 |
+
from dominate.tags import meta, h3, table, tr, td, p, a, img, br
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
class HTML:
    """Build a simple image-gallery HTML page under ``web_dir``.

    Uses dominate to assemble the document in memory; call :meth:`save` to
    write ``index.html`` into ``web_dir``.
    """

    def __init__(self, web_dir, title, reflesh=0):
        # 'reflesh' (sic) is kept for backward compatibility; when > 0 a
        # meta refresh tag with that interval (seconds) is emitted.
        self.title = title
        self.web_dir = web_dir
        # exist_ok avoids the check-then-create race of the original
        # exists()/makedirs() pair.
        os.makedirs(self.web_dir, exist_ok=True)

        # print(self.img_dir)

        self.doc = dominate.document(title=title)
        if reflesh > 0:
            with self.doc.head:
                meta(http_equiv="reflesh", content=str(reflesh))

    def get_image_dir(self):
        # NOTE(review): self.img_dir is never assigned anywhere in this class,
        # so calling this raises AttributeError — confirm the intended source.
        return self.img_dir

    def add_header(self, str):
        """Append an <h3> header with the given text."""
        with self.doc:
            h3(str)

    def add_table(self, border=1):
        """Start a new fixed-layout table and attach it to the document."""
        self.t = table(border=border, style="table-layout: fixed;")
        self.doc.add(self.t)

    def add_images(self, ims, txts, links, width=400, height=300):
        """Add one table row of linked images with captions.

        ims/txts/links are parallel lists: image path, caption text, and the
        href target for each cell. Paths are rooted at '/' in the page.
        """
        self.add_table()
        with self.t:
            with tr():
                for im, txt, link in zip(ims, txts, links):
                    with td(style="word-wrap: break-word; height:{}px; width:{}px".format(height + 10, width + 10), halign="center", valign="top"):
                        with p():
                            with a(href=os.path.join('/', link)):
                                img(style="width:{}px;height:{}".format(width, height), src=os.path.join('/', im))
                            br()
                            p(txt)

    def save(self):
        """Render the document to <web_dir>/index.html."""
        html_file = '%s/index.html' % self.web_dir
        # Context manager guarantees the handle is closed even if render()
        # or write() raises (original used open/close without protection).
        with open(html_file, 'wt') as f:
            f.write(self.doc.render())
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
if __name__ == '__main__':
    # Minimal self-test: build a page with four image cells and write
    # web/index.html (the image files themselves are not required to exist).
    html = HTML('web/', 'test_html')
    html.add_header('hello world')

    ims = []
    txts = []
    links = []
    for n in range(4):
        ims.append('image_%d.png' % n)
        txts.append('text_%d' % n)
        links.append('image_%d.png' % n)
    html.add_images(ims, txts, links)
    html.save()
|
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/html_server.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Tiny static-file server used to preview the generated result pages
# (index.html + imgs) in a browser.  Serves the current working directory.
import http.server
import socketserver

PORT = 8081  # fixed port; change here if it clashes with another service
Handler = http.server.SimpleHTTPRequestHandler

with socketserver.TCPServer(("", PORT), Handler) as httpd:
    print("serving at port", PORT)
    httpd.serve_forever()  # blocks until interrupted
|
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/imgs
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
/home/ysheng/Dataset/dropbox/Research/SSN_Training_Share/touch
|
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/index.html
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/make_html.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import json
|
| 3 |
+
import pdb
|
| 4 |
+
import os
|
| 5 |
+
from os.path import join
|
| 6 |
+
import html
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
import argparse
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import matplotlib.pyplot as plt
|
| 11 |
+
|
| 12 |
+
def get_files(folder):
    """Return full paths of the regular files directly inside *folder*."""
    candidates = (join(folder, name) for name in os.listdir(folder))
    return [path for path in candidates if os.path.isfile(path)]
|
| 14 |
+
|
| 15 |
+
def get_folders(folder):
    """Return full paths of the sub-directories directly inside *folder*."""
    candidates = (join(folder, name) for name in os.listdir(folder))
    return [path for path in candidates if os.path.isdir(path)]
|
| 17 |
+
|
| 18 |
+
vis_img_folder = 'imgs'            # ground-truth renderings, relative to the web dir
vis_eval_img_folder = 'eval_imgs'  # network predictions, relative to the web dir
def eval_gen(webpage, output_folder, is_pattern=True):
    """Fill *webpage* with one row per (model, IBL): the IBL image, the
    color-inverted mitsuba ground-truth shadow, and the network prediction.

    webpage       -- an HTML page object (see utils/html.py)
    output_folder -- folder containing the 'imgs' sub-folder of results
    is_pattern    -- browse the 'pattern' IBL renders instead of 'real'
    """
    def get_file(files, key_world):
        # Return the first path in *files* whose name contains *key_world*,
        # or '' when none matches.
        mitsuba_shadow = ''
        for f in files:
            if f.find(key_world) != -1:
                mitsuba_shadow = f
                break
        return mitsuba_shadow

    def flip_shadow(img_file):
        # Save a color-inverted (1 - img) copy of *img_file* next to it and
        # return the new path; a black 256x256 placeholder is used when the
        # file path is empty.
        dirname, fname = os.path.dirname(img_file), os.path.splitext(os.path.basename(img_file))[0]
        if img_file == '':
            print('find one zero')
            mts_shadow_np = np.zeros((256,256,3))
        else:
            mts_shadow_np = plt.imread(img_file)

        save_path = join(dirname, fname + '_flip.png')
        plt.imsave(save_path, 1.0-mts_shadow_np)
        return save_path

    img_folders = join(output_folder, 'imgs')
    folders = get_folders(img_folders)
    print("There are {} folders".format(len(folders)))

    for model in tqdm(folders):
        cur_model_relative = join(vis_img_folder, os.path.basename(model))
        evl_cur_model_relative = join(vis_eval_img_folder, os.path.basename(model))

        if is_pattern:
            ibl_relative = join(cur_model_relative, 'pattern')
        else:
            ibl_relative = join(cur_model_relative, 'real')

        # import pdb; pdb.set_trace()
        ibl_folders = get_folders(ibl_relative)
        ibl_folders.sort()
        for ibl in ibl_folders:
            cur_ibl_relative = join(ibl_relative,os.path.basename(ibl))
            gt_files = get_files(cur_ibl_relative)
            mts_shadow = get_file(gt_files, '_shadow.png')

            # the IBL image shares its folder's name
            ibl_name = os.path.basename(ibl)
            ibl = join(cur_ibl_relative, ibl_name + '.png')

            mitsuba_shadow = flip_shadow(mts_shadow)

            # NOTE(review): predictions are always looked up under 'pattern'
            # even when is_pattern is False -- confirm whether 'real'
            # predictions live elsewhere.
            cur_eval_folder = join(evl_cur_model_relative, join('pattern', ibl_name))
            net_predict = get_file(get_files(cur_eval_folder), 'predict.png')

            # mitsuba_final = join(cur_ibl_relative, 'composite.png')
            # pred_final = join(cur_ibl_relative, 'composite_pred.png')

            # print(ibl_name)
            ims, txts, links = [ibl,mitsuba_shadow, net_predict], ['ibl','mitsuba', 'predict'], [ibl,mitsuba_shadow, net_predict]

            webpage.add_images(ims, txts, links)
|
| 77 |
+
|
| 78 |
+
vis_pattern_folder = '/home/ysheng/Documents/vis_pattern'  # NOTE: author-specific paths
vis_real_folder = '/home/ysheng/Documents/vis_real'
def vis_files_in_folder():
    """Build an HTML gallery for every image in vis_models/imgs.

    File names are assumed to end in a 3-digit index; files sharing the same
    prefix are shown together on one row.
    NOTE(review): `html` here must resolve to the local utils/html.py module
    (which defines class HTML); the stdlib `html` module does not -- verify
    the import path when running this script.
    """
    folder = '/home/ysheng/Documents/vis_models'
    webpage = html.HTML(folder, 'models', reflesh=1)
    img_folders = join(folder, 'imgs')
    files = get_files(img_folders)
    print("There are {} files".format(len(files)))

    prefix_set = set()
    for cur_file in tqdm(files):
        cur_name = os.path.splitext(os.path.basename(cur_file))[0]
        prefix_set.add(cur_name[:-3])  # drop the trailing 3-digit index

    print('there are {} prefixs'.format(len(prefix_set)))
    prefix_set = list(prefix_set)
    prefix_set.sort()

    # import pdb; pdb.set_trace()
    relative_folder = './imgs'
    for i, prefix in enumerate(prefix_set):
        # assumes every prefix has the same number of files -- TODO confirm
        ims = [join(relative_folder, prefix + '{:03d}.png'.format(i)) for i in range(len(files) // len(prefix_set))]
        txts = [prefix + '{:03d}'.format(i) for i in range(len(files) // len(prefix_set))]
        links = ims
        webpage.add_images(ims, txts, links)

    webpage.save()
    print('finished')
|
| 106 |
+
|
| 107 |
+
def vis_files(df_file):
    """ input is a pandas dataframe
    format: path, path,..., name,name, ...

    Each CSV row holds image paths in its first half and the matching
    captions in its second half; one gallery row is emitted per CSV row
    into ./index.html.
    """
    folder = '.'
    webpage = html.HTML(folder, 'benchmark', reflesh=1)

    relative_folder = './imgs'
    # for i, prefix in enumerate(prefix_set):
    #     ims = [join(relative_folder, prefix + '{:03d}.png'.format(i)) for i in range(len(files) // len(prefix_set))]
    #     txts = [prefix + '{:03d}'.format(i) for i in range(len(files) // len(prefix_set))]
    #     links = ims
    #     webpage.add_images(ims, txts, links)

    df = pd.read_csv(df_file)
    for i,v in tqdm(df.iterrows(), total=len(df)):
        # first half of the columns (after the index) are image paths,
        # second half are the caption strings
        img_range = len(v)//2+1
        imgs = [join(relative_folder,v[i]) for i in range(1,img_range)]
        txts = [v[i] for i in range(img_range, len(v))]
        links = imgs
        webpage.add_images(imgs, txts, links)

    webpage.save()
    print('finished')
|
| 131 |
+
|
| 132 |
+
if __name__ == "__main__":
    # hard-coded CSV from the author's benchmark run; adjust locally
    vis_files('/home/ysheng/Documents/paper_project/adobe/soft_shadow/benchmark_results/html.csv')
|
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/net_utils.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#import matplotlib.pyplot as plt
|
| 2 |
+
import os
|
| 3 |
+
from torchvision import transforms, utils
|
| 4 |
+
import torch
|
| 5 |
+
#import matplotlib.pyplot as plt
|
| 6 |
+
import numpy as np
|
| 7 |
+
from utils.utils_file import get_cur_time_stamp, create_folder
|
| 8 |
+
|
| 9 |
+
def compute_differentiable_params(net):
    """Count the trainable (requires_grad) parameters of *net*."""
    total = 0
    for param in net.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
|
| 11 |
+
|
| 12 |
+
def convert_Relight_latent_light(latent_feature):
    """ Convert n x 6 x 16 x 16 -> n x 3 x 16 x 32 """
    # Row-major reshape: each pair of 16x16 planes folds into one 16x32 plane.
    # Unpacking enforces a 4-D input, as in the documented contract.
    num_samples, _, _, _ = latent_feature.size()
    return torch.reshape(latent_feature, (num_samples, 3, 16, 32))
|
| 19 |
+
|
| 20 |
+
def show_batch(sample_batch, out_file=None):
    """Display a batch of image tensors as a grid; optionally save the figure.

    sample_batch -- batch of images accepted by torchvision's make_grid
    out_file     -- when given, the figure is also written to this path
    """
    # Fix: the module-level matplotlib import is commented out, so `plt` was
    # an unconditional NameError.  Import lazily here so importing this
    # module does not require a display backend.
    import matplotlib.pyplot as plt

    grid = utils.make_grid(sample_batch)
    plt.figure(figsize=(30, 20))
    # torch images are C x H x W; matplotlib wants H x W x C
    plt.imshow(grid.detach().cpu().numpy().transpose((1, 2, 0)))

    if out_file is not None:
        print('try save ', out_file)
        plt.savefig(out_file)

    plt.show()
|
| 30 |
+
|
| 31 |
+
def show_light_batch(light_batch):
    """Reshape a batch of latent light codes and display them as a grid."""
    reshaped = convert_Relight_latent_light(light_batch)
    show_batch(reshaped)
|
| 34 |
+
|
| 35 |
+
def save_loss(figure_fname, train_loss, valid_loss):
    """Plot training and validation loss curves and save them to *figure_fname*."""
    # Fix: the module-level matplotlib import is commented out, so `plt` was
    # an unconditional NameError.  Import lazily here instead.
    import matplotlib.pyplot as plt

    plt.plot(train_loss)
    plt.plot(valid_loss)
    plt.legend(['train_loss', 'valid_loss'])
    plt.savefig(figure_fname)
|
| 40 |
+
|
| 41 |
+
def save_model(output_folder, model, optimizer, epoch, best_loss, fname, hist_train_loss, hist_valid_loss, hist_lr, params):
    """Checkpoint the model/optimizer state plus training history.

    The checkpoint is written to <output_folder>/<fname>; the folder is
    created when missing.  Returns the path of the written file.
    """
    create_folder(output_folder)
    output_fname = os.path.join(output_folder, fname)

    # When running multi-GPU, unwrap the inner model (e.g. a DataParallel
    # `.module`) so the saved state_dict has plain key names.
    net_to_save = model.module if (params.multi_gpu and hasattr(model, 'module')) else model

    checkpoint = {
        'epoch': epoch,
        'best_loss': best_loss,
        'model_state_dict': net_to_save.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'hist_train_loss': hist_train_loss,
        'hist_valid_loss': hist_valid_loss,
        'hist_lr': hist_lr,
        'params': str(params),
    }
    torch.save(checkpoint, output_fname)
    return output_fname
|
| 63 |
+
|
| 64 |
+
def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group
    (None when there are no parameter groups)."""
    groups = optimizer.param_groups
    if groups:
        return groups[0]['lr']
|
| 67 |
+
|
| 68 |
+
def set_lr(optimizer, lr):
    """Overwrite the learning rate of every parameter group with *lr*."""
    for group in optimizer.param_groups:
        group['lr'] = lr
|
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/tensorboard_utils.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 4 |
+
|
| 5 |
+
def tensorboard_plot_loss(win_name, loss, writer):
    """Append the latest value of *loss* to the 'Loss/<win_name>' scalar.

    The global step is the current history length, so call once per update.
    """
    writer.add_scalar("Loss/{}".format(win_name), loss[-1], len(loss))
    writer.flush()
|
| 8 |
+
|
| 9 |
+
def normalize_img(imgs):
    """Normalize an interleaved batch for display.

    The batch is assumed to hold pairs: the first half and second half at the
    same offset share one scale (the max of the first-half image), so the two
    stay directly comparable.  The first-half/second-half images are scaled
    in place on *imgs*; the returned tensor is clamped to [0, 1].
    """
    pair_count = imgs.shape[0] // 2
    for idx in range(pair_count):
        scale = torch.max(imgs[idx])
        imgs[idx] = imgs[idx] / scale
        imgs[pair_count + idx] = imgs[pair_count + idx] / scale

    return torch.clamp(imgs, 0.0, 1.0)
|
| 19 |
+
|
| 20 |
+
def tensorboard_show_batch(imgs, writer, win_name=None, nrow=2, normalize=True, step=0):
    """Log an image batch to tensorboard under *win_name* at *step*.

    NOTE(review): *nrow* is accepted but unused here (add_images lays out the
    grid itself); it appears to be kept for call-compat with the visdom helper.
    """
    if normalize:
        imgs = normalize_img(imgs)

    writer.add_images('{}'.format(win_name), imgs, step)
    writer.flush()
|
| 26 |
+
|
| 27 |
+
def tensorboard_log(log_info, writer, win_name='logger', step=0):
    """Write *log_info* as a tensorboard text entry under *win_name* at *step*."""
    writer.add_text(win_name, log_info, step)
    writer.flush()
|
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/time_utils.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import datetime
|
| 2 |
+
|
| 3 |
+
def get_time_stamp():
    """Return the current local time formatted as 'YYYY-MM-DD HH:MM:SS'."""
    return datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
| 5 |
+
|
| 6 |
+
|
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/utils_file.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from shutil import copyfile
|
| 2 |
+
from os import listdir
|
| 3 |
+
from os.path import isfile, join
|
| 4 |
+
import os
|
| 5 |
+
import datetime
|
| 6 |
+
|
| 7 |
+
def get_all_folders(folder):
    """Return the names (not full paths) of the sub-folders of *folder*."""
    if not check_file_exists(folder):
        print("Cannot find the folder ", folder)
        return []
    return [entry for entry in os.listdir(folder) if not isfile(join(folder, entry))]
|
| 13 |
+
|
| 14 |
+
def get_all_files(folder):
    """Return the names (not full paths) of the regular files in *folder*."""
    if not check_file_exists(folder):
        print("Cannot find the folder ", folder)
        return []

    return [entry for entry in listdir(folder) if isfile(join(folder, entry))]
|
| 21 |
+
|
| 22 |
+
def create_folder(folder):
    """Create *folder* (including missing parents) if it does not exist."""
    # makedirs handles nested paths and is race-safe; the previous
    # os.mkdir failed on missing parents and raced between the
    # exists() check and the creation.
    os.makedirs(folder, exist_ok=True)
|
| 25 |
+
|
| 26 |
+
def create_folders(folder_list):
    """Create every folder in *folder_list* (missing ones only)."""
    for folder in folder_list:
        create_folder(folder)
|
| 29 |
+
|
| 30 |
+
def replace_file_ext(fname, new_ext):
    """Return *fname* with its extension replaced by *new_ext* (no dot).

    Prints a warning and returns None when *fname* has no extension.
    """
    # Fix: find(".") located the FIRST dot, so "a.b.txt" became "a.<ext>";
    # rfind targets the actual extension separator.
    ext_pos = fname.rfind(".")
    if ext_pos != -1:
        return fname[0:ext_pos] + "." + new_ext
    else:
        print("Please check " + fname)
|
| 36 |
+
|
| 37 |
+
def check_file_exists(fname, verbose=True):
    """Return True when *fname* exists and is non-empty.

    verbose -- print a message when the file is missing/empty
    """
    try:
        if not os.path.exists(fname) or get_file_size(fname) == 0:
            if verbose:
                print("file {} does not exists! ".format(fname))
            return False
    # Fix: a bare `except:` also swallowed KeyboardInterrupt/SystemExit and
    # programming errors; only filesystem failures should be reported here.
    except OSError:
        print("File {} has some issue! ".format(fname))
        return False
    return True
|
| 47 |
+
|
| 48 |
+
def delete_file(fname):
    """Remove *fname* from disk when it exists and is non-empty."""
    if not check_file_exists(fname):
        return
    os.remove(fname)
|
| 51 |
+
|
| 52 |
+
def get_file_size(fname):
    """Return the size of *fname* in bytes."""
    return os.stat(fname).st_size
|
| 54 |
+
|
| 55 |
+
def get_folder_size(folder):
    """Return the total size in bytes of the regular files directly in *folder*."""
    # Fix: getsize() previously used `folder + f` without a path separator,
    # which pointed at the wrong path whenever *folder* lacked a trailing
    # slash (the isfile() test in the same expression already used join()).
    return sum(os.path.getsize(join(folder, f)) for f in listdir(folder) if isfile(join(folder, f)))
|
| 57 |
+
|
| 58 |
+
def get_cur_time_stamp():
    """Return the current local time formatted as e.g. '05-March-09-30-PM'."""
    now = datetime.datetime.now()
    return '{:%d-%B-%I-%M-%p}'.format(now)
|
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/vis_test_results.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import json
|
| 3 |
+
import pdb
|
| 4 |
+
import os
|
| 5 |
+
from os.path import join
|
| 6 |
+
import html
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
import argparse
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import cv2 as cv
|
| 11 |
+
|
| 12 |
+
# Hard-coded experiment configuration for visualising/comparing test results.
base_softness = "0.1"                  # softness setting of the baseline run
base_exp_name = "fov_results_real_89"  # baseline experiment folder name
root_dir = "/mnt/share/yifan/code/soft_shadow-master/vis_res/"  # NOTE: author-specific path
base_dir = root_dir + base_exp_name + '/' + base_softness

# experiment folders to compare against the baseline
exp_name = ["fov_results_real_hd", "fov_results_real_hd_0713"]

# softness values to visualise
softness = [0.1, 0.2]

# test-case sub-folders to include
case_name = ["case1", "case2", "case7"]
|
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/visdom_utils.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from visdom import Visdom
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
# viz = Visdom(port=8002)
|
| 6 |
+
# viz2 = Visdom(port=8003)
|
| 7 |
+
|
| 8 |
+
def setup_visdom(port=8002):
    """Connect to a running Visdom server on *port* and return the client."""
    return Visdom(port=port)
|
| 10 |
+
|
| 11 |
+
def visdom_plot_loss(win_name, loss, cur_viz):
    """Plot a loss history as a line chart in the Visdom window *win_name*.

    loss    -- sequence of scalar loss values (one per step)
    cur_viz -- Visdom client to draw on
    """
    loss_np = np.array(loss)
    x = np.arange(1, 1 + len(loss))  # 1-based step axis
    cur_viz.line(win=win_name,
                 X=x,
                 Y=loss_np,
                 opts=dict(showlegend=True, legend=[win_name]))
|
| 18 |
+
|
| 19 |
+
def guassian_light(light_tensor):
    """Gaussian-blur each channel of a light tensor for visualization.

    Each channel is scaled by 100, blurred (sigma=2) and clamped to [0, 1].
    Returns a new CPU tensor of the same shape; the input is not modified.
    (Function name keeps the original "guassian" spelling for callers.)
    """
    # Fix: gaussian_filter was never imported, so this raised NameError.
    from scipy.ndimage import gaussian_filter

    light_tensor = light_tensor.detach().cpu()
    channel = light_tensor.size()[0]
    tensor_ret = torch.zeros(light_tensor.size())
    for i in range(channel):
        # Fix: the loop always read channel 0 (light_tensor[0]) instead of
        # channel i, so every output channel was the blurred first channel.
        light_np = light_tensor[i].numpy() * 100.0
        light_np = gaussian_filter(light_np, sigma=2)
        tensor_ret[i] = torch.from_numpy(light_np)
        tensor_ret[i] = torch.clamp(tensor_ret[i], 0.0, 1.0)

    return tensor_ret
|
| 30 |
+
|
| 31 |
+
def normalize_img(imgs):
    """Normalize an interleaved batch for display.

    The batch is assumed to hold pairs: the first half and second half at the
    same offset share one scale (the max of the first-half image), so the two
    stay directly comparable.  The per-image scaling mutates *imgs* in place;
    the returned tensor is clamped to [0, 1].
    """
    pair_count = imgs.shape[0] // 2
    for idx in range(pair_count):
        scale = torch.max(imgs[idx])
        imgs[idx] = imgs[idx] / scale
        imgs[pair_count + idx] = imgs[pair_count + idx] / scale

    return torch.clamp(imgs, 0.0, 1.0)
|
| 42 |
+
|
| 43 |
+
def visdom_show_batch(imgs, cur_viz, win_name=None, nrow=2, normalize=True):
    """Display an image batch in a Visdom window.

    imgs      -- B x C x H x W tensor
    win_name  -- target window; a shared default window is used when None
    nrow      -- images per grid row
    normalize -- rescale paired halves via normalize_img before display
    """
    if normalize:
        imgs = normalize_img(imgs)

    if win_name is None:
        cur_viz.images(imgs, win="batch visualize",nrow=nrow)
    else:
        cur_viz.images(imgs, win=win_name, opts=dict(title=win_name),nrow=nrow)
|
| 51 |
+
|
| 52 |
+
def visdom_log(log_info, viz, win_name='logger'):
    """Write *log_info* text into the Visdom window *win_name* (replaces it)."""
    viz.text(log_info, win=win_name)
|
PixHtLab-Src/Demo/PixhtLab/Torch_Render/hshadow_cuda.cpp
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <torch/extension.h>
|
| 2 |
+
|
| 3 |
+
#include <vector>
|
| 4 |
+
#include <stdio.h>
|
| 5 |
+
|
| 6 |
+
// CUDA forward declarations
|
| 7 |
+
std::vector<torch::Tensor> hshadow_render_cuda_forward(torch::Tensor rgb, torch::Tensor mask, torch::Tensor mask_bb, torch::Tensor hmap, torch::Tensor rechmap, torch::Tensor light_pos);
|
| 8 |
+
std::vector<torch::Tensor> reflect_render_cuda_forward(torch::Tensor rgb, torch::Tensor mask, torch::Tensor hmap, torch::Tensor rechmap, torch::Tensor thresholds);
|
| 9 |
+
std::vector<torch::Tensor> glossy_reflect_render_cuda_forward(torch::Tensor rgb, torch::Tensor mask, torch::Tensor hmap, torch::Tensor rechmap, const int sample_n, const float glossy);
|
| 10 |
+
torch::Tensor ray_intersect_cuda_forward(torch::Tensor rgb, torch::Tensor mask, torch::Tensor hmap, torch::Tensor rechmap, torch::Tensor rd_map);
|
| 11 |
+
torch::Tensor ray_scene_intersect_cuda_forward(torch::Tensor rgb, torch::Tensor mask, torch::Tensor hmap, torch::Tensor ro, torch::Tensor rd, float dh);
|
| 12 |
+
|
| 13 |
+
// C++ interface
|
| 14 |
+
// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
|
| 15 |
+
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
|
| 16 |
+
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
|
| 17 |
+
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
| 18 |
+
|
| 19 |
+
/* Heightmap Shadow Rendering
|
| 20 |
+
rgb: B x 3 x H x W
|
| 21 |
+
mask: B x 1 x H x W
|
| 22 |
+
mask: B x 1
|
| 23 |
+
hmap: B x 1 x H x W
|
| 24 |
+
rechmap: B x 1 x H x W
|
| 25 |
+
light_pos: B x 1 (x,y,h)
|
| 26 |
+
*/
|
| 27 |
+
// Validate that all tensors are contiguous CUDA tensors, then dispatch
// height-map shadow rendering to the CUDA kernel (shapes documented in the
// comment block above).
std::vector<torch::Tensor> hshadow_render_forward(torch::Tensor rgb, torch::Tensor mask, torch::Tensor bb, torch::Tensor hmap, torch::Tensor rechmap, torch::Tensor light_pos) {
    CHECK_INPUT(rgb);
    CHECK_INPUT(mask);
    CHECK_INPUT(bb);
    CHECK_INPUT(hmap);
    CHECK_INPUT(rechmap);
    CHECK_INPUT(light_pos);

    return hshadow_render_cuda_forward(rgb, mask, bb, hmap, rechmap, light_pos);
}
|
| 37 |
+
|
| 38 |
+
// Validate inputs and dispatch reflection rendering to the CUDA kernel.
std::vector<torch::Tensor> reflect_render_forward(torch::Tensor rgb, torch::Tensor mask, torch::Tensor hmap, torch::Tensor rechmap, torch::Tensor thresholds) {
    CHECK_INPUT(rgb);
    CHECK_INPUT(mask);
    CHECK_INPUT(hmap);
    CHECK_INPUT(rechmap);
    CHECK_INPUT(thresholds);

    return reflect_render_cuda_forward(rgb, mask, hmap, rechmap, thresholds);
}
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
// Validate inputs and dispatch glossy reflection rendering to the CUDA
// kernel.  sample_n and glossy are plain scalars, so no CHECK is needed.
std::vector<torch::Tensor> glossy_reflect_render_forward(torch::Tensor rgb,
                                                         torch::Tensor mask,
                                                         torch::Tensor hmap,
                                                         torch::Tensor rechmap,
                                                         int sample_n,
                                                         float glossy) {
    CHECK_INPUT(rgb);
    CHECK_INPUT(mask);
    CHECK_INPUT(hmap);
    CHECK_INPUT(rechmap);

    return glossy_reflect_render_cuda_forward(rgb, mask, hmap, rechmap, sample_n, glossy);
}
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
// Validate inputs and dispatch per-pixel ray/height-map intersection.
// NOTE(review): rd_map is passed through without CHECK_INPUT -- confirm the
// kernel tolerates non-contiguous/CPU rd_map.  The "foward" typo in the
// name is kept because the PYBIND block below binds this symbol.
torch::Tensor ray_intersect_foward(torch::Tensor rgb,
                                   torch::Tensor mask,
                                   torch::Tensor hmap,
                                   torch::Tensor rechmap,
                                   torch::Tensor rd_map) {
    CHECK_INPUT(rgb);
    CHECK_INPUT(mask);
    CHECK_INPUT(hmap);
    CHECK_INPUT(rechmap);

    return ray_intersect_cuda_forward(rgb, mask, hmap, rechmap, rd_map);
}
|
| 76 |
+
|
| 77 |
+
// Validate inputs and dispatch ray/scene intersection with explicit ray
// origins (ro) and directions (rd); dh is a scalar height tolerance.
torch::Tensor ray_scene_intersect_foward(torch::Tensor rgb,
                                         torch::Tensor mask,
                                         torch::Tensor hmap,
                                         torch::Tensor ro,
                                         torch::Tensor rd,
                                         float dh) {
    CHECK_INPUT(rgb);
    CHECK_INPUT(mask);
    CHECK_INPUT(hmap);
    CHECK_INPUT(ro);
    CHECK_INPUT(rd);

    return ray_scene_intersect_cuda_forward(rgb, mask, hmap, ro, rd, dh);
}
|
| 91 |
+
|
| 92 |
+
// Python bindings: expose the C++ wrappers above as module-level functions.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &hshadow_render_forward, "Heightmap Shadow Rendering Forward (CUDA)");
    m.def("reflection", &reflect_render_forward, "Reflection Rendering Forward (CUDA)");
    m.def("glossy_reflection", &glossy_reflect_render_forward, "Glossy Reflection Rendering Forward (CUDA)");
    m.def("ray_intersect", &ray_intersect_foward, "Ray scene intersection");
    m.def("ray_scene_intersect", &ray_scene_intersect_foward, "Ray scene intersection");
}
|
PixHtLab-Src/Demo/PixhtLab/Torch_Render/hshadow_cuda_kernel.cu
ADDED
|
@@ -0,0 +1,682 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <torch/extension.h>
|
| 2 |
+
|
| 3 |
+
#include <cuda.h>
|
| 4 |
+
#include <cuda_runtime.h>
|
| 5 |
+
|
| 6 |
+
#include <vector>
|
| 7 |
+
#include <stdio.h>
|
| 8 |
+
|
| 9 |
+
namespace {
|
| 10 |
+
// Sign of t: +1 for t > 0, otherwise -1 (note: zero maps to -1).
template <typename scalar_t>
__device__
scalar_t sign(scalar_t t) {
    if (t > 0.0) {
        return (scalar_t)1.0;
    } else {
        return -(scalar_t)1.0;
    }
}
|
| 19 |
+
|
| 20 |
+
// Minimal device-side 2D vector (plain data holder, no operators).
template <typename scalar_t>
struct vec2 {
    scalar_t x, y;

    __device__
    vec2() { x=0.0, y=0.0;}

    __device__
    vec2(scalar_t x, scalar_t y):x(x), y(y) {}
};
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
// Minimal device-side 3D vector (plain data holder, no operators).
template <typename scalar_t>
struct vec3 {
    scalar_t x, y, z;

    __device__
    vec3() { x=0.0, y=0.0, z=0.0;}

    __device__
    vec3(scalar_t x, scalar_t y, scalar_t z):x(x), y(y), z(z) {}
};
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
// Linear interpolation between a and b; t in [0,1] maps a -> b.
template <typename scalar_t>
__device__
scalar_t lerp(scalar_t a, scalar_t b, scalar_t t) {
    return (1.0-t) * a + t * b;
}
|
| 49 |
+
|
| 50 |
+
// Project the point (x0,y0,h0) along the line towards (x1,y1,h1) onto the
// ground plane h == 0; the hit's xy coordinates are written to (x2, y2).
template <typename scalar_t>
__device__
void proj_ground(
    scalar_t x0, scalar_t y0, scalar_t h0,
    scalar_t x1, scalar_t y1, scalar_t h1,
    scalar_t &x2, scalar_t &y2
) {
    // parameter value where the interpolated height reaches 0
    scalar_t t = (0-h0)/(h1-h0);
    x2 = lerp(x0, x1, t);
    y2 = lerp(y0, y1, t);
}
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
// line checking condition with thickness value dh, which is the height difference for double-height map
|
| 64 |
+
// we can also use dh as a tolerance value
|
| 65 |
+
// Test whether the segment (xa,ya,ha)->(xb,yb,hb), evaluated at (x,y), lies
// within the height slab [h-dh, h].  *flag* reports whether the segment's
// interpolated height is below-or-equal (+1) or above (-1) the query height.
template <typename scalar_t>
__device__
bool check_intersect(
    scalar_t xa, scalar_t ya, scalar_t ha,
    scalar_t xb, scalar_t yb, scalar_t hb,
    scalar_t x, scalar_t y, scalar_t h,
    scalar_t dh, int& flag) {
    // parameterise along y when the segment is vertical in x to avoid
    // dividing by ~0
    scalar_t t = xa == xb ? (y-ya)/(yb-ya):(x-xa)/(xb-xa);
    scalar_t h_ = lerp(ha, hb, t);
    flag = h_ <= h ? 1:-1;
    return (h_ <= h) && (h_ >= h-dh);
}
|
| 77 |
+
|
| 78 |
+
/*
 * Ray trace in the current scene
 * Given start point xyh, light point xyh, current receiver's height map,
 * Return:
 *  1. if intersect or not
 *  2. the color for the intersection point
 * */
// Marches pixel-by-pixel along the 2D projection of the segment from receiver `s`
// to light `l`, testing each masked height-map sample against the 3D segment
// (via check_intersect) and also reporting a hit when the above/below flag from
// check_intersect changes sign between consecutive valid samples.
// On a hit, `out` receives the RGB of the hit pixel (nearest-neighbor) and the
// function returns true; otherwise returns false and `out` is untouched.
//
// NOTE(review): `d_mask`/`d_hmap` are indexed with the floating-point sx/sy,
// which implicitly truncates toward zero at the accessor boundary — presumably
// intentional nearest-pixel sampling; confirm.
template <typename scalar_t>
__device__
bool ray_trace(vec3<scalar_t> s,
               vec3<scalar_t> l,
               const int bi,
               const torch::PackedTensorAccessor64<scalar_t,4> d_mask,
               const torch::PackedTensorAccessor64<scalar_t,4> d_hmap,
               const torch::PackedTensorAccessor64<scalar_t,4> d_rechmap,
               const torch::PackedTensorAccessor64<scalar_t,4> d_rgb,
               vec3<scalar_t> &out) {
    bool ret = false;

    const int batch_size = d_rgb.size(0), h = d_rgb.size(2), w = d_rgb.size(3);
    // Light position (x, y, height).
    scalar_t lx = l.x;
    scalar_t ly = l.y;
    scalar_t lh = l.z;

    // Receiver (ray start) position (x, y, height).
    scalar_t recx = s.x;
    scalar_t recy = s.y;
    scalar_t rech = s.z;

    scalar_t dirx = lx - recx, diry = ly - recy;
    // March along whichever image axis the ray moves fastest in.
    bool gox = abs(dirx) > abs(diry);
    int searching_n = gox ? w : h;
    int starti = 0, endi = searching_n;
    // Light above the ground plane: only search between receiver and light.
    if (lh > 0) {
        if (gox) {
            starti = recx < lx ? recx:lx;

            endi = recx < lx ? lx:recx;
        }
        else {
            starti = recy < ly ? recy:ly;
            endi = recy < ly ? ly:recy;
        }
    }
    // Light below: search from the image border toward the receiver instead.
    if (lh < 0) {
        if (gox) {
            starti = recx < lx ? 0:recx;
            endi = recx < lx ? recx:endi;
        }
        else {
            starti = recy < ly ? 0:recy;
            endi = recy < ly ? recy:endi;
        }
    }

    scalar_t sx, sy;
    // flag: above/below state from check_intersect; last_flag: state at the
    // previous valid sample (0 = no valid previous sample).
    int flag = 0, last_flag = 0;
    for(int si = starti; si < endi; ++si) {
        /* Searching Point xyh */
        if (gox) {
            sx = si;
            sy = recy + (sx-recx)/dirx * diry;
        } else {
            sy = si;
            sx = recx + (sy-recy)/diry * dirx;
        }

        // Skip samples outside the image or off the occluder mask; reset the
        // sign-change tracker so a gap does not produce a spurious crossing.
        if (sx < 0 || sx > w-1 || sy < 0 || sy > h-1 || d_mask[bi][0][sy][sx] < 0.989) {
            last_flag = 0;
            continue;
        }


        scalar_t sh0 = d_hmap[bi][0][sy][sx];
        scalar_t sh1, sh;
        // do linear interpolation for sh; note that either sy or sx are floating number
        if (gox) {
            if ( sy+1 > h-1 || d_mask[bi][0][sy+1][sx] < 0.989)
                sh = sh0;
            else {
                sh1 = d_hmap[bi][0][sy+1][sx];
                sh = lerp(sh0, sh1, sy - int(sy));
            }
        }
        else {
            if ( sx+1 > w-1 || d_mask[bi][0][sy][sx+1] < 0.989)
                sh = sh0;
            else {
                sh1 = d_hmap[bi][0][sy][sx+1];
                sh = lerp(sh0, sh1, sx - int(sx));
            }
        }

        scalar_t dh = 1.0; // this controls the thickness; for double height map, dh = h_f - h_b
        bool intersect = check_intersect(recx, recy, rech, lx, ly, lh, sx, sy, sh, dh, flag);
        if (intersect) {
            /* TODO, which sampling? linear interpolation? */
            out.x = d_rgb[bi][0][(int)sy][(int)sx];
            out.y = d_rgb[bi][1][(int)sy][(int)sx];
            out.z = d_rgb[bi][2][(int)sy][(int)sx];

            ret = true;
            break;
        }
        // A flip of the above/below flag between consecutive valid samples means
        // the ray crossed the surface between them — treat it as a hit too.
        if (last_flag != 0){
            if (last_flag != flag) {
                out.x = d_rgb[bi][0][(int)sy][(int)sx];
                out.y = d_rgb[bi][1][(int)sy][(int)sx];
                out.z = d_rgb[bi][2][(int)sy][(int)sx];

                ret = true;
                break;
            }
        }
        last_flag = flag;
    }

    return ret;
}
|
| 196 |
+
|
| 197 |
+
/*
 * Ray trace in the current scene
 * Given start point xyh, light point xyh, current receiver's height map,
 * Return:
 *  1. if intersect or not
 *  2. the color for the intersection point
 * */
// Marches a ray with origin `ro` and direction `rd` across the height map,
// reporting a hit when the ray height comes within `dh` of the surface height
// at a masked sample. On a hit, `out` receives the hit pixel's RGB and the
// function returns true.
//
// NOTE(review): `prev_sign`/`cur_sign` are declared `int` but store the
// floating-point difference `cur_h - sh`, so the value is truncated toward
// zero; the active test `abs(cur_sign) < dh` therefore works on integers.
// This may be intentional quantization or a leftover of the commented-out
// sign-crossing test — confirm before changing.
template <typename scalar_t>
__device__
bool ray_scene_intersect(vec3<scalar_t> ro,
                         vec3<scalar_t> rd,
                         const scalar_t dh,
                         const int bi,
                         const torch::PackedTensorAccessor64<scalar_t,4> d_mask,
                         const torch::PackedTensorAccessor64<scalar_t,4> d_hmap,
                         const torch::PackedTensorAccessor64<scalar_t,4> d_rgb,
                         vec3<scalar_t> &out) {
    bool ret = false;
    int h = d_mask.size(2);
    int w = d_mask.size(3);
    scalar_t dirx = rd.x, diry = rd.y, dirh = rd.z;

    /* Special Case, there's no direction update in x or y, but in h */
    // Purely vertical ray: it can only hit the pixel directly at the origin.
    if (abs(dirx) < 1e-6f && abs(diry) < 1e-6f) {
        out.x = d_rgb[bi][0][(int)ro.y][(int)ro.x];
        out.y = d_rgb[bi][1][(int)ro.y][(int)ro.x];
        out.z = d_rgb[bi][2][(int)ro.y][(int)ro.x];
        return true;
    }

    // March along whichever image axis the ray moves fastest in.
    bool gox = abs(dirx) > abs(diry);
    int searching_n = gox ? w : h;

    scalar_t cur_h;
    scalar_t sx, sy;

    int prev_sign, cur_sign;

    // for(int si = starti; si < endi; ++si) {
    for(int si = 0; si < searching_n; ++si) {
        /* Searching Point XYH */
        if (gox) {
            sx = ro.x + si * sign(dirx);
            sy = ro.y + (sx-ro.x)/dirx * diry;
        } else {
            sy = ro.y + si * sign(diry);
            sx = ro.x + (sy-ro.y)/diry * dirx;
        }

        // Skip samples outside the image or off the occluder mask.
        if (sx < 0 || sx > w-1 || sy < 0 || sy > h-1 || d_mask[bi][0][sy][sx] < 0.989) {
            continue;
        }

        scalar_t sh0 = d_hmap[bi][0][sy][sx];
        scalar_t sh1, sh;
        // do linear interpolation for sh; note that either sy or sx are floating number
        if (gox) {
            if (sy+1 > h-1 || d_mask[bi][0][sy+1][sx] < 0.989)
                sh = sh0;
            else {
                sh1 = d_hmap[bi][0][sy+1][sx];
                sh = lerp(sh0, sh1, sy - int(sy)); // Always use 0.5 to do interpolation
            }

            // Ray height at this x, from the parametric ray equation.
            cur_h = ro.z + (sx - ro.x) / dirx * dirh;
        }
        else {
            if ( sx + 1 > w-1 || d_mask[bi][0][sy][sx+1] < 0.989)
                sh = sh0;
            else {
                sh1 = d_hmap[bi][0][sy][sx+1];
                sh = lerp(sh0, sh1, sx - int(sx));
            }

            cur_h = ro.z + (sy - ro.y) / diry * dirh;
        }

        // collide with the rechmap?
        if (si == 0) { /* First sign */
            cur_sign = cur_h - sh;
            continue;
        } else {
            prev_sign = cur_sign;
        }

        cur_sign = cur_h - sh;
        // if (cur_sign * prev_sign < 0.0 || abs(cur_sign) < dh) { /* pass through some objects */
        if (abs(cur_sign) < dh) { /* pass through some objects */
            out.x = d_rgb[bi][0][(int)sy][(int)sx];
            out.y = d_rgb[bi][1][(int)sy][(int)sx];
            out.z = d_rgb[bi][2][(int)sy][(int)sx];
            ret = true;
            break;
        }

    }

    return ret;
}
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
// Kernel: renders a hard shadow map. For every receiver pixel, traces a ray
// toward the per-batch light position; if the ray hits an occluder (ray_trace
// returns true) the pixel's shadow value is 0, otherwise 1. The same value is
// written to all three channels of d_shadow.
// Grid-stride loops over batch (z), width (x) and height (y).
template <typename scalar_t>
__global__ void hshadow_render_cuda_forward(
    const torch::PackedTensorAccessor64<scalar_t,4> d_rgb,
    const torch::PackedTensorAccessor64<scalar_t,4> d_mask,
    const torch::PackedTensorAccessor64<scalar_t,2> d_bb,
    const torch::PackedTensorAccessor64<scalar_t,4> d_hmap,
    const torch::PackedTensorAccessor64<scalar_t,4> d_rechmap,
    const torch::PackedTensorAccessor64<scalar_t,2> d_lightpos,
    torch::PackedTensorAccessor64<scalar_t,4> d_shadow) {
    const int wstride = gridDim.x * blockDim.x, hstride = gridDim.y * blockDim.y, bstride = gridDim.z * blockDim.z;
    const int batch_size = d_rgb.size(0), h = d_rgb.size(2), w = d_rgb.size(3);

    for (int bi = blockIdx.z; bi < batch_size; bi += bstride) {
        /* light xyh */
        scalar_t lx = d_lightpos[bi][0], ly = d_lightpos[bi][1], lh = d_lightpos[bi][2];
        // NOTE(review): the clamped bounding box below is computed but never
        // used in this kernel — presumably a leftover optimization hook.
        int minh = max((int)d_bb[bi][0], 0), maxh = min((int)d_bb[bi][1], h-1), minw = max((int)d_bb[bi][2], 0), maxw = min((int)d_bb[bi][3], w-1);

        vec3<scalar_t> light(lx, ly, lh);
        for (int wi = blockIdx.x * blockDim.x + threadIdx.x; wi < w; wi += wstride) for(int hi = blockIdx.y * blockDim.y + threadIdx.y; hi < h; hi += hstride) {
            // mask_alpha is unused here.
            scalar_t shadow(1.0), mask_alpha(0.0);
            // Sample at the pixel center; receiver height from the receiver height map.
            scalar_t recx = wi + 0.5, recy = hi+0.5, rech = d_rechmap[bi][0][hi][wi];

            vec3<scalar_t> start(recx, recy, rech);
            vec3<scalar_t> intersect_color;

            /* Searching Potentials */
            if (ray_trace(start, light, bi, d_mask, d_hmap, d_rechmap, d_rgb, intersect_color)) {
                shadow = 0.0;
            }

            d_shadow[bi][0][hi][wi] = shadow;
            d_shadow[bi][1][hi][wi] = shadow;
            d_shadow[bi][2][hi][wi] = shadow;
        }
    }
}
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
// Kernel: per-pixel ray/scene intersection using a per-pixel ray direction map.
// For each pixel, traces from (pixel center, receiver height) along the
// direction stored in d_rd_map; writes the hit RGB plus a 1.0 hit flag into
// channels 0-3 of d_intersect, or zeros (default_value + 0.0 flag) on a miss.
// Grid-stride loops over batch (z), width (x) and height (y).
template <typename scalar_t>
__global__ void ray_intersect_cuda_forward(
    const torch::PackedTensorAccessor64<scalar_t,4> d_rgb,
    const torch::PackedTensorAccessor64<scalar_t,4> d_mask,
    const torch::PackedTensorAccessor64<scalar_t,4> d_hmap,
    const torch::PackedTensorAccessor64<scalar_t,4> d_rechmap,
    const torch::PackedTensorAccessor64<scalar_t,4> d_rd_map,
    torch::PackedTensorAccessor64<scalar_t,4> d_intersect) {

    const int wstride = gridDim.x * blockDim.x, hstride = gridDim.y * blockDim.y, bstride = gridDim.z * blockDim.z;
    const int batch_size = d_rgb.size(0), h = d_rgb.size(2), w = d_rgb.size(3);
    const scalar_t default_value = 0.0;

    for (int bi = blockIdx.z; bi < batch_size; bi += bstride) {
        for (int wi = blockIdx.x * blockDim.x + threadIdx.x; wi < w; wi += wstride) for(int hi = blockIdx.y * blockDim.y + threadIdx.y; hi < h; hi += hstride) {
            // shadow/mask_alpha are unused here.
            scalar_t shadow(1.0), mask_alpha(0.0);
            // Ray start: pixel center at the receiver height.
            scalar_t recx = wi + 0.5, recy = hi+0.5, rech = d_rechmap[bi][0][hi][wi];

            // Per-pixel ray target/direction components from the direction map.
            scalar_t lx = d_rd_map[bi][0][hi][wi];
            scalar_t ly = d_rd_map[bi][1][hi][wi];
            scalar_t lh = d_rd_map[bi][2][hi][wi];

            vec3<scalar_t> start(recx, recy, rech);
            vec3<scalar_t> rd(lx, ly, lh);

            vec3<scalar_t> intersect_color;
            if (ray_trace(start, rd, bi, d_mask, d_hmap, d_rechmap, d_rgb, intersect_color)) {
                d_intersect[bi][0][hi][wi] = intersect_color.x;
                d_intersect[bi][1][hi][wi] = intersect_color.y;
                d_intersect[bi][2][hi][wi] = intersect_color.z;
                d_intersect[bi][3][hi][wi] = 1.0;  // hit flag
            } else {
                d_intersect[bi][0][hi][wi] = default_value;
                d_intersect[bi][1][hi][wi] = default_value;
                d_intersect[bi][2][hi][wi] = default_value;
                d_intersect[bi][3][hi][wi] = 0.0;  // miss flag
            }
        }
    }
}
|
| 376 |
+
|
| 377 |
+
// Kernel: per-pixel ray/scene intersection with fully per-pixel ray origins
// (d_ro) and directions (d_rd), with tolerance dh. Writes hit RGB + 1.0 hit
// flag into channels 0-3 of d_intersect, or zeros + 0.0 flag on a miss.
// Grid-stride loops over batch (z), width (x) and height (y).
template <typename scalar_t>
__global__ void ray_scene_intersect_cuda_forward(
    const torch::PackedTensorAccessor64<scalar_t,4> d_rgb,
    const torch::PackedTensorAccessor64<scalar_t,4> d_mask,
    const torch::PackedTensorAccessor64<scalar_t,4> d_hmap,
    const torch::PackedTensorAccessor64<scalar_t,4> d_ro,
    const torch::PackedTensorAccessor64<scalar_t,4> d_rd,
    const scalar_t dh,
    torch::PackedTensorAccessor64<scalar_t,4> d_intersect) {

    const int wstride = gridDim.x * blockDim.x, hstride = gridDim.y * blockDim.y, bstride = gridDim.z * blockDim.z;
    const int batch_size = d_rgb.size(0), h = d_rgb.size(2), w = d_rgb.size(3);
    const scalar_t default_value = 0.0;

    for (int bi = blockIdx.z; bi < batch_size; bi += bstride) {
        for (int wi = blockIdx.x * blockDim.x + threadIdx.x; wi < w; wi += wstride) {
            for(int hi = blockIdx.y * blockDim.y + threadIdx.y; hi < h; hi += hstride) {
                // scalar_t rox = wi + 0.5, roy = hi+0.5, roh = d_ro[bi][0][hi][wi];
                // Ray origin (x, y, height) read from the origin map.
                scalar_t rox = d_ro[bi][0][hi][wi];
                scalar_t roy = d_ro[bi][1][hi][wi];
                scalar_t roh = d_ro[bi][2][hi][wi];

                // Ray direction (x, y, height) read from the direction map.
                scalar_t rdx = d_rd[bi][0][hi][wi];
                scalar_t rdy = d_rd[bi][1][hi][wi];
                scalar_t rdh = d_rd[bi][2][hi][wi];

                vec3<scalar_t> ro(rox, roy, roh);
                vec3<scalar_t> rd(rdx, rdy, rdh);

                vec3<scalar_t> intersect_color;
                if (ray_scene_intersect(ro, rd, dh, bi, d_mask, d_hmap, d_rgb, intersect_color)) {
                    d_intersect[bi][0][hi][wi] = intersect_color.x;
                    d_intersect[bi][1][hi][wi] = intersect_color.y;
                    d_intersect[bi][2][hi][wi] = intersect_color.z;
                    d_intersect[bi][3][hi][wi] = 1.0;  // hit flag
                } else {
                    d_intersect[bi][0][hi][wi] = default_value;
                    d_intersect[bi][1][hi][wi] = default_value;
                    d_intersect[bi][2][hi][wi] = default_value;
                    d_intersect[bi][3][hi][wi] = 0.0;  // miss flag
                }
            }
        }
    }
}
|
| 422 |
+
|
| 423 |
+
// Kernel: renders a mirror reflection on the ground plane. For each pixel,
// scans straight up the same column (ti < hi) for the masked object pixel whose
// reflected position (height * 2 + ti) lands closest to this row; if that
// distance is under the per-batch threshold, copies that pixel's RGB/height and
// marks the reflection mask.
// Grid-stride loops over batch (z), width (x) and height (y).
template <typename scalar_t>
__global__ void reflect_render_cuda_forward(
    const torch::PackedTensorAccessor64<scalar_t,4> d_rgb,
    const torch::PackedTensorAccessor64<scalar_t,4> d_mask,
    const torch::PackedTensorAccessor64<scalar_t,4> d_hmap,
    const torch::PackedTensorAccessor64<scalar_t,4> d_rechmap,
    const torch::PackedTensorAccessor64<scalar_t,2> d_thresholds,
    torch::PackedTensorAccessor64<scalar_t,4> d_reflect,
    torch::PackedTensorAccessor64<scalar_t,4> d_reflect_height,
    torch::PackedTensorAccessor64<scalar_t,4> d_reflect_mask) {
    const int wstride = gridDim.x * blockDim.x, hstride = gridDim.y * blockDim.y, bstride = gridDim.z * blockDim.z;
    const int batch_size = d_rgb.size(0), h = d_rgb.size(2), w = d_rgb.size(3);

    for (int bi = blockIdx.z; bi < batch_size; bi += bstride) {
        for (int wi = blockIdx.x * blockDim.x + threadIdx.x; wi < w; wi += wstride) for(int hi = blockIdx.y * blockDim.y + threadIdx.y; hi < h; hi += hstride) {
            /* Back tracing along the height
               Find the closest point, filter the closest point
            */
            scalar_t min_dis = FLT_MAX;
            scalar_t min_r, min_g, min_b, min_height, min_mask;
            // Scan upward in the same column for the best reflection source.
            for(int ti = hi-1; ti >= 0; --ti) {
                if (d_mask[bi][0][ti][wi] < 0.45)
                    continue;

                // Distance between this row and the mirrored position of row ti
                // (row ti reflected about its own height: ti + 2 * height).
                scalar_t dis = abs(d_hmap[bi][0][ti][wi] * 2 + ti - hi);
                if (dis < min_dis) {
                    min_dis = dis;
                    min_r = d_rgb[bi][0][ti][wi];
                    min_g = d_rgb[bi][1][ti][wi];
                    min_b = d_rgb[bi][2][ti][wi];

                    min_height = d_hmap[bi][0][ti][wi];
                    min_mask = d_mask[bi][0][ti][wi];
                }
            }

            /* Check Condition */
            scalar_t cur_thresholds = d_thresholds[bi][0];
            if (min_dis < cur_thresholds) {
                /* Let Use Nearest Neighbor First */
                d_reflect[bi][0][hi][wi] = min_r;
                d_reflect[bi][1][hi][wi] = min_g;
                d_reflect[bi][2][hi][wi] = min_b;
                d_reflect_height[bi][0][hi][wi] = min_height;
                d_reflect_mask[bi][0][hi][wi] = 1.0;
            }

            // Alternative soft-fading branch kept for reference:
            // } else {
            //     scalar_t fadding = 1.0-(min_dis-cur_thresholds);
            //     if (fadding < 0.0) fadding = 0.0;
            //     d_reflect[bi][0][hi][wi] = min_r * fadding + (1.0-fadding);
            //     d_reflect[bi][1][hi][wi] = min_g * fadding + (1.0-fadding);
            //     d_reflect[bi][2][hi][wi] = min_b * fadding + (1.0-fadding);
            //     d_reflect_height[bi][0][hi][wi] = 0.0;
            //     d_reflect_mask[bi][0][hi][wi] = fadding;
            // }
        }
    }
}
|
| 482 |
+
|
| 483 |
+
// Kernel: glossy variant of reflect_render_cuda_forward. Currently performs the
// same nearest-neighbor column scan with a fixed 1e-1 threshold and writes only
// the RGB reflection.
// NOTE(review): `sample_n` and `glossy` are accepted but not used anywhere in
// this body, and `min_height`/`min_mask` are computed but never written out —
// presumably stochastic glossy sampling is unfinished.
template <typename scalar_t>
__global__ void glossy_reflect_render_cuda_forward(
    const torch::PackedTensorAccessor64<scalar_t,4> d_rgb,
    const torch::PackedTensorAccessor64<scalar_t,4> d_mask,
    const torch::PackedTensorAccessor64<scalar_t,4> d_hmap,
    const torch::PackedTensorAccessor64<scalar_t,4> d_rechmap,
    const int sample_n,
    const float glossy,
    torch::PackedTensorAccessor64<scalar_t,4> d_reflect) {

    const int wstride = gridDim.x * blockDim.x, hstride = gridDim.y * blockDim.y, bstride = gridDim.z * blockDim.z;
    const int batch_size = d_rgb.size(0), h = d_rgb.size(2), w = d_rgb.size(3);

    for (int bi = blockIdx.z; bi < batch_size; bi += bstride) {
        for (int wi = blockIdx.x * blockDim.x + threadIdx.x; wi < w; wi += wstride) for(int hi = blockIdx.y * blockDim.y + threadIdx.y; hi < h; hi += hstride) {
            /* Back tracing along the height
               Find the closest point, filter the closest point
            */
            scalar_t min_dis = FLT_MAX;
            scalar_t min_r, min_g, min_b, min_height, min_mask;
            // Scan upward in the same column for the best reflection source.
            for(int ti = hi-1; ti >= 0; --ti) {
                if (d_mask[bi][0][ti][wi] < 0.45)
                    continue;

                // Distance between this row and the mirrored position of row ti.
                scalar_t dis = abs(d_hmap[bi][0][ti][wi] * 2 + ti - hi);
                if (dis < min_dis) {
                    min_dis = dis;
                    min_r = d_rgb[bi][0][ti][wi];
                    min_g = d_rgb[bi][1][ti][wi];
                    min_b = d_rgb[bi][2][ti][wi];

                    min_height = d_hmap[bi][0][ti][wi];
                    min_mask = d_mask[bi][0][ti][wi];
                }
            }

            /* Check Condition */
            float cur_thresholds = 1e-1;  // fixed threshold (cf. per-batch threshold in reflect_render)
            if (min_dis < cur_thresholds) {
                /* Let Use Nearest Neighbor First */
                d_reflect[bi][0][hi][wi] = min_r;
                d_reflect[bi][1][hi][wi] = min_g;
                d_reflect[bi][2][hi][wi] = min_b;
            }
        }
    }
}
|
| 530 |
+
|
| 531 |
+
} // namespace
|
| 532 |
+
|
| 533 |
+
// Host launcher for the hard-shadow render kernel.
// Inputs (all CUDA tensors): rgb/mask/hmap/rechmap are NCHW image tensors,
// bb and light_pos are per-batch 2-D tensors. Returns {shadow} with shape
// (N, 3, H, W), initialized to 1 (lit) and set to 0 where occluded.
std::vector<torch::Tensor> hshadow_render_cuda_forward(
    torch::Tensor rgb,
    torch::Tensor mask,
    torch::Tensor bb,
    torch::Tensor hmap,
    torch::Tensor rechmap,
    torch::Tensor light_pos) {
    const auto batch_size = rgb.size(0);
    const auto h = rgb.size(2);
    const auto w = rgb.size(3);
    // One thread per pixel, 16x16 tiles; one grid z-slice per batch element.
    const dim3 threads(16, 16, 1);
    const dim3 blocks((w + threads.x - 1) / threads.x, (h+threads.y-1)/threads.y, batch_size);
    torch::Tensor shadow_tensor = torch::ones({batch_size, 3, h, w}).to(rgb);

    // scalar_type() replaces the deprecated Tensor::type() for dispatch.
    AT_DISPATCH_FLOATING_TYPES(rgb.scalar_type(), "hshadow_render_cuda_forward", ([&] {
        hshadow_render_cuda_forward<scalar_t><<<blocks, threads>>>(
            rgb.packed_accessor64<scalar_t,4>(),
            mask.packed_accessor64<scalar_t,4>(),
            bb.packed_accessor64<scalar_t,2>(),
            hmap.packed_accessor64<scalar_t,4>(),
            rechmap.packed_accessor64<scalar_t,4>(),
            light_pos.packed_accessor64<scalar_t,2>(),
            shadow_tensor.packed_accessor64<scalar_t,4>());
    }));

    return {shadow_tensor};
}
|
| 561 |
+
|
| 562 |
+
// Host launcher for the mirror-reflection render kernel.
// Inputs (all CUDA tensors): rgb/mask/hmap/rechmap are NCHW image tensors,
// thresholds is per-batch. Returns {reflection (N,3,H,W), reflection_height
// (N,1,H,W), reflection_mask (N,1,H,W)}.
std::vector<torch::Tensor> reflect_render_cuda_forward(
    torch::Tensor rgb,
    torch::Tensor mask,
    torch::Tensor hmap,
    torch::Tensor rechmap,
    torch::Tensor thresholds) {
    const auto batch_size = rgb.size(0);
    const auto h = rgb.size(2);
    const auto w = rgb.size(3);
    // One thread per pixel, 16x16 tiles; one grid z-slice per batch element.
    const dim3 threads(16, 16, 1);
    const dim3 blocks((w + threads.x - 1) / threads.x, (h+threads.y-1)/threads.y, batch_size);
    torch::Tensor reflection_tensor = torch::ones({batch_size, 3, h, w}).to(rgb);
    torch::Tensor reflection_mask_tensor = torch::zeros({batch_size, 1, h, w}).to(rgb);
    torch::Tensor reflection_height_tensor = torch::zeros({batch_size, 1, h, w}).to(rgb);

    // scalar_type() replaces the deprecated Tensor::type() for dispatch.
    AT_DISPATCH_FLOATING_TYPES(rgb.scalar_type(), "reflect_render_cuda_forward", ([&] {
        reflect_render_cuda_forward<scalar_t><<<blocks, threads>>>(
            rgb.packed_accessor64<scalar_t,4>(),
            mask.packed_accessor64<scalar_t,4>(),
            hmap.packed_accessor64<scalar_t,4>(),
            rechmap.packed_accessor64<scalar_t,4>(),
            thresholds.packed_accessor64<scalar_t,2>(),
            reflection_tensor.packed_accessor64<scalar_t,4>(),
            reflection_height_tensor.packed_accessor64<scalar_t,4>(),
            reflection_mask_tensor.packed_accessor64<scalar_t,4>());
    }));

    return {reflection_tensor, reflection_height_tensor, reflection_mask_tensor};
}
|
| 592 |
+
|
| 593 |
+
|
| 594 |
+
// Host launcher for the glossy-reflection render kernel.
// Returns {reflection} with shape (N, 3, H, W). sample_n/glossy are forwarded
// to the kernel (currently unused there).
std::vector<torch::Tensor> glossy_reflect_render_cuda_forward(torch::Tensor rgb,
                                                              torch::Tensor mask,
                                                              torch::Tensor hmap,
                                                              torch::Tensor rechmap,
                                                              const int sample_n,
                                                              const float glossy) {
    const auto batch_size = rgb.size(0);
    const auto h = rgb.size(2);
    const auto w = rgb.size(3);
    // One thread per pixel, 16x16 tiles; one grid z-slice per batch element.
    const dim3 threads(16, 16, 1);
    const dim3 blocks((w + threads.x - 1) / threads.x, (h+threads.y-1)/threads.y, batch_size);

    torch::Tensor reflection_tensor = torch::ones({batch_size, 3, h, w}).to(rgb);

    // Fixed: dispatch debug name previously said "reflect_render_cuda_forward"
    // (copy-paste), which mislabels error messages from this launcher.
    // scalar_type() replaces the deprecated Tensor::type().
    AT_DISPATCH_FLOATING_TYPES(rgb.scalar_type(), "glossy_reflect_render_cuda_forward", ([&] {
        glossy_reflect_render_cuda_forward<scalar_t><<<blocks, threads>>>(
            rgb.packed_accessor64<scalar_t,4>(),
            mask.packed_accessor64<scalar_t,4>(),
            hmap.packed_accessor64<scalar_t,4>(),
            rechmap.packed_accessor64<scalar_t,4>(),
            sample_n,
            glossy,
            reflection_tensor.packed_accessor64<scalar_t,4>());
    }));

    return {reflection_tensor};
}
|
| 623 |
+
|
| 624 |
+
|
| 625 |
+
// Host launcher for the per-pixel ray intersection kernel (direction map
// variant). Returns a (N, 4, H, W) tensor: RGB of the hit in channels 0-2 and
// a hit flag (1.0 hit / 0.0 miss) in channel 3.
torch::Tensor ray_intersect_cuda_forward(torch::Tensor rgb,
                                         torch::Tensor mask,
                                         torch::Tensor hmap,
                                         torch::Tensor rechmap,
                                         torch::Tensor rd_map){
    const auto batch_size = rgb.size(0);
    const auto h = rgb.size(2);
    const auto w = rgb.size(3);
    // One thread per pixel, 16x16 tiles; one grid z-slice per batch element.
    const dim3 threads(16, 16, 1);
    const dim3 blocks((w + threads.x - 1) / threads.x, (h+threads.y-1)/threads.y, batch_size);

    torch::Tensor intersect_tensor = torch::ones({batch_size, 4, h, w}).to(rgb);

    // Fixed: dispatch debug name previously said "reflect_render_cuda_forward"
    // (copy-paste), which mislabels error messages from this launcher.
    // scalar_type() replaces the deprecated Tensor::type().
    AT_DISPATCH_FLOATING_TYPES(rgb.scalar_type(), "ray_intersect_cuda_forward", ([&] {
        ray_intersect_cuda_forward<scalar_t><<<blocks, threads>>>(
            rgb.packed_accessor64<scalar_t,4>(),
            mask.packed_accessor64<scalar_t,4>(),
            hmap.packed_accessor64<scalar_t,4>(),
            rechmap.packed_accessor64<scalar_t,4>(),
            rd_map.packed_accessor64<scalar_t,4>(),
            intersect_tensor.packed_accessor64<scalar_t,4>());
    }));

    return intersect_tensor;
}
|
| 652 |
+
|
| 653 |
+
|
| 654 |
+
// Host launcher for the per-pixel ray/scene intersection kernel (per-pixel
// origin + direction maps, tolerance dh). Returns a (N, 4, H, W) tensor: RGB
// of the hit in channels 0-2 and a hit flag (1.0 hit / 0.0 miss) in channel 3.
torch::Tensor ray_scene_intersect_cuda_forward(torch::Tensor rgb,
                                               torch::Tensor mask,
                                               torch::Tensor hmap,
                                               torch::Tensor ro,
                                               torch::Tensor rd,
                                               float dh){
    const auto batch_size = rgb.size(0);
    const auto h = rgb.size(2);
    const auto w = rgb.size(3);
    // One thread per pixel, 16x16 tiles; one grid z-slice per batch element.
    const dim3 threads(16, 16, 1);
    const dim3 blocks((w + threads.x - 1) / threads.x, (h+threads.y-1)/threads.y, batch_size);

    torch::Tensor intersect_tensor = torch::ones({batch_size, 4, h, w}).to(rgb);

    // Fixed: dispatch debug name previously said "reflect_render_cuda_forward"
    // (copy-paste), which mislabels error messages from this launcher.
    // scalar_type() replaces the deprecated Tensor::type().
    AT_DISPATCH_FLOATING_TYPES(rgb.scalar_type(), "ray_scene_intersect_cuda_forward", ([&] {
        ray_scene_intersect_cuda_forward<scalar_t><<<blocks, threads>>>(
            rgb.packed_accessor64<scalar_t,4>(),
            mask.packed_accessor64<scalar_t,4>(),
            hmap.packed_accessor64<scalar_t,4>(),
            ro.packed_accessor64<scalar_t,4>(),
            rd.packed_accessor64<scalar_t,4>(),
            dh,
            intersect_tensor.packed_accessor64<scalar_t,4>());
    }));

    return intersect_tensor;
}
|
PixHtLab-Src/Demo/PixhtLab/Torch_Render/plane_visualize.cpp
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
#include <torch/extension.h>
|
| 3 |
+
|
| 4 |
+
#include <vector>
|
| 5 |
+
#include <stdio.h>
|
| 6 |
+
|
| 7 |
+
// CUDA forward declarations
|
| 8 |
+
std::vector<torch::Tensor> plane_visualize_cuda(torch::Tensor planes, torch::Tensor camera, int h, int w);
|
| 9 |
+
|
| 10 |
+
// C++ interface
|
| 11 |
+
// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
|
| 12 |
+
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
|
| 13 |
+
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
|
| 14 |
+
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
| 15 |
+
|
| 16 |
+
// C++ entry point: validates that both tensors are contiguous CUDA tensors,
// then forwards to the CUDA implementation.
// planes: plane parameters; camera: camera parameters; h/w: output image size.
std::vector<torch::Tensor> plane_visualize(torch::Tensor planes, torch::Tensor camera, int h, int w) {
    CHECK_INPUT(planes);
    CHECK_INPUT(camera);

    return plane_visualize_cuda(planes, camera, h, w);
}
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
// Python binding: exposes plane_visualize as `forward` in the extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &plane_visualize, "Plane Visualization (CUDA)");
}
|
PixHtLab-Src/Demo/PixhtLab/Torch_Render/plane_visualize_cuda.cu
ADDED
|
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <torch/extension.h>
|
| 2 |
+
|
| 3 |
+
#include <cuda.h>
|
| 4 |
+
#include <cuda_runtime.h>
|
| 5 |
+
|
| 6 |
+
#include <vector>
|
| 7 |
+
|
| 8 |
+
namespace {
|
| 9 |
+
// Minimal 3-component vector usable on both host and device.
// Supports scalar multiply/divide and component-wise add/subtract; all
// operators return a new value (no in-place mutation).
template <typename scalar_t>
struct vec3 {
    scalar_t x, y, z;

    // Zero vector.
    __device__ __host__
    vec3<scalar_t>():x(0),y(0),z(0) {}

    // Splat constructor: all components set to a.
    __device__ __host__
    vec3<scalar_t>(scalar_t a):x(a),y(a),z(a) {}

    __device__ __host__
    vec3<scalar_t>(scalar_t xx, scalar_t yy, scalar_t zz):x(xx),y(yy),z(zz) {}

    // Component-wise scale by a scalar.
    __device__ __host__
    vec3<scalar_t> operator*(const scalar_t &rhs) const {
        vec3<scalar_t> ret(x,y,z);
        ret.x = x * rhs;
        ret.y = y * rhs;
        ret.z = z * rhs;
        return ret;
    }

    // Component-wise division by a scalar.
    __device__ __host__
    vec3<scalar_t> operator/(const scalar_t &rhs) const {
        vec3<scalar_t> ret(x,y,z);
        ret.x = x / rhs;
        ret.y = y / rhs;
        ret.z = z / rhs;
        return ret;
    }

    // Component-wise vector addition.
    __device__ __host__
    vec3<scalar_t> operator+(const vec3<scalar_t> &rhs) const {
        vec3<scalar_t> ret(x,y,z);
        ret.x = x + rhs.x;
        ret.y = y + rhs.y;
        ret.z = z + rhs.z;
        return ret;
    }

    // Component-wise vector subtraction.
    __device__ __host__
    vec3<scalar_t> operator-(const vec3<scalar_t> &rhs) const {
        vec3<scalar_t> ret(x,y,z);
        ret.x = x - rhs.x;
        ret.y = y - rhs.y;
        ret.z = z - rhs.z;
        return ret;
    }
};
|
| 58 |
+
|
| 59 |
+
// A ray: origin `ro` and (not necessarily normalized) direction `rd`.
template <typename scalar_t>
struct Ray {
    vec3<scalar_t> ro, rd;
};
|
| 63 |
+
|
| 64 |
+
// Scene description: a single plane given by a point `pp` on it and its
// normal `pn`.
template <typename scalar_t>
struct Scene {
    vec3<scalar_t> pp, pn;
};
|
| 68 |
+
|
| 69 |
+
// Convert degrees to radians.
// NOTE(review): returns float regardless of scalar_t, so double inputs lose
// precision here — confirm whether that is acceptable for the callers.
template <typename scalar_t>
__device__
float deg2rad(scalar_t d) {
    return d/180.0 * 3.1415926f;
}
|
| 74 |
+
|
| 75 |
+
// Dot (inner) product of two 3-vectors.
template <typename scalar_t>
__device__
scalar_t dot(vec3<scalar_t> a, vec3<scalar_t> b) {
    scalar_t acc = a.x * b.x;
    acc += a.y * b.y;
    acc += a.z * b.z;
    return acc;
}
|
| 80 |
+
|
| 81 |
+
// Cross product a x b (right-handed).
template <typename scalar_t>
__device__
vec3<scalar_t> cross(vec3<scalar_t> a, vec3<scalar_t> b) {
    return vec3<scalar_t>(a.y * b.z - a.z * b.y,
                          a.z * b.x - a.x * b.z,
                          a.x * b.y - a.y * b.x);
}
|
| 90 |
+
|
| 91 |
+
// Euclidean length of a 3-vector.
template <typename scalar_t>
__device__
scalar_t length(vec3<scalar_t> a) {
    scalar_t sq = dot(a, a);
    return sqrt(sq);
}
|
| 96 |
+
|
| 97 |
+
// Return a unit-length copy of `a` (undefined for the zero vector).
template <typename scalar_t>
__device__
vec3<scalar_t> normalize(vec3<scalar_t> a) {
    scalar_t len = length(a);
    return a / len;
}
|
| 102 |
+
|
| 103 |
+
// Focal length in pixels for an image of width `w` and a full horizontal
// field of view `fov` in degrees.
template <typename scalar_t>
__device__
scalar_t get_focal(int w, scalar_t fov) {
    auto half_fov_rad = deg2rad(fov * 0.5);
    return 0.5 * w / tan(half_fov_rad);
}
|
| 108 |
+
|
| 109 |
+
// Build a (non-normalized) view ray direction for normalized image
// coordinates x, y through a camera basis (right, front, up) and focal length.
template <typename scalar_t>
__device__
vec3<scalar_t> get_rd(vec3<scalar_t> right, vec3<scalar_t> front, vec3<scalar_t> up, int h, int w, scalar_t focal, scalar_t x, scalar_t y) {
    /* x, y in [-1, 1] */
    // Re-normalize the basis defensively before composing the direction.
    right = normalize(right);
    front = normalize(front);
    up = normalize(up);

    vec3<scalar_t> fwd_part = front * focal;
    vec3<scalar_t> horiz_part = right * x * (float)w * 0.5f;
    vec3<scalar_t> vert_part = up * y * (float)h * 0.5f;
    return fwd_part + horiz_part + vert_part;
}
|
| 119 |
+
|
| 120 |
+
// Build a camera ray through normalized image coordinates (wi, hi).
// NOTE(review): the focal here uses tan(deg2rad(fov)) while get_focal uses
// tan(deg2rad(fov * 0.5)) — confirm which convention is intended.
template <typename scalar_t>
__device__
Ray<scalar_t> get_ray(int h, int w, float hi, float wi, vec3<scalar_t> right, vec3<scalar_t>front, vec3<scalar_t> up, vec3<scalar_t> cam_pos, float fov) {
    /* Note, wi/hi is in [-1.0, 1.0] */
    const float focal = 0.5f * w / tan(deg2rad(fov));

    // Compose the direction term by term.
    vec3<scalar_t> dir = front * focal;
    dir = dir + right * 0.5f * w * wi;
    dir = dir + up * 0.5f * h * hi;

    Ray<scalar_t> ray;
    ray.ro = cam_pos;
    ray.rd = dir;
    return ray;
}
|
| 131 |
+
|
| 132 |
+
// Ray/plane intersection. Plane is (point p, normal n). Writes the ray
// parameter into `t` and returns true when the intersection is in front of
// the origin (t >= 0). `t` is written even on a miss.
template <typename scalar_t>
__device__
bool plane_intersect(Ray<scalar_t> ray, vec3<scalar_t> p, vec3<scalar_t> n, float &t) {
    scalar_t numer = dot(p - ray.ro, n);
    scalar_t denom = dot(ray.rd, n);
    t = numer / denom;
    return t >= 0.0;
}
|
| 139 |
+
|
| 140 |
+
// Derive the unit camera front vector from a requested horizon row: tilt the
// default -z view so the horizon sits `h/2 - horizon` pixels from the image
// center. Uses the same tan(deg2rad(fov)) focal convention as get_ray.
template <typename scalar_t>
__device__
vec3<scalar_t> horizon2front(scalar_t horizon, int h, int w, float fov) {
    scalar_t yoffset = h / 2 - horizon;
    scalar_t focal = 0.5f * w / tan(deg2rad(fov));
    vec3<scalar_t> front = vec3<scalar_t>(0.0f,0.0f,-1.0f) * focal + vec3<scalar_t>(0.0f, 1.0f, 0.0f) * yoffset;
    return normalize(front);
}
|
| 148 |
+
|
| 149 |
+
// Procedural grid texture for the ground plane: white where sin(x * freq) or
// sin(z * freq) is near zero (the grid lines), black elsewhere.
template <typename scalar_t>
__device__
vec3<scalar_t> plane_texture(vec3<scalar_t> p) {
    float freq = 6.0f;           // spatial frequency of the grid lines
    float u = sin(p.x * freq), v = sin(p.z * freq);
    vec3<scalar_t> ret(0.0f);
    float line_width = 0.05f;    // half-width of a line, in sin-space
    if ((abs(u) < line_width || abs(v) < line_width)) {
        ret = vec3<scalar_t>(1.0f);
    }
    return ret;
}
|
| 161 |
+
|
| 162 |
+
// Trace a ray against the scene's single ground plane (point scene.pp, normal
// scene.pn). On a hit, write the plane texture color at the intersection and
// return true; otherwise leave color black and return false.
template <typename scalar_t>
__device__
bool ray_scene_trace(Ray<scalar_t> ray, Scene<scalar_t> scene, vec3<scalar_t> &color) {
    color = vec3<scalar_t>(0.0f);
    float t;
    if (plane_intersect(ray, scene.pp, scene.pn, t)) {
        vec3<scalar_t> intersect_pos = ray.ro + ray.rd * t;
        color = plane_texture(intersect_pos);
        return true;
    }
    return false;
}
|
| 174 |
+
|
| 175 |
+
// Forward kernel: for every batch element, ray-trace the ground plane and
// write an antialiased grid visualization into d_vis.
//
// d_plane:  (B, 6) rows = [plane point xyz, plane normal xyz]
// d_camera: (B, 2) rows = [fov, horizon row]
// d_vis:    (B, 3, H, W) output; written only where a ray hits the plane.
template <typename scalar_t>
__global__ void plane_visualize_foward(
    const torch::PackedTensorAccessor64<scalar_t,2> d_plane,
    const torch::PackedTensorAccessor64<scalar_t,2> d_camera,
    torch::PackedTensorAccessor64<scalar_t,4> d_vis) {
    // Grid-stride steps (index += gridDim * blockDim) so any launch shape
    // covers the full image and batch.
    const int wstride = gridDim.x * blockDim.x, hstride = gridDim.y * blockDim.y, bstride = gridDim.z * blockDim.z;
    const int batch_size = d_vis.size(0), h = d_vis.size(2), w = d_vis.size(3);
    const int samples = 10;    // samples x samples stratified rays per pixel

    // Default camera; front/up are re-derived per batch element below.
    vec3<scalar_t> cam_pos(0.0f, 1.0f, 1.0f), front(0.0f,0.0f,-1.0f), right(1.0f, 0.0f, 0.0f), up(0.0f, 1.0f, 0.0f);
    for (int bi = blockIdx.z; bi < batch_size; bi += bstride) {
        // scalar_t px = d_plane[bi][0], py = d_plane[bi][1], pz = d_plane[bi][2];
        vec3<scalar_t> plane_pos(d_plane[bi][0], d_plane[bi][1], d_plane[bi][2]);
        vec3<scalar_t> plane_norm(d_plane[bi][3], d_plane[bi][4], d_plane[bi][5]);
        Scene<scalar_t> scene = {plane_pos, plane_norm};

        scalar_t fov = d_camera[bi][0], horizon = d_camera[bi][1];
        // Rebuild the camera basis from the requested horizon line.
        front = normalize(horizon2front(horizon, h, w, fov));
        up = normalize(cross(right, front));
        for (int wi = blockIdx.x * blockDim.x + threadIdx.x; wi < w; wi += wstride)
        for(int hi = blockIdx.y * blockDim.y + threadIdx.y; hi < h; hi += hstride) {
            bool intersect = false;
            vec3<scalar_t> color(0.0f);
            // Average the plane texture over a sub-pixel sample grid.
            for (int si = 0; si < samples * samples; ++si) {
                float hoffset = (float)(si/samples)/max(samples-1, 1);
                float woffset = (float)(si%samples)/max(samples-1, 1);
                // Jittered pixel coordinate mapped to [-1, 1].
                float x = (float)(wi + woffset)/w * 2.0 - 1.0;
                float y = (float)(hi + hoffset)/h * 2.0 - 1.0;
                Ray<scalar_t> ray = get_ray(h, w, y, x, right, front, up, cam_pos,fov);
                vec3<scalar_t> tmp_color(0.0f);
                if(ray_scene_trace(ray, scene, tmp_color)) {
                    color = color + tmp_color;
                    intersect = intersect || true;
                }
            }
            if (intersect) {
                color = color / float(samples * samples);
                d_vis[bi][0][hi][wi] = color.x;
                d_vis[bi][1][hi][wi] = color.y;
                d_vis[bi][2][hi][wi] = color.z;
            }
        }
    }
}
|
| 219 |
+
|
| 220 |
+
} // namespace
|
| 221 |
+
|
| 222 |
+
// Host-side launcher: render a ground-plane grid visualization for each
// (plane, camera) pair in the batch.
//
// planes: (B, 6) rows = [px, py, pz, nx, ny, nz]; camera: (B, 2) = [fov, horizon].
// Returns a single (B, 3, h, w) tensor, zero where no ray hits the plane.
std::vector<torch::Tensor> plane_visualize_cuda(torch::Tensor planes, torch::Tensor camera, int h, int w){
    const auto batch_size = planes.size(0);
    const int threads = 512;
    // 1-D thread blocks (512 x 1 x 1); the kernel's grid-stride loops cover
    // any pixels beyond the launched grid.
    const dim3 blocks((w + threads - 1) / threads, (h+threads-1)/threads, batch_size);

    // Allocate on the same device/dtype as the input.
    torch::Tensor vis_tensor = torch::zeros({batch_size, 3, h, w}).to(planes);
    // scalar_type() replaces the deprecated Tensor::type() for dtype dispatch.
    AT_DISPATCH_FLOATING_TYPES(planes.scalar_type(), "plane_visualize_foward", ([&] {
        plane_visualize_foward<scalar_t><<<blocks, threads>>>(
            planes.packed_accessor64<scalar_t,2>(),
            camera.packed_accessor64<scalar_t,2>(),
            vis_tensor.packed_accessor64<scalar_t,4>()
        );
    }));

    return {vis_tensor};
}
|
PixHtLab-Src/Demo/PixhtLab/Torch_Render/setup.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension


# Build the `hshadow` CUDA extension (hard shadow / reflection renderer).
# Requires a CUDA toolkit compatible with the installed PyTorch build.
setup(
    name='hshadow',
    ext_modules=[
        CUDAExtension('hshadow', [
            'hshadow_cuda.cpp',
            'hshadow_cuda_kernel.cu',
        ])
    ],
    cmdclass={
        'build_ext': BuildExtension
    }
)

# Alternative build for the `plane_visualize` extension, kept for reference.
# setup(
#     name='plane_visualize',
#     ext_modules=[
#         CUDAExtension('plane_visualize', [
#             'plane_visualize.cpp',
#             'plane_visualize_cuda.cu',
#         ])
#     ],
#     cmdclass={
#         'build_ext': BuildExtension
#     }
# )
|
PixHtLab-Src/Demo/PixhtLab/Torch_Render/test_ground.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
import torch
|
| 3 |
+
import plane_visualize
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
from PIL import Image
|
| 6 |
+
from torchvision import transforms
|
| 7 |
+
import os
|
| 8 |
+
from os.path import join
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
# Output directory for rendered ground images; created up front.
test_output = 'imgs/output'
os.makedirs(test_output, exist_ok=True)
# NOTE(review): hard-coded to the first CUDA device.
device = torch.device("cuda:0")
|
| 14 |
+
|
| 15 |
+
def test_ground():
    """Render a batch of 5 identical ground planes; returns the (B,3,H,W) output."""
    fov, horizon = 120, 400
    cam_params = torch.tensor([[fov, horizon]])
    # Plane through the origin with a +Y normal: [px, py, pz, nx, ny, nz].
    plane_params = torch.tensor([[0.0, 0.0, 0.0, 0.0, 1.0, 0.0]])

    cam_params = cam_params.repeat(5, 1).float().to(device)
    plane_params = plane_params.repeat(5, 1).float().to(device)

    return plane_visualize.forward(plane_params, cam_params, int(512), int(512))[0]
|
| 25 |
+
|
| 26 |
+
# Time one batched render, then dump each image in the batch to disk.
t = time.time()
ground_vis = test_ground()
print('{} s'.format(time.time() - t))
batch = ground_vis.shape[0]
for bi in range(batch):
    # (3,H,W) -> (H,W,3) for matplotlib.
    img = ground_vis[bi].detach().cpu().numpy().transpose(1,2,0)
    img = np.clip(img, 0.0, 1.0)
    plt.imsave(join(test_output, 'ground_{}.png'.format(bi)),img)
|
PixHtLab-Src/Demo/PixhtLab/Torch_Render/test_hshadow.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
import torch
|
| 3 |
+
import hshadow
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
from PIL import Image
|
| 6 |
+
from torchvision import transforms
|
| 7 |
+
import os
|
| 8 |
+
from os.path import join
|
| 9 |
+
import numpy as np
|
| 10 |
+
from scipy.ndimage import uniform_filter
|
| 11 |
+
|
| 12 |
+
# Output directory for rendered shadow/reflection results.
test_output = 'imgs/output'
os.makedirs(test_output, exist_ok=True)
|
| 14 |
+
def test_shadow(rgb, mask, hmap, rechmap, light_pos):
    """Render a hard shadow via the `hshadow` CUDA extension and composite it.

    Returns (composited image, raw shadow map) as tensors.
    NOTE(review): at the call site below the inputs are 4-D (B,C,H,W), so
    `rgb.shape[:2]` yields (B, C), not (H, W) -- confirm the bounding box
    passed to hshadow.forward is intended.
    Relies on the module-level `device`, which this script defines further
    down but before the first call.
    """
    h,w = rgb.shape[:2]
    # Bounding box rows: [top, bottom, left, right].
    bb = torch.tensor([[0, h-1, 0, w-1]]).float().to(device)
    start = time.time()
    shadow = hshadow.forward(rgb, mask, bb, hmap, rechmap, light_pos)[0]
    end = time.time()

    print('Shadow rendering: {}s'.format(end-start))
    # Darken the background by the shadow; keep the masked foreground untouched.
    res = (1.0-mask) * rgb * shadow + mask * rgb
    return res, shadow
|
| 24 |
+
|
| 25 |
+
def frenel_reflect(reflect_tensor, reflect_mask, fov, ref_ind):
    # Use schilik approximation https://en.wikipedia.org/wiki/Schlick%27s_approximation
    """Fade a reflection image toward white with a vertical falloff ramp.

    NOTE(review): the Schlick terms (`ind`, `cos_map`) are computed but never
    used -- the fading actually applied is the linear ramp below. The name is
    a typo for "fresnel_reflect"; kept for call-site compatibility.
    """
    def deg2rad(deg):
        return deg/180.0 * 3.1415926

    def img2cos(reflect_img, fov, horizon):
        # Note, this factor needs calibration if we have camera parameters
        b, c, h, w = reflect_img.shape
        focal = 0.5 * h / np.tan(deg2rad(0.5 * fov))
        # Per-row view-angle cosine relative to the horizon row.
        fadding_map = torch.arange(0, h).unsqueeze(1).expand(h, w).unsqueeze(0).unsqueeze(0).repeat(b, c, 1, 1)
        fadding_map = focal / torch.sqrt((fadding_map-horizon)**2 + focal **2)
        return fadding_map.to(reflect_img)


    # Schlick base reflectance -- currently unused (see note above).
    ind = (1.0-ref_ind)/(1.0+ref_ind)
    ind = ind ** 2
    h = reflect_tensor.shape[2]
    horizon = h * 0.7
    cos_map = img2cos(reflect_tensor, fov, horizon)
    # fadding = 1.0 - (ind + (1.0-ind) * torch.pow(1.0-cos_map, 4))
    b, c, h, w = reflect_tensor.shape
    # Vertical ramp (3 at top -> 0 at bottom), raised to the 4th power, clipped.
    fadding = torch.linspace(3.0,0.0,h)[None, None, ..., None].repeat(b,c,1,w).to(reflect_tensor) ** 4
    fadding = torch.clip(fadding, 0.0, 1.0)
    # Debug dump -- writes a file on every call.
    plt.imsave('test_fadding.png', fadding[0].detach().cpu().numpy().transpose(1,2,0))
    reflect_mask = reflect_mask.repeat(1,3,1,1)
    # Blend the faded reflection with white outside the mask.
    return fadding * reflect_tensor * reflect_mask + (1.0-reflect_mask * fadding) * torch.ones_like(reflect_tensor)
|
| 51 |
+
|
| 52 |
+
def refine_boundary(output, filter=3, *, enabled=False):
    """Optionally smooth each channel of an (H, W, C) image with a mean filter.

    The original code returned the input untouched (an unconditional `return`
    left the filtering unreachable); that remains the default. Pass
    ``enabled=True`` to apply a `filter` x `filter` uniform filter per
    channel, in place, and get the same array back.

    NOTE: the `filter` parameter shadows the builtin; the name is kept for
    call-site compatibility.
    """
    if not enabled:
        return output
    h, w, c = output.shape
    for i in range(c):
        output[..., i] = uniform_filter(output[..., i], size=filter)
    return output
|
| 58 |
+
|
| 59 |
+
def height_fadding(reflect, reflect_height, reflect_mask, fadding_factor):
    """Fade a reflection toward white as the reflected height grows.

    Heights are normalized to [0, 1], pushed through a sigmoid scaled by
    `fadding_factor`, and the result blends `reflect` with a white image.
    """
    def _sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    rows = reflect.shape[0]
    # Normalize heights: by image height first, then to a [0, 1] range.
    norm_h = reflect_height / rows
    norm_h = norm_h / norm_h.max()
    # Map height in [0, 1] to a fade weight in [1, 0] via a centered sigmoid.
    weight = (1.0 - 2.0 * (_sigmoid(norm_h * fadding_factor) - 0.5)) * reflect_mask
    white = np.ones_like(reflect)
    return weight * reflect + (1.0 - weight) * white
|
| 69 |
+
|
| 70 |
+
def to_numpy(tensor):
    """First batch element of a (B, C, H, W) tensor as an (H, W, C) numpy array."""
    first = tensor[0].detach().cpu()
    return first.numpy().transpose(1, 2, 0)
|
| 72 |
+
|
| 73 |
+
def test_reflect(rgb, mask, hmap, rechmap, thresholds=1.5, fadding_factor=10.0):
    """Render a mirror reflection with `hshadow.reflection`, fade it by height,
    and composite over the background.

    Returns (composite, reflection) as (H, W, C) numpy arrays.
    """
    b, c, h, w = rgb.shape
    # thresholds = torch.tensor([[1.0 + i/b] for i in range(b)]).float().to(device)
    thresholds = torch.tensor([[thresholds]]).float().to(device)
    start = time.time()
    reflect, reflect_height, reflect_mask = hshadow.reflection(rgb, mask, hmap, rechmap, thresholds)
    end = time.time()
    print('Reflection rendering: {}s'.format(end-start))

    # Switch from tensors to (H, W, C) numpy for the post-processing below.
    reflect, reflect_height, reflect_mask = to_numpy(reflect), to_numpy(reflect_height), to_numpy(reflect_mask)
    # reflect = frenel_reflect(reflect, reflect_height, reflect_mask, 175, 0.9)
    refine_reflect, refine_reflect_height = refine_boundary(reflect), refine_boundary(reflect_height)
    refine_reflect_mask = reflect_mask
    reflect = height_fadding(refine_reflect, refine_reflect_height, refine_reflect_mask, fadding_factor)
    rgb, mask = to_numpy(rgb), to_numpy(mask)
    # Composite: reflection modulates the background, masked foreground kept.
    res = (1.0-mask) * rgb * reflect + mask * rgb
    return res, reflect
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def test_glossy_reflect(rgb, mask, hmap, rechmap, sample, glossy, fadding_factor=10.0):
    """Render a glossy (multi-sample) reflection with `hshadow.glossy_reflection`.

    NOTE(review): `fadding_factor` and the assigned `refine_reflect` are
    currently unused on this path -- no height fading is applied, unlike
    test_reflect.
    """
    b, c, h, w = rgb.shape
    # thresholds = torch.tensor([[1.0 + i/b] for i in range(b)]).float().to(device)
    start = time.time()
    reflect = hshadow.glossy_reflection(rgb, mask, hmap, rechmap, sample, glossy)[0]
    end = time.time()
    print('Reflection rendering: {}s'.format(end-start))

    reflect = to_numpy(reflect)
    refine_reflect = refine_boundary(reflect)
    rgb, mask = to_numpy(rgb), to_numpy(mask)
    # Composite: reflection modulates the background, masked foreground kept.
    res = (1.0-mask) * rgb * reflect + mask * rgb
    return res, reflect
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
# NOTE(review): hard-coded to the first CUDA device.
device = torch.device("cuda:0")
to_tensor = transforms.ToTensor()
# for i in range(1,5):
i = 2
# `if True:` keeps the indentation of the disabled batch loop above.
if True:
    prefix = 'canvas{}'.format(i)
    # Load RGB, a single-channel mask, and a single-channel height map.
    rgb, mask, hmap = to_tensor(Image.open('imgs/{}_rgb.png'.format(prefix)).convert('RGB')).to(device), to_tensor(Image.open('imgs/{}_mask.png'.format(prefix)).convert('RGB'))[0:1].to(device), to_tensor(Image.open('imgs/{}_height.png'.format(prefix)).convert('RGB'))[0:1].to(device)
    h, w = hmap.shape[1:]
    # Scale normalized heights into pixel units.
    hmap = hmap * h * 0.45

    rechmap = torch.zeros_like(hmap)
    # Add the batch dimension expected by the renderers.
    rgb, mask, hmap, rechmap = rgb.unsqueeze(dim=0), mask.unsqueeze(dim=0), hmap.unsqueeze(dim=0), rechmap.unsqueeze(dim=0)
    lightpos = torch.tensor([[300, -100, 200.0]]).to(device)
    shadow_res, shadow = test_shadow(rgb, mask, hmap, rechmap, lightpos)
    reflect_res, reflect = test_reflect(rgb, mask, hmap, rechmap, thresholds=2.5, fadding_factor=18.0)

    glossy_reflect_res, glossy_reflect = test_glossy_reflect(rgb, mask, hmap, rechmap, sample=10, glossy=0.5, fadding_factor=18.0)

    plt.imsave(join(test_output, prefix + "_shadow_final.png"), shadow_res[0].detach().cpu().numpy().transpose(1,2,0))
    plt.imsave(join(test_output, prefix + "_shadow.png"), shadow[0].detach().cpu().numpy().transpose(1,2,0))
    plt.imsave(join(test_output, prefix + "_reflect_final.png"), reflect_res)
    plt.imsave(join(test_output, prefix + "_reflect.png"), reflect)
    plt.imsave(join(test_output, prefix + "_glossy_reflect_final.png"), glossy_reflect_res)
    plt.imsave(join(test_output, prefix + "_glossy_reflect.png"), glossy_reflect)
|
PixHtLab-Src/Demo/PixhtLab/camera.py
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import math
|
| 3 |
+
from abc import ABC
|
| 4 |
+
import copy
|
| 5 |
+
|
| 6 |
+
class camera(ABC):
    """Pin-hole camera above a ground plane (y = 0).

    Conventions visible in this file: x right, y up, image origin top-left.
    An ``xyh`` point is (pixel x, pixel y, pixel height above its ground
    foot); ``xyz`` is a world-space point.
    """

    def __init__(self, hfov, h, w, height=100.0):
        """
        :param hfov:   horizontal field of view in degrees
        :param h:      image height in pixels
        :param w:      image width in pixels
        :param height: camera height above the ground plane
        """
        self.fov = hfov
        self.h = h
        self.w = w

        self.ori_height = height
        self.height = copy.deepcopy(self.ori_height)
        self.O = np.array([0.0, self.height, 0.0]) # ray origianl


    ######################################################################################
    """ Abstraction
    """
    def align_horizon(self, cur_horizon):
        raise NotImplementedError('Not implemented yet')

    def C(self):
        raise NotImplementedError('Not implemented yet')

    def right(self):
        raise NotImplementedError('Not implemented yet')

    def up(self):
        raise NotImplementedError('Not implemented yet')

    ######################################################################################

    def deg2rad(self, d):
        # Fix: use math.pi instead of the mistyped constant 3.1415925.
        return d / 180.0 * math.pi


    def rad2deg(self, d):
        return d / math.pi * 180.0


    def get_ray(self, xy):
        """ Assume the center is on the top-left corner
        """
        u, v = xy
        mat = self.get_ABC_mat()
        r = np.dot(mat, np.array([u, v, 1.0]).T)

        # r = r/np.sqrt(r @ r)
        return r


    def project(self, xyz):
        """Project a world point onto the image; returns pixel (x, y)."""
        relative = xyz - self.O

        mat = self.get_ABC_mat()
        pp = np.dot(np.linalg.inv(mat), relative)
        pixel = np.array([pp[0]/pp[2], pp[1]/pp[2]])

        return pixel


    def xyh2w(self, xyh):
        """Ray-scale factor w for an xyh point, derived from its ground foot."""
        u, v, h = xyh

        # Foot of the point: shift y down by the pixel height, zero the height.
        foot_xyh = np.copy(xyh)
        foot_xyh[1] = foot_xyh[1] + foot_xyh[2]
        foot_xyh[2] = 0.0
        fu, fv, fh = foot_xyh

        a = self.right()
        b = -self.up()
        c = self.C()

        w = -self.height/(a[1] * fu + b[1] * fv + c[1])
        return w


    def xyh2xyz(self, xyh):
        """Lift an xyh pixel point to world coordinates."""
        u, v, h = xyh

        # Same depth computation as xyh2w (previously duplicated inline).
        w = self.xyh2w(xyh)
        mat = self.get_ABC_mat()
        xyz = self.O + np.dot(mat, np.array([u, v, 1.0]).T) * w

        return xyz


    def xyz2xyh(self, xyz):
        """Inverse of xyh2xyz: world point -> (pixel x, pixel y, pixel height)."""
        foot_xyz = np.copy(xyz)
        foot_xyz[1] = 0.0

        foot_xy = self.project(foot_xyz)
        xy = self.project(xyz)

        ret = np.copy(xyz)
        ret[:2] = xy
        # Pixel height = vertical distance between the point and its foot.
        ret[2] = foot_xy[1] - xy[1]

        return ret


    def get_ABC_mat(self):
        """Columns are [right, -up, C]: pixel homogeneous coords -> ray direction."""
        a = self.right()
        b = -self.up()
        c = self.C()

        mat = np.concatenate([a[:, None], b[:,None], c[:, None]], axis=1)
        return mat
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class pitch_camera(camera):
    """Camera that aligns the horizon by pitching the view direction.
    """
    def __init__(self, hfov, h, w, height=100.0):
        """
        alignment algorithm:
            1. pitch alignment
            2. axis alignment
        """
        super().__init__(hfov, h, w, height)

        # Default view direction is straight down -z.
        self.ori_view = np.array([0.0, 0.0, -1.0])
        self.cur_view = np.copy(self.ori_view)


    def align_horizon(self, cur_horizon):
        """ Given horizon, compute the camera pitch
        """
        ref_horizon = self.h / 2
        rel_distance = -(ref_horizon - cur_horizon)

        focal = self.focal()
        pitch = math.atan2(rel_distance, focal)

        # construct a rotation matrix about the x axis
        c, s = np.cos(pitch), np.sin(pitch)
        # Fix: the [0][0] entry of an x-axis rotation must be 1 (was 0). The
        # old matrix happened to work only because ori_view's x component is 0.
        rot = np.array([[1, 0, 0], [0, c, -s], [0, s, c]])

        # compute the new view vector
        img_plane_view = self.ori_view * focal
        img_plane_view = rot @ img_plane_view.T

        self.cur_view = img_plane_view / math.sqrt(np.dot(img_plane_view, img_plane_view))

    def C(self):
        """Direction through the top-left image corner."""
        return self.view() * self.focal() - 0.5 * self.w * self.right() + 0.5 * self.h * self.up()


    def right(self):
        return np.array([1.0, 0.0, 0.0])


    def up(self):
        return np.cross(self.right(), self.view())


    def focal(self):
        """Focal length in pixels from the horizontal fov."""
        focal = self.w * 0.5 / math.tan(self.deg2rad(self.fov * 0.5))
        return focal


    def view(self):
        return self.cur_view
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
class axis_camera(camera):
    """Camera that aligns the horizon by shifting the image axis (the corner
    vector C) instead of rotating the view.
    """
    def __init__(self, hfov, h, w, height=100.0):
        super().__init__(hfov, h, w, height)

        # Fixed axes: this camera model never rotates.
        self.up_vec = np.array([0.0, 1.0, 0.0])
        self.right_vec = np.array([1.0, 0.0, 0.0])

        # Top-left image-plane corner for a centered horizon.
        focal_px = self.w * 0.5 / math.tan(self.deg2rad(self.fov * 0.5))
        self.ori_c = np.array([-0.5 * self.w, 0.5 * self.h, -focal_px])
        self.c_vec = np.copy(self.ori_c)


    def align_horizon(self, cur_horizon):
        """Shift C vertically so the horizon lands on row `cur_horizon`."""
        shift = cur_horizon - self.h // 2
        self.c_vec = self.ori_c + shift * self.up()
        # self.height = self.ori_height + delta_horizon
        # self.O = np.array([0.0, self.height, 0.0])


    def C(self):
        return self.c_vec


    def right(self):
        return self.right_vec


    def up(self):
        return self.up_vec
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def test(ppc):
    """Smoke test: round-trip xyh -> xyz -> xyh and print results for several
    horizon alignments. `ppc` is any camera subclass instance.
    """
    xyh = np.array([500, 500, 100.0])

    proj_xyz = ppc.xyh2xyz(xyh)
    proj_xyh = ppc.xyz2xyh(proj_xyz)

    print('xyh: {}, proj xyz: {}, proj xyh: {}'.format(xyh, proj_xyz, proj_xyh))

    # import pdb; pdb.set_trace()
    new_horizon_list = [0, 100, 250, 400, 500]
    # The first list is immediately overwritten; only the second is used.
    new_horizon_list = [100, 250, 400, 500]

    # import pdb; pdb.set_trace()
    for cur_horizon in new_horizon_list:
        ppc.align_horizon(cur_horizon)
        # Probe a zero-height point sitting on the aligned horizon row.
        test_xyh = np.array([500, cur_horizon, 0])
        test_xyz = ppc.xyh2xyz(test_xyh)

        # print('{} -> {} -> {}'.format(test_xyh, test_xyz, ppc.xyz2xyh(test_xyz)))
        print('{} \t -> {} \t -> {}'.format(test_xyh, test_xyz, ppc.xyz2xyh(test_xyz)))
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
if __name__ == '__main__':
    # Exercise both horizon-alignment strategies at the same fov/resolution.
    p_camera = pitch_camera(90.0, 500, 500)
    a_camera = axis_camera(90.0, 500, 500)

    test(p_camera)
    test(a_camera)
|
PixHtLab-Src/Demo/PixhtLab/gssn_demo.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
import torch
|
| 3 |
+
import gradio as gr
|
| 4 |
+
from torch import nn
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def render_btn_fn(mask, background, buffers, pitch, roll, softness):
    """Gradio callback for the Render button.

    Currently a stub: logs its inputs and returns None (no output image yet).
    """
    print('Pitch and roll: {}, {}'.format(pitch, roll))
    # Fixed typo in the log label ("bufferss" -> "buffers").
    print('Mask, background, buffers: {}, {}, {}'.format(mask.shape, background.shape, buffers.shape))
    return None
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# UI layout: three input images, three sliders, a render button, one output.
with gr.Blocks() as demo:
    with gr.Row():
        mask_input = gr.Image(shape=(256, 256), image_mode="L", label="Mask")
        bg_input = gr.Image(shape=(256, 256), image_mode="RGB", label="Background")
        buff_input = gr.Image(shape=(256, 256), image_mode="RGB", label="Buffers")

    with gr.Row():
        with gr.Column():
            # NOTE(review): `default=` / `shape=` suggest an older Gradio API
            # (current releases use `value=` for sliders) -- confirm against
            # the pinned gradio version before upgrading.
            pitch_input = gr.Slider(minimum=0, maximum=1, step=0.01, default=0.5, label="Pitch")
            roll_input = gr.Slider(minimum=0, maximum=1, step=0.01, default=0.5, label="Roll")
            softness_input = gr.Slider(minimum=0, maximum=1, step=0.01, default=0.5, label="Softness")

        render_btn = gr.Button(label="Render")
        output = gr.Image(shape=(256, 256), image_mode="RGB", label="Output")

    render_btn.click(render_btn_fn, inputs=[mask_input, bg_input, buff_input, pitch_input, roll_input, softness_input], outputs=output)


demo.launch()
|
PixHtLab-Src/Demo/PixhtLab/hshadow_render.py
ADDED
|
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
import torch
import hshadow
# import plane_visualize
import numpy as np
from torchvision import transforms
from scipy.ndimage import uniform_filter
from ShadowStyle.inference import inference_shadow
import cv2
import matplotlib.pyplot as plt
from utils import *
from GSSN.inference_shadow import SSN_Infernece

device = torch.device("cuda:0")
to_tensor = transforms.ToTensor()
# NOTE(review): hard-coded absolute checkpoint paths -- these resolve only on
# the original author's machine; make them configurable before reuse. Both
# models are loaded at import time (side effect).
model = inference_shadow.init_models('/home/ysheng/Documents/Research/GSSN/HardShadow/qtGUI/weights/human_baseline_all_21-July-04-52-AM.pt')
# GSSN_model = SSN_Infernece('GSSN/weights/0000000700.pt')
GSSN_model = SSN_Infernece('/home/ysheng/Documents/Research/GSSN/HardShadow/qtGUI/GSSN/weights/only_shadow/0000000200.pt')
|
| 19 |
+
|
| 20 |
+
def crop_mask(mask):
    """Axis-aligned bounding box of the nonzero pixels of a 2-D mask.

    Returns (row_min, row_max, col_min, col_max) with inclusive max values.
    An all-zero mask raises ValueError (min/max of an empty array).
    """
    rows, cols = np.nonzero(mask)
    return (rows.min(), rows.max(), cols.min(), cols.max())
|
| 24 |
+
|
| 25 |
+
def norm_output(np_img):
    # Min-max normalize to [0, 1] via OpenCV, then clip for float safety.
    return np.clip(cv2.normalize(np_img, None, 0.0, 1.0, cv2.NORM_MINMAX),0.0,1.0)
|
| 27 |
+
|
| 28 |
+
def padding(mask, shadow, mask_aabb, shadow_aabb, final_shape=(512, 512)):
    """Crop mask/shadow to their AABBs, rescale the mask to ~40% of the canvas,
    and paste both onto fixed-size canvases with matching relative offsets.

    Returns (mask_canvas, shadow_canvas, trans_info); trans_info records the
    data needed to undo the transform (used by transform_output).
    """
    mh, mhh, mw, mww = mask_aabb
    sh, shh, sw, sww = shadow_aabb
    # NOTE(review): crop_mask's max values are inclusive but are used here as
    # exclusive slice ends, so the last row/col is dropped -- confirm intended.
    cropped_mask, cropped_shadow = mask[mh:mhh, mw:mww], shadow[sh:shh, sw:sww]
    global_h, global_w = mask.shape[:2]
    h, w, c, sc = *cropped_mask.shape, shadow.shape[2]
    fract = 0.4   # fraction of the canvas the mask's longer side should span
    if h > w:
        newh = int(final_shape[0]*fract)
        neww = int(newh/h*w)
    else:
        neww = int(final_shape[1]*fract)
        newh = int(neww/w*h)

    small_mask = cv2.resize(cropped_mask, (neww, newh), interpolation=cv2.INTER_AREA)
    if len(small_mask.shape) == 2:
        # cv2.resize drops a singleton channel; restore it.
        small_mask = small_mask[...,np.newaxis]

    # Mask canvas starts black; shadow canvas starts white (no shadow).
    mask_ret, shadow_ret = np.zeros((final_shape[0], final_shape[1], c)),np.ones((final_shape[0], final_shape[1], sc))
    paddingh, paddingw = 10, (final_shape[0]-neww)//2
    mask_lpos = (paddingh, paddingw)
    mask_ret = overlap_replace(mask_ret, small_mask, mask_lpos)

    # padding shadow
    hscale, wscale = newh/h, neww/w
    newsh, newsw = int((shh-sh) * hscale), int((sww-sw) * wscale)
    small_shadow = cv2.resize(cropped_shadow, (newsw, newsh), interpolation=cv2.INTER_AREA)

    if len(small_shadow.shape) == 2:
        small_shadow = small_shadow[...,np.newaxis]


    # Preserve the shadow's offset relative to the mask after rescaling.
    loffseth, loffsetw = int((sh-mh)*hscale), int((sw-mw)*wscale)
    shadow_lpos = (paddingh + loffseth, paddingw + loffsetw)
    shadow_ret = overlap_replace(shadow_ret, small_shadow, shadow_lpos)

    # return mask_ret, shadow_ret[...,0:1], [mask_aabb, mask_lpos, hscale, wscale, final_shape, mask.shape[0], mask.shape[1]]
    return mask_ret, shadow_ret, [mask_aabb, mask_lpos, hscale, wscale, final_shape, mask.shape[0], mask.shape[1]]
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def transform_input(mask, hardshadow):
    """ Note, trans_info marks the AABBs, and scaling factors
    """
    # Crop both inputs to the nonzero region of their first channel.
    mask_aabb, shadow_aabb = crop_mask(mask[...,0]), crop_mask(hardshadow[...,0])
    # import pdb; pdb.set_trace()
    cmask, cshadow, trans_info = padding(mask, hardshadow, mask_aabb, shadow_aabb)
    # HWC -> NCHW with a batch dim; the shadow is inverted here and the
    # inversion is undone in transform_output.
    return cmask.transpose(2,0,1)[np.newaxis,...], 1.0 - cshadow.transpose(2,0,1)[np.newaxis, ...], trans_info
|
| 75 |
+
|
| 76 |
+
|
def transform_output(softshadow, trans_info):
    """Map a network soft-shadow prediction back to the source image frame.

    softshadow is NCHW; trans_info is the tuple produced by padding().
    Returns an H x W x 3 image where shadow values are inverted (1 - shadow).
    """
    mask_aabb, mask_lpos, hscale, wscale, final_shape, h, w = trans_info

    # Undo the padding-time resize to recover the global resolution.
    gsh = int(final_shape[0] / hscale)
    gsw = int(final_shape[1] / wscale)
    global_shadow = cv2.resize(softshadow[0, 0], (gsw, gsh))

    # Global start = global mask AABB corner minus the (rescaled) local
    # mask position inside the padded crop.
    mh, mw = mask_aabb[0], mask_aabb[2]
    mask_lh, mask_lw = mask_lpos
    starth = int(mh - mask_lh / hscale)
    startw = int(mw - mask_lw / wscale)

    canvas = np.zeros((h, w, 1))
    ret = norm_output(overlap_replace(canvas, global_shadow[..., np.newaxis], (starth, startw)))
    if len(ret.shape) == 2:
        ret = ret[..., np.newaxis]

    return 1.0 - ret.repeat(3, axis=2)
| 91 |
+
|
def style_hardshadow(mask, hardshadow, softness):
    """Soften a hard shadow with the shadow-style network.

    Returns the soft shadow mapped back into the image frame, plus a debug
    tuple (net-input mask, net-input hard shadow, raw net soft shadow),
    each normalized for display.
    """
    mask_net, hardshadow_net, trans_info = transform_input(mask, hardshadow)
    netsoftshadow = inference_shadow.net_render_np(model, mask_net, hardshadow_net, softness, 0.0)
    softshadow = transform_output(netsoftshadow, trans_info)

    dbg = (norm_output(mask_net[0, 0]),
           norm_output(hardshadow_net[0, 0]),
           norm_output(netsoftshadow[0, 0]))
    return softshadow, dbg
| 98 |
+
|
def gssn_shadow(mask, pixel_height, shadow_channels, softness):
    """Render a soft shadow with the GSSN model.

    mask / pixel_height / shadow_channels are H x W x C images. The
    pixel-height channel is scaled by 1/512 (presumably the network's
    training resolution — TODO confirm) and the shadow channels are
    inverted before being fed to the network. Returns the soft shadow
    mapped back into the image frame.
    """
    mask_aabb = crop_mask(mask[..., 0])
    shadow_aabb = crop_mask(shadow_channels[..., 0])
    ph_channel, hardshadow_net, trans_info = padding(pixel_height, shadow_channels, mask_aabb, shadow_aabb)

    net_input = np.concatenate([ph_channel / 512.0, 1.0 - hardshadow_net], axis=2)

    pred = np.clip(GSSN_model.render_ss(net_input, softness), 0.0, 1.0)
    pred = pred.transpose((2, 0, 1))[None, ...]  # HWC -> 1CHW
    return transform_output(pred, trans_info)
| 116 |
+
|
| 117 |
+
|
def proj_ground(p, light_pos):
    """Project point p onto the z=0 ground plane along rays from each light.

    p:         (3,) point [x, y, z].
    light_pos: (B, 3) light positions.
    Returns the (B, 2) ground-plane xy intersections.
    """
    pz = p[2]
    # Solve (1 - t) * pz + t * light_z = 0 for t; epsilon guards against
    # a light exactly at the point's height.
    t = -pz / (light_pos[:, 2:3] - pz + 1e-6)
    return (1.0 - t) * p[:2] + t * light_pos[:, :2]
| 124 |
+
|
def proj_bb(mask, hmap, light_pos, mouse_pos):
    """Conservative per-light bounding boxes for the projected shadow.

    mask:      H x W binary object mask.
    hmap:      H x W pixel-height map.
    light_pos: (3,) single light or (3, B) batch of lights.
    mouse_pos: (x, y) position that must also fall inside each box.

    Returns a (B, 4) array of [hmin, hmax, wmin, wmax] boxes covering the
    object's 2D AABB, the ground projections of its 3D bounding-box
    corners, and the mouse position.

    Fix: the original computed `highest`, `highest_h`, `highest_w`
    (including an argmax/unravel_index pass) without using them, and
    called hmap.max() twice more for the corners; the max is now computed
    once and the dead code removed.
    """
    tmp_lights = light_pos.copy()
    if len(light_pos.shape) == 1:
        tmp_lights = tmp_lights[..., np.newaxis]

    # Object AABB in image space plus its height extent.
    top = hmap.max()
    hbb, wbb = np.nonzero(mask)
    h, hh, w, ww = hbb.min(), hbb.max(), wbb.min(), wbb.max()

    # 3D bounding-box corners: the two top corners carry the max height,
    # the two bottom corners sit on the ground (height 0).
    bb0 = np.array([w, h, top])
    bb1 = np.array([ww, h, top])
    bb2 = np.array([w, hh, 0])
    bb3 = np.array([ww, hh, 0])

    # Project every corner onto the ground for each light.
    tmp_lights = tmp_lights.transpose(1, 0)  # -> (B, 3)
    bb0 = proj_ground(bb0, tmp_lights)
    bb1 = proj_ground(bb1, tmp_lights)
    bb2 = proj_ground(bb2, tmp_lights)
    bb3 = proj_ground(bb3, tmp_lights)

    batch = len(tmp_lights)
    new_bb = np.zeros((batch, 4))
    for i in range(batch):
        new_bb[i, 0] = min([bb0[i, 1], bb1[i, 1], bb2[i, 1], bb3[i, 1], mouse_pos[1], h])   # h min
        new_bb[i, 1] = max([bb0[i, 1], bb1[i, 1], bb2[i, 1], bb3[i, 1], mouse_pos[1], hh])  # h max
        new_bb[i, 2] = min([bb0[i, 0], bb1[i, 0], bb2[i, 0], bb3[i, 0], mouse_pos[0], w])   # w min
        new_bb[i, 3] = max([bb0[i, 0], bb1[i, 0], bb2[i, 0], bb3[i, 0], mouse_pos[0], ww])  # w max

    return new_bb
| 150 |
+
|
def to_torch_device(np_img):
    """Upload a numpy array to the torch device.

    HWC images (ndim == 3) are routed through to_tensor, which converts to
    CHW, and get a leading batch dimension; any other shape is wrapped
    as-is with torch.from_numpy.
    """
    if np_img.ndim == 3:
        return to_tensor(np_img).float().unsqueeze(dim=0).contiguous().to(device)
    return torch.from_numpy(np_img).float().contiguous().to(device)
| 156 |
+
|
def hshadow_render(rgb, mask, hmap, rechmap, light_pos, mouse_pos):
    """Heightmap hard-shadow rendering.

    rgb:       H x W x 3 image.
    mask:      H x W x 1 object mask.
    hmap:      H x W x 1 pixel-height map of the object.
    rechmap:   H x W x 1 receiver height map.
    light_pos: (3,) single light or (3, B) batch of lights.
    mouse_pos: (x, y) position folded into the projection bbox.

    Returns an H x W x C shadow masking image averaged over the lights.

    Fix: removed dead code — an unused full-image np.nonzero(mask) pass
    and the unused `batch = 1` / `h, w = rgb.shape[:2]` assignments —
    plus a large block of commented-out experimental weighting code.
    """
    # Restrict shadow ray evaluation to a conservative bbox (speed).
    bb = proj_bb(mask[..., 0], hmap[..., 0], light_pos, mouse_pos)

    if len(light_pos.shape) == 1:
        # Single light: batch of 1, inputs uploaded directly.
        light_pos_d = torch.from_numpy(light_pos).to(device).unsqueeze(dim=0).float()
        rgb_d, mask_d, hmap_d, rechmap_d = to_torch_device(rgb), to_torch_device(mask), to_torch_device(hmap), to_torch_device(rechmap)
        bb_d = torch.from_numpy(bb).float().to(device)
    else:
        # Batched lights: replicate each input along the batch dimension.
        light_pos_d = torch.from_numpy(np.ascontiguousarray(light_pos.transpose(1, 0))).float().to(device)
        batch = len(light_pos_d)
        rgb_d = to_torch_device(np.repeat(rgb[np.newaxis, ...].transpose(0, 3, 1, 2), batch, axis=0))
        mask_d = to_torch_device(np.repeat(mask[np.newaxis, ...].transpose(0, 3, 1, 2), batch, axis=0))
        hmap_d = to_torch_device(np.repeat(hmap[np.newaxis, ...].transpose(0, 3, 1, 2), batch, axis=0))
        rechmap_d = to_torch_device(np.repeat(rechmap[np.newaxis, ...].transpose(0, 3, 1, 2), batch, axis=0))
        bb_d = torch.from_numpy(np.ascontiguousarray(bb)).float().to(device)

    shadow = hshadow.forward(rgb_d, mask_d, bb_d, hmap_d, rechmap_d, light_pos_d)

    # Average the per-light shadow maps into a single map.
    shadow = shadow[0].sum(dim=0, keepdim=True) / len(shadow[0])
    return shadow[0].detach().cpu().numpy().transpose(1, 2, 0)
| 204 |
+
|
def refine_shadow(shadow, intensity=0.6, filter=5):
    """Box-blur each channel of the shadow map, then rescale its darkness.

    intensity in [0, 1] controls how dark the shadow appears (0 removes
    it entirely); `filter` is the box-filter size in pixels. Note the
    blur is applied to `shadow` in place before the rescaled map is
    returned.
    """
    for ch in range(3):
        shadow[..., ch] = uniform_filter(shadow[..., ch], size=filter)
    return 1.0 - (1.0 - shadow) * intensity
| 210 |
+
|
def render_ao(rgb, mask, hmap):
    """Render a soft, ambient-occlusion-style contact shadow for the object.

    Uses a fixed near-overhead light, renders the hard heightmap shadow,
    softens it with the shadow-style network (softness 0.45), then blurs
    and attenuates it. Returns an H x W x 3 soft shadow map.

    Fix: removed a dead np.nonzero(mask) pass whose result was only
    referenced by commented-out code.
    """
    rechmap = np.zeros_like(hmap)
    # Fixed near-overhead light. A mask-derived alternative was
    # [hbb.min(), (wbb.min() + wbb.max()) * 0.8, -100000].
    light_pos = np.array([-1300.10811363, -46999.86253089, 46486.73121776])
    mouse_pos = light_pos

    shadow = hshadow_render(rgb, mask, hmap, rechmap, light_pos, mouse_pos)
    softshadow = style_hardshadow(mask, shadow[..., :1], 0.45)[0]
    return refine_shadow(softshadow)
| 222 |
+
|
def ao_composite(rgb, mask, hmap, rechmap, light_pos, mouse_pos):
    """Composite the AO-style contact shadow into the image.

    Only rgb, mask and hmap influence the result; the remaining
    parameters are kept for signature compatibility with the other
    compositors. Returns (composited image, soft-shadow copy).
    """
    softshadow = render_ao(rgb, mask, hmap)
    mask3 = np.repeat(mask, 3, axis=2)
    composited = (1.0 - mask3) * softshadow * rgb + mask3 * rgb
    return composited, softshadow.copy()
| 231 |
+
|
| 232 |
+
|
def render_shadow(rgb, mask, hmap, rechmap, light_pos, mouse_pos, softness, shadow_intensity=0.6):
    """Render the object's shadow, optionally softened by the style network.

    When softness is None the hard heightmap shadow is refined directly
    and the debug tuple is None. Returns (refined shadow, debug images).
    """
    dbgs = None
    shadow = hshadow_render(rgb, mask, hmap, rechmap, light_pos, mouse_pos)

    if softness is not None:
        shadow, dbgs = style_hardshadow(mask, shadow[..., :1], softness)

    return refine_shadow(shadow, intensity=shadow_intensity), dbgs
| 243 |
+
|
| 244 |
+
|
def hshadow_composite(rgb, mask, hmap, rechmap, light_pos, mouse_pos, softness, shadow_intensity=0.6):
    """Shadow rendering and composition.

    rgb:       H x W x 3
    mask:      H x W x 1
    hmap:      H x W x 1
    rechmap:   H x W x 1
    light_pos: [x, y, h]
    Returns (composited image, shadow copy, debug images).
    """
    shadow, dbgs = render_shadow(rgb, mask, hmap, rechmap, light_pos, mouse_pos, softness, shadow_intensity)
    mask3 = np.repeat(mask, 3, axis=2)
    composited = (1.0 - mask3) * shadow * rgb + mask3 * rgb
    return composited, shadow.copy(), dbgs
| 258 |
+
|
| 259 |
+
# def vis_horizon(fov, horizon, h, w):
|
| 260 |
+
# # fov, horizon = 120, 400
|
| 261 |
+
# camera = torch.tensor([[fov, horizon]])
|
| 262 |
+
# planes = torch.tensor([[0.0, 0.0, 0.0, 0.0, 1.0, 0.0]])
|
| 263 |
+
|
| 264 |
+
# camera = camera.float().to(device)
|
| 265 |
+
# planes = planes.float().to(device)
|
| 266 |
+
|
| 267 |
+
# ground_vis = plane_visualize.forward(planes, camera, h, w)[0]
|
| 268 |
+
# return 1.0-ground_vis[0].detach().cpu().numpy().transpose(1,2,0)
|