3ZadeSSG committed
Commit 99e2b6c · 1 Parent(s): 908f07a

initial commit

.gitattributes CHANGED
@@ -32,4 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.xz filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,21 @@
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # Models and Engines
+ *.onnx
+ *.onnx.data
+ *.pth
+ *.engine
+
+ # Images
+ *.png
+ *.jpeg
+ *.JPG
+
+ # Videos
+ *.mp4
+
+ # Logs
+ logs/
.huggingface.yaml ADDED
@@ -0,0 +1,3 @@
+ sdk: gradio
+ python_version: '3.12'
+ requirements_file: requirements.txt
README.md CHANGED
@@ -1,14 +1,97 @@
- ---
- title: PVSDNet Depth Only
- emoji: 🔥
- colorFrom: indigo
- colorTo: blue
- sdk: gradio
- sdk_version: 6.3.0
- app_file: app.py
- pinned: false
- license: agpl-3.0
- short_description: Monocular Depth Estimation Model for Real-Time Inference
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ <div align="center">
+ <a href="#"><img src='https://img.shields.io/badge/-Paper-00629B?style=flat&logo=ieee&logoColor=white' alt='arXiv'></a>
+ <a href='https://realistic3d-miun.github.io/PVSDNet/'><img src='https://img.shields.io/badge/Project_Page-Website-green?logo=googlechrome&logoColor=white' alt='Project Page'></a>
+ <a href='https://huggingface.co/spaces/3ZadeSSG/PVSDNet'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Demo_(Coming_Soon)-blue'></a>
+ </div>
+
+ # PVSDNet: Joint Depth Prediction and View Synthesis via Shared Latent Spaces in Real-Time
+
+ ## Supplementary Video (head to the Project Page for more visual results)
+ [![Watch the video](https://img.youtube.com/vi/49s2UPvRA6I/maxresdefault.jpg)](https://youtu.be/49s2UPvRA6I)
+
+ # 1. PVSDNet - Joint Depth and View
+ **Note:** Will be added soon.
+
+ ## 1.A. Normal Inference (Recommended for minimal setup)
+ **Note:** Will be added soon.
+
+ ## 1.B. Faster Inference (For best possible FPS)
+ **Note:** Will be added soon.
+
+
+ # 2. PVSDNet Depth-Only Model
+ This model is a variant of the original PVSDNet that predicts only depth, not the target views. The model core is the same, except that the rendering network and the positional encoding are removed.
+
+ * Download the checkpoints from the following table and place them in the `checkpoint_onnx` directory.
+
+ | Model | Size | Checkpoint |
+ |-----------------|--------|----------------|
+ | PVSDNet-Depth-Only | 1.11 GB | [Download](https://huggingface.co/3ZadeSSG/PVSDNet-Depth-Only/resolve/main/depth_only_model.pth) |
+ | PVSDNet-Depth-Only-Lite | 279 MB | [Download](https://huggingface.co/3ZadeSSG/PVSDNet-Depth-Only/resolve/main/depth_only_lite_model.pth) |
+
+ ## 2.A. Normal Inference (Recommended for minimal setup)
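+ The depth-only models run with plain PyTorch; no ONNX or TensorRT setup is needed. The snippet below is a minimal sketch, not an official script from this repo: it assumes the modules shipped here (`models.depth_only_model`, `helperFunctions`, `depth_only_parameters`) and a checkpoint placed as described above; `input.png` is a placeholder path.
+ ```python
+ import torch
+ from PIL import Image
+ import torchvision.transforms as transforms
+
+ import depth_only_parameters as params
+ import helperFunctions as helper
+ from models.depth_only_model import PVSDNet
+
+ # Build the model and load the downloaded checkpoint (CPU-safe load).
+ model = PVSDNet(total_image_input=params.params_number_input)
+ model = helper.load_Checkpoint("./checkpoint_onnx/depth_only_model.pth", model, load_cpu=True)
+ model.to(params.DEVICE).eval()
+
+ # Resize to the configured resolution (384x384 by default) and predict.
+ transform = transforms.Compose([
+     transforms.Resize((params.params_height, params.params_width)),
+     transforms.ToTensor(),
+ ])
+ image = Image.open("input.png").convert("RGB")  # placeholder input path
+ with torch.no_grad():
+     depth = model(transform(image).unsqueeze(0).to(params.DEVICE))  # 1x1xHxW relative depth
+ ```
+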
+ ## 2.B. Faster Inference (For best possible FPS)
38
+ You need to setup your own TRT Engine for this purpose.
39
+
40
+ * Make sure you modify the `depth_only_parameters` to set resolution you need. By default we have kept it at `384x384`.
41
+
42
+ * Run `export_onnx_depth.py` to conver the normal pytorch models located into into onnx
43
+ ```
44
+ python export_onnx_depth.py
45
+ ```
46
+ * Create TRT Engine directory
47
+ ```
48
+ mkdir TRT_Engine
49
+ ```
50
+ * Build the TRT engine based on created onnx files (which by default will be located in `checkpoint_onnx`)
51
+ ```
52
+ trtexec --onnx=./checkpoint_onnx/depth_only_model.onnx --saveEngine=./TRT_Engine/depth_only_model_fp16.engine --fp16
53
+ ```
54
+ ```
55
+ trtexec --onnx=./checkpoint_onnx/depth_only_lite_model.onnx --saveEngine=./TRT_Engine/depth_only_lite_model_fp16.engine --fp16
56
+ ```
57
+
58
+
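+ Optionally, you can sanity-check a built engine by loading and benchmarking it with random inputs (standard `trtexec` usage, not a script from this repo; adjust the path for the lite engine):
+ ```
+ trtexec --loadEngine=./TRT_Engine/depth_only_model_fp16.engine
+ ```
+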
+ ## 2.C. Predicting on Depth Datasets using Multi-Resolution Fusion
+
+ We run the scripts inside the `depth_dataset_predictor` directory. There are two sample images for each dataset to test the code.
+ * First, we build the TRT engine for each dataset, since we use multi-resolution fusion:
+ ```
+ python depth_dataset_predictor/build_trt_<dataset_name>.py
+ ```
+ * Then we run the prediction script:
+ ```
+ python depth_dataset_predictor/predict_<dataset_name>_TensorRT.py
+ ```
+
+ | Dataset | Step 1 | Step 2 |
+ |---|---|---|
+ | ETH3D | `python depth_dataset_predictor/build_trt_ETH3D.py` | `python depth_dataset_predictor/predict_ETH3D_TensorRT.py` |
+ | Sintel | `python depth_dataset_predictor/build_trt_Sintel.py` | `python depth_dataset_predictor/predict_Sintel_TensorRT.py` |
+ | KITTI | `python depth_dataset_predictor/build_trt_KITTI.py` | `python depth_dataset_predictor/predict_KITTI_TensorRT.py` |
+ | DIODE | `python depth_dataset_predictor/build_trt_DIODE.py` | `python depth_dataset_predictor/predict_DIODE_TensorRT.py` |
+ | NYU | `python depth_dataset_predictor/build_trt_NYU.py` | `python depth_dataset_predictor/predict_NYU_TensorRT.py` |
+
+ ## 2.D. Predicting on 1080p In-The-Wild Images/Videos using Multi-Resolution Fusion
+ As with the datasets, we can use multi-resolution fusion to predict on 1080p in-the-wild images and videos.
+
+ * First, we build the TRT engine:
+ ```
+ python depth_in_wild_predictor/build_trt_1080p.py
+ ```
+ * Then we run the prediction script for images:
+ ```
+ python depth_in_wild_predictor/predict_1080p_TensorRT.py
+ ```
+ OR run the prediction script for videos:
+ ```
+ python depth_in_wild_predictor/predict_video_1080p_TensorRT.py
+ ```
+ #### Note
+ * For any other resolution, you can modify the resolutions in the above scripts to suit your needs. We have kept the default at 1080p for this example.
+ * We recommend 3-6 resolutions for best results, but you can use 1-2 smaller resolutions when working with low-resolution images/videos, since the receptive field of the network can handle that without any issues. The fusion itself is sketched below.
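+
+ For reference, the fusion is simply an average of predictions made at several resolutions and resized back to the input size. A minimal sketch of that logic (mirroring `app.py`; `model` and `image` are assumed to be set up as in Section 2.A):
+ ```python
+ import torch.nn.functional as F
+ import torchvision.transforms as transforms
+
+ def fuse_depth(image, model, resolutions, device="cpu"):
+     """Average depth maps predicted at several (width, height) resolutions."""
+     W, H = image.size  # PIL image size is (width, height)
+     fused = []
+     for w, h in resolutions:
+         t = transforms.Compose([transforms.Resize((h, w)), transforms.ToTensor()])
+         d = model(t(image).unsqueeze(0).to(device)).detach()  # 1x1xhxw prediction
+         d = F.interpolate(d, (H, W), mode="bilinear", align_corners=False)
+         fused.append(d[0, 0].cpu())
+     return sum(fused) / len(fused)  # HxW relative depth (normalize before display)
+ ```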
__init__.py ADDED
File without changes
app.py ADDED
@@ -0,0 +1,238 @@
+ import gradio as gr
+ import torch
+ import numpy as np
+ import matplotlib.pyplot as plt
+ from PIL import Image
+ import torch.nn.functional as F
+ import torchvision.transforms as transforms
+ import depth_only_parameters as params
+
+ from models.depth_only_model import PVSDNet
+ from models.depth_only_lite_model import PVSDNet_Lite
+
+ import helperFunctions as helper
+ import socket
+ from huggingface_hub import hf_hub_download
+ import joblib
+
+ REPO_ID = "3ZadeSSG/PVSDNet-Depth-Only"
+ print("Downloading/Loading checkpoints from Hugging Face Hub...")
+ params.MODEL_Small_Location = hf_hub_download(
+     repo_id=REPO_ID,
+     filename="depth_only_lite_model.pth"
+ )
+
+ params.MODEL_Large_Location = hf_hub_download(
+     repo_id=REPO_ID,
+     filename="depth_only_model.pth"
+ )
+
+ print(f"Large Model loaded at: {params.MODEL_Large_Location}")
+ print(f"Lite Model loaded at: {params.MODEL_Small_Location}")
+
+
+ def get_valid_resolutions(width, height):
+     """Dynamically determines valid resolutions based on input size.
+     - Caps the highest resolution at 1024px to avoid unnecessary high-res computations.
+     - Uses 6 resolutions for large images to improve multi-scale fusion quality.
+     - Uses 4 resolutions for smaller images (< 512px width or height).
+     """
+     def make_divisible(n, base=16):
+         return max(base, int(round(n / base) * base))
+
+     max_resolution = 1024
+     high_w, high_h = make_divisible(min(width, max_resolution)), make_divisible(min(height, max_resolution))
+
+     # Calculate more intermediate steps for better fusion
+     r80_w, r80_h = make_divisible(int(high_w // 1.25)), make_divisible(int(high_h // 1.25))
+     r66_w, r66_h = make_divisible(int(high_w // 1.5)), make_divisible(int(high_h // 1.5))
+     r50_w, r50_h = make_divisible(int(high_w // 2)), make_divisible(int(high_h // 2))
+     r40_w, r40_h = make_divisible(int(high_w // 2.5)), make_divisible(int(high_h // 2.5))
+     r33_w, r33_h = make_divisible(max(256, int(high_w // 3))), make_divisible(max(256, int(high_h // 3)))
+
+     if width < 512 or height < 512:
+         return [(high_w, high_h), (r80_w, r80_h), (r66_w, r66_h), (r50_w, r50_h)]
+     else:
+         return [
+             (high_w, high_h),
+             (r80_w, r80_h),
+             (r66_w, r66_h),
+             (r50_w, r50_h),
+             (r40_w, r40_h),
+             (r33_w, r33_h)
+         ]
+
+
+ def get_transforms(resolutions):
+     return [transforms.Compose([transforms.Resize((h, w)), transforms.ToTensor()]) for w, h in resolutions]
+
+ def get_prediction(image, transform, model):
+     img_input = image.convert('RGB')
+     img_input = transform(img_input).unsqueeze(0).to(params.DEVICE)
+     depth_out = model(img_input).detach().squeeze(0).to("cpu")
+     return depth_out
+
+ def predict_single_image(image, model_type):
+     if image is None:
+         return None, None
+
+     # Select model class and checkpoint
+     if model_type == "Lite":
+         model_class = PVSDNet_Lite
+         checkpoint = params.MODEL_Small_Location
+     else:  # Default to "Large"
+         model_class = PVSDNet
+         checkpoint = params.MODEL_Large_Location
+
+     model = model_class(total_image_input=params.params_number_input)
+     model = helper.load_Checkpoint(checkpoint, model, load_cpu=True)
+     model.to(params.DEVICE)
+     model.eval()
+
+     original_width, original_height = image.size
+
+     resolutions = get_valid_resolutions(original_width, original_height)
+     print(f"Resolutions: {resolutions} for Model Type: {model_type}")
+     transforms_list = get_transforms(resolutions)
+
+     # Predict at every resolution, resize back to the input size, and average (multi-scale fusion)
+     depth_maps = [get_prediction(image, t, model) for t in transforms_list]
+
+     depth_maps_resized = [
+         F.interpolate(depth[None], (original_height, original_width), mode='bilinear', align_corners=False)[0, 0]
+         for depth in depth_maps
+     ]
+
+     depth_final = sum(depth_maps_resized) / len(depth_maps_resized)
+
+     # Min-max normalize for display
+     depth_image = (depth_final - depth_final.min()) / (depth_final.max() - depth_final.min())
+
+     img_out = depth_image.numpy()
+     img_out_colored = plt.get_cmap('inferno')(img_out / np.max(img_out))[:, :, :3]
+     img_out_colored = (img_out_colored * 255).astype(np.uint8)
+
+     gray_scale_img_out = (depth_image.numpy() * 255).astype(np.uint8)
+
+     return Image.fromarray(img_out_colored), Image.fromarray(gray_scale_img_out)
+
+ with gr.Blocks(title="PVSDNet-Depth-Only Model", theme="default") as demo:
+     gr.Markdown(
+         """
+         ## PVSDNet-Depth-Only ZeroShot Relative Depth Estimation Model
+         * Upload an image and get its depth estimation with multi-scale fusion.
+         * Images use 2 - 6 resolutions for multi-scale fusion.
+
+         **Note:** The Hugging Face demo is running on CPU, so inference speeds will be slow.
+         ### Head to our [Project Page](https://realistic3d-miun.github.io/PVSDNet/) for more details about the models.
+         """)
+
+     with gr.Row():
+         with gr.Column():
+             img_input = gr.Image(type="pil", label="RGB Image", height=384)
+             with gr.Accordion("Advanced Settings", open=False):
+                 model_type_dropdown = gr.Dropdown(["Large", "Lite"], label="Model Type", value="Large")
+             generate_btn = gr.Button("Estimate Depth", variant="primary")
+
+         with gr.Column():
+             output_color = gr.Image(type="pil", label="Depth Map (Color)", height=384)
+             output_gray = gr.Image(type="pil", label="Depth Map (Grayscale)", height=384)
+
+     generate_btn.click(
+         fn=predict_single_image,
+         inputs=[img_input, model_type_dropdown],
+         outputs=[output_color, output_gray]
+     )
+
+     gr.Markdown("### Example Samples")
+     with gr.Column():
+         with gr.Row():
+             with gr.Column(scale=2): gr.Markdown("**Example Image (Click to load)**")
+             with gr.Column(scale=1): gr.Markdown("**Resolution**")
+             with gr.Column(scale=2): gr.Markdown("**Fusion Resolutions**")
+
+         with gr.Row(variant="panel"):
+             with gr.Column(scale=2):
+                 diode_preview = gr.Image("./samples/DIODE/00022_00195_outdoor_010_030.png", label="DIODE", height=120, interactive=False, show_label=True)
+             with gr.Column(scale=1):
+                 gr.Markdown("1024 x 768")
+             with gr.Column(scale=2):
+                 gr.Markdown("1024x768, 816x608, 688x512, 512x384, 416x304, 336x256")
+
+         with gr.Row(variant="panel"):
+             with gr.Column(scale=2):
+                 eth3d_preview = gr.Image("./samples/ETH3D/DSC_0243.JPG", label="ETH3D", height=120, interactive=False, show_label=True)
+             with gr.Column(scale=1):
+                 gr.Markdown("6048 x 4032")
+             with gr.Column(scale=2):
+                 gr.Markdown("1024x1024, 816x816, 688x688, 512x512, 416x416, 336x336")
+
+         with gr.Row(variant="panel"):
+             with gr.Column(scale=2):
+                 sintel_preview = gr.Image("./samples/Sintel/frame_0028_temple.png", label="Sintel", height=120, interactive=False, show_label=True)
+             with gr.Column(scale=1):
+                 gr.Markdown("1024 x 436")
+             with gr.Column(scale=2):
+                 gr.Markdown("1024x432, 816x352, 688x288, 512x224")
+
+         with gr.Row(variant="panel"):
+             with gr.Column(scale=2):
+                 kitti_preview = gr.Image("./samples/KITTI/2011_10_03_drive_0047_sync_image_0000000383_image_02.png", label="KITTI", height=120, interactive=False, show_label=True)
+             with gr.Column(scale=1):
+                 gr.Markdown("1216 x 532")
+             with gr.Column(scale=2):
+                 gr.Markdown("1024x352, 816x288, 688x240, 512x176")
+
+         with gr.Row(variant="panel"):
+             with gr.Column(scale=2):
+                 wild_1_preview = gr.Image("./samples/Wild/toy.jpeg", label="Wild Image 1", height=120, interactive=False, show_label=True)
+             with gr.Column(scale=1):
+                 gr.Markdown("3019 x 3018")
+             with gr.Column(scale=2):
+                 gr.Markdown("1024x1024, 816x816, 688x688, 512x512, 416x416, 336x336")
+
+         with gr.Row(variant="panel"):
+             with gr.Column(scale=2):
+                 wild_2_preview = gr.Image("./samples/Wild/hamburg.jpeg", label="Wild Image 2", height=120, interactive=False, show_label=True)
+             with gr.Column(scale=1):
+                 gr.Markdown("1536 x 1920")
+             with gr.Column(scale=2):
+                 gr.Markdown("1024x1024, 816x816, 688x688, 512x512, 416x416, 336x336")
+
+         with gr.Row(variant="panel"):
+             with gr.Column(scale=2):
+                 wild_3_preview = gr.Image("./samples/Wild/north_hill.jpeg", label="Wild Image 3", height=120, interactive=False, show_label=True)
+             with gr.Column(scale=1):
+                 gr.Markdown("2320 x 2321")
+             with gr.Column(scale=2):
+                 gr.Markdown("1024x1024, 816x816, 688x688, 512x512, 416x416, 336x336")
+
+         with gr.Row(variant="panel"):
+             with gr.Column(scale=2):
+                 wild_4_preview = gr.Image("./samples/Wild/EH.jpeg", label="Wild Image 4", height=120, interactive=False, show_label=True)
+             with gr.Column(scale=1):
+                 gr.Markdown("1920 x 1080")
+             with gr.Column(scale=2):
+                 gr.Markdown("1024x1024, 816x816, 688x688, 512x512, 416x416, 336x336")
+
+         with gr.Row(variant="panel"):
+             with gr.Column(scale=2):
+                 wild_5_preview = gr.Image("./samples/Wild/train_station.jpeg", label="Wild Image 5", height=120, interactive=False, show_label=True)
+             with gr.Column(scale=1):
+                 gr.Markdown("1066 x 1060")
+             with gr.Column(scale=2):
+                 gr.Markdown("1024x1024, 816x816, 688x688, 512x512, 416x416, 336x336")
+
+
+     # Define click events to load images
+     eth3d_preview.select(fn=lambda: Image.open("./samples/ETH3D/DSC_0243.JPG"), outputs=img_input)
+     sintel_preview.select(fn=lambda: Image.open("./samples/Sintel/frame_0028_temple.png"), outputs=img_input)
+     kitti_preview.select(fn=lambda: Image.open("./samples/KITTI/2011_10_03_drive_0047_sync_image_0000000383_image_02.png"), outputs=img_input)
+     diode_preview.select(fn=lambda: Image.open("./samples/DIODE/00022_00195_outdoor_010_030.png"), outputs=img_input)
+
+     wild_1_preview.select(fn=lambda: Image.open("./samples/Wild/toy.jpeg"), outputs=img_input)
+     wild_2_preview.select(fn=lambda: Image.open("./samples/Wild/hamburg.jpeg"), outputs=img_input)
+     wild_3_preview.select(fn=lambda: Image.open("./samples/Wild/north_hill.jpeg"), outputs=img_input)
+     wild_4_preview.select(fn=lambda: Image.open("./samples/Wild/EH.jpeg"), outputs=img_input)
+     wild_5_preview.select(fn=lambda: Image.open("./samples/Wild/train_station.jpeg"), outputs=img_input)
+
+
+ demo.launch()
depth_only_parameters.py ADDED
@@ -0,0 +1,21 @@
+ import os
+
+ params_height = 384
+ params_width = 384
+
+ params_number_input = 1
+
+ LOG_FILE_LOCATION = "./logs/training_log_0.txt"
+ CHECKPOINT_LOCATION = "./checkpoint/"
+ DEVICE = "cpu"
+ ONNX_PATH = "./checkpoint_onnx"
+
+ MODEL_Large_Location = "./checkpoint/depth_only_model.pth"
+ MODEL_Small_Location = "./checkpoint/depth_only_lite_model.pth"
+
+ os.makedirs(ONNX_PATH, exist_ok=True)
+ os.makedirs("./logs", exist_ok=True)
+ os.makedirs("./checkpoint", exist_ok=True)
+ os.makedirs("./output", exist_ok=True)
helperFunctions.py ADDED
@@ -0,0 +1,26 @@
+ import torch
+ import os
+ import torch.nn.functional as F
+
+ def save_checkpoint(model, filelocation, save_parallel=True):
+     # Unwrap the DataParallel/DistributedDataParallel wrapper before saving
+     if save_parallel:
+         torch.save(model.module.state_dict(), filelocation)
+     else:
+         torch.save(model.state_dict(), filelocation)
+
+ def load_Checkpoint(fileLocation, model, load_cpu=False):
+     if load_cpu:
+         model.load_state_dict(torch.load(fileLocation, map_location=lambda storage, loc: storage))
+     else:
+         model.load_state_dict(torch.load(fileLocation))
+     return model
+
+ def writeLog(logList, filename):
+     with open(filename, 'w') as outfile:
+         outfile.write("\n".join(logList))
+
+
+ def kl_loss(mu, logvar):
+     # KL divergence of N(mu, exp(logvar)) from the standard normal, averaged over elements
+     return -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).mean()
helper_image_functions.py ADDED
@@ -0,0 +1,290 @@
+ '''
+ Author: Manu Gond (manu.gond@miun.se)
+ Date: Nov-15-2022
+ Objective: Accumulation of some general functions which I
+            use daily in my code related to image-related tasks.
+            The function names and parameters are self-explanatory.
+ Requirements: Installed python libraries which have been imported.
+ '''
+
+ import torch
+ from torchvision.utils import save_image
+ from torchvision.transforms import transforms
+ import torchmetrics
+ import cv2
+ import numpy as np
+ from PIL import Image
+ import utils
+
+
+ #======================= Read and Write =====================#
+ def readImage(location):
+     image = Image.open(location).convert("RGB")
+     return image
+
+
+ def writeImage(image, location):
+     image.save(location)
+
+
+ def writeTensorImage(image, filename):
+     save_image(image, filename)
+
+
+ def removeChannel(sourceLocation, targetLocation):
+     img = readImage(sourceLocation)
+     writeImage(img, targetLocation)
+
+
+ def getImageTransform(width, height):
+     transform = transforms.Compose([transforms.Resize((height, width)),
+                                     transforms.ToTensor()])
+     return transform
+
+
+ def convertTensor(image):
+     transform = getImageTransform(image.size[0], image.size[1])
+     image = transform(image)
+     return image
+
+
+ #=================== 360 Images =======================#
+
+ def rotateERP180(image):
+     '''
+     :param image: PIL Image
+     :return: BxHxW Torch Tensor Image
+     '''
+     W = image.size[0]
+     H = image.size[1]
+     transform = getImageTransform(W, H)
+     image = transform(image)
+     image1 = image[:, :, 0:(W//2)]
+     image2 = image[:, :, (W//2):W]
+     image3 = torch.zeros(image.size())
+     image3[:, :, 0:(W//2)] = image2
+     image3[:, :, (W//2):W] = image1
+     return image3
+
+
+ def convertERP2Cube(e_img, face_w=256, mode='bilinear', cube_format='dice'):
+     '''
+     e_img: ndarray in shape of [H, W, *]
+     face_w: int, the length of each face of the cubemap
+     '''
+     assert len(e_img.shape) == 3
+     h, w = e_img.shape[:2]
+     if mode == 'bilinear':
+         order = 1
+     elif mode == 'nearest':
+         order = 0
+     else:
+         raise NotImplementedError('unknown mode')
+
+     xyz = utils.xyzcube(face_w)
+     uv = utils.xyz2uv(xyz)
+     coor_xy = utils.uv2coor(uv, h, w)
+
+     cubemap = np.stack([
+         utils.sample_equirec(e_img[..., i], coor_xy, order=order)
+         for i in range(e_img.shape[2])
+     ], axis=-1)
+
+     if cube_format == 'horizon':
+         pass
+     elif cube_format == 'list':
+         cubemap = utils.cube_h2list(cubemap)
+     elif cube_format == 'dict':
+         cubemap = utils.cube_h2dict(cubemap)
+     elif cube_format == 'dice':
+         cubemap = utils.cube_h2dice(cubemap)
+     else:
+         raise NotImplementedError()
+     return cubemap
+
+
+ def convertCube2ERP(cubemap, h, w, mode='bilinear', cube_format='dice'):
+     if mode == 'bilinear':
+         order = 1
+     elif mode == 'nearest':
+         order = 0
+     else:
+         raise NotImplementedError('unknown mode')
+
+     if cube_format == 'horizon':
+         pass
+     elif cube_format == 'list':
+         cubemap = utils.cube_list2h(cubemap)
+     elif cube_format == 'dict':
+         cubemap = utils.cube_dict2h(cubemap)
+     elif cube_format == 'dice':
+         cubemap = utils.cube_dice2h(cubemap)
+     else:
+         raise NotImplementedError('unknown cube_format')
+     assert len(cubemap.shape) == 3
+     assert cubemap.shape[0] * 6 == cubemap.shape[1]
+     assert w % 8 == 0
+     face_w = cubemap.shape[0]
+
+     uv = utils.equirect_uvgrid(h, w)
+     u, v = np.split(uv, 2, axis=-1)
+     u = u[..., 0]
+     v = v[..., 0]
+     cube_faces = np.stack(np.split(cubemap, 6, 1), 0)
+
+     # Get face id to each pixel: 0F 1R 2B 3L 4U 5D
+     tp = utils.equirect_facetype(h, w)
+     coor_x = np.zeros((h, w))
+     coor_y = np.zeros((h, w))
+
+     for i in range(4):
+         mask = (tp == i)
+         coor_x[mask] = 0.5 * np.tan(u[mask] - np.pi * i / 2)
+         coor_y[mask] = -0.5 * np.tan(v[mask]) / np.cos(u[mask] - np.pi * i / 2)
+
+     mask = (tp == 4)
+     c = 0.5 * np.tan(np.pi / 2 - v[mask])
+     coor_x[mask] = c * np.sin(u[mask])
+     coor_y[mask] = c * np.cos(u[mask])
+
+     mask = (tp == 5)
+     c = 0.5 * np.tan(np.pi / 2 - np.abs(v[mask]))
+     coor_x[mask] = c * np.sin(u[mask])
+     coor_y[mask] = -c * np.cos(u[mask])
+
+     # Final renormalize
+     coor_x = (np.clip(coor_x, -0.5, 0.5) + 0.5) * face_w
+     coor_y = (np.clip(coor_y, -0.5, 0.5) + 0.5) * face_w
+
+     equirec = np.stack([
+         utils.sample_cubefaces(cube_faces[..., i], tp, coor_y, coor_x, order=order)
+         for i in range(cube_faces.shape[3])
+     ], axis=-1)
+     return equirec
+
+
+ def convertCube2Slices(image):
+     '''
+     :param image: Image numpy array
+     :return: List of Torch Tensors, CxHxW
+     '''
+     image = convertTensor(image)
+     C, H, W = image.size()
+     #print(C, H, W)
+     top = torch.zeros((C, W//4, W//4))
+     left = torch.zeros(top.size())
+     front = torch.zeros(top.size())
+     right = torch.zeros(top.size())
+     back = torch.zeros(top.size())
+     bottom = torch.zeros(top.size())
+
+     top = image[:, 0:H//3, (W//4):(W//4)*2]
+     left = image[:, (H//3):(H//3)*2, 0:W//4]
+     front = image[:, (H//3):(H//3)*2, (W//4):(W//4)*2]
+     right = image[:, (H//3):(H//3)*2, (W//4)*2:(W//4)*3]
+     back = image[:, (H//3):(H//3)*2, (W//4)*3:]
+     bottom = image[:, (H//3)*2:, (W//4):(W//4)*2]
+
+     '''
+     save_image(top, 'top.png')
+     save_image(left, 'left.png')
+     save_image(front, 'front.png')
+     save_image(right, 'right.png')
+     save_image(back, 'back.png')
+     save_image(bottom, 'bottom.png')
+     '''
+     return [top, left, front, right, back, bottom]
+
+ def convertSlicesToCube(imageList):
+     '''
+     top = convertTensor(readImage(imageList[0]))
+     left = convertTensor(readImage(imageList[1]))
+     front = convertTensor(readImage(imageList[2]))
+     right = convertTensor(readImage(imageList[3]))
+     back = convertTensor(readImage(imageList[4]))
+     bottom = convertTensor(readImage(imageList[5]))
+     '''
+     top = imageList[0]
+     left = imageList[1]
+     front = imageList[2]
+     right = imageList[3]
+     back = imageList[4]
+     bottom = imageList[5]
+
+     C, H, W = 3, top.size()[1]*3, top.size()[2]*4
+     cube = torch.zeros((C, H, W))
+
+     cube[:, 0:H//3, (W//4):(W//4)*2] = top
+     cube[:, (H//3):(H//3)*2, 0:W//4] = left
+     cube[:, (H//3):(H//3)*2, (W//4):(W//4)*2] = front
+     cube[:, (H//3):(H//3)*2, (W//4)*2:(W//4)*3] = right
+     cube[:, (H//3):(H//3)*2, (W//4)*3:] = back
+     cube[:, (H//3)*2:, (W//4):(W//4)*2] = bottom
+
+     return cube
+
+
+ #=================== Quality Measures =======================#
+ '''
+ Predicted Shape : BxCxHxW
+ Original Shape : BxCxHxW
+ Data Type: Torch Tensor
+ '''
+ def getSSIM(predicted, original):
+     SSIM = torchmetrics.StructuralSimilarityIndexMeasure()
+     return SSIM(predicted, original).item()
+
+
+ def getPSNR(predicted, original):
+     PSNR = torchmetrics.PeakSignalNoiseRatio()
+     return PSNR(predicted, original).item()
+
+
+ def getMSE(predicted, original):
+     MSE = torchmetrics.MeanSquaredError()
+     return MSE(predicted, original).item()
+
+
+ def getMAE(predicted, original):
+     MAE = torchmetrics.MeanAbsoluteError()
+     return MAE(predicted, original).item()
+
+
+ if __name__ == "__main__":
+     '''
+     img = readImage("31_image_0_0.png")
+     img = convertERP2Cube(e_img=np.asarray(img), face_w=256)
+     img = Image.fromarray(img.astype('uint8'), 'RGB')
+     convertCube2Slices(img)
+     '''
+     #image = convertSlicesToCube(["top.png", "left.png", "front.png", "right.png", "back.png", "bottom.png"])
+     #writeTensorImage(image, 'this.png')
+
+     '''
+     writeImage(img, 'cube.png')
+
+     img = readImage('cube.png')
+     img = convertCube2ERP(np.asarray(img), 512, 1024)
+     img = Image.fromarray(img.astype('uint8'), 'RGB')
+     writeImage(img, 'cubeERP.png')
+
+     img1 = readImage("31_image_0_0.png")
+     img2 = readImage("cubeERP.png")
+     img1 = convertTensor(img1)
+     img2 = convertTensor(img2)
+     print(getSSIM(img1.unsqueeze(0), img2.unsqueeze(0)))
+     '''
+
+     #img = rotateERP180(img)
+     #writeTensorImage(img, 'rotated_image.png')
+     #img = convertTensor(img)
+     #print(getMAE(img.unsqueeze(0), img.unsqueeze(0)))
models/__init__.py ADDED
File without changes
models/depth_only_lite_model.py ADDED
@@ -0,0 +1,234 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import warnings
+ warnings.filterwarnings("ignore")
+ import torchvision
+ import sys
+ import os
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+ import depth_only_parameters as params
+
+ def getConvLayer(in_channel, out_channel, stride=1, padding=1, activation=nn.ReLU()):
+     return nn.Sequential(
+         nn.Conv2d(in_channel, out_channel,
+                   kernel_size=3,
+                   stride=stride,
+                   padding=padding,
+                   padding_mode='reflect'),
+         activation
+     )
+
+ def getConvTransposeLayer(in_channel, out_channel, kernel=3, stride=1, padding=1, activation=nn.ReLU()):
+     return nn.Sequential(
+         nn.ConvTranspose2d(in_channel,
+                            out_channel,
+                            kernel_size=kernel,
+                            stride=stride,
+                            padding=padding),
+         activation
+     )
+
+ class Flatten(nn.Module):
+     def forward(self, input):
+         return input.view(input.size(0), -1)
+
+ class UnFlatten(nn.Module):
+     def forward(self, input, size=1):
+         return input.view(input.size(0), 1, params.params_height//8, params.params_width//8)
+
+ class ResidualBlock(nn.Module):
+     def __init__(self, in_channels, out_channels, stride=1):
+         super(ResidualBlock, self).__init__()
+         self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
+                                stride=stride, padding=1, bias=False)
+         self.relu = nn.ReLU()
+         self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
+                                stride=1, padding=1, bias=False)
+         self.stride = stride
+
+         self.shortcut = nn.Sequential()
+         if stride != 1 or in_channels != out_channels:
+             self.shortcut = nn.Sequential(
+                 nn.Conv2d(in_channels, out_channels, kernel_size=1,
+                           stride=stride, bias=False),
+                 nn.BatchNorm2d(out_channels)
+             )
+
+     def forward(self, x):
+         residual = x
+         out = self.conv1(x)
+         out = self.relu(out)
+         out = self.conv2(out)
+         out = out + self.shortcut(residual)
+         out = self.relu(out)
+         return out
+
+ class UpperEncoder(nn.Module):
+     def __init__(self):
+         super().__init__()
+         model = torchvision.models.resnet152(pretrained=True)
+         layers = list(model.children())
+         self.ResNetEncoder = nn.Sequential(*layers[:5].copy())
+         del model
+
+     def forward(self, x):
+         x1 = x[:, 0:3, :, :]
+         x1 = self.ResNetEncoder(x1)
+         return x1
+
+     def apply_resnet_encoder(self, x):
+         x1 = x[:, 0:3, :, :]
+         x1 = self.ResNetEncoder(x1)
+         return x1
+
+ class LowerEncoder(nn.Module):
+     def __init__(self, total_image_input=1):
+         super().__init__()
+         # Halved channels compared to the original
+         self.encoder_pre = ResidualBlock(total_image_input*3, 10)
+         self.encoder_layer1 = ResidualBlock(10, 15)
+         self.encoder_layer2 = ResidualBlock(15, 25)
+
+         self.encoder_layer3 = nn.Sequential(
+             ResidualBlock(25, 50),
+             nn.MaxPool2d(kernel_size=2, stride=2)
+         )
+         self.encoder_layer4 = ResidualBlock(50, 100)
+         self.encoder_layer5 = nn.Sequential(
+             ResidualBlock(100, 200),
+             nn.MaxPool2d(kernel_size=2, stride=2)
+         )
+         self.encoder_layer6 = ResidualBlock(200, 300)
+         self.encoder_layer7 = nn.Sequential(
+             ResidualBlock(300, 400),
+             nn.MaxPool2d(kernel_size=2, stride=2)
+         )
+         self.encoder_layer8 = ResidualBlock(400, 500)
+         self.encoder_layer9 = nn.Sequential(
+             ResidualBlock(500, 600),
+             nn.MaxPool2d(kernel_size=2, stride=2)
+         )
+         self.encoder_layer10 = ResidualBlock(600, 700)
+         self.encoder_layer11 = ResidualBlock(700, 800)
+
+     def forward(self, x):
+         x = self.encoder_pre(x)
+         x = self.encoder_layer1(x)
+         x = self.encoder_layer2(x)
+         skip1 = self.encoder_layer3(x)
+
+         x = self.encoder_layer4(skip1)
+         skip2 = self.encoder_layer5(x)
+
+         x = self.encoder_layer6(skip2)
+         skip3 = self.encoder_layer7(x)
+
+         x = self.encoder_layer8(skip3)
+         skip4 = self.encoder_layer9(x)
+
+         x = self.encoder_layer10(skip4)
+         x = self.encoder_layer11(x)
+         return x, [skip1, skip2, skip3, skip4]
+
+ class MergeDecoder(nn.Module):
+     def __init__(self):
+         super().__init__()
+         # Halved channels for decoder blocks
+         self.decoder_layer1 = ResidualBlock(800, 700)
+         self.decoder_layer2 = ResidualBlock(700, 600)
+         self.decoder_layer3 = ResidualBlock(600, 500)
+
+         self.decoder_layer4 = nn.Sequential(
+             nn.ConvTranspose2d(500, 400, kernel_size=2, stride=2, padding=0),
+             nn.ReLU(True)
+         )
+         self.decoder_layer5 = ResidualBlock(400, 300)
+
+         self.decoder_layer6 = nn.Sequential(
+             nn.ConvTranspose2d(300, 200, kernel_size=2, stride=2, padding=0),
+             nn.ReLU(True)
+         )
+         self.decoder_layer7 = ResidualBlock(200, 100)
+
+         self.decoder_layer8 = nn.Sequential(
+             nn.ConvTranspose2d(100, 50, kernel_size=2, stride=2, padding=0),
+             nn.ReLU(True)
+         )
+         self.decoder_layer9 = ResidualBlock(50, 50)
+
+         self.decoder_layer10 = nn.Sequential(
+             nn.ConvTranspose2d(50, 50, kernel_size=2, stride=2, padding=0),
+             nn.ReLU(True)
+         )
+         self.decoder_layer11 = ResidualBlock(50, 50)
+         self.decoder_layer12 = ResidualBlock(50, 25)
+         self.decoder_layer13 = ResidualBlock(25, 20)
+         self.decoder_layer14 = ResidualBlock(20, 10)
+         self.decoder_layer15 = nn.Sequential(
+             nn.Conv2d(10, 4, kernel_size=3, stride=1, padding=1),
+             nn.ReLU(True)
+         )
+         self.decoder_layer16 = nn.Sequential(
+             nn.Conv2d(4, 1, kernel_size=3, stride=1, padding=1),
+             nn.ReLU(True)
+         )
+
+     def forward(self, x, lower_skip_list, upper_skip_list):
+         x = self.decoder_layer1(x)
+         x = self.decoder_layer2(x)
+         # Expecting lower_skip_list[3] and upper_skip_list[1] to have matching dimensions
+         x = x + lower_skip_list[3] + upper_skip_list[1]
+
+         x = self.decoder_layer3(x)
+         x = self.decoder_layer4(x)
+         x = x + lower_skip_list[2] + upper_skip_list[0]
+
+         x = self.decoder_layer5(x)
+         x = self.decoder_layer6(x)
+         x = x + lower_skip_list[1]
+
+         x = self.decoder_layer7(x)
+         x = self.decoder_layer8(x)
+         x = x + lower_skip_list[0]
+
+         x = self.decoder_layer9(x)
+         x = self.decoder_layer10(x)
+         x = self.decoder_layer11(x)
+         x = self.decoder_layer12(x)
+         x = self.decoder_layer13(x)
+         x = self.decoder_layer14(x)
+         x = self.decoder_layer15(x)
+         x = self.decoder_layer16(x)
+         return x
+
+ class PVSDNet_Lite(nn.Module):
+     def __init__(self, total_image_input=1):
+         super().__init__()
+         # Upper encoder remains mostly the same
+         self.upper_encoder = UpperEncoder()
+         self.lower_encoder = LowerEncoder(total_image_input)
+         self.merge_decoder = MergeDecoder()
+         # Halved extra layers for upper branch:
+         self.upper_encoder_extra_1 = nn.Sequential(
+             ResidualBlock(256, 400),
+             nn.MaxPool2d(kernel_size=2, stride=2)
+         )
+         self.upper_encoder_extra_2 = nn.Sequential(
+             ResidualBlock(400, 600),
+             nn.MaxPool2d(kernel_size=2, stride=2)
+         )
+
+     def forward(self, x):
+         # First Encoder Branch (Upper)
+         upper_features_1 = self.upper_encoder.apply_resnet_encoder(x)
+         upper_features_1 = self.upper_encoder_extra_1(upper_features_1)
+         upper_features_2 = self.upper_encoder_extra_2(upper_features_1)
+
+         # Second Encoder Branch (Lower)
+         lower_feature, skip_list = self.lower_encoder(x)
+
+         # Merge and decode features
+         merged_feature = self.merge_decoder(lower_feature, skip_list, [upper_features_1, upper_features_2])
+         return merged_feature
+
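+
+ if __name__ == "__main__":
+     # Minimal smoke test (editorial sketch, not part of the original file): the four
+     # pooling stages in the encoder are undone by four stride-2 transposed convolutions
+     # in the decoder, so a 384x384 RGB input should yield a 1x1x384x384 depth map.
+     net = PVSDNet_Lite(total_image_input=1)
+     x = torch.randn(1, 3, 384, 384)
+     print(net(x).shape)  # expected: torch.Size([1, 1, 384, 384])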
models/depth_only_model.py ADDED
@@ -0,0 +1,232 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import warnings
+ warnings.filterwarnings("ignore")
+ import torchvision
+ import sys
+ import os
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+ import depth_only_parameters as params
+
+ def getConvLayer(in_channel, out_channel, stride=1, padding=1, activation=nn.ReLU()):
+     return nn.Sequential(nn.Conv2d(in_channel,
+                                    out_channel,
+                                    kernel_size=3,
+                                    stride=stride,
+                                    padding=padding,
+                                    padding_mode='reflect'),
+                          activation)
+
+ def getConvTransposeLayer(in_channel, out_channel, kernel=3, stride=1, padding=1, activation=nn.ReLU()):
+     return nn.Sequential(nn.ConvTranspose2d(in_channel,
+                                             out_channel,
+                                             kernel_size=kernel,
+                                             stride=stride,
+                                             padding=padding),
+                          activation)
+
+
+ class Flatten(nn.Module):
+     def forward(self, input):
+         return input.view(input.size(0), -1)
+
+ class UnFlatten(nn.Module):
+     def forward(self, input, size=1):
+         return input.view(input.size(0), 1, params.params_height//8, params.params_width//8)
+
+ class ResidualBlock(nn.Module):
+     def __init__(self, in_channels, out_channels, stride=1):
+         super(ResidualBlock, self).__init__()
+         self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
+         self.relu = nn.ReLU()
+         self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
+         self.stride = stride
+
+         self.shortcut = nn.Sequential()
+         if stride != 1 or in_channels != out_channels:
+             self.shortcut = nn.Sequential(
+                 nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
+                 nn.BatchNorm2d(out_channels)
+             )
+
+     def forward(self, x):
+         residual = x
+
+         out = self.conv1(x)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+
+         out = out + self.shortcut(residual)
+         out = self.relu(out)
+         return out
+
+ class UpperEncoder(nn.Module):
+     def __init__(self):
+         super().__init__()
+         model = torchvision.models.resnet152(pretrained=False)
+         layers = list(model.children())
+         self.ResNetEncoder = torch.nn.Sequential(*layers[:5].copy())
+         del model
+
+     def forward(self, x):
+         x1 = x[:, 0:3, :, :]
+         x1 = self.ResNetEncoder(x1)
+         return x1
+
+     def apply_resnet_encoder(self, x):
+         x1 = x[:, 0:3, :, :]
+         x1 = self.ResNetEncoder(x1)
+         return x1
+
+
+ class LowerEncoder(nn.Module):
+     def __init__(self, total_image_input=1):
+         super().__init__()
+         self.encoder_pre = ResidualBlock((total_image_input*3), 20)
+         self.encoder_layer1 = ResidualBlock(20, 30)
+         self.encoder_layer2 = ResidualBlock(30, 50)
+
+         self.encoder_layer3 = nn.Sequential(
+             ResidualBlock(50, 100),
+             nn.MaxPool2d(kernel_size=2, stride=2)
+         )
+
+         self.encoder_layer4 = ResidualBlock(100, 200)
+         self.encoder_layer5 = nn.Sequential(
+             ResidualBlock(200, 400),
+             nn.MaxPool2d(kernel_size=2, stride=2)
+         )
+
+         self.encoder_layer6 = ResidualBlock(400, 600)
+         self.encoder_layer7 = nn.Sequential(
+             ResidualBlock(600, 800),
+             nn.MaxPool2d(kernel_size=2, stride=2)
+         )
+
+         self.encoder_layer8 = ResidualBlock(800, 1000)
+         self.encoder_layer9 = nn.Sequential(
+             ResidualBlock(1000, 1200),
+             nn.MaxPool2d(kernel_size=2, stride=2)
+         )
+
+         self.encoder_layer10 = ResidualBlock(1200, 1400)
+         self.encoder_layer11 = ResidualBlock(1400, 1600)
+
+     def forward(self, x):
+         x = self.encoder_pre(x)
+         x = self.encoder_layer1(x)
+         x = self.encoder_layer2(x)
+         skip1 = self.encoder_layer3(x)
+
+         x = self.encoder_layer4(skip1)
+         skip2 = self.encoder_layer5(x)
+
+         x = self.encoder_layer6(skip2)
+         skip3 = self.encoder_layer7(x)
+
+         x = self.encoder_layer8(skip3)
+         skip4 = self.encoder_layer9(x)
+
+         x = self.encoder_layer10(skip4)
+         x = self.encoder_layer11(x)
+
+         return x, [skip1, skip2, skip3, skip4]
+
+ class MergeDecoder(nn.Module):
+     def __init__(self):
+         super().__init__()
+
+         self.decoder_layer1 = ResidualBlock(1600, 1400)
+         self.decoder_layer2 = ResidualBlock(1400, 1200)
+         self.decoder_layer3 = ResidualBlock(1200, 1000)
+
+         self.decoder_layer4 = nn.Sequential(
+             nn.ConvTranspose2d(1000, 800, 2, stride=2, padding=0),
+             nn.ReLU(True)
+         )
+         self.decoder_layer5 = ResidualBlock(800, 600)
+
+         self.decoder_layer6 = nn.Sequential(
+             nn.ConvTranspose2d(600, 400, 2, stride=2, padding=0),
+             nn.ReLU(True)
+         )
+         self.decoder_layer7 = ResidualBlock(400, 200)
+
+         self.decoder_layer8 = nn.Sequential(
+             nn.ConvTranspose2d(200, 100, 2, stride=2, padding=0),
+             nn.ReLU(True)
+         )
+         self.decoder_layer9 = ResidualBlock(100, 100)
+
+         self.decoder_layer10 = nn.Sequential(
+             nn.ConvTranspose2d(100, 100, 2, stride=2, padding=0),
+             nn.ReLU(True)
+         )
+         self.decoder_layer11 = ResidualBlock(100, 100)
+         self.decoder_layer12 = ResidualBlock(100, 50)
+         self.decoder_layer13 = ResidualBlock(50, 40)
+         self.decoder_layer14 = ResidualBlock(40, 20)
+         self.decoder_layer15 = nn.Sequential(
+             nn.Conv2d(20, 8, 3, stride=1, padding=1),
+             nn.ReLU(True)
+         )
+         self.decoder_layer16 = nn.Sequential(
+             nn.Conv2d(8, 1, 3, stride=1, padding=1),
+             nn.ReLU(True)
+         )
+
+     def forward(self, x, lower_skip_list, upper_skip_list):
+         x = self.decoder_layer1(x)
+         x = self.decoder_layer2(x)
+         x = x + lower_skip_list[3] + upper_skip_list[1]
+
+         x = self.decoder_layer3(x)
+         x = self.decoder_layer4(x)
+         x = x + lower_skip_list[2] + upper_skip_list[0]
+
+         x = self.decoder_layer5(x)
+         x = self.decoder_layer6(x)
+         x = x + lower_skip_list[1]
+
+         x = self.decoder_layer7(x)
+         x = self.decoder_layer8(x)
+         x = x + lower_skip_list[0]
+
+         x = self.decoder_layer9(x)
+         x = self.decoder_layer10(x)
+         x = self.decoder_layer11(x)
+         x = self.decoder_layer12(x)
+         x = self.decoder_layer13(x)
+         x = self.decoder_layer14(x)
+         x = self.decoder_layer15(x)
+         x = self.decoder_layer16(x)
+         return x
+
+ class PVSDNet(nn.Module):
+     def __init__(self, total_image_input=1):
+         super().__init__()
+         self.upper_encoder = UpperEncoder()
+         self.lower_encoder = LowerEncoder(total_image_input)
+         self.merge_decoder = MergeDecoder()
+
+         self.upper_encoder_extra_1 = nn.Sequential(
+             ResidualBlock(256, 800),
+             nn.MaxPool2d(kernel_size=2, stride=2)
+         )
+         self.upper_encoder_extra_2 = nn.Sequential(
+             ResidualBlock(800, 1200),
+             nn.MaxPool2d(kernel_size=2, stride=2)
+         )
+
+     def forward(self, x):
+         upper_features_1 = self.upper_encoder.apply_resnet_encoder(x)
+         upper_features_1 = self.upper_encoder_extra_1(upper_features_1)
+         upper_features_2 = self.upper_encoder_extra_2(upper_features_1)
+
+         lower_feature, skip_list = self.lower_encoder(x)
+
+         merged_feature = self.merge_decoder(lower_feature, skip_list, [upper_features_1, upper_features_2])
+         return merged_feature
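+
+
+ if __name__ == "__main__":
+     # Minimal smoke test (editorial sketch, not part of the original file): the encoder
+     # downsamples four times by a factor of 2 and the decoder upsamples four times, so
+     # a 384x384 RGB input should yield a 1x1x384x384 depth map.
+     net = PVSDNet(total_image_input=1)
+     x = torch.randn(1, 3, 384, 384)
+     print(net(x).shape)  # expected: torch.Size([1, 1, 384, 384])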
requirements.txt ADDED
@@ -0,0 +1,122 @@
+ aiofiles==24.1.0
+ annotated-doc==0.0.4
+ annotated-types==0.7.0
+ anyio==4.12.1
+ av==16.0.1
+ blinker==1.9.0
+ brotli==1.2.0
+ certifi==2026.1.4
+ charset-normalizer==3.4.4
+ click==8.3.1
+ colorama==0.4.6
+ contourpy==1.3.3
+ cuda-toolkit==12.9.1
+ cycler==0.12.1
+ decorator==4.4.2
+ fastapi==0.128.0
+ ffmpy==1.0.0
+ filelock==3.20.0
+ Flask==3.1.2
+ fonttools==4.61.1
+ fsspec==2025.12.0
+ fvcore==0.1.5.post20221221
+ gradio==6.2.0
+ gradio_client==2.0.2
+ groovy==0.1.2
+ h11==0.16.0
+ hf-xet==1.2.0
+ httpcore==1.0.9
+ httpx==0.28.1
+ huggingface_hub==1.2.4
+ idna==3.11
+ ImageIO==2.37.2
+ imageio-ffmpeg==0.6.0
+ iopath==0.1.10
+ itsdangerous==2.2.0
+ Jinja2==3.1.6
+ joblib==1.5.3
+ kiwisolver==1.4.9
+ lazy_loader==0.4
+ Mako==1.3.10
+ markdown-it-py==4.0.0
+ MarkupSafe==2.1.5
+ matplotlib==3.10.8
+ matplotlib-inline==0.1.6
+ mdurl==0.1.2
+ ml_dtypes==0.5.4
+ moviepy==1.0.3
+ mpmath==1.3.0
+ networkx==3.6.1
+ numpy==1.26.4
+ nvidia-cuda-runtime-cu12==12.9.79
+ onnx==1.20.0
+ onnx-ir==0.1.14
+ onnxscript==0.5.7
+ opencv-python==4.6.0.66
+ orjson==3.11.5
+ packaging==25.0
+ pandas==2.3.3
+ parameterized==0.9.0
+ pillow==10.4.0
+ pillow_heif==0.15.0
+ platformdirs==4.5.1
+ portalocker==3.2.0
+ proglog==0.1.12
+ protobuf==6.33.2
+ pycuda==2025.1.2
+ pydantic==2.12.5
+ pydantic_core==2.41.5
+ pydub==0.25.1
+ Pygments==2.19.2
+ pyparsing==3.3.1
+ python-dateutil==2.9.0.post0
+ python-multipart==0.0.21
+ pytools==2025.2.5
+ pytorch-msssim==1.0.0
+ pytorchvideo==0.1.5
+ pytz==2025.2
+ pywin32==311
+ PyYAML==6.0.3
+ requests==2.32.5
+ rich==14.2.0
+ safehttpx==0.1.7
+ safetensors==0.7.0
+ scikit-image==0.26.0
+ scikit-learn==1.8.0
+ scipy==1.11.2
+ semantic-version==2.10.0
+ setuptools==80.9.0
+ shellingham==1.5.4
+ siphash24==1.8
+ six==1.17.0
+ starlette==0.50.0
+ sympy==1.14.0
+ tabulate==0.9.0
+ tensorrt_cu12==10.14.1.48.post1
+ tensorrt_cu12_bindings==10.14.1.48.post1
+ tensorrt_cu12_libs==10.14.1.48.post1
+ tensorrt_dispatch_cu12==10.14.1.48.post1
+ tensorrt_dispatch_cu12_bindings==10.14.1.48.post1
+ tensorrt_dispatch_cu12_libs==10.14.1.48.post1
+ tensorrt_lean_cu12==10.14.1.48.post1
+ tensorrt_lean_cu12_bindings==10.14.1.48.post1
+ tensorrt_lean_cu12_libs==10.14.1.48.post1
+ termcolor==3.3.0
+ threadpoolctl==3.6.0
+ tifffile==2025.12.20
+ timm==1.0.24
+ tomlkit==0.13.3
+ torch==2.9.1+cu130
+ torchvision==0.24.1+cu130
+ tqdm==4.65.0
+ traitlets==5.14.3
+ typer==0.21.1
+ typer-slim==0.21.1
+ typing-inspection==0.4.2
+ typing_extensions==4.15.0
+ tzdata==2025.3
+ urllib3==2.6.3
+ uvicorn==0.40.0
+ Werkzeug==3.1.5
+ wheel==0.45.1
+ yacs==0.1.8
rff_torch.py ADDED
@@ -0,0 +1,53 @@
+ import numpy as np
+ import torch
+ from torch import Tensor
+ import torch.nn as nn
+
+ @torch.jit.script
+ def positional_encoding(
+         v: Tensor,
+         sigma: float,
+         m: int) -> Tensor:
+     r"""Computes :math:`\gamma(\mathbf{v}) = (\dots, \cos{2 \pi \sigma^{(j/m)} \mathbf{v}} , \sin{2 \pi \sigma^{(j/m)} \mathbf{v}}, \dots)`
+     where :math:`j \in \{0, \dots, m-1\}`
+
+     Args:
+         v (Tensor): input tensor of shape :math:`(N, *, \text{input_size})`
+         sigma (float): constant chosen based upon the domain of :attr:`v`
+         m (int): number of frequencies to map to
+
+     Returns:
+         Tensor: mapped tensor of shape :math:`(N, *, 2 \cdot m \cdot \text{input_size})`
+
+     See :class:`~rff.layers.PositionalEncoding` for more details.
+     """
+     j = torch.arange(m, device=v.device)
+     coeffs = 2 * np.pi * sigma ** (j / m)
+     vp = coeffs * torch.unsqueeze(v, -1)
+     vp_cat = torch.cat((torch.cos(vp), torch.sin(vp)), dim=-1)
+     return vp_cat.flatten(-2, -1)
+
+
+ class PositionalEncoding(nn.Module):
+     """Layer for mapping coordinates using the positional encoding"""
+
+     def __init__(self, sigma: float, m: int):
+         r"""
+         Args:
+             sigma (float): frequency constant
+             m (int): number of frequencies to map to
+         """
+         super().__init__()
+         self.sigma = sigma
+         self.m = m
+
+     def forward(self, v: Tensor) -> Tensor:
+         r"""Computes :math:`\gamma(\mathbf{v}) = (\dots, \cos{2 \pi \sigma^{(j/m)} \mathbf{v}} , \sin{2 \pi \sigma^{(j/m)} \mathbf{v}}, \dots)`
+
+         Args:
+             v (Tensor): input tensor of shape :math:`(N, *, \text{input_size})`
+
+         Returns:
+             Tensor: mapped tensor of shape :math:`(N, *, 2 \cdot m \cdot \text{input_size})`
+         """
+         return positional_encoding(v, self.sigma, self.m)
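+
+
+ if __name__ == "__main__":
+     # Minimal usage sketch (editorial, not part of the original file): coordinates of
+     # shape (N, *, input_size) map to (N, *, 2 * m * input_size) Fourier features,
+     # e.g. (4, 256, 2) -> (4, 256, 40) for m=10.
+     enc = PositionalEncoding(sigma=1.0, m=10)
+     coords = torch.rand(4, 256, 2)
+     print(enc(coords).shape)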
utils.py ADDED
@@ -0,0 +1,243 @@
+ import numpy as np
+ from scipy.ndimage import map_coordinates
+
+
+ def xyzcube(face_w):
+     '''
+     Return the xyz coordinates of the unit cube in [F R B L U D] format.
+     '''
+     out = np.zeros((face_w, face_w * 6, 3), np.float32)
+     rng = np.linspace(-0.5, 0.5, num=face_w, dtype=np.float32)
+     grid = np.stack(np.meshgrid(rng, -rng), -1)
+
+     # Front face (z = 0.5)
+     out[:, 0*face_w:1*face_w, [0, 1]] = grid
+     out[:, 0*face_w:1*face_w, 2] = 0.5
+
+     # Right face (x = 0.5)
+     out[:, 1*face_w:2*face_w, [2, 1]] = grid
+     out[:, 1*face_w:2*face_w, 0] = 0.5
+
+     # Back face (z = -0.5)
+     out[:, 2*face_w:3*face_w, [0, 1]] = grid
+     out[:, 2*face_w:3*face_w, 2] = -0.5
+
+     # Left face (x = -0.5)
+     out[:, 3*face_w:4*face_w, [2, 1]] = grid
+     out[:, 3*face_w:4*face_w, 0] = -0.5
+
+     # Up face (y = 0.5)
+     out[:, 4*face_w:5*face_w, [0, 2]] = grid
+     out[:, 4*face_w:5*face_w, 1] = 0.5
+
+     # Down face (y = -0.5)
+     out[:, 5*face_w:6*face_w, [0, 2]] = grid
+     out[:, 5*face_w:6*face_w, 1] = -0.5
+
+     return out
+
+
+ def equirect_uvgrid(h, w):
+     u = np.linspace(-np.pi, np.pi, num=w, dtype=np.float32)
+     v = np.linspace(np.pi, -np.pi, num=h, dtype=np.float32) / 2
+
+     return np.stack(np.meshgrid(u, v), axis=-1)
+
+
+ def equirect_facetype(h, w):
+     '''
+     0F 1R 2B 3L 4U 5D
+     '''
+     tp = np.roll(np.arange(4).repeat(w // 4)[None, :].repeat(h, 0), 3 * w // 8, 1)
+
+     # Prepare ceil mask (np.bool was removed in modern NumPy; use the builtin bool)
+     mask = np.zeros((h, w // 4), bool)
+     idx = np.linspace(-np.pi, np.pi, w // 4) / 4
+     idx = h // 2 - np.round(np.arctan(np.cos(idx)) * h / np.pi).astype(int)
+     for i, j in enumerate(idx):
+         mask[:j, i] = 1
+     mask = np.roll(np.concatenate([mask] * 4, 1), 3 * w // 8, 1)
+
+     tp[mask] = 4
+     tp[np.flip(mask, 0)] = 5
+
+     return tp.astype(np.int32)
+
+
+ def xyzpers(h_fov, v_fov, u, v, out_hw, in_rot):
+     out = np.ones((*out_hw, 3), np.float32)
+
+     x_max = np.tan(h_fov / 2)
+     y_max = np.tan(v_fov / 2)
+     x_rng = np.linspace(-x_max, x_max, num=out_hw[1], dtype=np.float32)
+     y_rng = np.linspace(-y_max, y_max, num=out_hw[0], dtype=np.float32)
+     out[..., :2] = np.stack(np.meshgrid(x_rng, -y_rng), -1)
+     Rx = rotation_matrix(v, [1, 0, 0])
+     Ry = rotation_matrix(u, [0, 1, 0])
+     Ri = rotation_matrix(in_rot, np.array([0, 0, 1.0]).dot(Rx).dot(Ry))
+
+     return out.dot(Rx).dot(Ry).dot(Ri)
+
+
+ def xyz2uv(xyz):
+     '''
+     xyz: ndarray in shape of [..., 3]
+     '''
+     x, y, z = np.split(xyz, 3, axis=-1)
+     u = np.arctan2(x, z)
+     c = np.sqrt(x**2 + z**2)
+     v = np.arctan2(y, c)
+
+     return np.concatenate([u, v], axis=-1)
+
+
+ def uv2unitxyz(uv):
+     u, v = np.split(uv, 2, axis=-1)
+     y = np.sin(v)
+     c = np.cos(v)
+     x = c * np.sin(u)
+     z = c * np.cos(u)
+
+     return np.concatenate([x, y, z], axis=-1)
+
+
+ def uv2coor(uv, h, w):
+     '''
+     uv: ndarray in shape of [..., 2]
+     h: int, height of the equirectangular image
+     w: int, width of the equirectangular image
+     '''
+     u, v = np.split(uv, 2, axis=-1)
+     coor_x = (u / (2 * np.pi) + 0.5) * w - 0.5
+     coor_y = (-v / np.pi + 0.5) * h - 0.5
+
+     return np.concatenate([coor_x, coor_y], axis=-1)
+
+
+ def coor2uv(coorxy, h, w):
+     coor_x, coor_y = np.split(coorxy, 2, axis=-1)
+     u = ((coor_x + 0.5) / w - 0.5) * 2 * np.pi
+     v = -((coor_y + 0.5) / h - 0.5) * np.pi
+
+     return np.concatenate([u, v], axis=-1)
+
+
+ def sample_equirec(e_img, coor_xy, order):
+     w = e_img.shape[1]
+     coor_x, coor_y = np.split(coor_xy, 2, axis=-1)
+     pad_u = np.roll(e_img[[0]], w // 2, 1)
+     pad_d = np.roll(e_img[[-1]], w // 2, 1)
+     e_img = np.concatenate([e_img, pad_d, pad_u], 0)
+     return map_coordinates(e_img, [coor_y, coor_x],
+                            order=order, mode='wrap')[..., 0]
+
+
+ def sample_cubefaces(cube_faces, tp, coor_y, coor_x, order):
+     cube_faces = cube_faces.copy()
+     cube_faces[1] = np.flip(cube_faces[1], 1)
+     cube_faces[2] = np.flip(cube_faces[2], 1)
+     cube_faces[4] = np.flip(cube_faces[4], 0)
+
+     # Pad up down
+     pad_ud = np.zeros((6, 2, cube_faces.shape[2]))
+     pad_ud[0, 0] = cube_faces[5, 0, :]
+     pad_ud[0, 1] = cube_faces[4, -1, :]
+     pad_ud[1, 0] = cube_faces[5, :, -1]
+     pad_ud[1, 1] = cube_faces[4, ::-1, -1]
+     pad_ud[2, 0] = cube_faces[5, -1, ::-1]
+     pad_ud[2, 1] = cube_faces[4, 0, ::-1]
+     pad_ud[3, 0] = cube_faces[5, ::-1, 0]
+     pad_ud[3, 1] = cube_faces[4, :, 0]
+     pad_ud[4, 0] = cube_faces[0, 0, :]
+     pad_ud[4, 1] = cube_faces[2, 0, ::-1]
+     pad_ud[5, 0] = cube_faces[2, -1, ::-1]
+     pad_ud[5, 1] = cube_faces[0, -1, :]
+     cube_faces = np.concatenate([cube_faces, pad_ud], 1)
+
+     # Pad left right
+     pad_lr = np.zeros((6, cube_faces.shape[1], 2))
+     pad_lr[0, :, 0] = cube_faces[1, :, 0]
+     pad_lr[0, :, 1] = cube_faces[3, :, -1]
+     pad_lr[1, :, 0] = cube_faces[2, :, 0]
+     pad_lr[1, :, 1] = cube_faces[0, :, -1]
+     pad_lr[2, :, 0] = cube_faces[3, :, 0]
+     pad_lr[2, :, 1] = cube_faces[1, :, -1]
+     pad_lr[3, :, 0] = cube_faces[0, :, 0]
+     pad_lr[3, :, 1] = cube_faces[2, :, -1]
+     pad_lr[4, 1:-1, 0] = cube_faces[1, 0, ::-1]
+     pad_lr[4, 1:-1, 1] = cube_faces[3, 0, :]
+     pad_lr[5, 1:-1, 0] = cube_faces[1, -2, :]
+     pad_lr[5, 1:-1, 1] = cube_faces[3, -2, ::-1]
+     cube_faces = np.concatenate([cube_faces, pad_lr], 2)
+
+     return map_coordinates(cube_faces, [tp, coor_y, coor_x], order=order, mode='wrap')
+
+
+ def cube_h2list(cube_h):
+     assert cube_h.shape[0] * 6 == cube_h.shape[1]
+     return np.split(cube_h, 6, axis=1)
+
+
+ def cube_list2h(cube_list):
+     assert len(cube_list) == 6
+     assert sum(face.shape == cube_list[0].shape for face in cube_list) == 6
+     return np.concatenate(cube_list, axis=1)
+
+
+ def cube_h2dict(cube_h):
+     cube_list = cube_h2list(cube_h)
+     return dict([(k, cube_list[i])
+                  for i, k in enumerate(['F', 'R', 'B', 'L', 'U', 'D'])])
+
+
+ def cube_dict2h(cube_dict, face_k=['F', 'R', 'B', 'L', 'U', 'D']):
+     assert len(face_k) == 6
+     return cube_list2h([cube_dict[k] for k in face_k])
+
+
+ def cube_h2dice(cube_h):
+     assert cube_h.shape[0] * 6 == cube_h.shape[1]
+     w = cube_h.shape[0]
+     cube_dice = np.zeros((w * 3, w * 4, cube_h.shape[2]), dtype=cube_h.dtype)
+     cube_list = cube_h2list(cube_h)
+     # Order: F R B L U D
+     sxy = [(1, 1), (2, 1), (3, 1), (0, 1), (1, 0), (1, 2)]
+     for i, (sx, sy) in enumerate(sxy):
+         face = cube_list[i]
+         if i in [1, 2]:
+             face = np.flip(face, axis=1)
+         if i == 4:
+             face = np.flip(face, axis=0)
+         cube_dice[sy*w:(sy+1)*w, sx*w:(sx+1)*w] = face
+     return cube_dice
+
+
+ def cube_dice2h(cube_dice):
+     w = cube_dice.shape[0] // 3
+     assert cube_dice.shape[0] == w * 3 and cube_dice.shape[1] == w * 4
+     cube_h = np.zeros((w, w * 6, cube_dice.shape[2]), dtype=cube_dice.dtype)
+     # Order: F R B L U D
+     sxy = [(1, 1), (2, 1), (3, 1), (0, 1), (1, 0), (1, 2)]
+     for i, (sx, sy) in enumerate(sxy):
+         face = cube_dice[sy*w:(sy+1)*w, sx*w:(sx+1)*w]
+         if i in [1, 2]:
+             face = np.flip(face, axis=1)
+         if i == 4:
+             face = np.flip(face, axis=0)
+         cube_h[:, i*w:(i+1)*w] = face
+     return cube_h
+
+
+ def rotation_matrix(rad, ax):
+     ax = np.array(ax)
+     assert len(ax.shape) == 1 and ax.shape[0] == 3
+     ax = ax / np.sqrt((ax**2).sum())
+     R = np.diag([np.cos(rad)] * 3)
+     R = R + np.outer(ax, ax) * (1.0 - np.cos(rad))
+
+     ax = ax * np.sin(rad)
+     R = R + np.array([[0, -ax[2], ax[1]],
+                       [ax[2], 0, -ax[0]],
+                       [-ax[1], ax[0], 0]])
+
+     return R
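+
+
+ if __name__ == "__main__":
+     # Minimal usage sketch (editorial, not part of the original file): build cubemap
+     # sampling coordinates for a 512x1024 equirectangular image, following the shapes
+     # documented in the functions above.
+     xyz = xyzcube(face_w=64)              # (64, 384, 3) points on the unit cube
+     uv = xyz2uv(xyz)                      # (64, 384, 2) spherical (u, v) angles
+     coor_xy = uv2coor(uv, h=512, w=1024)  # (64, 384, 2) ERP pixel coordinates
+     print(coor_xy.shape)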