karthikeya1212 commited on
Commit
cda88e0
·
verified ·
1 Parent(s): 978fc9b

Upload 115 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. PixHtLab-Src/.gitignore +5 -0
  3. PixHtLab-Src/Data/data_prepare.py +21 -0
  4. PixHtLab-Src/Demo/PixhtLab/Demo.ipynb +0 -0
  5. PixHtLab-Src/Demo/PixhtLab/Examples/009_depth.npy +3 -0
  6. PixHtLab-Src/Demo/PixhtLab/Examples/009_depth_valid_mask.npy +3 -0
  7. PixHtLab-Src/Demo/PixhtLab/Examples/009_pixht_new.npy +3 -0
  8. PixHtLab-Src/Demo/PixhtLab/Examples/010_depth.npy +3 -0
  9. PixHtLab-Src/Demo/PixhtLab/Examples/010_depth_valid_mask.npy +3 -0
  10. PixHtLab-Src/Demo/PixhtLab/Examples/010_pixht_new.npy +3 -0
  11. PixHtLab-Src/Demo/PixhtLab/Examples/011_depth.npy +3 -0
  12. PixHtLab-Src/Demo/PixhtLab/Examples/011_depth_valid_mask.npy +3 -0
  13. PixHtLab-Src/Demo/PixhtLab/Examples/011_pixht_new.npy +3 -0
  14. PixHtLab-Src/Demo/PixhtLab/Examples/c82c09fc01e84282bc8870c263dcf81b_bg.jpg +3 -0
  15. PixHtLab-Src/Demo/PixhtLab/GSSN/__init__.py +0 -0
  16. PixHtLab-Src/Demo/PixhtLab/GSSN/inference_shadow.py +70 -0
  17. PixHtLab-Src/Demo/PixhtLab/ShadowStyle/__init__.py +0 -0
  18. PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/__init__.py +0 -0
  19. PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/inference_shadow.py +70 -0
  20. PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/params.py +92 -0
  21. PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/ssn/__init__.py +0 -0
  22. PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/ssn/perturb_touch.py +24 -0
  23. PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/ssn/random_pattern.py +92 -0
  24. PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/ssn/ssn.py +146 -0
  25. PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/ssn/ssn_dataset.py +290 -0
  26. PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/ssn/ssn_submodule.py +282 -0
  27. PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/ssn/test.py +24 -0
  28. PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/ssn/test_dataset.py +21 -0
  29. PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/__init__.py +0 -0
  30. PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/html.py +61 -0
  31. PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/html_server.py +9 -0
  32. PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/imgs +1 -0
  33. PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/index.html +0 -0
  34. PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/make_html.py +133 -0
  35. PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/net_utils.py +70 -0
  36. PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/tensorboard_utils.py +29 -0
  37. PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/time_utils.py +6 -0
  38. PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/utils_file.py +59 -0
  39. PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/vis_test_results.py +21 -0
  40. PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/visdom_utils.py +53 -0
  41. PixHtLab-Src/Demo/PixhtLab/Torch_Render/hshadow_cuda.cpp +98 -0
  42. PixHtLab-Src/Demo/PixhtLab/Torch_Render/hshadow_cuda_kernel.cu +682 -0
  43. PixHtLab-Src/Demo/PixhtLab/Torch_Render/plane_visualize.cpp +26 -0
  44. PixHtLab-Src/Demo/PixhtLab/Torch_Render/plane_visualize_cuda.cu +237 -0
  45. PixHtLab-Src/Demo/PixhtLab/Torch_Render/setup.py +29 -0
  46. PixHtLab-Src/Demo/PixhtLab/Torch_Render/test_ground.py +33 -0
  47. PixHtLab-Src/Demo/PixhtLab/Torch_Render/test_hshadow.py +130 -0
  48. PixHtLab-Src/Demo/PixhtLab/camera.py +246 -0
  49. PixHtLab-Src/Demo/PixhtLab/gssn_demo.py +32 -0
  50. PixHtLab-Src/Demo/PixhtLab/hshadow_render.py +268 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ PixHtLab-Src/Demo/PixhtLab/Examples/c82c09fc01e84282bc8870c263dcf81b_bg.jpg filter=lfs diff=lfs merge=lfs -text
PixHtLab-Src/.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ .DS_Store
2
+ .idea
3
+ *.log
4
+ tmp/
5
+ *__pycache__*
PixHtLab-Src/Data/data_prepare.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import objaverse
2
+ import multiprocessing
3
+ import random
4
+
5
+ processes = 8
6
+
7
+ random.seed(0)
8
+ uids = objaverse.load_uids()
9
+ random_object_uids = random.sample(uids, 100)
10
+ objects = objaverse.load_objects(
11
+ uids=random_object_uids,
12
+ download_processes=processes
13
+ )
14
+
15
+
16
+ objects = objaverse.load_objects(
17
+ uids=random_object_uids,
18
+ download_processes=processes
19
+ )
20
+
21
+ import pdb; pdb.set_trace()
PixHtLab-Src/Demo/PixhtLab/Demo.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
PixHtLab-Src/Demo/PixhtLab/Examples/009_depth.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:17338c5be1e9aa07c8d68ed26076af44f433374a92a90130b3f803d66f8d2442
3
+ size 1048704
PixHtLab-Src/Demo/PixhtLab/Examples/009_depth_valid_mask.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:85f888fa80c5da36acf97938cda8f55323758aefda263fc859ab83b2955c92c1
3
+ size 262272
PixHtLab-Src/Demo/PixhtLab/Examples/009_pixht_new.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:25dfedce61e31739fec2164898d9975ab9224c02aea1540029eb7b04d0505118
3
+ size 2097280
PixHtLab-Src/Demo/PixhtLab/Examples/010_depth.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:872277399788bc63b43a1fc63e817d8c12d18e6960b6d7e51a41d2ea2fae21f6
3
+ size 1048704
PixHtLab-Src/Demo/PixhtLab/Examples/010_depth_valid_mask.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2ceb9b34576604a240e557f2bb2fa87b7328e30a28a44357d799c6a3271c6bf6
3
+ size 262272
PixHtLab-Src/Demo/PixhtLab/Examples/010_pixht_new.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bb2031249ece53825e983e62e1a695c53bdfcd5ae12d3a1350dff80e710f0aa3
3
+ size 2097280
PixHtLab-Src/Demo/PixhtLab/Examples/011_depth.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bfdabce8a328c02e7e52d156a2f2f03e1302fb695e9e70f35355d1aced7de24f
3
+ size 1048704
PixHtLab-Src/Demo/PixhtLab/Examples/011_depth_valid_mask.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:12eaae90b70229aa83450d8ea63fd77c7e8fdc6240e473aafab77021c2891efe
3
+ size 262272
PixHtLab-Src/Demo/PixhtLab/Examples/011_pixht_new.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9ef1a1f21f4ac53dc680fc6a965cce4f8043bd313a89e52ef733c5986be16595
3
+ size 2097280
PixHtLab-Src/Demo/PixhtLab/Examples/c82c09fc01e84282bc8870c263dcf81b_bg.jpg ADDED

Git LFS Details

  • SHA256: d51aad0c7a87ff499f6bee8834e87c9013e3800b0db25aee2b497398d0687ab7
  • Pointer size: 132 Bytes
  • Size of remote file: 1.26 MB
PixHtLab-Src/Demo/PixhtLab/GSSN/__init__.py ADDED
File without changes
PixHtLab-Src/Demo/PixhtLab/GSSN/inference_shadow.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ from PIL import Image
4
+ import matplotlib.pyplot as plt
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.utils.data import Dataset, DataLoader
9
+ import torch.optim as optim
10
+ import torchvision
11
+ import torchvision.transforms as T
12
+ import argparse
13
+ import time
14
+ from tqdm import tqdm
15
+ import numpy as np
16
+ import os
17
+ from os.path import join
18
+
19
+ import math
20
+
21
+ import cv2
22
+ import random
23
+
24
+ import sys
25
+ # sys.path.insert(0, '../../Training/app/models')
26
+ sys.path.insert(0, '/home/ysheng/Documents/Research/GSSN/Training/app/models')
27
+
28
+ from SSN_v1 import SSN_v1
29
+ from SSN import SSN
30
+
31
+
32
+ class SSN_Infernece():
33
+ def __init__(self, ckpt, device=torch.device('cuda:0')):
34
+ self.device = device
35
+ self.model = SSN(3, 1, mid_act='gelu', out_act='null', resnet=False)
36
+
37
+ weight = torch.load(ckpt)
38
+ self.model.to(device)
39
+ self.model.load_state_dict(weight['model'])
40
+
41
+ # inference related
42
+ BINs = 100
43
+ MAX_RAD = 20
44
+ self.size_interval = MAX_RAD / BINs
45
+ self.soft_distribution = [[np.exp(-0.2 * (i - j) ** 2) for i in np.arange(BINs)] for j in np.arange(BINs)]
46
+
47
+
48
+ def render_ss(self, input_np, softness):
49
+ """ input_np:
50
+ H x W x C
51
+ """
52
+ input_tensor = torch.tensor(input_np.transpose((2, 0, 1)))[None, ...].float().to(self.device)
53
+ transform = T.Resize((256, 256))
54
+
55
+ c = input_tensor.shape[1]
56
+ # for i in range(c):
57
+ # print(input_tensor[:, i].min(), input_tensor[:, i].max())
58
+
59
+ # print('softness: ', softness)
60
+ l = torch.from_numpy(np.array(self.soft_distribution[int(softness/self.size_interval)]).astype(np.float32)).unsqueeze(dim=0).to(self.device)
61
+
62
+ input_tensor = transform(input_tensor)
63
+ output_tensor = self.model(input_tensor, l)
64
+ output_np = output_tensor[0].detach().cpu().numpy().transpose((1,2,0))
65
+
66
+ return output_np
67
+
68
+
69
+ if __name__ == '__main__':
70
+ model = SSN_Infernece('weights/0000000700.pt')
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/__init__.py ADDED
File without changes
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/__init__.py ADDED
File without changes
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/inference_shadow.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ from PIL import Image
4
+ import matplotlib.pyplot as plt
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.utils.data import Dataset, DataLoader
9
+ import torch.optim as optim
10
+ import torchvision
11
+ import torchvision.transforms as transforms
12
+ import argparse
13
+ import time
14
+ from tqdm import tqdm
15
+ import numpy as np
16
+ import os
17
+ from os.path import join
18
+
19
+ import math
20
+
21
+ import cv2
22
+ import random
23
+ from .ssn.ssn import Relight_SSN
24
+ device = torch.device('cuda:0')
25
+
26
+ def net_render_np(model, mask_np, hard_shadow_np, size, orientation):
27
+ """
28
+ input:
29
+ mask_np shape: b x c x h x w
30
+ ibl_np shape: 1 x 16 x 32
31
+ output:
32
+ shadow_predict shape: b x c x h x w
33
+ """
34
+
35
+ size_interval = 0.5 / 100
36
+ ori_interval = np.pi / 100
37
+
38
+ soft_distribution = [[np.exp(-0.2 * (i - j) ** 2) for i in np.arange(0.5 / size_interval)]
39
+ for j in np.arange(0.5 / size_interval)]
40
+
41
+ # print('mask_np: {}, hard_shadow_np: {}'.format(mask_np.shape, hard_shadow_np.shape))
42
+ s = time.time()
43
+ if mask_np.dtype == np.uint8:
44
+ mask_np = mask_np / 255.0
45
+
46
+ mask, h_shadow = torch.Tensor(mask_np), torch.Tensor(hard_shadow_np)
47
+ size_soft = torch.Tensor(np.array(soft_distribution[int(size / size_interval)])).unsqueeze(0)
48
+ ori_soft = torch.Tensor(np.array(soft_distribution[int(orientation / ori_interval)])).unsqueeze(0)
49
+
50
+ with torch.no_grad():
51
+ I_m, I_h, size_t, ori = mask.to(device), h_shadow.to(device), size_soft.to(device), ori_soft.to(device)
52
+ # print('I_m: {}, I_h: {}'.format(I_m.shape, I_h.shape))
53
+ predicted_img = model(I_h, I_m, size_t, ori)
54
+
55
+ # print('net predict finished, time: {}s'.format(time.time() - s))
56
+
57
+ return predicted_img.detach().cpu().numpy()
58
+
59
+ def init_models(ckpt):
60
+ baseline_model = Relight_SSN(1, 1, is_training=False)
61
+ baseline_checkpoint = torch.load(ckpt)
62
+ baseline_model.to(device)
63
+ baseline_model.load_state_dict(baseline_checkpoint['model_state_dict'])
64
+ return baseline_model
65
+
66
+ if __name__ == '__main__':
67
+ softness = [0.02, 0.2, 0.3, 0.4]
68
+ model = init_models('weights/human_baseline123.pt')
69
+ mask, hard_shadow, size, orientation = np.random.randn(1,1,256,256), np.random.randn(1,1,256,256), softness[0], 0
70
+ shadow = net_render_np(model, mask, hard_shadow, size, orientation)
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/params.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ class params():
4
+ """ Singleton class for doing experiments """
5
+
6
+ class __params():
7
+ def __init__(self):
8
+ self.norm = 'group_norm'
9
+ self.prelu = False
10
+ self.weight_decay = 5e-4
11
+ self.small_ds = False
12
+ self.multi_gpu = False
13
+ self.log = False
14
+ self.input_channel = 1
15
+ self.vis_port = 8002
16
+ self.cpu = False
17
+ self.pred_touch = False
18
+ self.tbaseline = False
19
+ self.touch_loss = False
20
+ self.input_channel = 1
21
+
22
+ def set_params(self, options):
23
+ self.options = options
24
+ self.norm = options.norm
25
+ self.prelu = options.prelu
26
+ self.weight_decay = options.weight_decay
27
+ self.small_ds = options.small_ds
28
+ self.multi_gpu = options.multi_gpu
29
+ self.log = options.log
30
+ self.input_channel = options.input_channel
31
+ self.vis_port = options.vis_port
32
+ self.cpu = options.cpu
33
+ self.ds_folder = options.ds_folder
34
+ self.pred_touch = options.pred_touch
35
+ self.tbaseline = options.tbaseline
36
+ self.touch_loss = options.touch_loss
37
+
38
+ def __str__(self):
39
+ return 'norm: {} prelu: {} weight decay: {} small ds: {}'.format(self.norm, self.prelu, self.weight_decay, self.small_ds)
40
+
41
+ # private static variable
42
+ param_instance = None
43
+
44
+ def __init__(self):
45
+ if not params.param_instance:
46
+ params.param_instance = params.__params()
47
+
48
+ def get_params(self):
49
+ return params.param_instance
50
+
51
+ def set_params(self, options):
52
+ params.param_instance.set_params(options)
53
+
54
+ def parse_params():
55
+ parser = argparse.ArgumentParser()
56
+ parser.add_argument('--workers', type=int, help='number of data loading workers', default=16)
57
+ parser.add_argument('--batch_size', type=int, default=28, help='input batch size during training')
58
+ parser.add_argument('--epochs', type=int, default=10000, help='number of epochs to train for')
59
+ parser.add_argument('--lr', type=float, default=0.003, help='learning rate, default=0.005')
60
+ parser.add_argument('--beta1', type=float, default=0.9, help='momentum for SGD, default=0.9')
61
+ parser.add_argument('--resume', action='store_true', help='resume training')
62
+ parser.add_argument('--relearn', action='store_true', help='forget previous best validation loss')
63
+ parser.add_argument('--weight_file',type=str, help='weight file')
64
+ parser.add_argument('--multi_gpu', action='store_true', help='use multiple GPU training')
65
+ parser.add_argument('--timers', type=int, default=1, help='number of epochs to train for')
66
+ parser.add_argument('--use_schedule', action='store_true',help='use automatic schedule')
67
+ parser.add_argument('--patience', type=int, default=2, help='use automatic schedule')
68
+ parser.add_argument('--exp_name', type=str, default='l1 loss',help='experiment name')
69
+ parser.add_argument('--norm', type=str, default='group_norm', help='use group norm')
70
+ parser.add_argument('--ds_folder', type=str, default='./dataset/general_dataset', help='Dataset folder')
71
+ parser.add_argument('--hd_dir', type=str, default='/mnt/yifan/data/Adobe/HD_styleshadow/', help='Dataset folder')
72
+ parser.add_argument('--prelu', action='store_true', help='use prelu')
73
+ parser.add_argument('--small_ds', action='store_true', help='small dataset')
74
+ parser.add_argument('--log', action='store_true', help='log information')
75
+ parser.add_argument('--vis_port', default=8002,type=int, help='visdom port')
76
+ parser.add_argument('--weight_decay', type=float, default=4e-5, help='weight decay for model weight')
77
+ parser.add_argument('--save', action='store_true', help='save batch results?')
78
+ parser.add_argument('--cpu', action='store_true', help='Force training on CPU')
79
+ parser.add_argument('--pred_touch', action='store_true', help='Use touching surface')
80
+ parser.add_argument('--input_channel', type=int, default=1, help='how many input channels')
81
+
82
+ # based on baseline method, for fine tuning
83
+ parser.add_argument('--from_baseline', action='store_true', help='training from baseline')
84
+ parser.add_argument('--tbaseline', action='store_true', help='T-baseline, input two channels')
85
+ parser.add_argument('--touch_loss', action='store_true', help='Use touching loss')
86
+
87
+
88
+ arguments = parser.parse_args()
89
+ parameter = params()
90
+ parameter.set_params(arguments)
91
+
92
+ return arguments
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/ssn/__init__.py ADDED
File without changes
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/ssn/perturb_touch.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ import random
4
+
5
+ random.seed(19920208)
6
+
7
+ def random_kernel():
8
+ ksize = random.randint(1,3)
9
+ kernel = np.ones((ksize, ksize))
10
+ return kernel
11
+
12
+ def random_perturb(img):
13
+ return img
14
+ # perturbed = img.copy()
15
+ # if random.random() < 0.5:
16
+ # perturbed = cv2.erode(perturbed, random_kernel(), iterations = 1)
17
+
18
+ # if random.random() < 0.5:
19
+ # perturbed = cv2.dilate(perturbed, random_kernel(), iterations = 1)
20
+
21
+ # cv2.normalize(perturbed, perturbed, 0.0,1.0, cv2.NORM_MINMAX)
22
+ # if len(perturbed.shape) == 2:
23
+ # perturbed = perturbed[:,:,np.newaxis]
24
+ # return perturbed
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/ssn/random_pattern.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import time
3
+ import numbergen as ng
4
+ import imagen as ig
5
+ import numpy as np
6
+ import cv2
7
+ from param.parameterized import get_logger
8
+ import logging
9
+
10
+ get_logger().setLevel(logging.ERROR)
11
+
12
+ class random_pattern():
13
+ def __init__(self, maximum_blob=50):
14
+ # self.generator_list = []
15
+
16
+ # start = time.time()
17
+ # for i in range(maximum_blob):
18
+ # self.generator_list.append(ig.Gaussian(size=))
19
+ # print('random pattern init time: {}s'.format(time.time()-start))
20
+
21
+ pass
22
+
23
+ def y_transform(self, y):
24
+ # y = []
25
+ pass
26
+
27
+ def get_pattern(self, w, h, x_density=512, y_density=128, num=50, scale=3.0, size=0.1, energy=3500,
28
+ mitsuba=False, seed=None, dataset=False):
29
+ if seed is None:
30
+ seed = random.randint(0, 19920208)
31
+ else:
32
+ seed = seed + int(time.time())
33
+
34
+ if num == 0:
35
+ ibl = np.zeros((y_density, x_density))
36
+ orientation = np.pi * ng.UniformRandom(seed=seed + 3)()
37
+ else:
38
+ y_fact = y_density / 256
39
+ num = 1
40
+ size = size * ng.UniformRandom(seed=seed + 4)()
41
+ orientation = np.pi * ng.UniformRandom(seed=seed + 3)()
42
+ gs = ig.Composite(operator=np.add,
43
+ generators=[ig.Gaussian(
44
+ size=size,
45
+ scale=1.0,
46
+ x=ng.UniformRandom(seed=seed + i + 1) - 0.5,
47
+ y=((1.0 - ng.UniformRandom(seed=seed + i + 2) * y_fact) - 0.5),
48
+ aspect_ratio=0.7,
49
+ orientation=orientation,
50
+ ) for i in range(num)],
51
+ position=(0, 0),
52
+ xdensity=512)
53
+
54
+ # gs = ig.Composite(operator=np.add,
55
+ # generators=[ig.Gaussian(
56
+ # size=size * ng.UniformRandom(seed=seed + i + 4),
57
+ # scale=scale * (ng.UniformRandom(seed=seed + i + 5) + 1e-3),
58
+ # x=int(ind / h),
59
+ # y=ind % h,
60
+ # aspect_ratio=0.7,
61
+ # orientation=np.pi * ng.UniformRandom(seed=seed + i + 3),
62
+ # ) for i in range(num)],
63
+ # position=(0, 0),
64
+ # xdensity=512)
65
+ ibl = gs()[:y_density, :]
66
+
67
+ # prepare to fix energy inconsistent
68
+ if dataset:
69
+ ibl = self.to_dataset(ibl, w, h)
70
+
71
+ if mitsuba:
72
+ return ibl, size, orientation
73
+ else:
74
+ return ibl, size, orientation
75
+
76
+ def to_mts_ibl(self, ibl):
77
+ """ Input: 256 x 512 pattern generated ibl
78
+ Output: the ibl in mitsuba ibl
79
+ """
80
+ return np.repeat(ibl[:, :, np.newaxis], 3, axis=2)
81
+
82
+ def normalize(self, ibl, energy=30.0):
83
+ total_energy = np.sum(ibl)
84
+ if total_energy < 1e-3:
85
+ print('small energy: ', total_energy)
86
+ h, w = ibl.shape
87
+ return np.zeros((h, w))
88
+
89
+ return ibl * energy / total_energy
90
+
91
+ def to_dataset(self, ibl, w, h):
92
+ return self.normalize(cv2.flip(cv2.resize(ibl, (w, h)), 0), 30)
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/ssn/ssn.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from .ssn_submodule import Conv, Up, Up_Stream, get_layer_info, add_coords
6
+ import copy
7
+
8
+ class Relight_SSN(nn.Module):
9
+ """ Implementation of Relighting Net """
10
+
11
+ def __init__(self, n_channels=3, out_channels=3, is_training=True, activation_func = 'relu'):
12
+ super(Relight_SSN, self).__init__()
13
+ self.is_training = is_training
14
+
15
+ norm_layer1, activation_func1 = get_layer_info(16, activation_func)
16
+ norm_layer2, activation_func2 = get_layer_info(16, activation_func)
17
+ if norm_layer1 is not None:
18
+ self.in_conv1 = nn.Sequential(
19
+ nn.Conv2d(n_channels, 16, kernel_size=7, padding=3, bias=True),
20
+ norm_layer1,
21
+ activation_func1
22
+ )
23
+ elif norm_layer1 is None:
24
+ self.in_conv1 = nn.Sequential(
25
+ nn.Conv2d(n_channels, 16, kernel_size=7, padding=3, bias=True),
26
+ activation_func1
27
+ )
28
+
29
+ if norm_layer2 is not None:
30
+ self.in_conv2 = nn.Sequential(
31
+ nn.Conv2d(n_channels, 16, kernel_size=7, padding=3, bias=True),
32
+ norm_layer2,
33
+ activation_func2
34
+ )
35
+ elif norm_layer2 is None:
36
+ self.in_conv2 = nn.Sequential(
37
+ nn.Conv2d(n_channels, 16, kernel_size=7, padding=3, bias=True),
38
+ activation_func2
39
+ )
40
+
41
+ self.down_256_128 = Conv(32, 64, conv_stride=2)
42
+ self.down_128_128 = Conv(64, 64, conv_stride=1)
43
+ self.down_128_64 = Conv(64, 128, conv_stride=2)
44
+ self.down_64_64 = Conv(128, 128, conv_stride=1)
45
+ self.down_64_32 = Conv(128, 256, conv_stride=2)
46
+ self.down_32_32 = Conv(256, 256, conv_stride=1)
47
+ self.down_32_16 = Conv(256, 512, conv_stride=2)
48
+ self.down_16_16_1 = Conv(512, 512, conv_stride=1)
49
+ self.down_16_16_2 = Conv(512, 512, conv_stride=1)
50
+ self.down_16_16_3 = Conv(512, 512, conv_stride=1)
51
+ self.to_bottleneck = Conv(512, 2, conv_stride=1)
52
+
53
+ self.up_stream = Up_Stream(out_channels)
54
+
55
+ """
56
+ Input is (source image, target light, source light, )
57
+ Output is: predicted new image, predicted source light, self-supervision image
58
+ """
59
+
60
+ def forward(self, I_h, I_m, size, angle):
61
+ if self.is_training:
62
+ bs, fake_bs, dim = size.shape
63
+ bs, fake_bs, ch, w, h = I_h.shape
64
+ size = size.reshape(bs * fake_bs, dim)
65
+ angle = angle.reshape(bs * fake_bs, dim)
66
+ I_h = I_h.reshape(bs * fake_bs, ch, w, h)
67
+ I_m = I_m.reshape(bs * fake_bs, ch, w, h)
68
+ else:
69
+ size = size
70
+ angle = angle
71
+ I_h = I_h
72
+ I_m = I_m
73
+ style = torch.cat((size, angle), dim=1)
74
+ x1 = self.in_conv1(I_m) # 29 x 256 x 256
75
+ x2 = self.in_conv2(I_h)
76
+
77
+ x1 = torch.cat((x1, x2), dim=1) # 32 x 256 x 256
78
+
79
+ x2 = self.down_256_128(x1, x1) # 64 x 128 x 128
80
+
81
+ x3 = self.down_128_128(x2, x1) # 64 x 128 x 128
82
+
83
+ x4 = self.down_128_64(x3, x1) # 128 x 64 x 64
84
+
85
+ x5 = self.down_64_64(x4, x1) # 128 x 64 x 64
86
+
87
+ x6 = self.down_64_32(x5, x1) # 256 x 32 x 32
88
+
89
+ x7 = self.down_32_32(x6, x1) # 256 x 32 x 32
90
+
91
+ x8 = self.down_32_16(x7, x1) # 512 x 16 x 16
92
+
93
+ x9 = self.down_16_16_1(x8, x1) # 512 x 16 x 16
94
+
95
+ x10 = self.down_16_16_2(x9, x1) # 512 x 16 x 16
96
+
97
+ x11 = self.down_16_16_3(x10, x1) # 512 x 16 x 16
98
+ ty = self.up_stream(x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, style)
99
+
100
+ return ty
101
+
102
+
103
+ def baseline_2_tbaseline(model):
104
+ """ change input layer to be two channels
105
+ """
106
+ input_channel = 2
107
+ tbase_inlayer = nn.Sequential(
108
+ nn.Conv2d(input_channel, 32 - input_channel, kernel_size=7, padding=3, bias=True),
109
+ nn.GroupNorm(1, 32 - input_channel),
110
+ nn.ReLU()
111
+ )
112
+ model.in_conv = tbase_inlayer
113
+ return model
114
+
115
+
116
+ def baseline_2_touchloss(model):
117
+ """ change output layer to be two channels
118
+ """
119
+ touchless_outlayer = nn.Sequential(
120
+ nn.Conv2d(64, 2, stride=1, kernel_size=3, padding=1, bias=True),
121
+ nn.GroupNorm(1, 2),
122
+ nn.ReLU()
123
+ )
124
+ model.up_stream.out_conv = touchless_outlayer
125
+ return model
126
+
127
+
128
+ if __name__ == '__main__':
129
+ mask_test, touch_test = torch.zeros((1, 1, 256, 256)), torch.zeros((1, 1, 256, 256))
130
+ ibl = torch.zeros((1, 1, 16, 32))
131
+
132
+ I_s = mask_test
133
+ baseline = Relight_SSN(1, 1)
134
+ baseline_output, _ = baseline(I_s, ibl)
135
+
136
+ tbaseline = baseline_2_tbaseline(copy.deepcopy(baseline))
137
+ I_s = torch.cat((mask_test, touch_test), axis=1)
138
+ tbaseline_output, _ = tbaseline(I_s, ibl)
139
+
140
+ t_loss_baseline = baseline_2_touchloss(copy.deepcopy(baseline))
141
+ I_s = mask_test
142
+ tloss_output, _ = t_loss_baseline(I_s, ibl)
143
+
144
+ print('baseline output: ', baseline_output.shape)
145
+ print('tbaseline output: ', tbaseline_output.shape)
146
+ print('tloss output: ', tloss_output.shape)
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/ssn/ssn_dataset.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+
3
+ sys.path.append("..")
4
+
5
+ import os
6
+ from os.path import join
7
+ import torch
8
+ import numpy as np
9
+ import torch
10
+ from torch.utils.data import Dataset, DataLoader
11
+ from torchvision import transforms, utils
12
+ import time
13
+ import random
14
+ # import matplotlib.pyplot as plt
15
+ import cv2
16
+ from params import params
17
+ from .random_pattern import random_pattern
18
+ from .perturb_touch import random_perturb
19
+
20
+
21
+ class ToTensor(object):
22
+ """Convert ndarrays in sample to Tensors."""
23
+
24
+ def __call__(self, img, is_transpose=True):
25
+ # swap color axis because
26
+ # numpy image: H x W x C
27
+ # torch image: C X H X W
28
+ if is_transpose:
29
+ img = img.transpose((0, 3, 1, 2))
30
+ return torch.Tensor(img)
31
+
32
+
33
+ class SSN_Dataset(Dataset):
34
+ def __init__(self, ds_dir, hd_dir, is_training, fake_batch_size=8):
35
+ start = time.time()
36
+ self.fake_batch_size = fake_batch_size
37
+ # # of samples in each group
38
+ # magic number here
39
+ self.ibl_group_size = 16
40
+
41
+ parameter = params().get_params()
42
+
43
+ # (shadow_path, mask_path)
44
+ self.meta_data = self.init_meta(ds_dir, hd_dir)
45
+
46
+ self.is_training = is_training
47
+ self.to_tensor = ToTensor()
48
+
49
+ end = time.time()
50
+ print("Dataset initialize spent: {} ms".format(end - start))
51
+
52
+ # fake random
53
+ np.random.seed(19950220)
54
+ np.random.shuffle(self.meta_data)
55
+
56
+ self.valid_divide = 10
57
+ if parameter.small_ds:
58
+ self.meta_data = self.meta_data[:len(self.meta_data) // self.valid_divide]
59
+
60
+ self.training_num = len(self.meta_data) - len(self.meta_data) // self.valid_divide
61
+ print('training: {}, validation: {}'.format(self.training_num, len(self.meta_data) // self.valid_divide))
62
+
63
+ self.random_pattern_generator = random_pattern()
64
+
65
+ self.thread_id = os.getpid()
66
+ self.seed = os.getpid()
67
+ self.perturb = not parameter.pred_touch and not parameter.touch_loss
68
+ self.size_interval = 0.5 / 100
69
+ self.ori_interval = np.pi / 100
70
+
71
+ self.soft_distribution = [[np.exp(-0.4 * (i - j) ** 2) for i in np.arange(0.5 / self.size_interval)]
72
+ for j in np.arange(0.5 / self.size_interval)]
73
+
74
+ def __len__(self):
75
+ if self.is_training:
76
+ return self.training_num
77
+ else:
78
+ # return len(self.meta_data) - self.training_num
79
+ return len(self.meta_data) // self.valid_divide
80
+
81
+ def __getitem__(self, idx):
82
+ if self.is_training and idx > self.training_num:
83
+ print("error")
84
+ # offset to validation set
85
+ if not self.is_training:
86
+ idx = self.training_num + idx
87
+
88
+ cur_seed = idx * 1234 + os.getpid() + time.time()
89
+ random.seed(cur_seed)
90
+
91
+ # random ibls
92
+ shadow_path, mask_path, hard_path, touch_path = self.meta_data[idx]
93
+ hard_folder = hard_path.replace(hard_path.split('/')[-1], '')
94
+ if os.path.exists(hard_folder):
95
+ # hard_shadow = cv2.imread(hard_folder)
96
+
97
+ mask_img = cv2.imread(mask_path)
98
+ mask_img = mask_img[:, :, 0]
99
+ if mask_img.dtype == np.uint8:
100
+ mask_img = mask_img / 255.0
101
+ mask_img, shadow_bases = np.expand_dims(mask_img, axis=2), np.load(shadow_path)
102
+
103
+ w, h, c, m = shadow_bases.shape
104
+ shadow_soft_list = []
105
+ shadow_hard_list = []
106
+ size_list = []
107
+ orientation_list = []
108
+ mask_img_list = []
109
+ for i in range(int(self.fake_batch_size)):
110
+ shadow_img, light_img, size, orientation = self.render_new_shadow(shadow_bases)
111
+
112
+ h, w = mask_img.shape[0], mask_img.shape[1]
113
+ hi, wi = np.where(light_img == light_img.max())
114
+
115
+ while len(hi) > 1:
116
+ shadow_img, light_img, size, orientation = self.render_new_shadow(shadow_bases)
117
+ hi, wi = np.where(light_img == light_img.max())
118
+ size_soft = np.array(self.soft_distribution[int(size / self.size_interval)])
119
+ ori_soft = np.array(self.soft_distribution[int(orientation / self.ori_interval)])
120
+ prefix = '_ibli_' + str(int(wi * 8)) + '_iblj_' + str(int(hi * 8) + 128) + '_shadow.png'
121
+ shadow_hard_path = hard_path.replace('_shadow.png', prefix)
122
+ shadow_base = cv2.imread(shadow_hard_path, -1)[:, :, 0] / 255.0
123
+ shadow_base = np.expand_dims(shadow_base, axis=2)
124
+ shadow_base = self.line_aug(shadow_base)
125
+ shadow_soft_list.append(shadow_img)
126
+ shadow_hard_list.append(shadow_base)
127
+ size_list.append(size_soft)
128
+ orientation_list.append(ori_soft)
129
+ mask_img_list.append(mask_img)
130
+ shadow_softs = np.array(shadow_soft_list)
131
+ shadow_hards = np.array(shadow_hard_list)
132
+ sizes = np.array(size_list)
133
+ orientations = np.array(orientation_list)
134
+ mask_imgs = np.array(mask_img_list)
135
+
136
+ # touch_img = self.read_img(touch_path)
137
+ # touch_img = touch_img[:, :, 0:1]
138
+
139
+ # if self.perturb:
140
+ # touch_img = random_perturb(touch_img)
141
+
142
+ # input_img = np.concatenate((mask_img, touch_img), axis=2)
143
+ size = torch.Tensor(sizes)
144
+ ori = torch.Tensor(orientations)
145
+ hard_shadow, soft_shadow, mask_img = self.to_tensor(shadow_hards), self.to_tensor(
146
+ shadow_softs), self.to_tensor(
147
+ mask_imgs)
148
+ return {"hard_shadow": hard_shadow, "soft_shadow": soft_shadow, "mask_img": mask_img, "size": size,
149
+ "angle": ori}
150
+ else:
151
+ mask_img = cv2.imread(mask_path)
152
+ mask_img = mask_img[:, :, 0]
153
+ if mask_img.dtype == np.uint8:
154
+ mask_img = mask_img / 255.0
155
+ mask_img, shadow_bases = np.expand_dims(mask_img, axis=2), 1.0 - np.load(shadow_path)
156
+
157
+ w, h, c, m = shadow_bases.shape
158
+ shadow_soft_list = []
159
+ shadow_hard_list = []
160
+ size_list = []
161
+ orientation_list = []
162
+ mask_img_list = []
163
+ for i in range(int(self.fake_batch_size)):
164
+ shadow_img, light_img, size, orientation = self.render_new_shadow(shadow_bases)
165
+
166
+ h, w = mask_img.shape[0], mask_img.shape[1]
167
+ hi, wi = np.where(light_img == light_img.max())
168
+
169
+ while len(hi) > 1:
170
+ shadow_img, light_img, size, orientation = self.render_new_shadow(shadow_bases)
171
+ hi, wi, _ = np.where(light_img == light_img[:, :, :].max())
172
+ size_soft = np.array(self.soft_distribution[int(size / self.size_interval)])
173
+ ori_soft = np.array(self.soft_distribution[int(orientation / self.ori_interval)])
174
+
175
+ shadow_base = shadow_bases[:, :, wi, hi]
176
+ shadow_base[shadow_base > 0.3] = 1
177
+ shadow_base[shadow_base < 0.4] = 0
178
+ shadow_base = self.line_aug(shadow_base)
179
+ mask_img = np.expand_dims(cv2.resize(mask_img, (512, 512)), axis=2)
180
+ shadow_base = np.expand_dims(cv2.resize(shadow_base, (512, 512)), axis=2)
181
+ shadow_img = np.expand_dims(cv2.resize(shadow_img, (512, 512)), axis=2)
182
+ shadow_soft_list.append(shadow_img)
183
+ shadow_hard_list.append(shadow_base)
184
+ size_list.append(size_soft)
185
+ orientation_list.append(ori_soft)
186
+ mask_img_list.append(mask_img)
187
+ shadow_softs = np.array(shadow_soft_list)
188
+ shadow_hards = np.array(shadow_hard_list)
189
+ sizes = np.array(size_list)
190
+ orientations = np.array(orientation_list)
191
+ mask_imgs = np.array(mask_img_list)
192
+
193
+ # touch_img = self.read_img(touch_path)
194
+ # touch_img = touch_img[:, :, 0:1]
195
+
196
+ # if self.perturb:
197
+ # touch_img = random_perturb(touch_img)
198
+
199
+ # input_img = np.concatenate((mask_img, touch_img), axis=2)
200
+ size = torch.Tensor(sizes)
201
+ ori = torch.Tensor(orientations)
202
+
203
+ hard_shadow, soft_shadow, mask_img = self.to_tensor(shadow_hards), self.to_tensor(
204
+ shadow_softs), self.to_tensor(
205
+ mask_imgs)
206
+
207
+ return {"hard_shadow": hard_shadow, "soft_shadow": soft_shadow, "mask_img": mask_img, "size": size,
208
+ "angle": ori}
209
+
210
+ def init_meta(self, ds_dir, hd_dir):
211
+ metadata = []
212
+ # base_folder = join(ds_dir, 'base')
213
+ # mask_folder = join(ds_dir, 'mask')
214
+ # hard_folder = join(ds_dir, 'hard')
215
+ # touch_folder = join(ds_dir, 'touch')
216
+ # model_list = [f for f in os.listdir(base_folder) if os.path.isdir(join(base_folder, f))]
217
+ # for m in model_list:
218
+ # shadow_folder, cur_mask_folder = join(base_folder, m), join(mask_folder, m)
219
+ # shadows = [f for f in os.listdir(shadow_folder) if f.find('_shadow.npy') != -1]
220
+ # for s in shadows:
221
+ # prefix = s[:s.find('_shadow')]
222
+ # metadata.append((join(shadow_folder, s),
223
+ # join(cur_mask_folder, prefix + '_mask.png'),
224
+ # join(join(hard_folder, m), prefix + '_shadow.png'),
225
+ # join(join(touch_folder, m), prefix + '_touch.png')))
226
+
227
+ base_folder = join(hd_dir, 'base')
228
+ mask_folder = join(hd_dir, 'mask')
229
+ hard_folder = join(hd_dir, 'hard')
230
+ touch_folder = join(hd_dir, 'touch')
231
+ model_list = [f for f in os.listdir(base_folder) if os.path.isdir(join(base_folder, f))]
232
+ for m in model_list:
233
+ shadow_folder, cur_mask_folder = join(base_folder, m), join(mask_folder, m)
234
+ shadows = [f for f in os.listdir(shadow_folder) if f.find('_shadow.npy') != -1]
235
+ for s in shadows:
236
+ prefix = s[:s.find('_shadow')]
237
+ metadata.append((join(shadow_folder, s),
238
+ join(cur_mask_folder, prefix + '_mask.png'),
239
+ join(join(hard_folder, m), prefix + '_shadow.png'),
240
+ join(join(touch_folder, m), prefix + '_touch.png')))
241
+
242
+ return metadata
243
+
244
+ def line_aug(self, shadow):
245
+ p = np.random.random()
246
+ if p > 0.6:
247
+ k = np.tan(min((np.random.random() + 0.000000001), 0.999) * np.pi - np.pi / 2)
248
+ x, y, c = shadow.shape
249
+ b_max = y - x * k
250
+ line_num = np.random.randint(1, 20)
251
+ b_list = np.random.random(line_num) * b_max
252
+ x_coord = np.tile(np.arange(shadow.shape[1])[None, :], (shadow.shape[0], 1))
253
+ y_coord = np.tile(np.arange(shadow.shape[0])[:, None], (1, shadow.shape[1]))
254
+
255
+ for b in b_list:
256
+ mask_res = y_coord - k * x_coord - b
257
+ shadow[np.abs(mask_res) < 1] = 0
258
+ return shadow
259
+
260
+ def get_prefix(self, path):
261
+ folder = os.path.dirname(path)
262
+ basename = os.path.basename(path)
263
+ return os.path.join(folder, basename[:basename.find('_')])
264
+
265
+ def render_new_shadow(self, shadow_bases):
266
+ shadow_bases = shadow_bases[:, :, :, :]
267
+ h, w, iw, ih = shadow_bases.shape
268
+
269
+ num = random.randint(0, 50)
270
+ pattern_img, size, orientation = self.random_pattern_generator.get_pattern(iw, ih, num=num, size=0.5,
271
+ mitsuba=False)
272
+
273
+ # flip to mitsuba ibl
274
+ pattern_img = self.normalize_energy(cv2.flip(cv2.resize(pattern_img, (iw, ih)), 0))
275
+ shadow = np.tensordot(shadow_bases, pattern_img, axes=([2, 3], [1, 0]))
276
+ # pattern_img = np.expand_dims(cv2.resize(pattern_img, (iw, 16)), 2)
277
+
278
+ return np.expand_dims(shadow, 2), pattern_img, size, orientation
279
+
280
+ def get_min_max(self, batch_data, name):
281
+ print('{} min: {}, max: {}'.format(name, np.min(batch_data), np.max(batch_data)))
282
+
283
+ def log(self, log_info):
284
+ with open('log.txt', 'a+') as f:
285
+ f.write(log_info)
286
+
287
+ def normalize_energy(self, ibl, energy=30.0):
288
+ if np.sum(ibl) < 1e-3:
289
+ return ibl
290
+ return ibl * energy / np.sum(ibl)
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/ssn/ssn_submodule.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ def get_layer_info(out_channels, activation_func='relu'):
6
+ if out_channels >= 32:
7
+ group_num = 32
8
+ else:
9
+ group_num = 1
10
+
11
+ norm_layer = nn.GroupNorm(group_num, out_channels)
12
+
13
+ if activation_func == 'relu':
14
+ activation_func = nn.ReLU()
15
+ elif activation_func == 'prelu':
16
+ activation_func = nn.PReLU(out_channels)
17
+
18
+ return norm_layer, activation_func
19
+
20
+
21
+ # add coord_conv
22
+ class add_coords(nn.Module):
23
+ def __init__(self, use_cuda=True):
24
+ super(add_coords, self).__init__()
25
+ self.use_cuda = use_cuda
26
+
27
+ def forward(self, input_tensor):
28
+ b, c, dim_y, dim_x = input_tensor.shape
29
+ xx_ones = torch.ones([1, 1, 1, dim_x], dtype=torch.int32)
30
+ yy_ones = torch.ones([1, 1, 1, dim_y], dtype=torch.int32)
31
+
32
+ xx_range = torch.arange(dim_y, dtype=torch.int32)
33
+ yy_range = torch.arange(dim_x, dtype=torch.int32)
34
+ xx_range = xx_range[None, None, :, None]
35
+ yy_range = yy_range[None, None, :, None]
36
+
37
+ xx_channel = torch.matmul(xx_range, xx_ones)
38
+ yy_channel = torch.matmul(yy_range, yy_ones)
39
+
40
+ # transpose y
41
+ yy_channel = yy_channel.permute(0, 1, 3, 2)
42
+
43
+ xx_channel = xx_channel.float() / (dim_y - 1)
44
+ yy_channel = yy_channel.float() / (dim_x - 1)
45
+
46
+ xx_channel = xx_channel * 2 - 1
47
+ yy_channel = yy_channel * 2 - 1
48
+
49
+ xx_channel = xx_channel.repeat(b, 1, 1, 1)
50
+ yy_channel = yy_channel.repeat(b, 1, 1, 1)
51
+
52
+ if torch.cuda.is_available and self.use_cuda:
53
+ input_tensor = input_tensor.cuda()
54
+ xx_channel = xx_channel.cuda()
55
+ yy_channel = yy_channel.cuda()
56
+ out = torch.cat([input_tensor, xx_channel, yy_channel], dim=1)
57
+ return out
58
+
59
+
60
+ class Conv(nn.Module):
61
+ """ (convolution => [BN] => ReLU) """
62
+
63
+ def __init__(self, in_channels, out_channels, kernel_size=3, conv_stride=1, padding=1, bias=True,
64
+ activation_func='relu', style=False):
65
+ super().__init__()
66
+
67
+ self.style = style
68
+ norm_layer, activation_func = get_layer_info(out_channels, activation_func)
69
+ if style:
70
+ self.styleconv = Conv2DMod(in_channels, out_channels, kernel_size)
71
+ self.relu = nn.LeakyReLU(0.2, inplace=True)
72
+ else:
73
+ if norm_layer is not None:
74
+ self.conv = nn.Sequential(
75
+ nn.Conv2d(in_channels, out_channels, stride=conv_stride, kernel_size=kernel_size, padding=padding,
76
+ bias=bias),
77
+ norm_layer,
78
+ activation_func)
79
+ else:
80
+ self.conv = nn.Sequential(
81
+ nn.Conv2d(in_channels, out_channels, stride=conv_stride, kernel_size=kernel_size, padding=padding,
82
+ bias=bias),
83
+ activation_func)
84
+
85
+ def forward(self, x, style_fea):
86
+ if self.style:
87
+ res = self.styleconv(x, style_fea)
88
+ res = self.relu(res)
89
+ return res
90
+ else:
91
+ return self.conv(x)
92
+
93
+
94
+ class Conv2DMod(nn.Module):
95
+ def __init__(self, in_chan, out_chan, kernel, demod=True, stride=1, dilation=1, eps=1e-8, **kwargs):
96
+ super().__init__()
97
+ self.filters = out_chan
98
+ self.demod = demod
99
+ self.kernel = kernel
100
+ self.stride = stride
101
+ self.dilation = dilation
102
+ self.weight = nn.Parameter(torch.randn((out_chan, in_chan, kernel, kernel)))
103
+ self.eps = eps
104
+ nn.init.kaiming_normal_(self.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
105
+
106
+ def _get_same_padding(self, size, kernel, dilation, stride):
107
+ return ((size - 1) * (stride - 1) + dilation * (kernel - 1)) // 2
108
+
109
+ def forward(self, x, y):
110
+ b, c, h, w = x.shape
111
+
112
+ w1 = y[:, None, :, None, None]
113
+ w2 = self.weight[None, :, :, :, :]
114
+ weights = w2 * (w1 + 1)
115
+
116
+ if self.demod:
117
+ d = torch.rsqrt((weights ** 2).sum(dim=(2, 3, 4), keepdim=True) + self.eps)
118
+ weights = weights * d
119
+
120
+ x = x.reshape(1, -1, h, w)
121
+
122
+ _, _, *ws = weights.shape
123
+ weights = weights.reshape(b * self.filters, *ws)
124
+
125
+ padding = self._get_same_padding(h, self.kernel, self.dilation, self.stride)
126
+ x = F.conv2d(x, weights, padding=padding, groups=b)
127
+
128
+ x = x.reshape(-1, self.filters, h, w)
129
+ return x
130
+
131
+ class Up(nn.Module):
132
+ """ Upscaling then conv """
133
+
134
+ def __init__(self, in_channels, out_channels, activation_func='relu', style=False):
135
+ super().__init__()
136
+ self.style = style
137
+ activation_func = 'relu'
138
+ norm_layer, activation_func = get_layer_info(out_channels, activation_func)
139
+
140
+ self.up_layer = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
141
+ self.up = Conv(in_channels, in_channels // 4, activation_func=activation_func, style=style)
142
+
143
+ def forward(self, x, style_fea):
144
+ if self.style:
145
+ x = self.up_layer(x)
146
+ return self.up(x, style_fea)
147
+ else:
148
+ x = self.up_layer(x)
149
+ return self.up(x, style_fea)
150
+
151
+
152
+ class PSPUpsample(nn.Module):
153
+ def __init__(self, in_channels, out_channels, scale_x, scale_y):
154
+ super().__init__()
155
+ self.conv = Conv(in_channels, out_channels)
156
+ self.scale_x = scale_x
157
+ self.scale_y = scale_y
158
+
159
+ def forward(self, x):
160
+ h, w = self.scale_y * x.size(2), self.scale_x * x.size(3)
161
+ p = F.upsample(input=x, size=(h, w), mode='bilinear')
162
+ return self.conv(p)
163
+
164
+
165
+ class PSP(nn.Module):
166
+ def __init__(self, in_channels):
167
+ super().__init__()
168
+ # pooling
169
+ self.pool2 = nn.AdaptiveAvgPool2d((2, 2))
170
+ self.pool4 = nn.AdaptiveAvgPool2d((4, 4))
171
+ self.pool8 = nn.AdaptiveAvgPool2d((8, 8))
172
+
173
+ # conv -> compress channels
174
+ avg_channel = in_channels // 4
175
+ self.conv2 = Conv(in_channels, avg_channel)
176
+ self.conv4 = Conv(in_channels, avg_channel)
177
+ self.conv8 = Conv(in_channels, avg_channel)
178
+ self.conv16 = Conv(in_channels, avg_channel)
179
+
180
+ # up sapmle -> match dimension
181
+ self.up2 = PSPUpsample(avg_channel, avg_channel, 16 // 2, 16 // 2)
182
+ self.up4 = PSPUpsample(avg_channel, avg_channel, 16 // 4, 16 // 4)
183
+ self.up8 = PSPUpsample(avg_channel, avg_channel, 16 // 8, 16 // 8)
184
+
185
+ def forward(self, x):
186
+ x2 = self.up2(self.conv2(self.pool2(x)))
187
+ x4 = self.up4(self.conv4(self.pool4(x)))
188
+ x8 = self.up8(self.conv8(self.pool8(x)))
189
+ x16 = self.conv16(x)
190
+ return torch.cat((x2, x4, x8, x16), dim=1)
191
+
192
+
193
+ class Up_Stream(nn.Module):
194
+ """ Up Stream Sequence """
195
+
196
+ def __init__(self, out_channels=3, activation_func = 'relu'):
197
+ super(Up_Stream, self).__init__()
198
+
199
+ input_channel = 512
200
+ fea_dim = 200
201
+ norm_layer, activation_func = get_layer_info(input_channel, activation_func)
202
+ self.to_style1 = nn.Linear(in_features=fea_dim, out_features=input_channel)
203
+ self.up_16_16_1 = Conv(input_channel, 256, activation_func=activation_func, style=True)
204
+ self.up_16_16_2 = Conv(768, 512, activation_func=activation_func)
205
+ self.up_16_16_3 = Conv(1024, 512, activation_func=activation_func)
206
+
207
+ self.up_16_32 = Up(1024, 256, activation_func=activation_func)
208
+ self.to_style2 = nn.Linear(in_features=fea_dim, out_features=512)
209
+ self.up_32_32_1 = Conv(512, 256, activation_func=activation_func, style=True)
210
+
211
+ self.up_32_64 = Up(512, 128, activation_func=activation_func)
212
+ self.to_style3 = nn.Linear(in_features=fea_dim, out_features=256)
213
+ self.up_64_64_1 = Conv(256, 128, activation_func=activation_func, style=True)
214
+
215
+ self.up_64_128 = Up(256, 64, activation_func=activation_func)
216
+ self.to_style4 = nn.Linear(in_features=fea_dim, out_features=128)
217
+ self.up_128_128_1 = Conv(128, 64, activation_func=activation_func, style=True)
218
+
219
+ self.up_128_256 = Up(128, 32, activation_func=activation_func)
220
+ self.out_conv = Conv(64, out_channels, activation_func='relu')
221
+
222
+ def forward(self, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, style):
223
+ batch_size, c, h, w = x1.size()
224
+
225
+ # import pdb; pdb.set_trace()
226
+ # multiple channel ibl
227
+ # style = torch.zeros(batch_size, 500).to(x11.device)
228
+
229
+ # y = l.view(-1, 512, 1, 1).repeat(1, 1, 16, 16)
230
+ style1 = self.to_style1(style)
231
+ y = self.up_16_16_1(x11, style1) # 256 x 16 x 16
232
+
233
+ y = torch.cat((x10, y), dim=1) # 768 x 16 x 16
234
+ # print(y.size())
235
+
236
+ y = self.up_16_16_2(y, y) # 512 x 16 x 16
237
+ # print(y.size())
238
+
239
+ y = torch.cat((x9, y), dim=1) # 1024 x 16 x 16
240
+ # print(y.size())
241
+
242
+ # import pdb; pdb.set_trace()
243
+ y = self.up_16_16_3(y, y) # 512 x 16 x 16
244
+ # print(y.size())
245
+
246
+ y = torch.cat((x8, y), dim=1) # 1024 x 16 x 16
247
+ # print(y.size())
248
+
249
+ # import pdb; pdb.set_trace()
250
+ y = self.up_16_32(y, y) # 256 x 32 x 32
251
+ # print(y.size())
252
+
253
+ y = torch.cat((x7, y), dim=1)
254
+ style2 = self.to_style2(style)
255
+ y = self.up_32_32_1(y, style2) # 256 x 32 x 32
256
+ # print(y.size())
257
+
258
+ y = torch.cat((x6, y), dim=1)
259
+ y = self.up_32_64(y, y)
260
+ # print(y.size())
261
+ y = torch.cat((x5, y), dim=1)
262
+ style3 = self.to_style3(style)
263
+ y = self.up_64_64_1(y, style3) # 128 x 64 x 64
264
+ # print(y.size())
265
+
266
+ y = torch.cat((x4, y), dim=1)
267
+ y = self.up_64_128(y, y)
268
+ # print(y.size())
269
+ y = torch.cat((x3, y), dim=1)
270
+ style4 = self.to_style4(style)
271
+ y = self.up_128_128_1(y, style4) # 64 x 128 x 128
272
+ # print(y.size())
273
+
274
+ y = torch.cat((x2, y), dim=1)
275
+ y = self.up_128_256(y, y) # 32 x 256 x 256
276
+ # print(y.size())
277
+
278
+ y = torch.cat((x1, y), dim=1)
279
+
280
+ y = self.out_conv(y, y) # 3 x 256 x 256
281
+
282
+ return y
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/ssn/test.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import ssn_dataset
3
+ from torchvision import transforms, utils
4
+ import numpy as np
5
+
6
+ csv_file = "~/Dataset/soft_shadow/train/metadata.csv"
7
+ # compose_transform = None
8
+ training_dataset = ssn_dataset.SSN_Dataset(csv_file, is_training = True)
9
+ testing_dataset = ssn_dataset.SSN_Dataset(csv_file, is_training = False)
10
+
11
+ print('training set size: ', len(training_dataset))
12
+ print('testing set size: ',len(testing_dataset))
13
+
14
+ print(len(training_dataset.meta_data))
15
+ print(training_dataset.meta_data[0])
16
+
17
+ # for j in range(10):
18
+ # for i in range(len(training_dataset)):
19
+ # data = training_dataset[i]
20
+ # # print("{} \r".format(i), flush=True, end="")
21
+ # print("{} ".format(i))
22
+
23
+ # for i,data in enumerate(testing_dataset):
24
+ # print("{} \r".format(i), flush=True, end="")
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/ssn/test_dataset.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ssn_dataset
2
+ import time
3
+
4
+ if __name__ == '__main__':
5
+ start = time.time()
6
+ csv_file = "~/Dataset/soft_shadow/single_human/metadata.csv"
7
+ training_dataset = ssn_dataset.SSN_Dataset(csv_file, is_training=True)
8
+ testing_dataset = ssn_dataset.SSN_Dataset(csv_file, is_training=False)
9
+
10
+ print("Training dataset num: ", len(training_dataset))
11
+ print("Testing dataset num: ",len(testing_dataset))
12
+
13
+ for i in range(len(training_dataset)):
14
+ data = training_dataset[i]
15
+ print('Training set: successfully iterate {} \r'.format(i), flush=True, end='')
16
+
17
+ for i in range(len(testing_dataset)):
18
+ data = testing_dataset[i]
19
+ print('Validation set: successfully iterate {} \r'.format(i), flush=True, end='')
20
+
21
+ end = time.time()
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/__init__.py ADDED
File without changes
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/html.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dominate
2
+ from dominate.tags import meta, h3, table, tr, td, p, a, img, br
3
+ import os
4
+
5
+ class HTML:
6
+ def __init__(self, web_dir, title, reflesh=0):
7
+ self.title = title
8
+ self.web_dir = web_dir
9
+ if not os.path.exists(self.web_dir):
10
+ os.makedirs(self.web_dir)
11
+
12
+ # print(self.img_dir)
13
+
14
+ self.doc = dominate.document(title=title)
15
+ if reflesh > 0:
16
+ with self.doc.head:
17
+ meta(http_equiv="reflesh", content=str(reflesh))
18
+
19
+ def get_image_dir(self):
20
+ return self.img_dir
21
+
22
+ def add_header(self, str):
23
+ with self.doc:
24
+ h3(str)
25
+
26
+ def add_table(self, border=1):
27
+ self.t = table(border=border, style="table-layout: fixed;")
28
+ self.doc.add(self.t)
29
+
30
+ def add_images(self, ims, txts, links, width=400, height=300):
31
+ self.add_table()
32
+ with self.t:
33
+ with tr():
34
+ for im, txt, link in zip(ims, txts, links):
35
+ with td(style="word-wrap: break-word; height:{}px; width:{}px".format(height + 10,width + 10), halign="center", valign="top"):
36
+ with p():
37
+ with a(href=os.path.join('/',link)):
38
+ img(style="width:{}px;height:{}".format(width, height), src=os.path.join('/',im))
39
+ br()
40
+ p(txt)
41
+
42
+ def save(self):
43
+ html_file = '%s/index.html' % self.web_dir
44
+ f = open(html_file, 'wt')
45
+ f.write(self.doc.render())
46
+ f.close()
47
+
48
+
49
+ if __name__ == '__main__':
50
+ html = HTML('web/', 'test_html')
51
+ html.add_header('hello world')
52
+
53
+ ims = []
54
+ txts = []
55
+ links = []
56
+ for n in range(4):
57
+ ims.append('image_%d.png' % n)
58
+ txts.append('text_%d' % n)
59
+ links.append('image_%d.png' % n)
60
+ html.add_images(ims, txts, links)
61
+ html.save()
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/html_server.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import http.server
2
+ import socketserver
3
+
4
+ PORT = 8081
5
+ Handler = http.server.SimpleHTTPRequestHandler
6
+
7
+ with socketserver.TCPServer(("", PORT), Handler) as httpd:
8
+ print("serving at port", PORT)
9
+ httpd.serve_forever()
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/imgs ADDED
@@ -0,0 +1 @@
 
 
1
+ /home/ysheng/Dataset/dropbox/Research/SSN_Training_Share/touch
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/index.html ADDED
The diff for this file is too large to render. See raw diff
 
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/make_html.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import json
3
+ import pdb
4
+ import os
5
+ from os.path import join
6
+ import html
7
+ from tqdm import tqdm
8
+ import argparse
9
+ import pandas as pd
10
+ import matplotlib.pyplot as plt
11
+
12
+ def get_files(folder):
13
+ return [join(folder, f) for f in os.listdir(folder) if os.path.isfile(join(folder, f))]
14
+
15
+ def get_folders(folder):
16
+ return [join(folder, f) for f in os.listdir(folder) if os.path.isdir(join(folder, f))]
17
+
18
+ vis_img_folder = 'imgs'
19
+ vis_eval_img_folder = 'eval_imgs'
20
+ def eval_gen(webpage, output_folder, is_pattern=True):
21
+ def get_file(files, key_world):
22
+ mitsuba_shadow = ''
23
+ for f in files:
24
+ if f.find(key_world) != -1:
25
+ mitsuba_shadow = f
26
+ break
27
+ return mitsuba_shadow
28
+
29
+ def flip_shadow(img_file):
30
+ dirname, fname = os.path.dirname(img_file), os.path.splitext(os.path.basename(img_file))[0]
31
+ if img_file == '':
32
+ print('find one zero')
33
+ mts_shadow_np = np.zeros((256,256,3))
34
+ else:
35
+ mts_shadow_np = plt.imread(img_file)
36
+
37
+ save_path = join(dirname, fname + '_flip.png')
38
+ plt.imsave(save_path, 1.0-mts_shadow_np)
39
+ return save_path
40
+
41
+ img_folders = join(output_folder, 'imgs')
42
+ folders = get_folders(img_folders)
43
+ print("There are {} folders".format(len(folders)))
44
+
45
+ for model in tqdm(folders):
46
+ cur_model_relative = join(vis_img_folder, os.path.basename(model))
47
+ evl_cur_model_relative = join(vis_eval_img_folder, os.path.basename(model))
48
+
49
+ if is_pattern:
50
+ ibl_relative = join(cur_model_relative, 'pattern')
51
+ else:
52
+ ibl_relative = join(cur_model_relative, 'real')
53
+
54
+ # import pdb; pdb.set_trace()
55
+ ibl_folders = get_folders(ibl_relative)
56
+ ibl_folders.sort()
57
+ for ibl in ibl_folders:
58
+ cur_ibl_relative = join(ibl_relative,os.path.basename(ibl))
59
+ gt_files = get_files(cur_ibl_relative)
60
+ mts_shadow = get_file(gt_files, '_shadow.png')
61
+
62
+ ibl_name = os.path.basename(ibl)
63
+ ibl = join(cur_ibl_relative, ibl_name + '.png')
64
+
65
+ mitsuba_shadow = flip_shadow(mts_shadow)
66
+
67
+ cur_eval_folder = join(evl_cur_model_relative, join('pattern', ibl_name))
68
+ net_predict = get_file(get_files(cur_eval_folder), 'predict.png')
69
+
70
+ # mitsuba_final = join(cur_ibl_relative, 'composite.png')
71
+ # pred_final = join(cur_ibl_relative, 'composite_pred.png')
72
+
73
+ # print(ibl_name)
74
+ ims, txts, links = [ibl,mitsuba_shadow, net_predict], ['ibl','mitsuba', 'predict'], [ibl,mitsuba_shadow, net_predict]
75
+
76
+ webpage.add_images(ims, txts, links)
77
+
78
+ vis_pattern_folder = '/home/ysheng/Documents/vis_pattern'
79
+ vis_real_folder = '/home/ysheng/Documents/vis_real'
80
+ def vis_files_in_folder():
81
+ folder = '/home/ysheng/Documents/vis_models'
82
+ webpage = html.HTML(folder, 'models', reflesh=1)
83
+ img_folders = join(folder, 'imgs')
84
+ files = get_files(img_folders)
85
+ print("There are {} files".format(len(files)))
86
+
87
+ prefix_set = set()
88
+ for cur_file in tqdm(files):
89
+ cur_name = os.path.splitext(os.path.basename(cur_file))[0]
90
+ prefix_set.add(cur_name[:-3])
91
+
92
+ print('there are {} prefixs'.format(len(prefix_set)))
93
+ prefix_set = list(prefix_set)
94
+ prefix_set.sort()
95
+
96
+ # import pdb; pdb.set_trace()
97
+ relative_folder = './imgs'
98
+ for i, prefix in enumerate(prefix_set):
99
+ ims = [join(relative_folder, prefix + '{:03d}.png'.format(i)) for i in range(len(files) // len(prefix_set))]
100
+ txts = [prefix + '{:03d}'.format(i) for i in range(len(files) // len(prefix_set))]
101
+ links = ims
102
+ webpage.add_images(ims, txts, links)
103
+
104
+ webpage.save()
105
+ print('finished')
106
+
107
+ def vis_files(df_file):
108
+ """ input is a pandas dataframe
109
+ format: path, path,..., name,name, ...
110
+ """
111
+ folder = '.'
112
+ webpage = html.HTML(folder, 'benchmark', reflesh=1)
113
+
114
+ relative_folder = './imgs'
115
+ # for i, prefix in enumerate(prefix_set):
116
+ # ims = [join(relative_folder, prefix + '{:03d}.png'.format(i)) for i in range(len(files) // len(prefix_set))]
117
+ # txts = [prefix + '{:03d}'.format(i) for i in range(len(files) // len(prefix_set))]
118
+ # links = ims
119
+ # webpage.add_images(ims, txts, links)
120
+
121
+ df = pd.read_csv(df_file)
122
+ for i,v in tqdm(df.iterrows(), total=len(df)):
123
+ img_range = len(v)//2+1
124
+ imgs = [join(relative_folder,v[i]) for i in range(1,img_range)]
125
+ txts = [v[i] for i in range(img_range, len(v))]
126
+ links = imgs
127
+ webpage.add_images(imgs, txts, links)
128
+
129
+ webpage.save()
130
+ print('finished')
131
+
132
+ if __name__ == "__main__":
133
+ vis_files('/home/ysheng/Documents/paper_project/adobe/soft_shadow/benchmark_results/html.csv')
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/net_utils.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #import matplotlib.pyplot as plt
2
+ import os
3
+ from torchvision import transforms, utils
4
+ import torch
5
+ #import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ from utils.utils_file import get_cur_time_stamp, create_folder
8
+
9
+ def compute_differentiable_params(net):
10
+ return sum(p.numel() for p in net.parameters() if p.requires_grad)
11
+
12
+ def convert_Relight_latent_light(latent_feature):
13
+ """ Convert n x 6 x 16 x 16 -> n x 3 x 16 x 32 """
14
+ # torch image: C X H X W
15
+ batch_size, C, H, W = latent_feature.size()
16
+ latent_feature = torch.reshape(latent_feature, (batch_size, 3, 16, 32)) # make sure it is right
17
+ # print(latent_feature.size())
18
+ return latent_feature
19
+
20
+ def show_batch(sample_batch, out_file=None):
21
+ grid = utils.make_grid(sample_batch)
22
+ plt.figure(figsize=(30,20))
23
+ plt.imshow(grid.detach().cpu().numpy().transpose((1,2,0)))
24
+
25
+ if not out_file is None:
26
+ print('try save ', out_file)
27
+ plt.savefig(out_file)
28
+
29
+ plt.show()
30
+
31
+ def show_light_batch(light_batch):
32
+ light_batch = convert_Relight_latent_light(light_batch)
33
+ show_batch(light_batch)
34
+
35
+ def save_loss(figure_fname, train_loss, valid_loss):
36
+ plt.plot(train_loss)
37
+ plt.plot(valid_loss)
38
+ plt.legend(['train_loss', 'valid_loss'])
39
+ plt.savefig(figure_fname)
40
+
41
+ def save_model(output_folder, model, optimizer, epoch, best_loss, fname, hist_train_loss, hist_valid_loss, hist_lr, params):
42
+ """ Save current best model into some folder """
43
+ create_folder(output_folder)
44
+
45
+ # cur_time_stamp = get_cur_time_stamp()
46
+ # output_fname = os.path.join(output_folder, exp_name + '_' + cur_time_stamp + ".pt")
47
+ output_fname = os.path.join(output_folder, fname)
48
+ tmp_model = model
49
+ if params.multi_gpu and hasattr(tmp_model, 'module'):
50
+ tmp_model = model.module
51
+
52
+ torch.save({
53
+ 'epoch': epoch,
54
+ 'best_loss': best_loss,
55
+ 'model_state_dict': tmp_model.state_dict(),
56
+ 'optimizer_state_dict': optimizer.state_dict(),
57
+ 'hist_train_loss': hist_train_loss,
58
+ 'hist_valid_loss': hist_valid_loss,
59
+ 'hist_lr':hist_lr,
60
+ 'params':str(params)
61
+ }, output_fname)
62
+ return output_fname
63
+
64
+ def get_lr(optimizer):
65
+ for param_group in optimizer.param_groups:
66
+ return param_group['lr']
67
+
68
+ def set_lr(optimizer, lr):
69
+ for param_group in optimizer.param_groups:
70
+ param_group['lr'] = lr
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/tensorboard_utils.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from torch.utils.tensorboard import SummaryWriter
4
+
5
+ def tensorboard_plot_loss(win_name, loss, writer):
6
+ writer.add_scalar("Loss/{}".format(win_name), loss[-1], len(loss))
7
+ writer.flush()
8
+
9
+ def normalize_img(imgs):
10
+ b,c,h,w = imgs.shape
11
+ gt_batch = b//2
12
+ for i in range(gt_batch):
13
+ factor = torch.max(imgs[i])
14
+ imgs[i] = imgs[i]/factor
15
+ imgs[gt_batch + i] = imgs[gt_batch + i]/factor
16
+
17
+ imgs = torch.clamp(imgs, 0.0,1.0)
18
+ return imgs
19
+
20
+ def tensorboard_show_batch(imgs, writer, win_name=None, nrow=2, normalize=True, step=0):
21
+ if normalize:
22
+ imgs = normalize_img(imgs)
23
+
24
+ writer.add_images('{}'.format(win_name), imgs, step)
25
+ writer.flush()
26
+
27
+ def tensorboard_log(log_info, writer, win_name='logger', step=0):
28
+ writer.add_text(win_name, log_info, step)
29
+ writer.flush()
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/time_utils.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import datetime
2
+
3
+ def get_time_stamp():
4
+ return '{:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now())
5
+
6
+
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/utils_file.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from shutil import copyfile
2
+ from os import listdir
3
+ from os.path import isfile, join
4
+ import os
5
+ import datetime
6
+
7
+ def get_all_folders(folder):
8
+ if check_file_exists(folder) == False:
9
+ print("Cannot find the folder ", folder)
10
+ return []
11
+ subfolders = [f for f in os.listdir(folder) if not isfile(join(folder,f))]
12
+ return subfolders
13
+
14
+ def get_all_files(folder):
15
+ if check_file_exists(folder) == False:
16
+ print("Cannot find the folder ", folder)
17
+ return []
18
+
19
+ ori_files = [f for f in listdir(folder) if isfile(join(folder, f))]
20
+ return ori_files
21
+
22
+ def create_folder(folder):
23
+ if not os.path.exists(folder):
24
+ os.mkdir(folder)
25
+
26
+ def create_folders(folder_list):
27
+ for f in folder_list:
28
+ create_folder(f)
29
+
30
+ def replace_file_ext(fname, new_ext):
31
+ ext_pos = fname.find(".")
32
+ if ext_pos != -1:
33
+ return fname[0:ext_pos] + "."+ new_ext
34
+ else:
35
+ print("Please check " + fname)
36
+
37
+ def check_file_exists(fname, verbose=True):
38
+ try:
39
+ if not os.path.exists(fname) or get_file_size(fname) == 0:
40
+ if verbose:
41
+ print("file {} does not exists! ".format(fname))
42
+ return False
43
+ except:
44
+ print("File {} has some issue! ".format(fname))
45
+ return False
46
+ return True
47
+
48
+ def delete_file(fname):
49
+ if check_file_exists(fname):
50
+ os.remove(fname)
51
+
52
+ def get_file_size(fname):
53
+ return os.path.getsize(fname)
54
+
55
+ def get_folder_size(folder):
56
+ return sum(os.path.getsize(folder + f) for f in listdir(folder) if isfile(join(folder, f)))
57
+
58
+ def get_cur_time_stamp():
59
+ return datetime.datetime.now().strftime("%d-%B-%I-%M-%p")
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/vis_test_results.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import json
3
+ import pdb
4
+ import os
5
+ from os.path import join
6
+ import html
7
+ from tqdm import tqdm
8
+ import argparse
9
+ import pandas as pd
10
+ import cv2 as cv
11
+
12
# --- Hard-coded visualization configuration ---

# Softness level used as the comparison baseline.
base_softness = "0.1"
# Name of the baseline experiment folder.
base_exp_name = "fov_results_real_89"
# Root directory holding all visualization results.
# NOTE(review): machine-specific absolute path — adjust per environment.
root_dir = "/mnt/share/yifan/code/soft_shadow-master/vis_res/"
# Baseline experiment results at the baseline softness level.
base_dir = root_dir + base_exp_name + '/' + base_softness

# Experiment folders to compare against the baseline (presumably same layout).
exp_name = ["fov_results_real_hd", "fov_results_real_hd_0713"]

# Shadow softness values to include.
softness = [0.1, 0.2]

# Test-case sub-folders to include.
case_name = ["case1", "case2", "case7"]
PixHtLab-Src/Demo/PixhtLab/ShadowStyle/inference/utils/visdom_utils.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from visdom import Visdom
2
+ import numpy as np
3
+ import torch
4
+
5
+ # viz = Visdom(port=8002)
6
+ # viz2 = Visdom(port=8003)
7
+
8
def setup_visdom(port=8002):
    """Connect to a running Visdom server on *port* and return the client."""
    client = Visdom(port=port)
    return client
10
+
11
def visdom_plot_loss(win_name, loss, cur_viz):
    """Plot a 1-D loss history as a line chart in the Visdom window *win_name*.

    :param win_name: window id, also used as the legend label
    :param loss:     sequence of loss values, one per step (x runs 1..len)
    :param cur_viz:  connected Visdom client
    """
    ys = np.array(loss)
    xs = np.arange(1, len(loss) + 1)
    chart_opts = dict(showlegend=True, legend=[win_name])
    cur_viz.line(win=win_name, X=xs, Y=ys, opts=chart_opts)
18
+
19
def guassian_light(light_tensor):
    """Blur each channel of a light map with a Gaussian and clamp to [0, 1].

    Bug fixes vs. the original:
      * every loop iteration read channel 0 (``light_tensor[0]``) instead of
        channel ``i``, so all output channels were blurred copies of the
        first channel;
      * ``gaussian_filter`` was referenced but never imported.

    :param light_tensor: (C, H, W) tensor; values are scaled by 100 before
        blurring (NOTE(review): the x100 scale is inherited from the
        original — confirm it is intentional).
    :return: (C, H, W) CPU tensor, per-channel blurred and clamped to [0, 1].
    """
    from scipy.ndimage import gaussian_filter

    light_tensor = light_tensor.detach().cpu()
    channel = light_tensor.size()[0]
    tensor_ret = torch.zeros(light_tensor.size())
    for i in range(channel):
        light_np = light_tensor[i].numpy() * 100.0
        light_np = gaussian_filter(light_np, sigma=2)
        tensor_ret[i] = torch.from_numpy(light_np)
        tensor_ret[i] = torch.clamp(tensor_ret[i], 0.0, 1.0)

    return tensor_ret
30
+
31
def normalize_img(imgs):
    """Jointly normalize prediction/ground-truth pairs in a stacked batch.

    The batch is assumed to hold one group in the first half and the
    matching group in the second half (prediction/GT pairing per the
    call sites); each pair is divided by the max of the first-half image
    so both stay directly comparable, then the result is clamped to [0, 1].
    Note: the division mutates *imgs* in place; the returned tensor is the
    clamped copy.
    """
    batch, _, _, _ = imgs.shape
    half = batch // 2
    for idx in range(half):
        scale = torch.max(imgs[idx])
        imgs[idx] = imgs[idx] / scale
        imgs[half + idx] = imgs[half + idx] / scale

    return torch.clamp(imgs, 0.0, 1.0)
42
+
43
def visdom_show_batch(imgs, cur_viz, win_name=None, nrow=2, normalize=True):
    """Display an image batch in Visdom, optionally pair-normalizing it first.

    :param imgs:     image batch to show
    :param cur_viz:  connected Visdom client
    :param win_name: target window; a default window is used when None
    :param nrow:     images per row in the grid
    :param normalize: when True, run the batch through normalize_img first
    """
    if normalize:
        imgs = normalize_img(imgs)

    if win_name is None:
        cur_viz.images(imgs, win="batch visualize", nrow=nrow)
        return
    cur_viz.images(imgs, win=win_name, opts=dict(title=win_name), nrow=nrow)
51
+
52
def visdom_log(log_info, viz, win_name='logger'):
    """Write *log_info* into a Visdom text window (default window: 'logger')."""
    viz.text(log_info, win=win_name)
PixHtLab-Src/Demo/PixhtLab/Torch_Render/hshadow_cuda.cpp ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <torch/extension.h>
2
+
3
+ #include <vector>
4
+ #include <stdio.h>
5
+
6
+ // CUDA forward declarations
7
+ std::vector<torch::Tensor> hshadow_render_cuda_forward(torch::Tensor rgb, torch::Tensor mask, torch::Tensor mask_bb, torch::Tensor hmap, torch::Tensor rechmap, torch::Tensor light_pos);
8
+ std::vector<torch::Tensor> reflect_render_cuda_forward(torch::Tensor rgb, torch::Tensor mask, torch::Tensor hmap, torch::Tensor rechmap, torch::Tensor thresholds);
9
+ std::vector<torch::Tensor> glossy_reflect_render_cuda_forward(torch::Tensor rgb, torch::Tensor mask, torch::Tensor hmap, torch::Tensor rechmap, const int sample_n, const float glossy);
10
+ torch::Tensor ray_intersect_cuda_forward(torch::Tensor rgb, torch::Tensor mask, torch::Tensor hmap, torch::Tensor rechmap, torch::Tensor rd_map);
11
+ torch::Tensor ray_scene_intersect_cuda_forward(torch::Tensor rgb, torch::Tensor mask, torch::Tensor hmap, torch::Tensor ro, torch::Tensor rd, float dh);
12
+
13
+ // C++ interface
14
+ // NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
15
+ #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
16
+ #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
17
+ #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
18
+
19
/* Heightmap Shadow Rendering — input validation + dispatch to CUDA.
   rgb:       B x 3 x H x W image
   mask:      B x 1 x H x W occluder mask
   bb:        B x 4 per-batch bounding box (the original comment mislabeled
              this line as a second "mask"; the kernel reads 4 entries)
   hmap:      B x 1 x H x W occluder pixel-height map
   rechmap:   B x 1 x H x W receiver pixel-height map
   light_pos: B x 3 light position (x, y, h)
*/
std::vector<torch::Tensor> hshadow_render_forward(torch::Tensor rgb, torch::Tensor mask, torch::Tensor bb, torch::Tensor hmap, torch::Tensor rechmap, torch::Tensor light_pos) {
    // Every tensor must be a contiguous CUDA tensor for the raw kernel.
    CHECK_INPUT(rgb);
    CHECK_INPUT(mask);
    CHECK_INPUT(bb);
    CHECK_INPUT(hmap);
    CHECK_INPUT(rechmap);
    CHECK_INPUT(light_pos);

    return hshadow_render_cuda_forward(rgb, mask, bb, hmap, rechmap, light_pos);
}
37
+
38
/* Reflection rendering — input validation + dispatch to CUDA.
   Returns {reflection, reflection_height, reflection_mask} (see the CUDA
   host wrapper for shapes).  `thresholds` is a B x 1 per-batch cutoff. */
std::vector<torch::Tensor> reflect_render_forward(torch::Tensor rgb, torch::Tensor mask, torch::Tensor hmap, torch::Tensor rechmap, torch::Tensor thresholds) {
    CHECK_INPUT(rgb);
    CHECK_INPUT(mask);
    CHECK_INPUT(hmap);
    CHECK_INPUT(rechmap);
    CHECK_INPUT(thresholds);

    return reflect_render_cuda_forward(rgb, mask, hmap, rechmap, thresholds);
}
47
+
48
+
49
/* Glossy reflection rendering — input validation + dispatch to CUDA.
   sample_n / glossy are forwarded untouched (note: the current kernel does
   not read them yet — see glossy_reflect_render_cuda_forward). */
std::vector<torch::Tensor> glossy_reflect_render_forward(torch::Tensor rgb,
                                                         torch::Tensor mask,
                                                         torch::Tensor hmap,
                                                         torch::Tensor rechmap,
                                                         int sample_n,
                                                         float glossy) {
    CHECK_INPUT(rgb);
    CHECK_INPUT(mask);
    CHECK_INPUT(hmap);
    CHECK_INPUT(rechmap);

    return glossy_reflect_render_cuda_forward(rgb, mask, hmap, rechmap, sample_n, glossy);
}
62
+
63
+
64
/* Per-pixel ray intersection — input validation + dispatch to CUDA.
   rd_map holds a per-pixel ray direction/target (B x 3 x H x W).
   NOTE(review): "foward" is a typo kept for ABI stability with the
   PYBIND11_MODULE registration below; the exported Python name is correct. */
torch::Tensor ray_intersect_foward(torch::Tensor rgb,
                                   torch::Tensor mask,
                                   torch::Tensor hmap,
                                   torch::Tensor rechmap,
                                   torch::Tensor rd_map) {
    // rd_map is not checked here — presumably an oversight; confirm whether
    // it must also be contiguous CUDA before adding CHECK_INPUT(rd_map).
    CHECK_INPUT(rgb);
    CHECK_INPUT(mask);
    CHECK_INPUT(hmap);
    CHECK_INPUT(rechmap);

    return ray_intersect_cuda_forward(rgb, mask, hmap, rechmap, rd_map);
}
76
+
77
/* Ray/scene intersection — input validation + dispatch to CUDA.
   ro / rd are per-pixel ray origins and directions (B x 3 x H x W);
   dh is the height tolerance used for the hit test.
   NOTE(review): "foward" typo kept to match the pybind registration. */
torch::Tensor ray_scene_intersect_foward(torch::Tensor rgb,
                                         torch::Tensor mask,
                                         torch::Tensor hmap,
                                         torch::Tensor ro,
                                         torch::Tensor rd,
                                         float dh) {
    CHECK_INPUT(rgb);
    CHECK_INPUT(mask);
    CHECK_INPUT(hmap);
    CHECK_INPUT(ro);
    CHECK_INPUT(rd);

    return ray_scene_intersect_cuda_forward(rgb, mask, hmap, ro, rd, dh);
}
91
+
92
/* Python bindings: expose the CUDA renderers as module-level functions. */
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &hshadow_render_forward, "Heightmap Shadow Rendering Forward (CUDA)");
    m.def("reflection", &reflect_render_forward, "Reflection Rendering Forward (CUDA)");
    m.def("glossy_reflection", &glossy_reflect_render_forward, "Glossy Reflection Rendering Forward (CUDA)");
    // The C++ helpers are spelled "foward" (typo) but the exported Python
    // names below are correct, so callers are unaffected.
    m.def("ray_intersect", &ray_intersect_foward, "Ray scene intersection");
    m.def("ray_scene_intersect", &ray_scene_intersect_foward, "Ray scene intersection");
}
PixHtLab-Src/Demo/PixhtLab/Torch_Render/hshadow_cuda_kernel.cu ADDED
@@ -0,0 +1,682 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <torch/extension.h>
2
+
3
+ #include <cuda.h>
4
+ #include <cuda_runtime.h>
5
+
6
+ #include <vector>
7
+ #include <stdio.h>
8
+
9
+ namespace {
10
/* Sign function for device code: +1 for strictly positive values,
 * -1 otherwise (note: zero maps to -1). */
template <typename scalar_t>
__device__
scalar_t sign(scalar_t t) {
    if (t > 0.0) {
        return (scalar_t)1.0;
    } else {
        return -(scalar_t)1.0;
    }
}
19
+
20
/* Minimal 2-component vector for device code (plain storage, no operators). */
template <typename scalar_t>
struct vec2 {
    scalar_t x, y;

    // Zero-initializing default constructor.
    __device__
    vec2() { x=0.0, y=0.0;}

    __device__
    vec2(scalar_t x, scalar_t y):x(x), y(y) {}
};
30
+
31
+
32
/* Minimal 3-component vector for device code; used throughout as either an
 * (x, y, height) scene point or an RGB color triple. */
template <typename scalar_t>
struct vec3 {
    scalar_t x, y, z;

    // Zero-initializing default constructor.
    __device__
    vec3() { x=0.0, y=0.0, z=0.0;}

    __device__
    vec3(scalar_t x, scalar_t y, scalar_t z):x(x), y(y), z(z) {}
};
42
+
43
+
44
/* Linear interpolation: returns a at t=0, b at t=1.
 * Kept as (1-t)*a + t*b — do not refactor to a + t*(b-a); the rounding
 * differs and callers rely on exact reproducibility. */
template <typename scalar_t>
__device__
scalar_t lerp(scalar_t a, scalar_t b, scalar_t t) {
    return (1.0-t) * a + t * b;
}
49
+
50
/* Project the line through (x0,y0,h0) and (x1,y1,h1) down to height 0 and
 * write the ground intersection into (x2, y2).
 * NOTE(review): divides by (h1-h0) with no guard — callers must ensure the
 * two heights differ. */
template <typename scalar_t>
__device__
void proj_ground(
    scalar_t x0, scalar_t y0, scalar_t h0,
    scalar_t x1, scalar_t y1, scalar_t h1,
    scalar_t &x2, scalar_t &y2
) {
    // Parameter t at which the segment's height reaches 0.
    scalar_t t = (0-h0)/(h1-h0);
    x2 = lerp(x0, x1, t);
    y2 = lerp(y0, y1, t);
}
61
+
62
+
63
// Line-vs-sample intersection test with thickness dh (the height difference
// for a double-height map); dh can also be used as a tolerance value.
// Tests whether the segment a->b, evaluated at (x, y), passes within dh
// below the sampled height h.  `flag` reports which side the segment is on
// (+1: at/below h, -1: above h) so callers can detect sign flips.
template <typename scalar_t>
__device__
bool check_intersect(
    scalar_t xa, scalar_t ya, scalar_t ha,
    scalar_t xb, scalar_t yb, scalar_t hb,
    scalar_t x, scalar_t y, scalar_t h,
    scalar_t dh, int& flag) {
    // Interpolation parameter along a->b; switch to the y-axis form when
    // xa == xb to avoid a division by zero.
    scalar_t t = xa == xb ? (y-ya)/(yb-ya):(x-xa)/(xb-xa);
    scalar_t h_ = lerp(ha, hb, t);
    flag = h_ <= h ? 1:-1;
    return (h_ <= h) && (h_ >= h-dh);
}
77
+
78
/*
 * Ray trace in the current scene.
 * Marches in image space from start point s = (x, y, height) toward
 * l = (x, y, height) (the light, or a generic second endpoint) and tests
 * each masked pixel's height column for an intersection.
 * Returns:
 *   1. whether the ray intersects an occluder
 *   2. `out` — the RGB color sampled at the intersection point
 */
template <typename scalar_t>
__device__
bool ray_trace(vec3<scalar_t> s,
               vec3<scalar_t> l,
               const int bi,
               const torch::PackedTensorAccessor64<scalar_t,4> d_mask,
               const torch::PackedTensorAccessor64<scalar_t,4> d_hmap,
               const torch::PackedTensorAccessor64<scalar_t,4> d_rechmap,
               const torch::PackedTensorAccessor64<scalar_t,4> d_rgb,
               vec3<scalar_t> &out) {
    bool ret = false;

    const int batch_size = d_rgb.size(0), h = d_rgb.size(2), w = d_rgb.size(3);
    scalar_t lx = l.x;
    scalar_t ly = l.y;
    scalar_t lh = l.z;

    scalar_t recx = s.x;
    scalar_t recy = s.y;
    scalar_t rech = s.z;

    // March along the dominant axis of the 2-D direction so each step
    // advances exactly one pixel on that axis.
    scalar_t dirx = lx - recx, diry = ly - recy;
    bool gox = abs(dirx) > abs(diry);
    int searching_n = gox ? w : h;
    int starti = 0, endi = searching_n;
    // Light above the ground plane: only the span between receiver and
    // light can occlude.  (Assignments truncate scalar_t -> int.)
    if (lh > 0) {
        if (gox) {
            starti = recx < lx ? recx:lx;

            endi = recx < lx ? lx:recx;
        }
        else {
            starti = recy < ly ? recy:ly;
            endi = recy < ly ? ly:recy;
        }
    }
    // Light below the plane: search the half-range on the receiver's side.
    if (lh < 0) {
        if (gox) {
            starti = recx < lx ? 0:recx;
            endi = recx < lx ? recx:endi;
        }
        else {
            starti = recy < ly ? 0:recy;
            endi = recy < ly ? recy:endi;
        }
    }

    scalar_t sx, sy;
    // flag/last_flag track which side of the ray the sampled height is on;
    // a sign change between consecutive masked samples also counts as a hit.
    int flag = 0, last_flag = 0;
    for(int si = starti; si < endi; ++si) {
        /* Searching point (sx, sy) on the ray for this step. */
        if (gox) {
            sx = si;
            sy = recy + (sx-recx)/dirx * diry;
        } else {
            sy = si;
            sx = recx + (sy-recy)/diry * dirx;
        }

        // Skip out-of-image samples and unmasked pixels; reset the sign
        // tracker so a flip across a gap is not treated as a hit.
        // NOTE(review): sy/sx are scalar_t — the accessor index converts
        // (truncates) them to integers implicitly.
        if (sx < 0 || sx > w-1 || sy < 0 || sy > h-1 || d_mask[bi][0][sy][sx] < 0.989) {
            last_flag = 0;
            continue;
        }


        scalar_t sh0 = d_hmap[bi][0][sy][sx];
        scalar_t sh1, sh;
        // Linearly interpolate the sampled height along the minor axis
        // (sy or sx is fractional); fall back to the nearest sample when the
        // neighbor is outside the image or unmasked.
        if (gox) {
            if ( sy+1 > h-1 || d_mask[bi][0][sy+1][sx] < 0.989)
                sh = sh0;
            else {
                sh1 = d_hmap[bi][0][sy+1][sx];
                sh = lerp(sh0, sh1, sy - int(sy));
            }
        }
        else {
            if ( sx+1 > w-1 || d_mask[bi][0][sy][sx+1] < 0.989)
                sh = sh0;
            else {
                sh1 = d_hmap[bi][0][sy][sx+1];
                sh = lerp(sh0, sh1, sx - int(sx));
            }
        }

        scalar_t dh = 1.0; // this controls the thickness; for double height map, dh = h_f - h_b
        bool intersect = check_intersect(recx, recy, rech, lx, ly, lh, sx, sy, sh, dh, flag);
        if (intersect) {
            /* TODO, which sampling? linear interpolation? (nearest used) */
            out.x = d_rgb[bi][0][(int)sy][(int)sx];
            out.y = d_rgb[bi][1][(int)sy][(int)sx];
            out.z = d_rgb[bi][2][(int)sy][(int)sx];

            ret = true;
            break;
        }
        // Sign flip between consecutive masked samples: the ray crossed the
        // height surface between steps — treat it as a hit as well.
        if (last_flag != 0){
            if (last_flag != flag) {
                out.x = d_rgb[bi][0][(int)sy][(int)sx];
                out.y = d_rgb[bi][1][(int)sy][(int)sx];
                out.z = d_rgb[bi][2][(int)sy][(int)sx];

                ret = true;
                break;
            }
        }
        last_flag = flag;
    }

    return ret;
}
196
+
197
/*
 * Ray trace in the current scene.
 * Marches from ray origin ro = (x, y, height) along direction
 * rd = (dx, dy, dheight) and compares the ray's height against the
 * height-map column at each masked pixel; a hit is declared when the gap
 * falls below the thickness/tolerance dh.
 * Returns:
 *   1. whether the ray intersects the scene
 *   2. `out` — the RGB color sampled at the intersection point
 */
template <typename scalar_t>
__device__
bool ray_scene_intersect(vec3<scalar_t> ro,
                         vec3<scalar_t> rd,
                         const scalar_t dh,
                         const int bi,
                         const torch::PackedTensorAccessor64<scalar_t,4> d_mask,
                         const torch::PackedTensorAccessor64<scalar_t,4> d_hmap,
                         const torch::PackedTensorAccessor64<scalar_t,4> d_rgb,
                         vec3<scalar_t> &out) {
    bool ret = false;
    int h = d_mask.size(2);
    int w = d_mask.size(3);
    scalar_t dirx = rd.x, diry = rd.y, dirh = rd.z;

    /* Special case: no lateral motion (pure vertical ray) — the ray stays in
     * its own pixel column, so report that pixel's color as the hit. */
    if (abs(dirx) < 1e-6f && abs(diry) < 1e-6f) {
        out.x = d_rgb[bi][0][(int)ro.y][(int)ro.x];
        out.y = d_rgb[bi][1][(int)ro.y][(int)ro.x];
        out.z = d_rgb[bi][2][(int)ro.y][(int)ro.x];
        return true;
    }

    // March along the dominant axis, one pixel per step, across the whole
    // image extent on that axis.
    bool gox = abs(dirx) > abs(diry);
    int searching_n = gox ? w : h;

    scalar_t cur_h;
    scalar_t sx, sy;

    // NOTE(review): declared int, but assigned cur_h - sh below — the
    // fractional part is truncated.  prev_sign is only kept for the
    // commented-out sign-flip test; confirm whether truncation is intended.
    int prev_sign, cur_sign;

    // for(int si = starti; si < endi; ++si) {
    for(int si = 0; si < searching_n; ++si) {
        /* Searching point (sx, sy) on the ray for this step. */
        if (gox) {
            sx = ro.x + si * sign(dirx);
            sy = ro.y + (sx-ro.x)/dirx * diry;
        } else {
            sy = ro.y + si * sign(diry);
            sx = ro.x + (sy-ro.y)/diry * dirx;
        }

        // Skip out-of-image samples and unmasked pixels.  (sy/sx are
        // scalar_t; the accessor index truncates them implicitly.)
        if (sx < 0 || sx > w-1 || sy < 0 || sy > h-1 || d_mask[bi][0][sy][sx] < 0.989) {
            continue;
        }

        scalar_t sh0 = d_hmap[bi][0][sy][sx];
        scalar_t sh1, sh;
        // Linearly interpolate the sampled height along the minor axis;
        // fall back to the nearest sample at the border / mask edge.
        // cur_h is the ray's own height at this step.
        if (gox) {
            if (sy+1 > h-1 || d_mask[bi][0][sy+1][sx] < 0.989)
                sh = sh0;
            else {
                sh1 = d_hmap[bi][0][sy+1][sx];
                sh = lerp(sh0, sh1, sy - int(sy)); // Always use 0.5 to do interpolation
            }

            cur_h = ro.z + (sx - ro.x) / dirx * dirh;
        }
        else {
            if ( sx + 1 > w-1 || d_mask[bi][0][sy][sx+1] < 0.989)
                sh = sh0;
            else {
                sh1 = d_hmap[bi][0][sy][sx+1];
                sh = lerp(sh0, sh1, sx - int(sx));
            }

            cur_h = ro.z + (sy - ro.y) / diry * dirh;
        }

        // Seed the sign on the first step; afterwards remember the previous
        // step's sign before recomputing.
        if (si == 0) { /* First sign */
            cur_sign = cur_h - sh;
            continue;
        } else {
            prev_sign = cur_sign;
        }

        cur_sign = cur_h - sh;
        // if (cur_sign * prev_sign < 0.0 || abs(cur_sign) < dh) { /* pass through some objects */
        if (abs(cur_sign) < dh) { /* ray passed within dh of the surface */
            out.x = d_rgb[bi][0][(int)sy][(int)sx];
            out.y = d_rgb[bi][1][(int)sy][(int)sx];
            out.z = d_rgb[bi][2][(int)sy][(int)sx];
            ret = true;
            break;
        }

    }

    return ret;
}
296
+
297
+
298
/* Shadow kernel: one thread per output pixel (grid-stride loops over x, y
 * and batch).  Each thread casts a ray from the receiver pixel toward the
 * light and writes 0 (shadowed) or 1 (lit) into all three channels. */
template <typename scalar_t>
__global__ void hshadow_render_cuda_forward(
    const torch::PackedTensorAccessor64<scalar_t,4> d_rgb,
    const torch::PackedTensorAccessor64<scalar_t,4> d_mask,
    const torch::PackedTensorAccessor64<scalar_t,2> d_bb,
    const torch::PackedTensorAccessor64<scalar_t,4> d_hmap,
    const torch::PackedTensorAccessor64<scalar_t,4> d_rechmap,
    const torch::PackedTensorAccessor64<scalar_t,2> d_lightpos,
    torch::PackedTensorAccessor64<scalar_t,4> d_shadow) {
    const int wstride = gridDim.x * blockDim.x, hstride = gridDim.y * blockDim.y, bstride = gridDim.z * blockDim.z;
    const int batch_size = d_rgb.size(0), h = d_rgb.size(2), w = d_rgb.size(3);

    for (int bi = blockIdx.z; bi < batch_size; bi += bstride) {
        /* Light position (x, y, height) for this batch entry. */
        scalar_t lx = d_lightpos[bi][0], ly = d_lightpos[bi][1], lh = d_lightpos[bi][2];
        // Clamped bounding box of the occluder.
        // NOTE(review): minh/maxh/minw/maxw are computed but never read by
        // the loop below — the whole image is traced regardless.
        int minh = max((int)d_bb[bi][0], 0), maxh = min((int)d_bb[bi][1], h-1), minw = max((int)d_bb[bi][2], 0), maxw = min((int)d_bb[bi][3], w-1);

        vec3<scalar_t> light(lx, ly, lh);
        for (int wi = blockIdx.x * blockDim.x + threadIdx.x; wi < w; wi += wstride) for(int hi = blockIdx.y * blockDim.y + threadIdx.y; hi < h; hi += hstride) {
            scalar_t shadow(1.0), mask_alpha(0.0);  // mask_alpha is unused
            // Ray origin: pixel center at the receiver's height.
            scalar_t recx = wi + 0.5, recy = hi+0.5, rech = d_rechmap[bi][0][hi][wi];

            vec3<scalar_t> start(recx, recy, rech);
            vec3<scalar_t> intersect_color;  // hit color is discarded here

            /* Occluded if the ray toward the light hits anything. */
            if (ray_trace(start, light, bi, d_mask, d_hmap, d_rechmap, d_rgb, intersect_color)) {
                shadow = 0.0;
            }

            d_shadow[bi][0][hi][wi] = shadow;
            d_shadow[bi][1][hi][wi] = shadow;
            d_shadow[bi][2][hi][wi] = shadow;
        }
    }
}
334
+
335
+
336
/* Per-pixel ray query kernel (grid-stride over x, y, batch): traces one ray
 * per pixel using the per-pixel endpoint from d_rd_map and writes the hit
 * color into channels 0-2 of d_intersect plus a 0/1 hit flag in channel 3. */
template <typename scalar_t>
__global__ void ray_intersect_cuda_forward(
    const torch::PackedTensorAccessor64<scalar_t,4> d_rgb,
    const torch::PackedTensorAccessor64<scalar_t,4> d_mask,
    const torch::PackedTensorAccessor64<scalar_t,4> d_hmap,
    const torch::PackedTensorAccessor64<scalar_t,4> d_rechmap,
    const torch::PackedTensorAccessor64<scalar_t,4> d_rd_map,
    torch::PackedTensorAccessor64<scalar_t,4> d_intersect) {

    const int wstride = gridDim.x * blockDim.x, hstride = gridDim.y * blockDim.y, bstride = gridDim.z * blockDim.z;
    const int batch_size = d_rgb.size(0), h = d_rgb.size(2), w = d_rgb.size(3);
    const scalar_t default_value = 0.0;  // color written on a miss

    for (int bi = blockIdx.z; bi < batch_size; bi += bstride) {
        for (int wi = blockIdx.x * blockDim.x + threadIdx.x; wi < w; wi += wstride) for(int hi = blockIdx.y * blockDim.y + threadIdx.y; hi < h; hi += hstride) {
            scalar_t shadow(1.0), mask_alpha(0.0);  // NOTE(review): both unused
            // Ray origin: pixel center at the receiver's height.
            scalar_t recx = wi + 0.5, recy = hi+0.5, rech = d_rechmap[bi][0][hi][wi];

            // Per-pixel second endpoint, interpreted by ray_trace the same
            // way as a light position (x, y, height).
            scalar_t lx = d_rd_map[bi][0][hi][wi];
            scalar_t ly = d_rd_map[bi][1][hi][wi];
            scalar_t lh = d_rd_map[bi][2][hi][wi];

            vec3<scalar_t> start(recx, recy, rech);
            vec3<scalar_t> rd(lx, ly, lh);

            vec3<scalar_t> intersect_color;
            if (ray_trace(start, rd, bi, d_mask, d_hmap, d_rechmap, d_rgb, intersect_color)) {
                d_intersect[bi][0][hi][wi] = intersect_color.x;
                d_intersect[bi][1][hi][wi] = intersect_color.y;
                d_intersect[bi][2][hi][wi] = intersect_color.z;
                d_intersect[bi][3][hi][wi] = 1.0;
            } else {
                d_intersect[bi][0][hi][wi] = default_value;
                d_intersect[bi][1][hi][wi] = default_value;
                d_intersect[bi][2][hi][wi] = default_value;
                d_intersect[bi][3][hi][wi] = 0.0;
            }
        }
    }
}
376
+
377
/* Ray/scene intersection kernel (grid-stride over x, y, batch): each pixel
 * supplies its own ray origin (d_ro) and direction (d_rd); the hit color is
 * written to channels 0-2 of d_intersect and a 0/1 hit flag to channel 3. */
template <typename scalar_t>
__global__ void ray_scene_intersect_cuda_forward(
    const torch::PackedTensorAccessor64<scalar_t,4> d_rgb,
    const torch::PackedTensorAccessor64<scalar_t,4> d_mask,
    const torch::PackedTensorAccessor64<scalar_t,4> d_hmap,
    const torch::PackedTensorAccessor64<scalar_t,4> d_ro,
    const torch::PackedTensorAccessor64<scalar_t,4> d_rd,
    const scalar_t dh,
    torch::PackedTensorAccessor64<scalar_t,4> d_intersect) {

    const int wstride = gridDim.x * blockDim.x, hstride = gridDim.y * blockDim.y, bstride = gridDim.z * blockDim.z;
    const int batch_size = d_rgb.size(0), h = d_rgb.size(2), w = d_rgb.size(3);
    const scalar_t default_value = 0.0;  // color written on a miss

    for (int bi = blockIdx.z; bi < batch_size; bi += bstride) {
        for (int wi = blockIdx.x * blockDim.x + threadIdx.x; wi < w; wi += wstride) {
            for(int hi = blockIdx.y * blockDim.y + threadIdx.y; hi < h; hi += hstride) {
                // Ray origin is fully data-driven (not the pixel center —
                // the earlier pixel-center variant is kept commented out).
                // scalar_t rox = wi + 0.5, roy = hi+0.5, roh = d_ro[bi][0][hi][wi];
                scalar_t rox = d_ro[bi][0][hi][wi];
                scalar_t roy = d_ro[bi][1][hi][wi];
                scalar_t roh = d_ro[bi][2][hi][wi];

                scalar_t rdx = d_rd[bi][0][hi][wi];
                scalar_t rdy = d_rd[bi][1][hi][wi];
                scalar_t rdh = d_rd[bi][2][hi][wi];

                vec3<scalar_t> ro(rox, roy, roh);
                vec3<scalar_t> rd(rdx, rdy, rdh);

                vec3<scalar_t> intersect_color;
                if (ray_scene_intersect(ro, rd, dh, bi, d_mask, d_hmap, d_rgb, intersect_color)) {
                    d_intersect[bi][0][hi][wi] = intersect_color.x;
                    d_intersect[bi][1][hi][wi] = intersect_color.y;
                    d_intersect[bi][2][hi][wi] = intersect_color.z;
                    d_intersect[bi][3][hi][wi] = 1.0;
                } else {
                    d_intersect[bi][0][hi][wi] = default_value;
                    d_intersect[bi][1][hi][wi] = default_value;
                    d_intersect[bi][2][hi][wi] = default_value;
                    d_intersect[bi][3][hi][wi] = 0.0;
                }
            }
        }
    }
}
422
+
423
/* Planar reflection kernel (grid-stride over x, y, batch).
 * For each output pixel (hi, wi), scan the rows ABOVE it in the same column
 * and find the masked pixel whose mirror position about its own height
 * (row ti + 2 * height(ti)) lands closest to hi; if that distance is below
 * the per-batch threshold, copy that pixel's color/height and set the mask. */
template <typename scalar_t>
__global__ void reflect_render_cuda_forward(
    const torch::PackedTensorAccessor64<scalar_t,4> d_rgb,
    const torch::PackedTensorAccessor64<scalar_t,4> d_mask,
    const torch::PackedTensorAccessor64<scalar_t,4> d_hmap,
    const torch::PackedTensorAccessor64<scalar_t,4> d_rechmap,
    const torch::PackedTensorAccessor64<scalar_t,2> d_thresholds,
    torch::PackedTensorAccessor64<scalar_t,4> d_reflect,
    torch::PackedTensorAccessor64<scalar_t,4> d_reflect_height,
    torch::PackedTensorAccessor64<scalar_t,4> d_reflect_mask) {
    const int wstride = gridDim.x * blockDim.x, hstride = gridDim.y * blockDim.y, bstride = gridDim.z * blockDim.z;
    const int batch_size = d_rgb.size(0), h = d_rgb.size(2), w = d_rgb.size(3);

    for (int bi = blockIdx.z; bi < batch_size; bi += bstride) {
        for (int wi = blockIdx.x * blockDim.x + threadIdx.x; wi < w; wi += wstride) for(int hi = blockIdx.y * blockDim.y + threadIdx.y; hi < h; hi += hstride) {
            /* Back-trace vertically and keep the closest mirrored source.
               min_* stay uninitialized when nothing is masked, but then
               min_dis keeps FLT_MAX and the write below is skipped. */
            scalar_t min_dis = FLT_MAX;
            scalar_t min_r, min_g, min_b, min_height, min_mask;  // min_mask unused below
            for(int ti = hi-1; ti >= 0; --ti) {
                if (d_mask[bi][0][ti][wi] < 0.45)
                    continue;

                // Distance from hi to the mirror image of row ti.
                scalar_t dis = abs(d_hmap[bi][0][ti][wi] * 2 + ti - hi);
                if (dis < min_dis) {
                    min_dis = dis;
                    min_r = d_rgb[bi][0][ti][wi];
                    min_g = d_rgb[bi][1][ti][wi];
                    min_b = d_rgb[bi][2][ti][wi];

                    min_height = d_hmap[bi][0][ti][wi];
                    min_mask = d_mask[bi][0][ti][wi];
                }
            }

            /* Accept only sufficiently close mirror sources. */
            scalar_t cur_thresholds = d_thresholds[bi][0];
            if (min_dis < cur_thresholds) {
                /* Nearest-neighbor sampling for now. */
                d_reflect[bi][0][hi][wi] = min_r;
                d_reflect[bi][1][hi][wi] = min_g;
                d_reflect[bi][2][hi][wi] = min_b;
                d_reflect_height[bi][0][hi][wi] = min_height;
                d_reflect_mask[bi][0][hi][wi] = 1.0;
            }

            // Disabled soft-fade fallback kept for reference:
            // } else {
            //     scalar_t fadding = 1.0-(min_dis-cur_thresholds);
            //     if (fadding < 0.0) fadding = 0.0;
            //     d_reflect[bi][0][hi][wi] = min_r * fadding + (1.0-fadding);
            //     d_reflect[bi][1][hi][wi] = min_g * fadding + (1.0-fadding);
            //     d_reflect[bi][2][hi][wi] = min_b * fadding + (1.0-fadding);
            //     d_reflect_height[bi][0][hi][wi] = 0.0;
            //     d_reflect_mask[bi][0][hi][wi] = fadding;
            // }
        }
    }
}
482
+
483
/* Glossy reflection kernel — currently IDENTICAL in logic to the sharp
 * reflection kernel except for a fixed threshold of 1e-1:
 * NOTE(review): sample_n and glossy are accepted but never used, and no
 * height/mask outputs are produced; glossy sampling appears unimplemented. */
template <typename scalar_t>
__global__ void glossy_reflect_render_cuda_forward(
    const torch::PackedTensorAccessor64<scalar_t,4> d_rgb,
    const torch::PackedTensorAccessor64<scalar_t,4> d_mask,
    const torch::PackedTensorAccessor64<scalar_t,4> d_hmap,
    const torch::PackedTensorAccessor64<scalar_t,4> d_rechmap,
    const int sample_n,
    const float glossy,
    torch::PackedTensorAccessor64<scalar_t,4> d_reflect) {

    const int wstride = gridDim.x * blockDim.x, hstride = gridDim.y * blockDim.y, bstride = gridDim.z * blockDim.z;
    const int batch_size = d_rgb.size(0), h = d_rgb.size(2), w = d_rgb.size(3);

    for (int bi = blockIdx.z; bi < batch_size; bi += bstride) {
        for (int wi = blockIdx.x * blockDim.x + threadIdx.x; wi < w; wi += wstride) for(int hi = blockIdx.y * blockDim.y + threadIdx.y; hi < h; hi += hstride) {
            /* Back-trace vertically and keep the closest mirrored source
               (row ti whose mirror position ti + 2*height lands nearest hi). */
            scalar_t min_dis = FLT_MAX;
            scalar_t min_r, min_g, min_b, min_height, min_mask;  // height/mask unused here
            for(int ti = hi-1; ti >= 0; --ti) {
                if (d_mask[bi][0][ti][wi] < 0.45)
                    continue;

                scalar_t dis = abs(d_hmap[bi][0][ti][wi] * 2 + ti - hi);
                if (dis < min_dis) {
                    min_dis = dis;
                    min_r = d_rgb[bi][0][ti][wi];
                    min_g = d_rgb[bi][1][ti][wi];
                    min_b = d_rgb[bi][2][ti][wi];

                    min_height = d_hmap[bi][0][ti][wi];
                    min_mask = d_mask[bi][0][ti][wi];
                }
            }

            /* Hard-coded acceptance threshold (vs. the per-batch tensor used
               by the non-glossy kernel). */
            float cur_thresholds = 1e-1;
            if (min_dis < cur_thresholds) {
                /* Nearest-neighbor sampling for now. */
                d_reflect[bi][0][hi][wi] = min_r;
                d_reflect[bi][1][hi][wi] = min_g;
                d_reflect[bi][2][hi][wi] = min_b;
            }
        }
    }
}
530
+
531
+ } // namespace
532
+
533
/* Host wrapper: allocate the B x 3 x H x W shadow output (initialized to 1 =
 * fully lit, matching the kernel's "lit" value) and launch the shadow kernel
 * with a 16x16 thread tile and one grid z-slice per batch entry. */
std::vector<torch::Tensor> hshadow_render_cuda_forward(
    torch::Tensor rgb,
    torch::Tensor mask,
    torch::Tensor bb,
    torch::Tensor hmap,
    torch::Tensor rechmap,
    torch::Tensor light_pos) {
    const auto batch_size = rgb.size(0);
    const auto channel_size = rgb.size(1);  // unused, kept for symmetry
    const auto h = rgb.size(2);
    const auto w = rgb.size(3);
    const dim3 threads(16, 16, 1);
    const dim3 blocks((w + threads.x - 1) / threads.x, (h+threads.y-1)/threads.y, batch_size);
    // .to(rgb) matches the input's device and dtype.
    torch::Tensor shadow_tensor = torch::ones({batch_size, 3, h, w}).to(rgb);

    // NOTE(review): tensor.type() is deprecated in newer PyTorch; scalar_type()
    // is the modern spelling — verify against the targeted libtorch version.
    AT_DISPATCH_FLOATING_TYPES(rgb.type(), "hshadow_render_cuda_forward", ([&] {
        hshadow_render_cuda_forward<scalar_t><<<blocks, threads>>>(
            rgb.packed_accessor64<scalar_t,4>(),
            mask.packed_accessor64<scalar_t,4>(),
            bb.packed_accessor64<scalar_t,2>(),
            hmap.packed_accessor64<scalar_t,4>(),
            rechmap.packed_accessor64<scalar_t,4>(),
            light_pos.packed_accessor64<scalar_t,2>(),
            shadow_tensor.packed_accessor64<scalar_t,4>());
    }));

    return {shadow_tensor};
}
561
+
562
/* Host wrapper: allocate the reflection color (ones), height (zeros) and
 * mask (zeros) outputs, then launch the reflection kernel.
 * Returns {reflection, reflection_height, reflection_mask}. */
std::vector<torch::Tensor> reflect_render_cuda_forward(
    torch::Tensor rgb,
    torch::Tensor mask,
    torch::Tensor hmap,
    torch::Tensor rechmap,
    torch::Tensor thresholds) {
    const auto batch_size = rgb.size(0);
    const auto channel_size = rgb.size(1);  // unused, kept for symmetry
    const auto h = rgb.size(2);
    const auto w = rgb.size(3);
    const dim3 threads(16, 16, 1);
    const dim3 blocks((w + threads.x - 1) / threads.x, (h+threads.y-1)/threads.y, batch_size);
    torch::Tensor reflection_tensor = torch::ones({batch_size, 3, h, w}).to(rgb);
    torch::Tensor reflection_mask_tensor = torch::zeros({batch_size, 1, h, w}).to(rgb);
    torch::Tensor reflection_height_tensor = torch::zeros({batch_size, 1, h, w}).to(rgb);

    AT_DISPATCH_FLOATING_TYPES(rgb.type(), "reflect_render_cuda_forward", ([&] {
        reflect_render_cuda_forward<scalar_t><<<blocks, threads>>>(
            rgb.packed_accessor64<scalar_t,4>(),
            mask.packed_accessor64<scalar_t,4>(),
            hmap.packed_accessor64<scalar_t,4>(),
            rechmap.packed_accessor64<scalar_t,4>(),
            thresholds.packed_accessor64<scalar_t,2>(),
            reflection_tensor.packed_accessor64<scalar_t,4>(),
            reflection_height_tensor.packed_accessor64<scalar_t,4>(),
            reflection_mask_tensor.packed_accessor64<scalar_t,4>());
    }));

    return {reflection_tensor, reflection_height_tensor,reflection_mask_tensor};
}
592
+
593
+
594
+ std::vector<torch::Tensor> glossy_reflect_render_cuda_forward(torch::Tensor rgb,
595
+ torch::Tensor mask,
596
+ torch::Tensor hmap,
597
+ torch::Tensor rechmap,
598
+ const int sample_n,
599
+ const float glossy) {
600
+ const auto batch_size = rgb.size(0);
601
+ const auto channel_size = rgb.size(1);
602
+ const auto h = rgb.size(2);
603
+ const auto w = rgb.size(3);
604
+ const dim3 threads(16, 16, 1);
605
+ const dim3 blocks((w + threads.x - 1) / threads.x, (h+threads.y-1)/threads.y, batch_size);
606
+
607
+ torch::Tensor reflection_tensor = torch::ones({batch_size, 3, h, w}).to(rgb);
608
+
609
+ AT_DISPATCH_FLOATING_TYPES(rgb.type(), "reflect_render_cuda_forward", ([&] {
610
+ glossy_reflect_render_cuda_forward<scalar_t><<<blocks, threads>>>(
611
+ rgb.packed_accessor64<scalar_t,4>(),
612
+ mask.packed_accessor64<scalar_t,4>(),
613
+ hmap.packed_accessor64<scalar_t,4>(),
614
+ rechmap.packed_accessor64<scalar_t,4>(),
615
+ sample_n,
616
+ glossy,
617
+ reflection_tensor.packed_accessor64<scalar_t,4>());
618
+ }));
619
+
620
+ return {reflection_tensor};
621
+
622
+ }
623
+
624
+
625
+ torch::Tensor ray_intersect_cuda_forward(torch::Tensor rgb,
626
+ torch::Tensor mask,
627
+ torch::Tensor hmap,
628
+ torch::Tensor rechmap,
629
+ torch::Tensor rd_map){
630
+ const auto batch_size = rgb.size(0);
631
+ const auto channel_size = rgb.size(1);
632
+ const auto h = rgb.size(2);
633
+ const auto w = rgb.size(3);
634
+ const dim3 threads(16, 16, 1);
635
+ const dim3 blocks((w + threads.x - 1) / threads.x, (h+threads.y-1)/threads.y, batch_size);
636
+
637
+ torch::Tensor intersect_tensor = torch::ones({batch_size, 4, h, w}).to(rgb);
638
+
639
+ AT_DISPATCH_FLOATING_TYPES(rgb.type(), "reflect_render_cuda_forward", ([&] {
640
+ ray_intersect_cuda_forward<scalar_t><<<blocks, threads>>>(
641
+ rgb.packed_accessor64<scalar_t,4>(),
642
+ mask.packed_accessor64<scalar_t,4>(),
643
+ hmap.packed_accessor64<scalar_t,4>(),
644
+ rechmap.packed_accessor64<scalar_t,4>(),
645
+ rd_map.packed_accessor64<scalar_t,4>(),
646
+ intersect_tensor.packed_accessor64<scalar_t,4>());
647
+ }));
648
+
649
+ return intersect_tensor;
650
+
651
+ }
652
+
653
+
654
+ torch::Tensor ray_scene_intersect_cuda_forward(torch::Tensor rgb,
655
+ torch::Tensor mask,
656
+ torch::Tensor hmap,
657
+ torch::Tensor ro,
658
+ torch::Tensor rd,
659
+ float dh){
660
+ const auto batch_size = rgb.size(0);
661
+ const auto channel_size = rgb.size(1);
662
+ const auto h = rgb.size(2);
663
+ const auto w = rgb.size(3);
664
+ const dim3 threads(16, 16, 1);
665
+ const dim3 blocks((w + threads.x - 1) / threads.x, (h+threads.y-1)/threads.y, batch_size);
666
+
667
+ torch::Tensor intersect_tensor = torch::ones({batch_size, 4, h, w}).to(rgb);
668
+
669
+ AT_DISPATCH_FLOATING_TYPES(rgb.type(), "reflect_render_cuda_forward", ([&] {
670
+ ray_scene_intersect_cuda_forward<scalar_t><<<blocks, threads>>>(
671
+ rgb.packed_accessor64<scalar_t,4>(),
672
+ mask.packed_accessor64<scalar_t,4>(),
673
+ hmap.packed_accessor64<scalar_t,4>(),
674
+ ro.packed_accessor64<scalar_t,4>(),
675
+ rd.packed_accessor64<scalar_t,4>(),
676
+ dh,
677
+ intersect_tensor.packed_accessor64<scalar_t,4>());
678
+ }));
679
+
680
+ return intersect_tensor;
681
+
682
+ }
PixHtLab-Src/Demo/PixhtLab/Torch_Render/plane_visualize.cpp ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ #include <torch/extension.h>
3
+
4
+ #include <vector>
5
+ #include <stdio.h>
6
+
7
+ // CUDA forward declarations
8
+ std::vector<torch::Tensor> plane_visualize_cuda(torch::Tensor planes, torch::Tensor camera, int h, int w);
9
+
10
// C++ interface
// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
// Input-validation helpers: the CUDA kernel requires contiguous GPU tensors.
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
15
+
16
// C++ entry point: validate inputs, then forward to the CUDA implementation.
//   planes: (B, 6) plane point (xyz) + plane normal (xyz) per batch item
//   camera: (B, 2) per-batch (fov degrees, horizon row) — see the kernel
//   h, w:   output image size in pixels
// Returns a vector holding one (B, 3, H, W) visualization tensor.
std::vector<torch::Tensor> plane_visualize(torch::Tensor planes, torch::Tensor camera, int h, int w) {
    // Both tensors must be contiguous CUDA tensors for the accessor-based kernel.
    CHECK_INPUT(planes);
    CHECK_INPUT(camera);

    return plane_visualize_cuda(planes, camera, h, w);
}
22
+
23
+
24
// Expose the visualizer to Python as `<extension>.forward(planes, camera, h, w)`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &plane_visualize, "Plane Visualization (CUDA)");
}
PixHtLab-Src/Demo/PixhtLab/Torch_Render/plane_visualize_cuda.cu ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <torch/extension.h>
2
+
3
+ #include <cuda.h>
4
+ #include <cuda_runtime.h>
5
+
6
+ #include <vector>
7
+
8
+ namespace {
9
+ template <typename scalar_t>
10
+ struct vec3 {
11
+ scalar_t x, y, z;
12
+
13
+ __device__ __host__
14
+ vec3<scalar_t>():x(0),y(0),z(0) {}
15
+
16
+ __device__ __host__
17
+ vec3<scalar_t>(scalar_t a):x(a),y(a),z(a) {}
18
+
19
+ __device__ __host__
20
+ vec3<scalar_t>(scalar_t xx, scalar_t yy, scalar_t zz):x(xx),y(yy),z(zz) {}
21
+
22
+ __device__ __host__
23
+ vec3<scalar_t> operator*(const scalar_t &rhs) const {
24
+ vec3<scalar_t> ret(x,y,z);
25
+ ret.x = x * rhs;
26
+ ret.y = y * rhs;
27
+ ret.z = z * rhs;
28
+ return ret;
29
+ }
30
+
31
+ __device__ __host__
32
+ vec3<scalar_t> operator/(const scalar_t &rhs) const {
33
+ vec3<scalar_t> ret(x,y,z);
34
+ ret.x = x / rhs;
35
+ ret.y = y / rhs;
36
+ ret.z = z / rhs;
37
+ return ret;
38
+ }
39
+
40
+ __device__ __host__
41
+ vec3<scalar_t> operator+(const vec3<scalar_t> &rhs) const {
42
+ vec3<scalar_t> ret(x,y,z);
43
+ ret.x = x + rhs.x;
44
+ ret.y = y + rhs.y;
45
+ ret.z = z + rhs.z;
46
+ return ret;
47
+ }
48
+
49
+ __device__ __host__
50
+ vec3<scalar_t> operator-(const vec3<scalar_t> &rhs) const {
51
+ vec3<scalar_t> ret(x,y,z);
52
+ ret.x = x - rhs.x;
53
+ ret.y = y - rhs.y;
54
+ ret.z = z - rhs.z;
55
+ return ret;
56
+ }
57
+ };
58
+
59
// A ray: origin `ro` plus direction `rd` (not necessarily normalized).
template <typename scalar_t>
struct Ray {
    vec3<scalar_t> ro, rd;
};
63
+
64
// The scene is a single infinite plane: a point `pp` on it and its normal `pn`.
template <typename scalar_t>
struct Scene {
    vec3<scalar_t> pp, pn;
};
68
+
69
+ template <typename scalar_t>
70
+ __device__
71
+ float deg2rad(scalar_t d) {
72
+ return d/180.0 * 3.1415926f;
73
+ }
74
+
75
// Dot product of two 3-vectors.
template <typename scalar_t>
__device__
scalar_t dot(vec3<scalar_t> a, vec3<scalar_t> b) {
    return a.x * b.x + a.y * b.y + a.z * b.z;
}
80
+
81
+ template <typename scalar_t>
82
+ __device__
83
+ vec3<scalar_t> cross(vec3<scalar_t> a, vec3<scalar_t> b) {
84
+ vec3<scalar_t> ret(0.0f);
85
+ ret.x = a.y * b.z - a.z * b.y;
86
+ ret.y = a.z * b.x - a.x * b.z;
87
+ ret.z = a.x * b.y - a.y * b.x;
88
+ return ret;
89
+ }
90
+
91
// Euclidean length of a 3-vector.
template <typename scalar_t>
__device__
scalar_t length(vec3<scalar_t> a) {
    return sqrt(dot(a, a));
}
96
+
97
// Unit-length copy of `a`. NOTE(review): divides by length(a) with no guard,
// so a zero vector produces inf/nan — callers must not pass one.
template <typename scalar_t>
__device__
vec3<scalar_t> normalize(vec3<scalar_t> a) {
    return a/length(a);
}
102
+
103
// Pinhole focal length in pixels, from image width and horizontal fov (degrees).
template <typename scalar_t>
__device__
scalar_t get_focal(int w, scalar_t fov) {
    return 0.5 * w / tan(deg2rad(fov * 0.5));
}
108
+
109
// Ray direction through normalized image coordinates (x, y), given the camera
// basis (right/front/up), image size and a precomputed focal length.
template <typename scalar_t>
__device__
vec3<scalar_t> get_rd(vec3<scalar_t> right, vec3<scalar_t> front, vec3<scalar_t> up, int h, int w, scalar_t focal, scalar_t x, scalar_t y) {
    /* x, y in [-1, 1] */
    right = normalize(right);
    front = normalize(front);
    up = normalize(up);

    // front scaled to the image plane plus offsets of half the image extent.
    return front * focal + right * x * (float)w * 0.5f + up * y * (float)h * 0.5f;
}
119
+
120
// Build the primary ray through normalized pixel coordinates (wi, hi).
// NOTE(review): the focal here uses tan(deg2rad(fov)) while get_focal() uses
// tan(deg2rad(fov * 0.5)) — one of the two omits the half-angle; confirm
// which fov convention is intended.
template <typename scalar_t>
__device__
Ray<scalar_t> get_ray(int h, int w, float hi, float wi, vec3<scalar_t> right, vec3<scalar_t>front, vec3<scalar_t> up, vec3<scalar_t> cam_pos, float fov) {
    /* Note, wi/hi is in [-1.0, 1.0] */
    Ray<scalar_t> ray;

    float focal = 0.5f * w / tan(deg2rad(fov));
    ray.ro = cam_pos;
    // Direction is left unnormalized; downstream plane_intersect tolerates that.
    ray.rd = front * focal + right * 0.5f * w * wi + up * 0.5f * h * hi;
    return ray;
}
131
+
132
// Ray/plane intersection. Solves dot(ro + t*rd - p, n) == 0 for t and writes
// it to `t`; returns true when the hit lies in front of the ray origin.
// NOTE(review): no guard when rd is parallel to the plane (dot(rd, n) == 0);
// t becomes inf/nan in that case — presumably acceptable for this visualizer.
template <typename scalar_t>
__device__
bool plane_intersect(Ray<scalar_t> ray, vec3<scalar_t> p, vec3<scalar_t> n, float &t) {
    vec3<scalar_t> ro = ray.ro, rd = ray.rd;
    t = dot(p-ro, n)/dot(rd, n);
    return t >= 0.0;
}
139
+
140
// Derive the camera "front" direction from a desired horizon row: tilt the
// default view (0, 0, -1) vertically so the horizon lands `yoffset` pixels
// from the image center. Same fov convention as get_ray (no half-angle).
template <typename scalar_t>
__device__
vec3<scalar_t> horizon2front(scalar_t horizon, int h, int w, float fov) {
    scalar_t yoffset = h / 2 - horizon;
    scalar_t focal = 0.5f * w / tan(deg2rad(fov));
    vec3<scalar_t> front = vec3<scalar_t>(0.0f,0.0f,-1.0f) * focal + vec3<scalar_t>(0.0f, 1.0f, 0.0f) * yoffset;
    return normalize(front);
}
148
+
149
+ template <typename scalar_t>
150
+ __device__
151
+ vec3<scalar_t> plane_texture(vec3<scalar_t> p) {
152
+ float freq = 6.0f;
153
+ float u = sin(p.x * freq), v = sin(p.z * freq);
154
+ vec3<scalar_t> ret(0.0f);
155
+ float line_width = 0.05f;
156
+ if ((abs(u) < line_width || abs(v) < line_width)) {
157
+ ret = vec3<scalar_t>(1.0f);
158
+ }
159
+ return ret;
160
+ }
161
+
162
// Trace one ray against the scene's plane. On a hit, writes the plane's
// procedural texture color at the intersection point into `color` and
// returns true; otherwise leaves `color` black and returns false.
template <typename scalar_t>
__device__
bool ray_scene_trace(Ray<scalar_t> ray, Scene<scalar_t> scene, vec3<scalar_t> &color) {
    color = vec3<scalar_t>(0.0f);
    float t;
    if (plane_intersect(ray, scene.pp, scene.pn, t)) {
        vec3<scalar_t> intersect_pos = ray.ro + ray.rd * t;
        color = plane_texture(intersect_pos);
        return true;
    }
    return false;
}
174
+
175
// Kernel: render the batch of ground planes. Threads cover the image with
// strided loops in x/y and the batch in z, so any launch configuration that
// covers each index at least once yields the same output.
//   d_plane:  (B, 6) plane point (xyz) + normal (xyz)
//   d_camera: (B, 2) horizontal fov in degrees + horizon row
//   d_vis:    (B, 3, H, W) output; pixels whose rays never hit are untouched.
template <typename scalar_t>
__global__ void plane_visualize_foward(
    const torch::PackedTensorAccessor64<scalar_t,2> d_plane,
    const torch::PackedTensorAccessor64<scalar_t,2> d_camera,
    torch::PackedTensorAccessor64<scalar_t,4> d_vis) {
    const int wstride = gridDim.x * blockDim.x, hstride = gridDim.y * blockDim.y, bstride = gridDim.z * blockDim.z;
    const int batch_size = d_vis.size(0), h = d_vis.size(2), w = d_vis.size(3);
    const int samples = 10;  // 10x10 = 100 sub-pixel rays per pixel (anti-aliasing)

    // Camera fixed at (0, 1, 1); front/up get recomputed per batch item below.
    vec3<scalar_t> cam_pos(0.0f, 1.0f, 1.0f), front(0.0f,0.0f,-1.0f), right(1.0f, 0.0f, 0.0f), up(0.0f, 1.0f, 0.0f);
    for (int bi = blockIdx.z; bi < batch_size; bi += bstride) {
        // scalar_t px = d_plane[bi][0], py = d_plane[bi][1], pz = d_plane[bi][2];
        vec3<scalar_t> plane_pos(d_plane[bi][0], d_plane[bi][1], d_plane[bi][2]);
        vec3<scalar_t> plane_norm(d_plane[bi][3], d_plane[bi][4], d_plane[bi][5]);
        Scene<scalar_t> scene = {plane_pos, plane_norm};

        scalar_t fov = d_camera[bi][0], horizon = d_camera[bi][1];
        // Tilt the view so the requested horizon row maps correctly.
        front = normalize(horizon2front(horizon, h, w, fov));
        up = normalize(cross(right, front));
        for (int wi = blockIdx.x * blockDim.x + threadIdx.x; wi < w; wi += wstride)
            for(int hi = blockIdx.y * blockDim.y + threadIdx.y; hi < h; hi += hstride) {
                bool intersect = false;
                vec3<scalar_t> color(0.0f);
                // Regular sub-pixel grid of offsets in [0, 1] x [0, 1].
                for (int si = 0; si < samples * samples; ++si) {
                    float hoffset = (float)(si/samples)/max(samples-1, 1);
                    float woffset = (float)(si%samples)/max(samples-1, 1);
                    // Map the sample to normalized device coordinates [-1, 1].
                    float x = (float)(wi + woffset)/w * 2.0 - 1.0;
                    float y = (float)(hi + hoffset)/h * 2.0 - 1.0;
                    Ray<scalar_t> ray = get_ray(h, w, y, x, right, front, up, cam_pos,fov);
                    vec3<scalar_t> tmp_color(0.0f);
                    if(ray_scene_trace(ray, scene, tmp_color)) {
                        color = color + tmp_color;
                        intersect = intersect || true;
                    }
                }
                if (intersect) {
                    // Average over ALL samples (misses contribute black).
                    color = color / float(samples * samples);
                    d_vis[bi][0][hi][wi] = color.x;
                    d_vis[bi][1][hi][wi] = color.y;
                    d_vis[bi][2][hi][wi] = color.z;
                }
            }
    }
}
219
+
220
+ } // namespace
221
+
222
+ std::vector<torch::Tensor> plane_visualize_cuda(torch::Tensor planes, torch::Tensor camera, int h, int w){
223
+ const auto batch_size = planes.size(0);
224
+ const int threads = 512;
225
+ const dim3 blocks((w + threads - 1) / threads, (h+threads-1)/threads, batch_size);
226
+
227
+ torch::Tensor vis_tensor = torch::zeros({batch_size, 3, h, w}).to(planes);
228
+ AT_DISPATCH_FLOATING_TYPES(planes.type(), "plane_visualize_foward", ([&] {
229
+ plane_visualize_foward<scalar_t><<<blocks, threads>>>(
230
+ planes.packed_accessor64<scalar_t,2>(),
231
+ camera.packed_accessor64<scalar_t,2>(),
232
+ vis_tensor.packed_accessor64<scalar_t,4>()
233
+ );
234
+ }));
235
+
236
+ return {vis_tensor};
237
+ }
PixHtLab-Src/Demo/PixhtLab/Torch_Render/setup.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension


# Build the `hshadow` CUDA extension (height-map shadow/reflection renderer).
# Usage: python setup.py install  (or: build_ext --inplace)
setup(
    name='hshadow',
    ext_modules=[
        CUDAExtension('hshadow', [
            'hshadow_cuda.cpp',
            'hshadow_cuda_kernel.cu',
        ])
    ],
    cmdclass={
        'build_ext': BuildExtension
    }
)

# Alternative build for the `plane_visualize` extension; swap in when needed.
# setup(
#     name='plane_visualize',
#     ext_modules=[
#         CUDAExtension('plane_visualize', [
#             'plane_visualize.cpp',
#             'plane_visualize_cuda.cu',
#         ])
#     ],
#     cmdclass={
#         'build_ext': BuildExtension
#     }
# )
PixHtLab-Src/Demo/PixhtLab/Torch_Render/test_ground.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import torch
3
+ import plane_visualize
4
+ import matplotlib.pyplot as plt
5
+ from PIL import Image
6
+ from torchvision import transforms
7
+ import os
8
+ from os.path import join
9
+ import numpy as np
10
+
11
# Where rendered test images are written.
test_output = 'imgs/output'
os.makedirs(test_output, exist_ok=True)
# All tensors live on the first GPU.
device = torch.device("cuda:0")
14
+
15
def test_ground():
    """Render a batch of five identical ground planes.

    Returns the (5, 3, 512, 512) visualization tensor produced by the
    `plane_visualize` CUDA extension.
    """
    fov, horizon = 120, 400
    # Per-batch camera = (fov, horizon); plane = point (0,0,0) with +Y normal.
    cam = torch.tensor([[fov, horizon]]).repeat(5, 1).float().to(device)
    ground = torch.tensor([[0.0, 0.0, 0.0, 0.0, 1.0, 0.0]]).repeat(5, 1).float().to(device)

    return plane_visualize.forward(ground, cam, int(512), int(512))[0]
25
+
26
# Time one batched render, then dump each batch item to a PNG for inspection.
t = time.time()
ground_vis = test_ground()
print('{} s'.format(time.time() - t))
batch = ground_vis.shape[0]
for bi in range(batch):
    # CHW tensor -> HWC numpy image, clamped to a displayable range.
    img = ground_vis[bi].detach().cpu().numpy().transpose(1,2,0)
    img = np.clip(img, 0.0, 1.0)
    plt.imsave(join(test_output, 'ground_{}.png'.format(bi)),img)
PixHtLab-Src/Demo/PixhtLab/Torch_Render/test_hshadow.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import torch
3
+ import hshadow
4
+ import matplotlib.pyplot as plt
5
+ from PIL import Image
6
+ from torchvision import transforms
7
+ import os
8
+ from os.path import join
9
+ import numpy as np
10
+ from scipy.ndimage import uniform_filter
11
+
12
# Where rendered test images are written.
test_output = 'imgs/output'
os.makedirs(test_output, exist_ok=True)
14
def test_shadow(rgb, mask, hmap, rechmap, light_pos):
    """Render a hard shadow with the `hshadow` extension and composite it.

    Returns (composited image, raw shadow tensor).
    """
    # NOTE(review): rgb arrives as (B, C, H, W), so shape[:2] is
    # (batch, channels), not (height, width) — the bounding box passed to
    # hshadow.forward looks suspicious; confirm against the kernel's contract.
    h,w = rgb.shape[:2]
    bb = torch.tensor([[0, h-1, 0, w-1]]).float().to(device)
    start = time.time()
    shadow = hshadow.forward(rgb, mask, bb, hmap, rechmap, light_pos)[0]
    end = time.time()

    print('Shadow rendering: {}s'.format(end-start))
    # Darken the background by the shadow mask; keep the foreground untouched.
    res = (1.0-mask) * rgb * shadow + mask * rgb
    return res, shadow
24
+
25
def frenel_reflect(reflect_tensor, reflect_mask, fov, ref_ind):
    """Fade a reflection toward white, loosely inspired by Fresnel falloff.

    The Schlick-based cosine map is computed but currently unused (see the
    commented-out `fadding` line); the active code fades linearly with image
    row instead. `ref_ind` only feeds the unused branch.
    """
    # Use schilik approximation https://en.wikipedia.org/wiki/Schlick%27s_approximation
    def deg2rad(deg):
        return deg/180.0 * 3.1415926

    def img2cos(reflect_img, fov, horizon):
        # Note, this factor needs calibration if we have camera parameters
        b, c, h, w = reflect_img.shape
        focal = 0.5 * h / np.tan(deg2rad(0.5 * fov))
        # Per-row view-angle cosine relative to the horizon row.
        fadding_map = torch.arange(0, h).unsqueeze(1).expand(h, w).unsqueeze(0).unsqueeze(0).repeat(b, c, 1, 1)
        fadding_map = focal / torch.sqrt((fadding_map-horizon)**2 + focal **2)
        return fadding_map.to(reflect_img)


    ind = (1.0-ref_ind)/(1.0+ref_ind)
    ind = ind ** 2
    h = reflect_tensor.shape[2]
    horizon = h * 0.7  # assumed horizon at 70% of image height — TODO confirm
    cos_map = img2cos(reflect_tensor, fov, horizon)
    # fadding = 1.0 - (ind + (1.0-ind) * torch.pow(1.0-cos_map, 4))
    b, c, h, w = reflect_tensor.shape
    # Active fade: top row weight 3.0 -> bottom row 0.0, raised to the 4th
    # power and clipped to [0, 1].
    fadding = torch.linspace(3.0,0.0,h)[None, None, ..., None].repeat(b,c,1,w).to(reflect_tensor) ** 4
    fadding = torch.clip(fadding, 0.0, 1.0)
    # Debug dump of the fade map; writes into the working directory.
    plt.imsave('test_fadding.png', fadding[0].detach().cpu().numpy().transpose(1,2,0))
    reflect_mask = reflect_mask.repeat(1,3,1,1)
    # Blend: faded reflection inside the mask, white elsewhere.
    return fadding * reflect_tensor * reflect_mask + (1.0-reflect_mask * fadding) * torch.ones_like(reflect_tensor)
51
+
52
def refine_boundary(output, filter=3):
    """Optionally smooth channel boundaries with a box filter.

    Currently DISABLED: the immediate `return output` short-circuits the
    function, leaving the uniform_filter pass below as dead code. Delete the
    early return to re-enable per-channel smoothing (size `filter`, in place).
    """
    return output
    h,w,c = output.shape
    for i in range(c):
        output[...,i] = uniform_filter(output[...,i], size=filter)
    return output
58
+
59
+ def height_fadding(reflect, reflect_height, reflect_mask, fadding_factor):
60
+ def np_sigmoid(a):
61
+ return 1.0/(1.0+np.exp(-a))
62
+
63
+ h,w,c = reflect.shape
64
+ reflect_h = reflect_height/h
65
+ reflect_h = reflect_h/reflect_h.max()
66
+ fadding = (1.0-(np_sigmoid(reflect_h * fadding_factor)-0.5)* 2.0) * reflect_mask
67
+ after_fadding = fadding * reflect + (1.0-fadding) * np.ones_like(reflect)
68
+ return after_fadding
69
+
70
def to_numpy(tensor):
    """Convert a (1, C, H, W) torch tensor to an (H, W, C) numpy array."""
    arr = tensor[0].detach().cpu().numpy()
    return np.transpose(arr, (1, 2, 0))
72
+
73
def test_reflect(rgb, mask, hmap, rechmap, thresholds=1.5, fadding_factor=10.0):
    """Render a mirror reflection via hshadow.reflection and composite it.

    Returns (composited numpy image, processed reflection image).
    """
    b, c, h, w = rgb.shape
    # thresholds = torch.tensor([[1.0 + i/b] for i in range(b)]).float().to(device)
    # Single shared threshold for the whole batch.
    thresholds = torch.tensor([[thresholds]]).float().to(device)
    start = time.time()
    reflect, reflect_height, reflect_mask = hshadow.reflection(rgb, mask, hmap, rechmap, thresholds)
    end = time.time()
    print('Reflection rendering: {}s'.format(end-start))

    # Post-process on CPU as HWC numpy arrays.
    reflect, reflect_height, reflect_mask = to_numpy(reflect), to_numpy(reflect_height), to_numpy(reflect_mask)
    # reflect = frenel_reflect(reflect, reflect_height, reflect_mask, 175, 0.9)
    # refine_boundary is currently a no-op (early return inside it).
    refine_reflect, refine_reflect_height = refine_boundary(reflect), refine_boundary(reflect_height)
    refine_reflect_mask = reflect_mask
    reflect = height_fadding(refine_reflect, refine_reflect_height, refine_reflect_mask, fadding_factor)
    rgb, mask = to_numpy(rgb), to_numpy(mask)
    # Reflection darkens the background; foreground pixels pass through.
    res = (1.0-mask) * rgb * reflect + mask * rgb
    return res, reflect
90
+
91
+
92
def test_glossy_reflect(rgb, mask, hmap, rechmap, sample, glossy, fadding_factor=10.0):
    """Render a glossy (multi-sample) reflection and composite it.

    Returns (composited numpy image, reflection image).
    NOTE(review): `fadding_factor` and `refine_reflect` are currently unused.
    """
    b, c, h, w = rgb.shape
    # thresholds = torch.tensor([[1.0 + i/b] for i in range(b)]).float().to(device)
    start = time.time()
    reflect = hshadow.glossy_reflection(rgb, mask, hmap, rechmap, sample, glossy)[0]
    end = time.time()
    print('Reflection rendering: {}s'.format(end-start))

    reflect = to_numpy(reflect)
    refine_reflect = refine_boundary(reflect)
    rgb, mask = to_numpy(rgb), to_numpy(mask)
    res = (1.0-mask) * rgb * reflect + mask * rgb
    return res, reflect
105
+
106
+
107
device = torch.device("cuda:0")
to_tensor = transforms.ToTensor()
# for i in range(1,5):
# `if True:` kept from the (disabled) loop so the body indentation still works.
i = 2
if True:
    prefix = 'canvas{}'.format(i)
    # rgb: (3,H,W); mask/hmap: first channel only, (1,H,W).
    rgb, mask, hmap = to_tensor(Image.open('imgs/{}_rgb.png'.format(prefix)).convert('RGB')).to(device), to_tensor(Image.open('imgs/{}_mask.png'.format(prefix)).convert('RGB'))[0:1].to(device), to_tensor(Image.open('imgs/{}_height.png'.format(prefix)).convert('RGB'))[0:1].to(device)
    h, w = hmap.shape[1:]
    # Scale normalized heights to pixel units (0.45 of image height max).
    hmap = hmap * h * 0.45

    rechmap = torch.zeros_like(hmap)
    # Add the batch dimension expected by the CUDA ops.
    rgb, mask, hmap, rechmap = rgb.unsqueeze(dim=0), mask.unsqueeze(dim=0), hmap.unsqueeze(dim=0), rechmap.unsqueeze(dim=0)
    lightpos = torch.tensor([[300, -100, 200.0]]).to(device)
    shadow_res, shadow = test_shadow(rgb, mask, hmap, rechmap, lightpos)
    reflect_res, reflect = test_reflect(rgb, mask, hmap, rechmap, thresholds=2.5, fadding_factor=18.0)

    glossy_reflect_res, glossy_reflect = test_glossy_reflect(rgb, mask, hmap, rechmap, sample=10, glossy=0.5, fadding_factor=18.0)

    # Dump every intermediate and final image for visual inspection.
    plt.imsave(join(test_output, prefix + "_shadow_final.png"), shadow_res[0].detach().cpu().numpy().transpose(1,2,0))
    plt.imsave(join(test_output, prefix + "_shadow.png"), shadow[0].detach().cpu().numpy().transpose(1,2,0))
    plt.imsave(join(test_output, prefix + "_reflect_final.png"), reflect_res)
    plt.imsave(join(test_output, prefix + "_reflect.png"), reflect)
    plt.imsave(join(test_output, prefix + "_glossy_reflect_final.png"), glossy_reflect_res)
    plt.imsave(join(test_output, prefix + "_glossy_reflect.png"), glossy_reflect)
PixHtLab-Src/Demo/PixhtLab/camera.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import math
3
+ from abc import ABC
4
+ import copy
5
+
6
class camera(ABC):
    """Base pinhole camera.

    Screen coordinates (u, v) have their origin at the top-left corner.
    An "xyh" triple is (u, v, pixel_height): the pixel position of a point
    plus its height above the ground plane (world y = 0), in pixels.
    Subclasses supply the camera basis via C()/right()/up().
    """
    def __init__(self, hfov, h, w, height=100.0):
        self.fov = hfov   # horizontal field of view, degrees
        self.h = h        # image height in pixels
        self.w = w        # image width in pixels

        self.ori_height = height
        self.height = copy.deepcopy(self.ori_height)
        self.O = np.array([0.0, self.height, 0.0]) # ray origianl


    ######################################################################################
    """ Abstraction
    """
    def align_horizon(self, cur_horizon):
        raise NotImplementedError('Not implemented yet')

    def C(self):
        raise NotImplementedError('Not implemented yet')

    def right(self):
        raise NotImplementedError('Not implemented yet')

    def up(self):
        raise NotImplementedError('Not implemented yet')

    ######################################################################################

    def deg2rad(self, d):
        # FIX: was `d / 180.0 * 3.1415925` — a mistyped pi constant.
        return d / 180.0 * math.pi


    def rad2deg(self, d):
        # FIX: same mistyped constant as deg2rad; use math.pi.
        return d / math.pi * 180.0


    def get_ray(self, xy):
        """ Assume the center is on the top-left corner

        Returns the (unnormalized) world-space ray direction through pixel (u, v).
        """
        u, v = xy
        mat = self.get_ABC_mat()
        r = np.dot(mat, np.array([u, v, 1.0]).T)

        # r = r/np.sqrt(r @ r)
        return r


    def project(self, xyz):
        """Project a world-space point to screen (u, v) via the inverse basis."""
        relative = xyz - self.O

        mat = self.get_ABC_mat()
        pp = np.dot(np.linalg.inv(mat), relative)
        pixel = np.array([pp[0]/pp[2], pp[1]/pp[2]])

        return pixel


    def xyh2w(self, xyh):
        """Ray-scale factor w that drops the point's ground foot onto y = 0."""
        foot_xyh = np.copy(xyh)
        foot_xyh[1] = foot_xyh[1] + foot_xyh[2]  # foot pixel: shift v down by height
        foot_xyh[2] = 0.0
        fu, fv, fh = foot_xyh

        a = self.right()
        b = -self.up()
        c = self.C()

        w = -self.height/(a[1] * fu + b[1] * fv + c[1])
        return w


    def xyh2xyz(self, xyh):
        """Lift an (u, v, height) sample to a world-space 3D point."""
        u, v, h = xyh

        foot_xyh = np.copy(xyh)
        foot_xyh[1] = foot_xyh[1] + foot_xyh[2]
        foot_xyh[2] = 0.0
        fu, fv, fh = foot_xyh

        a = self.right()
        b = -self.up()
        c = self.C()
        mat = self.get_ABC_mat()

        # Same depth factor as xyh2w: intersect the foot ray with the ground.
        w = -self.height/(a[1] * fu + b[1] * fv + c[1])
        xyz = self.O + np.dot(mat, np.array([u, v, 1.0]).T) * w

        return xyz


    def xyz2xyh(self, xyz):
        """Project a world point back to (u, v, pixel_height)."""
        foot_xyz = np.copy(xyz)
        foot_xyz[1] = 0.0

        foot_xy = self.project(foot_xyz)
        xy = self.project(xyz)

        ret = np.copy(xyz)
        ret[:2] = xy
        # Pixel height = vertical screen distance between the point and its foot.
        ret[2] = foot_xy[1] - xy[1]

        return ret


    def get_ABC_mat(self):
        """3x3 basis matrix [right | -up | C] mapping (u, v, 1) to a ray direction."""
        a = self.right()
        b = -self.up()
        c = self.C()

        mat = np.concatenate([a[:, None], b[:,None], c[:, None]], axis=1)
        return mat
123
+
124
+
125
class pitch_camera(camera):
    """ Picth alignment camera

    Horizon changes are realized by pitching the view direction up/down
    about the camera's x axis.
    """
    def __init__(self, hfov, h, w, height=100.0):
        """
        alignment algorithm:
            1. pitch alignment
            2. axis alignment
        """
        super().__init__(hfov, h, w, height)

        self.ori_view = np.array([0.0, 0.0, -1.0])
        self.cur_view = np.copy(self.ori_view)


    def align_horizon(self, cur_horizon):
        """ Given horizon, compute the camera pitch
        """
        ref_horizon = self.h / 2
        rel_distance = -(ref_horizon - cur_horizon)

        focal = self.focal()
        pitch = math.atan2(rel_distance, focal)

        # Rotation about the +x axis by `pitch`.
        # FIX: the first row was [0, 0, 0], which zeroed the x component of
        # the rotated vector; a rotation about x must preserve x ([1, 0, 0]).
        # (Harmless for the default view (0, 0, -1) whose x is 0, but wrong
        # for any view with a nonzero x component.)
        c, s = np.cos(pitch), np.sin(pitch)
        rot = np.array([[1, 0, 0], [0, c, -s], [0, s, c]])

        # Rotate the image-plane view vector, then renormalize.
        img_plane_view = self.ori_view * focal
        img_plane_view = rot @ img_plane_view.T

        self.cur_view = img_plane_view / math.sqrt(np.dot(img_plane_view, img_plane_view))

    def C(self):
        # Top-left corner of the image plane relative to the eye.
        return self.view() * self.focal() - 0.5 * self.w * self.right() + 0.5 * self.h * self.up()


    def right(self):
        return np.array([1.0, 0.0, 0.0])


    def up(self):
        return np.cross(self.right(), self.view())


    def focal(self):
        # Pinhole focal length in pixels from the horizontal fov.
        focal = self.w * 0.5 / math.tan(self.deg2rad(self.fov * 0.5))
        return focal


    def view(self):
        return self.cur_view
178
+
179
+
180
+
181
class axis_camera(camera):
    """ Axis alignment camera

    Horizon changes shift the image plane along the `up` axis instead of
    pitching the view.
    """
    def __init__(self, hfov, h, w, height=100.0):
        super().__init__(hfov, h, w, height)

        half_w = 0.5 * self.w
        focal = half_w / math.tan(self.deg2rad(self.fov * 0.5))
        self.up_vec = np.array([0.0, 1.0, 0.0])
        self.right_vec = np.array([1.0, 0.0, 0.0])

        # Top-left corner of the image plane, relative to the eye.
        self.ori_c = np.array([-half_w, 0.5 * self.h, -focal])
        self.c_vec = np.copy(self.ori_c)


    def align_horizon(self, cur_horizon):
        """ Given horizon, we move the axis to update the horizon
            i.e. we need to change C
        """
        delta_horizon = cur_horizon - self.h // 2
        self.c_vec = self.ori_c + delta_horizon * self.up()


    def C(self):
        return self.c_vec


    def right(self):
        return self.right_vec


    def up(self):
        return self.up_vec
217
+
218
+
219
def test(ppc):
    """Smoke test: round-trip xyh -> xyz -> xyh and probe horizon alignment."""
    xyh = np.array([500, 500, 100.0])

    proj_xyz = ppc.xyh2xyz(xyh)
    proj_xyh = ppc.xyz2xyh(proj_xyz)

    print('xyh: {}, proj xyz: {}, proj xyh: {}'.format(xyh, proj_xyz, proj_xyh))

    # import pdb; pdb.set_trace()
    new_horizon_list = [0, 100, 250, 400, 500]
    # NOTE: the second assignment intentionally overrides the list above.
    new_horizon_list = [100, 250, 400, 500]

    # import pdb; pdb.set_trace()
    for cur_horizon in new_horizon_list:
        ppc.align_horizon(cur_horizon)
        # A point on the horizon at ground level should round-trip cleanly.
        test_xyh = np.array([500, cur_horizon, 0])
        test_xyz = ppc.xyh2xyz(test_xyh)

        # print('{} -> {} -> {}'.format(test_xyh, test_xyz, ppc.xyz2xyh(test_xyz)))
        print('{} \t -> {} \t -> {}'.format(test_xyh, test_xyz, ppc.xyz2xyh(test_xyz)))
239
+
240
+
241
if __name__ == '__main__':
    # Exercise both horizon-alignment strategies with identical parameters.
    p_camera = pitch_camera(90.0, 500, 500)
    a_camera = axis_camera(90.0, 500, 500)

    test(p_camera)
    test(a_camera)
PixHtLab-Src/Demo/PixhtLab/gssn_demo.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import torch
3
+ import gradio as gr
4
+ from torch import nn
5
+ import numpy as np
6
+
7
+
8
def render_btn_fn(mask, background, buffers, pitch, roll, softness):
    """Debug stub for the Render button: log slider values and input shapes."""
    shapes = (mask.shape, background.shape, buffers.shape)
    print('Pitch and roll: {}, {}'.format(pitch, roll))
    print('Mask, background, bufferss: {}, {}, {}'.format(*shapes))
12
+
13
+
14
# Minimal Gradio UI: three image inputs, three sliders, one render button.
# (Nesting reconstructed from the diff; verify against the original layout.)
with gr.Blocks() as demo:
    with gr.Row():
        mask_input = gr.Image(shape=(256, 256), image_mode="L", label="Mask")
        bg_input = gr.Image(shape=(256, 256), image_mode="RGB", label="Background")
        buff_input = gr.Image(shape=(256, 256), image_mode="RGB", label="Buffers")

    with gr.Row():
        with gr.Column():
            pitch_input = gr.Slider(minimum=0, maximum=1, step=0.01, default=0.5, label="Pitch")
            roll_input = gr.Slider(minimum=0, maximum=1, step=0.01, default=0.5, label="Roll")
            softness_input = gr.Slider(minimum=0, maximum=1, step=0.01, default=0.5, label="Softness")

        render_btn = gr.Button(label="Render")
        output = gr.Image(shape=(256, 256), image_mode="RGB", label="Output")

    # Wire the button to the (stub) render callback.
    render_btn.click(render_btn_fn, inputs=[mask_input, bg_input, buff_input, pitch_input, roll_input, softness_input], outputs=output)


demo.launch()
PixHtLab-Src/Demo/PixhtLab/hshadow_render.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import torch
3
+ import hshadow
4
+ # import plane_visualize
5
+ import numpy as np
6
+ from torchvision import transforms
7
+ from scipy.ndimage import uniform_filter
8
+ from ShadowStyle.inference import inference_shadow
9
+ import cv2
10
+ import matplotlib.pyplot as plt
11
+ from utils import *
12
+ from GSSN.inference_shadow import SSN_Infernece
13
+
14
# Global inference context: everything runs on the first GPU.
device = torch.device("cuda:0")
to_tensor = transforms.ToTensor()
# Soft-shadow style network; weight paths are machine-specific. TODO: make configurable.
model = inference_shadow.init_models('/home/ysheng/Documents/Research/GSSN/HardShadow/qtGUI/weights/human_baseline_all_21-July-04-52-AM.pt')
# GSSN_model = SSN_Infernece('GSSN/weights/0000000700.pt')
GSSN_model = SSN_Infernece('/home/ysheng/Documents/Research/GSSN/HardShadow/qtGUI/GSSN/weights/only_shadow/0000000200.pt')
19
+
20
def crop_mask(mask):
    """Tight bounding box (hmin, hmax, wmin, wmax) of a 2D mask's nonzero pixels."""
    rows, cols = np.nonzero(mask)
    return (rows.min(), rows.max(), cols.min(), cols.max())
24
+
25
def norm_output(np_img):
    """Min-max normalize to [0, 1]; the extra clip guards against float overshoot."""
    return np.clip(cv2.normalize(np_img, None, 0.0, 1.0, cv2.NORM_MINMAX),0.0,1.0)
27
+
28
def padding(mask, shadow, mask_aabb, shadow_aabb, final_shape=(512, 512)):
    """Crop mask/shadow to their AABBs, rescale, and paste into canonical canvases.

    The mask is scaled so its longer side fills `fract` of the canvas and is
    pasted near the top; the shadow is scaled by the same factors and pasted
    at the matching relative offset. Returns (mask_canvas, shadow_canvas,
    trans_info) where trans_info records everything needed to undo the
    transform (see transform_output).
    """
    mh, mhh, mw, mww = mask_aabb
    sh, shh, sw, sww = shadow_aabb
    cropped_mask, cropped_shadow = mask[mh:mhh, mw:mww], shadow[sh:shh, sw:sww]
    global_h, global_w = mask.shape[:2]
    h, w, c, sc = *cropped_mask.shape, shadow.shape[2]
    fract = 0.4  # fraction of the canvas the longer mask side should cover
    if h > w:
        newh = int(final_shape[0]*fract)
        neww = int(newh/h*w)
    else:
        neww = int(final_shape[1]*fract)
        newh = int(neww/w*h)

    small_mask = cv2.resize(cropped_mask, (neww, newh), interpolation=cv2.INTER_AREA)
    if len(small_mask.shape) == 2:
        # cv2.resize drops the channel axis for single-channel inputs.
        small_mask = small_mask[...,np.newaxis]

    # Mask canvas starts black; shadow canvas starts white (no shadow).
    mask_ret, shadow_ret = np.zeros((final_shape[0], final_shape[1], c)),np.ones((final_shape[0], final_shape[1], sc))
    paddingh, paddingw = 10, (final_shape[0]-neww)//2
    mask_lpos = (paddingh, paddingw)
    mask_ret = overlap_replace(mask_ret, small_mask, mask_lpos)

    # padding shadow
    hscale, wscale = newh/h, neww/w
    newsh, newsw = int((shh-sh) * hscale), int((sww-sw) * wscale)
    small_shadow = cv2.resize(cropped_shadow, (newsw, newsh), interpolation=cv2.INTER_AREA)

    if len(small_shadow.shape) == 2:
        small_shadow = small_shadow[...,np.newaxis]


    # Keep the shadow's offset relative to the mask after rescaling.
    loffseth, loffsetw = int((sh-mh)*hscale), int((sw-mw)*wscale)
    shadow_lpos = (paddingh + loffseth, paddingw + loffsetw)
    shadow_ret = overlap_replace(shadow_ret, small_shadow, shadow_lpos)

    # return mask_ret, shadow_ret[...,0:1], [mask_aabb, mask_lpos, hscale, wscale, final_shape, mask.shape[0], mask.shape[1]]
    return mask_ret, shadow_ret, [mask_aabb, mask_lpos, hscale, wscale, final_shape, mask.shape[0], mask.shape[1]]
66
+
67
+
68
def transform_input(mask, hardshadow):
    """ Note, trans_info marks the AABBs, and scaling factors

    Crops/pastes mask and hard shadow into canonical canvases (see padding),
    then converts both from HWC to NCHW; the shadow is inverted (1 - s) for
    the network. Returns (mask_nchw, inv_shadow_nchw, trans_info).
    """
    mask_aabb, shadow_aabb = crop_mask(mask[...,0]), crop_mask(hardshadow[...,0])
    # import pdb; pdb.set_trace()
    cmask, cshadow, trans_info = padding(mask, hardshadow, mask_aabb, shadow_aabb)
    return cmask.transpose(2,0,1)[np.newaxis,...], 1.0 - cshadow.transpose(2,0,1)[np.newaxis, ...], trans_info
75
+
76
+
77
def transform_output(softshadow, trans_info):
    """Map the network's canonical-canvas shadow back to the original frame.

    Inverts the crop/scale/paste recorded in trans_info (see padding) and
    returns a (H, W, 3) image where 1.0 means "no shadow" (output is inverted
    relative to the network's convention).
    """
    mask_aabb, mask_lpos, hscale, wscale, final_shape, h, w = trans_info
    # import pdb; pdb.set_trace()
    # Resize the canvas-space shadow back to original-image scale.
    ret, gsh, gsw = np.zeros((h,w,1)), int(final_shape[0]/hscale), int(final_shape[1]/wscale)
    global_shadow = cv2.resize(softshadow[0,0], (gsw, gsh))

    # global start = global_mask_aabb - (local_mask_start)/scaling
    mh, mw, mask_lh, mask_lw = mask_aabb[0], mask_aabb[2], mask_lpos[0], mask_lpos[1]
    starth, startw = int(mh - mask_lh / hscale), int(mw - mask_lw / wscale)
    ret = norm_output(overlap_replace(ret, global_shadow[...,np.newaxis], (starth, startw)))
    if len(ret.shape) == 2:
        ret = ret[..., np.newaxis]

    return 1.0-ret.repeat(3,axis=2)
91
+
92
def style_hardshadow(mask, hardshadow, softness):
    """Soften a hard shadow with the SSN style network.

    Returns (softshadow in the original frame, debug triple of the
    normalized network inputs/output).
    """
    mask_net, hardshadow_net, trans_info = transform_input(mask, hardshadow)
    netsoftshadow = inference_shadow.net_render_np(model, mask_net, hardshadow_net, softness, 0.0)
    softshadow = transform_output(netsoftshadow, trans_info)

    return softshadow, (norm_output(mask_net[0,0]), norm_output(hardshadow_net[0,0]), norm_output(netsoftshadow[0,0]))
98
+
99
def gssn_shadow(mask, pixel_height, shadow_channels, softness):
    """Soften a shadow with the GSSN model using pixel-height input channels.

    Crops/pastes like style_hardshadow but feeds (pixel_height / 512,
    inverted hard shadow) concatenated channel-wise to the GSSN network.
    Returns the soft shadow mapped back to the original frame.
    """
    # mask_net, hardshadow_net, trans_info = transform_input(mask, shadow_channels)

    mask_aabb, shadow_aabb = crop_mask(mask[...,0]), crop_mask(shadow_channels[...,0])
    ph_channel, hardshadow_net, trans_info = padding(pixel_height, shadow_channels, mask_aabb, shadow_aabb)

    # Heights normalized by the canonical canvas size (512).
    ph_channel = ph_channel/512.0
    hardshadow_net = 1.0-hardshadow_net
    input_np = np.concatenate([ph_channel, hardshadow_net], axis=2)

    # import pdb; pdb.set_trace()

    netsoftshadow = np.clip(GSSN_model.render_ss(input_np, softness), 0.0, 1.0)
    # HWC -> NCHW for transform_output.
    netsoftshadow = netsoftshadow.transpose((2,0,1))[None, ...]
    softshadow = transform_output(netsoftshadow, trans_info)

    return softshadow
116
+
117
+
118
def proj_ground(p, light_pos):
    """Project a point onto the ground plane (height 0) through each light.

    p:         array [x, y, height] of the point to project
    light_pos: (B, 3) array of light positions [x, y, height]
    returns a (B, 2) array of projected (x, y) positions.
    """
    point = p.copy()
    # ray-plane intersection parameter; epsilon avoids division by zero
    # when a light sits at the same height as the point
    t = (0 - point[2]) / (light_pos[:, 2:3] - point[2] + 1e-6)
    # linear interpolation between the point and the light, evaluated at t
    return (1.0 - t) * point[:2] + t * light_pos[:, :2]
124
+
125
def proj_bb(mask, hmap, light_pos, mouse_pos):
    """Per-light conservative bounding box of the shadow cast on the ground.

    The object's AABB corners (top edge lifted to the mask's max height,
    bottom edge on the ground) are projected through each light onto the
    ground plane; the returned box is the union of those projections, the
    original mask box, and mouse_pos.

    mask:      H x W binary object mask
    hmap:      H x W pixel-height map
    light_pos: (3,) single light or (3, B) batch of lights
    mouse_pos: (x, y) position that must be covered by the box
    returns:   (B, 4) array of [hmin, hmax, wmin, wmax] per light

    Fix: removed dead computations (`highest`, `highest_h`, `highest_w` from
    np.argmax/np.unravel_index were never used).
    """
    tmp_lights = light_pos.copy()
    if len(light_pos.shape) == 1:
        tmp_lights = tmp_lights[..., np.newaxis]

    # object AABB in image coordinates
    hbb, wbb = np.nonzero(mask)
    h, hh, w, ww = hbb.min(), hbb.max(), wbb.min(), wbb.max()
    top = hmap.max()
    # four corners: two top corners at max height, two bottom corners on the ground
    bb0 = np.array([w, h, top])
    bb1 = np.array([ww, h, top])
    bb2 = np.array([w, hh, 0])
    bb3 = np.array([ww, hh, 0])

    # project each corner to the ground through every light
    tmp_lights = tmp_lights.transpose(1, 0)  # -> (B, 3)
    bb0, bb1, bb2, bb3 = (proj_ground(bb0, tmp_lights), proj_ground(bb1, tmp_lights),
                          proj_ground(bb2, tmp_lights), proj_ground(bb3, tmp_lights))

    batch = len(tmp_lights)
    new_bb = np.zeros((batch, 4))
    for i in range(batch):
        new_bb[i, 0] = min([bb0[i, 1], bb1[i, 1], bb2[i, 1], bb3[i, 1], mouse_pos[1], h])   # hmin
        new_bb[i, 1] = max([bb0[i, 1], bb1[i, 1], bb2[i, 1], bb3[i, 1], mouse_pos[1], hh])  # hmax
        new_bb[i, 2] = min([bb0[i, 0], bb1[i, 0], bb2[i, 0], bb3[i, 0], mouse_pos[0], w])   # wmin
        new_bb[i, 3] = max([bb0[i, 0], bb1[i, 0], bb2[i, 0], bb3[i, 0], mouse_pos[0], ww])  # wmax

    return new_bb
150
+
151
def to_torch_device(np_img):
    """Move a numpy image onto the module-level torch device.

    3-D arrays (H x W x C) go through to_tensor and gain a batch dimension;
    everything else is converted directly with torch.from_numpy.
    """
    if len(np_img.shape) == 3:
        tensor = to_tensor(np_img).float().unsqueeze(dim=0)
    else:
        tensor = torch.from_numpy(np_img).float()
    return tensor.contiguous().to(device)
156
+
157
def hshadow_render(rgb, mask, hmap, rechmap, light_pos, mouse_pos):
    """Heightmap shadow rendering via the `hshadow` extension.

    rgb:       H x W x 3 image
    mask:      H x W x 1 object mask
    hmap:      H x W x 1 object pixel-height map
    rechmap:   H x W x 1 receiver height map
    light_pos: (3,) single light or (3, B) batch of lights
    mouse_pos: (x, y) used to grow the projected bounding box
    returns:   H x W x 1 numpy shadow mask averaged over all lights

    Fix: removed dead code — the unused `hbb, wbb = np.nonzero(...)` call,
    the unused `batch = 1` / `h, w = rgb.shape[:2]` assignments, and a large
    commented-out light-weighting experiment.
    """
    # restrict kernel work to the projected shadow bounding box (speed optimization)
    bb = proj_bb(mask[..., 0], hmap[..., 0], light_pos, mouse_pos)

    if len(light_pos.shape) == 1:
        # single light: inputs keep their natural batch of 1
        light_pos_d = torch.from_numpy(light_pos).to(device).unsqueeze(dim=0).float()
        rgb_d, mask_d, hmap_d, rechmap_d = (to_torch_device(rgb), to_torch_device(mask),
                                            to_torch_device(hmap), to_torch_device(rechmap))
        bb_d = torch.from_numpy(bb).float().to(device)
    else:
        # light batch: replicate every input once per light
        light_pos_d = torch.from_numpy(np.ascontiguousarray(light_pos.transpose(1, 0))).float().to(device)
        batch = len(light_pos_d)
        rgb_d = to_torch_device(np.repeat(rgb[np.newaxis, ...].transpose(0, 3, 1, 2), batch, axis=0))
        mask_d = to_torch_device(np.repeat(mask[np.newaxis, ...].transpose(0, 3, 1, 2), batch, axis=0))
        hmap_d = to_torch_device(np.repeat(hmap[np.newaxis, ...].transpose(0, 3, 1, 2), batch, axis=0))
        rechmap_d = to_torch_device(np.repeat(rechmap[np.newaxis, ...].transpose(0, 3, 1, 2), batch, axis=0))
        bb_d = torch.from_numpy(np.ascontiguousarray(bb)).float().to(device)

    shadow = hshadow.forward(rgb_d, mask_d, bb_d, hmap_d, rechmap_d, light_pos_d)

    # average the per-light shadow masks into one soft mask
    shadow = shadow[0].sum(dim=0, keepdim=True) / len(shadow[0])
    return shadow[0].detach().cpu().numpy().transpose(1, 2, 0)
204
+
205
def refine_shadow(shadow, intensity=0.6, filter=5):
    """Blur each channel of the shadow image and rescale its darkness.

    shadow:    H x W x 3 image, blurred IN PLACE channel by channel
    intensity: darkness scale; maps shadow value s to 1 - (1 - s) * intensity
    filter:    uniform-filter window size
    returns the rescaled (lightened) shadow image.
    """
    for channel in range(3):
        shadow[..., channel] = uniform_filter(shadow[..., channel], size=filter)
    return 1.0 - (1.0 - shadow) * intensity
210
+
211
def render_ao(rgb, mask, hmap):
    """Render an ambient-occlusion-like soft contact shadow for the object.

    Casts a hard shadow from a fixed, distant, nearly-overhead light
    (hand-tuned constant), then softens and lightens it.

    rgb:  H x W x 3 image
    mask: H x W x 1 object mask
    hmap: H x W x 1 pixel-height map
    returns an H x W x 3 shadow image (1.0 = unshadowed).

    Fix: removed dead code — `hbb, wbb = np.nonzero(...)` whose only
    consumer was already commented out.
    """
    rechmap = np.zeros_like(hmap)  # shadow receiver is the flat ground
    # hand-picked distant light approximating ambient occlusion
    light_pos = np.array([-1300.10811363, -46999.86253089, 46486.73121776])
    mouse_pos = light_pos

    shadow = hshadow_render(rgb, mask, hmap, rechmap, light_pos, mouse_pos)
    softshadow = style_hardshadow(mask, shadow[..., :1], 0.45)[0]
    return refine_shadow(softshadow)
222
+
223
def ao_composite(rgb, mask, hmap, rechmap, light_pos, mouse_pos):
    """Composite the AO shadow under the object.

    rechmap, light_pos, and mouse_pos are accepted for interface symmetry
    with the other compositors but are unused — the AO light is fixed
    inside render_ao.
    returns (composited image, copy of the shadow image).
    """
    softshadow = render_ao(rgb, mask, hmap)
    mask3 = np.repeat(mask, 3, axis=2)
    composited = (1.0 - mask3) * softshadow * rgb + mask3 * rgb
    return composited, softshadow.copy()
231
+
232
+
233
def render_shadow(rgb, mask, hmap, rechmap, light_pos, mouse_pos, softness, shadow_intensity=0.6):
    """Render a hard shadow and optionally soften it with the SSN model.

    softness: pass None to skip the softening network (debug images then None)
    returns (refined shadow image, debug images or None).
    """
    shadow = hshadow_render(rgb, mask, hmap, rechmap, light_pos, mouse_pos)

    dbgs = None
    if softness is not None:
        shadow, dbgs = style_hardshadow(mask, shadow[..., :1], softness)

    shadow = refine_shadow(shadow, intensity=shadow_intensity)
    return shadow, dbgs
243
+
244
+
245
def hshadow_composite(rgb, mask, hmap, rechmap, light_pos, mouse_pos, softness, shadow_intensity=0.6):
    """Shadow rendering and composition.

    rgb:       H x W x 3
    mask:      H x W x 1
    hmap:      H x W x 1
    rechmap:   H x W x 1
    light_pos: [x, y, h]
    returns (composited image, copy of the shadow image, debug images).
    """
    shadow, dbgs = render_shadow(rgb, mask, hmap, rechmap, light_pos, mouse_pos, softness, shadow_intensity)
    mask3 = np.repeat(mask, 3, axis=2)
    composited = (1.0 - mask3) * shadow * rgb + mask3 * rgb
    return composited, shadow.copy(), dbgs
258
+
259
+ # def vis_horizon(fov, horizon, h, w):
260
+ # # fov, horizon = 120, 400
261
+ # camera = torch.tensor([[fov, horizon]])
262
+ # planes = torch.tensor([[0.0, 0.0, 0.0, 0.0, 1.0, 0.0]])
263
+
264
+ # camera = camera.float().to(device)
265
+ # planes = planes.float().to(device)
266
+
267
+ # ground_vis = plane_visualize.forward(planes, camera, h, w)[0]
268
+ # return 1.0-ground_vis[0].detach().cpu().numpy().transpose(1,2,0)