3ZadeSSG committed
Commit ff00a24 · 1 Parent(s): 0032477

initial commit
README.md CHANGED
@@ -1,14 +1,30 @@
- ---
- title: RT MPINet
- emoji: 👀
- colorFrom: red
- colorTo: indigo
- sdk: gradio
- sdk_version: 5.41.1
- app_file: app.py
- pinned: false
- license: gpl-2.0
- short_description: Multiplane Image Network for Real-Time View Synthesis
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ <div align="center">
+ <a href="#"><img src='https://img.shields.io/badge/-Paper-00629B?style=flat&logo=ieee&logoColor=white' alt='arXiv'></a>
+ <a href='https://realistic3d-miun.github.io/Research/RT_MPINet/index.html'><img src='https://img.shields.io/badge/Project_Page-Website-green?logo=googlechrome&logoColor=white' alt='Project Page'></a>
+ <a href='https://huggingface.co/spaces/3ZadeSSG/RT-MPINet'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Demo_(RT_MPINet)-blue'></a>
+ </div>
+
+ # RT-MPINet
+ #### Real-Time View Synthesis with Multiplane Image Network using Multimodal Supervision (RT-MPINet)
+
+ We present a real-time multiplane image (MPI) network. Unlike existing MPI-based approaches, which often rely on a separate depth estimation network to guide the estimation of MPI parameters, our method predicts these parameters directly from a single RGB image. To guide the network, we use a multimodal training strategy with joint supervision from view synthesis and depth estimation losses. More details can be found in the paper.
+
+ **Please head to the [Project Page](https://realistic3d-miun.github.io/Research/RT_MPINet/index.html) for supplementary materials and the full code.**
+
+ ## Acknowledgements
+ - We thank the authors of [AdaMPI](https://github.com/yxuhan/AdaMPI) for their implementation of the homography renderer, which is used in this codebase under the `./utils` directory.
+ - We thank the author of the [Deepview renderer](https://github.com/Findeton/deepview) template, which was used in our project page.
+
+ ## Citation
+ If you use our work, please use the following citation:
+ ```
+ @inproceedings{gond2025rtmpi,
+   title={Real-Time View Synthesis with Multiplane Image Network using Multimodal Supervision},
+   author={Gond, Manu and Shamshirgarha, Mohammadreza and Zerman, Emin and Knorr, Sebastian and Sj{\"o}str{\"o}m, M{\aa}rten},
+   booktitle={2025 IEEE 27th International Workshop on Multimedia Signal Processing (MMSP)},
+   pages={},
+   year={2025},
+   organization={IEEE}
+ }
+ ```
+
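The snippet below is a minimal offline sketch of how the pieces in this commit fit together: it loads the Medium model and renders a single novel view, mirroring the calls made in `app.py` further down. It assumes you run it from the repo root with the checkpoints in place; `input.png` is a hypothetical example image.

```python
import torch
import torchvision.transforms as transforms
from PIL import Image

import helperFunctions as helper
import parameters as params
from model_Medium import MMPI
from utils.mpi.homography_sampler import HomographySample
from utils.utils import render_novel_view

H = W = 384
model = MMPI(total_image_input=params.params_number_input, height=H, width=W)
model = helper.load_Checkpoint("./checkpoint/checkpoint_RT_MPI_Medium.pth", model, load_cpu=True)
model.eval()

img = Image.open("input.png").convert("RGB")  # hypothetical input image
x = transforms.Compose([transforms.Resize((H, W)), transforms.ToTensor()])(img).unsqueeze(0)

grid = params.get_disparity_all_src().unsqueeze(0)   # per-plane disparities
k = torch.tensor([[0.58, 0.0, 0.5], [0.0, 0.58, 0.5], [0.0, 0.0, 1.0]])
k[0, :] *= H
k[1, :] *= W            # scaling as in app.py (all demo resolutions are square)
k = k.unsqueeze(0)
pose = torch.eye(4).unsqueeze(0)
pose[0, 0, 3] = 0.03    # small lateral camera offset

with torch.no_grad():
    rgb_layers, sigma_layers = model.get_layers(x, height=H, width=W)
    view = render_novel_view(rgb_layers, sigma_layers, grid, pose,
                             torch.inverse(k), k, HomographySample(H, W, "cpu"))
```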
app.py ADDED
@@ -0,0 +1,188 @@
+ import gradio as gr
+ import torch
+ import numpy as np
+ import cv2
+ import tempfile
+ from PIL import Image
+ import torchvision.transforms as transforms
+ import matplotlib.pyplot as plt
+ from model_Small import MMPI as MMPI_S
+ from model_Medium import MMPI as MMPI_M
+ from model_Large import MMPI as MMPI_L
+ import helperFunctions as helper
+ import parameters as params
+ from utils.mpi.homography_sampler import HomographySample
+ from utils.utils import (
+     render_novel_view,
+ )
+
+ # Checkpoint locations for all models
+ MODEL_S_LOCATION = "./checkpoint/checkpoint_RT_MPI_Small.pth"
+ MODEL_M_LOCATION = "./checkpoint/checkpoint_RT_MPI_Medium.pth"
+ MODEL_L_LOCATION = "./checkpoint/checkpoint_RT_MPI_Large.pth"
+
+ DEVICE = "cpu"
+
+ def getPositionVector(x, y, z, pose):
+     # Write the translation column of the 4x4 camera pose in place
+     pose[0, 0, 3] = x
+     pose[0, 1, 3] = y
+     pose[0, 2, 3] = z
+     return pose
+
+ def generateCircularTrajectory(radius, num_frames):
+     angles = np.linspace(0, 2 * np.pi, num_frames, endpoint=False)
+     return [[radius * np.cos(angle), radius * np.sin(angle), 0] for angle in angles]
+
+ def generateWiggleTrajectory(radius, num_frames):
+     angles = np.linspace(0, 2 * np.pi, num_frames, endpoint=False)
+     return [[radius * np.cos(angle), 0, radius * np.sin(angle)] for angle in angles]
+
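+ # Both trajectories trace one full period over num_frames poses: "Circle"
+ # orbits in the camera x-y plane, while the wiggle ("Swing") trajectory moves
+ # in x-z, giving a back-and-forth parallax sweep.
+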
+ def create_video_from_memory(frames, fps=60):
+     if not frames:
+         return None
+     height, width, _ = frames[0].shape
+     fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+     temp_video = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
+     out = cv2.VideoWriter(temp_video.name, fourcc, fps, (width, height))
+     for frame in frames:
+         out.write(frame)
+     out.release()
+     return temp_video.name
+
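+ # Note: cv2.VideoWriter expects BGR uint8 frames, which is why process_image
+ # converts rendered RGB frames with cv2.cvtColor before collecting them.
+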
+ def process_image(img, video_type, radius, num_frames, num_loops, model_type, resolution):
+     # Parse resolution string
+     height, width = map(int, resolution.lower().split("x"))
+
+     # Select model class and checkpoint
+     if model_type == "Small":
+         model_class = MMPI_S
+         checkpoint = MODEL_S_LOCATION
+     elif model_type == "Medium":
+         model_class = MMPI_M
+         checkpoint = MODEL_M_LOCATION
+     else:
+         model_class = MMPI_L
+         checkpoint = MODEL_L_LOCATION
+
+     # Load model
+     model = model_class(total_image_input=params.params_number_input, height=height, width=width)
+     model = helper.load_Checkpoint(checkpoint, model, load_cpu=True)
+     model.to(DEVICE)
+     model.eval()
+
+     min_side = min(img.width, img.height)
+     left = (img.width - min_side) // 2
+     top = (img.height - min_side) // 2
+     right = left + min_side
+     bottom = top + min_side
+     img = img.crop((left, top, right, bottom))
+
+     if video_type == "Circle":
+         trajectory = generateCircularTrajectory(radius, num_frames)
+     elif video_type == "Swing":
+         trajectory = generateWiggleTrajectory(radius, num_frames)
+     else:
+         trajectory = generateCircularTrajectory(radius, num_frames)
+
+     transform = transforms.Compose([
+         transforms.Resize((height, width)),
+         transforms.ToTensor()
+     ])
+     img_input = transform(img).to(DEVICE).unsqueeze(0)
+
+     grid = params.get_disparity_all_src().unsqueeze(0).to(DEVICE)
+     k_tgt = torch.tensor([
+         [0.58, 0, 0.5],
+         [0, 0.58, 0.5],
+         [0, 0, 1]]).to(DEVICE)
+     k_tgt[0, :] *= height
+     k_tgt[1, :] *= width
+     k_tgt = k_tgt.unsqueeze(0)
+     k_src_inv = torch.inverse(k_tgt)
+     pose = torch.eye(4).to(DEVICE).unsqueeze(0)
+
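+     # The intrinsics above are normalized (fx = fy = 0.58, principal point at
+     # 0.5) and then scaled to pixels. Row 0 is scaled by height and row 1 by
+     # width; this only matches the usual fx-width / fy-height convention
+     # because every resolution offered by this demo is square.
+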
+     homography_sampler = HomographySample(height, width, DEVICE)
+
+     with torch.no_grad():
+         rgb_layers, sigma_layers = model.get_layers(img_input, height=height, width=width)
+
+         predicted_depth = model.get_depth(img_input)
+         predicted_depth = (predicted_depth - predicted_depth.min()) / (predicted_depth.max() - predicted_depth.min())
+         img_predicted_depth = predicted_depth.squeeze().cpu().detach().numpy()
+         img_predicted_depth_colored = plt.get_cmap('inferno')(img_predicted_depth / np.max(img_predicted_depth))[:, :, :3]
+         img_predicted_depth_colored = (img_predicted_depth_colored * 255).astype(np.uint8)
+         img_predicted_depth_colored = Image.fromarray(img_predicted_depth_colored)
+
+         layer_depth = model.get_layer_depth(img_input, grid)
+         img_layer_depth = layer_depth.squeeze().cpu().detach().numpy()
+         img_layer_depth_colored = plt.get_cmap('inferno')(img_layer_depth / np.max(img_layer_depth))[:, :, :3]
+         img_layer_depth_colored = (img_layer_depth_colored * 255).astype(np.uint8)
+         img_layer_depth_colored = Image.fromarray(img_layer_depth_colored)
+
+     single_loop_frames = []
+     for pose_coords in trajectory:
+         with torch.no_grad():
+             target_pose = getPositionVector(pose_coords[0], pose_coords[1], pose_coords[2], pose)
+             output_img = render_novel_view(rgb_layers,
+                                            sigma_layers,
+                                            grid,
+                                            target_pose,
+                                            k_src_inv,
+                                            k_tgt,
+                                            homography_sampler)
+
+         img_np = output_img.detach().cpu().squeeze(0).permute(1, 2, 0).numpy()
+         img_np = (img_np * 255).astype(np.uint8)
+         img_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
+         single_loop_frames.append(img_bgr)
+
+     final_frames = single_loop_frames * int(num_loops)
+
+     video_path = create_video_from_memory(final_frames)
+
+     return video_path, img_predicted_depth_colored, img_layer_depth_colored
+
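+ # get_layers returns 32 RGB planes (B, 32, 3, H, W) and 32 density planes
+ # (B, 32, 1, H, W) placed at the disparities in `grid`; render_novel_view
+ # warps each plane into the target view via per-plane homographies and
+ # alpha-composites the warped stack into the novel image.
+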
+ with gr.Blocks(title="RT-MPINet", theme="default") as demo:
+     gr.Markdown(
+         """
+         ## Parallax Video Generator via Real-Time Multiplane Image Network (RT-MPINet)
+         We use smaller models and resolutions for faster inference on CPU instances.
+
+         #### Notes:
+         1. Use a higher number of frames (>80) and loops (>4) to get a smoother video.
+         2. The default uses 60 frames and 4 camera loops for fast video generation.
+         3. Three models are available (the larger the model, the slower the inference):
+             * **Small:** 6.6 million parameters
+             * **Medium:** 69 million parameters
+             * **Large:** 288 million parameters (not available in this demo due to storage limits; download this model and run it locally)
+         """)
+     with gr.Row():
+         img_input = gr.Image(type="pil", label="Upload Image")
+         video_type = gr.Dropdown(["Circle", "Swing"], label="Video Type", value="Swing")
+         with gr.Column():
+             with gr.Accordion("Advanced Settings", open=False):
+                 radius = gr.Slider(0.001, 0.1, value=0.05, label="Radius (for Circle/Swing)")
+                 num_frames = gr.Slider(10, 180, value=60, step=1, label="Frames per Loop")
+                 num_loops = gr.Slider(1, 10, value=4, step=1, label="Number of Loops")
+         with gr.Column():
+             model_type_dropdown = gr.Dropdown(["Small", "Medium"], label="Model Type", value="Medium")
+             resolution_dropdown = gr.Dropdown(["256x256", "384x384", "512x512"], label="Input Resolution", value="384x384")
+             generate_btn = gr.Button("Generate Video", variant="primary")
+
+     with gr.Row():
+         video_output = gr.Video(label="Generated Video")
+         depth_output = gr.Image(label="Depth Map - From Depth Decoder")
+         layer_depth_output = gr.Image(label="Layer Depth Map - From MPI Layers")
+
+     generate_btn.click(fn=process_image,
+                        inputs=[img_input, video_type, radius, num_frames, num_loops, model_type_dropdown, resolution_dropdown],
+                        outputs=[video_output, depth_output, layer_depth_output])
+
+ demo.launch()
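+
+ # Running `python app.py` starts the Gradio app; by default Gradio serves it
+ # at http://127.0.0.1:7860.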
helperFunctions.py ADDED
@@ -0,0 +1,26 @@
+ import torch
+
+ def save_checkpoint(model, filelocation, save_parallel=True):
+     if save_parallel:
+         # DataParallel-style wrappers keep the underlying model in `.module`
+         torch.save(model.module.state_dict(), filelocation)
+     else:
+         torch.save(model.state_dict(), filelocation)
+
+ def load_Checkpoint(fileLocation, model, load_cpu=False):
+     if load_cpu:
+         model.load_state_dict(torch.load(fileLocation, map_location=lambda storage, loc: storage))
+     else:
+         model.load_state_dict(torch.load(fileLocation))
+     return model
+
+ def writeLog(logList, filename):
+     with open(filename, 'w') as outfile:
+         outfile.write("\n".join(logList))
+
+ def kl_loss(mu, logvar):
+     return -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).mean()
+
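+ # kl_loss is the closed-form KL divergence between a diagonal Gaussian
+ # N(mu, exp(logvar)) and the standard normal:
+ #   KL = -0.5 * mean(1 + logvar - mu^2 - exp(logvar))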
helper_image_functions.py ADDED
@@ -0,0 +1,290 @@
+ '''
+ Author: Manu Gond (manu.gond@miun.se)
+ Date: Nov-15-2022
+ Objective: Collection of general functions I use day to day for
+            image-related tasks. The function names and parameters
+            are self-explanatory.
+ Requirements: The imported Python libraries must be installed.
+ '''
+
+ import torch
+ from torchvision.utils import save_image
+ from torchvision.transforms import transforms
+ import torchmetrics
+ import numpy as np
+ from PIL import Image
+ import utils
+
+
+ #======================= Read and Write =====================#
+ def readImage(location):
+     image = Image.open(location).convert("RGB")
+     return image
+
+
+ def writeImage(image, location):
+     image.save(location)
+
+
+ def writeTensorImage(image, filename):
+     save_image(image, filename)
+
+
+ def removeChannel(sourceLocation, targetLocation):
+     img = readImage(sourceLocation)
+     writeImage(img, targetLocation)
+
+
+ def getImageTransform(width, height):
+     transform = transforms.Compose([transforms.Resize((height, width)),
+                                     transforms.ToTensor()])
+     return transform
+
+
+ def convertTensor(image):
+     transform = getImageTransform(image.size[0], image.size[1])
+     image = transform(image)
+     return image
+
+
+ #=================== 360 Images =======================#
+
+ def rotateERP180(image):
+     '''
+     :param image: PIL Image
+     :return: CxHxW Torch Tensor image
+     '''
+     W = image.size[0]
+     H = image.size[1]
+     transform = getImageTransform(W, H)
+     image = transform(image)
+     image1 = image[:, :, 0:(W//2)]
+     image2 = image[:, :, (W//2):W]
+     image3 = torch.zeros(image.size())
+     image3[:, :, 0:(W//2)] = image2
+     image3[:, :, (W//2):W] = image1
+     return image3
+
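+ # rotateERP180 applies a 180-degree yaw rotation to an equirectangular (ERP)
+ # panorama by swapping its left and right halves along the width axis.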
+
+ def convertERP2Cube(e_img, face_w=256, mode='bilinear', cube_format='dice'):
+     '''
+     e_img: ndarray in shape of [H, W, *]
+     face_w: int, the length of each face of the cubemap
+     '''
+     assert len(e_img.shape) == 3
+     h, w = e_img.shape[:2]
+     if mode == 'bilinear':
+         order = 1
+     elif mode == 'nearest':
+         order = 0
+     else:
+         raise NotImplementedError('unknown mode')
+
+     xyz = utils.xyzcube(face_w)
+     uv = utils.xyz2uv(xyz)
+     coor_xy = utils.uv2coor(uv, h, w)
+
+     cubemap = np.stack([
+         utils.sample_equirec(e_img[..., i], coor_xy, order=order)
+         for i in range(e_img.shape[2])
+     ], axis=-1)
+
+     if cube_format == 'horizon':
+         pass
+     elif cube_format == 'list':
+         cubemap = utils.cube_h2list(cubemap)
+     elif cube_format == 'dict':
+         cubemap = utils.cube_h2dict(cubemap)
+     elif cube_format == 'dice':
+         cubemap = utils.cube_h2dice(cubemap)
+     else:
+         raise NotImplementedError()
+     return cubemap
+
+
+ def convertCube2ERP(cubemap, h, w, mode='bilinear', cube_format='dice'):
+     if mode == 'bilinear':
+         order = 1
+     elif mode == 'nearest':
+         order = 0
+     else:
+         raise NotImplementedError('unknown mode')
+
+     if cube_format == 'horizon':
+         pass
+     elif cube_format == 'list':
+         cubemap = utils.cube_list2h(cubemap)
+     elif cube_format == 'dict':
+         cubemap = utils.cube_dict2h(cubemap)
+     elif cube_format == 'dice':
+         cubemap = utils.cube_dice2h(cubemap)
+     else:
+         raise NotImplementedError('unknown cube_format')
+     assert len(cubemap.shape) == 3
+     assert cubemap.shape[0] * 6 == cubemap.shape[1]
+     assert w % 8 == 0
+     face_w = cubemap.shape[0]
+
+     uv = utils.equirect_uvgrid(h, w)
+     u, v = np.split(uv, 2, axis=-1)
+     u = u[..., 0]
+     v = v[..., 0]
+     cube_faces = np.stack(np.split(cubemap, 6, 1), 0)
+
+     # Get face id to each pixel: 0F 1R 2B 3L 4U 5D
+     tp = utils.equirect_facetype(h, w)
+     coor_x = np.zeros((h, w))
+     coor_y = np.zeros((h, w))
+
+     for i in range(4):
+         mask = (tp == i)
+         coor_x[mask] = 0.5 * np.tan(u[mask] - np.pi * i / 2)
+         coor_y[mask] = -0.5 * np.tan(v[mask]) / np.cos(u[mask] - np.pi * i / 2)
+
+     mask = (tp == 4)
+     c = 0.5 * np.tan(np.pi / 2 - v[mask])
+     coor_x[mask] = c * np.sin(u[mask])
+     coor_y[mask] = c * np.cos(u[mask])
+
+     mask = (tp == 5)
+     c = 0.5 * np.tan(np.pi / 2 - np.abs(v[mask]))
+     coor_x[mask] = c * np.sin(u[mask])
+     coor_y[mask] = -c * np.cos(u[mask])
+
+     # Final renormalize
+     coor_x = (np.clip(coor_x, -0.5, 0.5) + 0.5) * face_w
+     coor_y = (np.clip(coor_y, -0.5, 0.5) + 0.5) * face_w
+
+     equirec = np.stack([
+         utils.sample_cubefaces(cube_faces[..., i], tp, coor_y, coor_x, order=order)
+         for i in range(cube_faces.shape[3])
+     ], axis=-1)
+     return equirec
+
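+
+ # In convertCube2ERP each face is treated as a plane at distance 0.5 from the
+ # viewer, so the in-face coordinates of the four side faces follow tan() of
+ # the yaw-shifted angles (u - i*pi/2), while the up/down faces use a radius
+ # c = 0.5 * tan(pi/2 - |v|) split into sin/cos components along u.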
+
+ def convertCube2Slices(image):
+     '''
+     :param image: PIL Image (dice-layout cubemap)
+     :return: List of Torch Tensors, CxHxW
+     '''
+     image = convertTensor(image)
+     C, H, W = image.size()
+
+     top = image[:, 0:H//3, (W//4):(W//4)*2]
+     left = image[:, (H//3):(H//3)*2, 0:W//4]
+     front = image[:, (H//3):(H//3)*2, (W//4):(W//4)*2]
+     right = image[:, (H//3):(H//3)*2, (W//4)*2:(W//4)*3]
+     back = image[:, (H//3):(H//3)*2, (W//4)*3:]
+     bottom = image[:, (H//3)*2:, (W//4):(W//4)*2]
+
+     return [top, left, front, right, back, bottom]
+
+ def convertSlicesToCube(imageList):
+     top, left, front, right, back, bottom = imageList
+
+     C, H, W = 3, top.size()[1]*3, top.size()[2]*4
+     cube = torch.zeros((C, H, W))
+
+     cube[:, 0:H//3, (W//4):(W//4)*2] = top
+     cube[:, (H//3):(H//3)*2, 0:W//4] = left
+     cube[:, (H//3):(H//3)*2, (W//4):(W//4)*2] = front
+     cube[:, (H//3):(H//3)*2, (W//4)*2:(W//4)*3] = right
+     cube[:, (H//3):(H//3)*2, (W//4)*3:] = back
+     cube[:, (H//3)*2:, (W//4):(W//4)*2] = bottom
+
+     return cube
+
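+ # Both slice helpers assume the 'dice' cubemap layout, a 3x4 grid of faces:
+ #
+ #          [top]
+ #  [left] [front] [right] [back]
+ #          [bottom]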
+
+
+ #=================== Quality Measures =======================#
+ '''
+ Predicted Shape : BxCxHxW
+ Original Shape  : BxCxHxW
+ Data Type: Torch Tensor
+ '''
+ def getSSIM(predicted, original):
+     SSIM = torchmetrics.StructuralSimilarityIndexMeasure()
+     return SSIM(predicted, original).item()
+
+
+ def getPSNR(predicted, original):
+     PSNR = torchmetrics.PeakSignalNoiseRatio()
+     return PSNR(predicted, original).item()
+
+
+ def getMSE(predicted, original):
+     MSE = torchmetrics.MeanSquaredError()
+     return MSE(predicted, original).item()
+
+
+ def getMAE(predicted, original):
+     MAE = torchmetrics.MeanAbsoluteError()
+     return MAE(predicted, original).item()
+
+
+ if __name__ == "__main__":
+     # Example usages kept from development (disabled):
+     '''
+     img = readImage("31_image_0_0.png")
+     img = convertERP2Cube(e_img=np.asarray(img), face_w=256)
+     img = Image.fromarray(img.astype('uint8'), 'RGB')
+     convertCube2Slices(img)
+
+     image = convertSlicesToCube(["top.png", "left.png", "front.png", "right.png", "back.png", "bottom.png"])
+     writeTensorImage(image, 'this.png')
+
+     writeImage(img, 'cube.png')
+     img = readImage('cube.png')
+     img = convertCube2ERP(np.asarray(img), 512, 1024)
+     img = Image.fromarray(img.astype('uint8'), 'RGB')
+     writeImage(img, 'cubeERP.png')
+
+     img1 = readImage("31_image_0_0.png")
+     img2 = readImage("cubeERP.png")
+     img1 = convertTensor(img1)
+     img2 = convertTensor(img2)
+     print(getSSIM(img1.unsqueeze(0), img2.unsqueeze(0)))
+     '''
+     pass
+
model_Large.py ADDED
@@ -0,0 +1,535 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import warnings
+ warnings.filterwarnings("ignore")
+ import torchvision
+ import parameters as params
+ import timm
+
+ class DinoV2FeatureExtractor(nn.Module):
+     def __init__(self, out_channels=256, out_size=(64, 64)):
+         super().__init__()
+         self.dino = timm.create_model('vit_base_patch14_dinov2.lvd142m', pretrained=False)
+         self.dino.eval()
+         for p in self.dino.parameters():
+             p.requires_grad = False
+
+         self.out_size = out_size
+         self.feat_proj = nn.Sequential(
+             nn.Conv2d(self.dino.embed_dim, out_channels, kernel_size=1),
+             nn.ReLU(),
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = F.interpolate(x, size=(518, 518), mode='bilinear', align_corners=False)
+         patch_tokens = self.dino.forward_features(x)
+         patch_tokens = patch_tokens[:, 1:]  # drop the class token
+         B, N, C = patch_tokens.shape
+         h = w = int(N ** 0.5)
+         feat_map = patch_tokens.transpose(1, 2).reshape(B, C, h, w)  # [B, C, H', W']
+         feat_map = F.interpolate(feat_map, size=self.out_size, mode='bilinear', align_corners=False)
+         return self.feat_proj(feat_map)
+
+ def getLinearLayer(in_feat, out_feat, activation=nn.ReLU(True)):
+     return nn.Sequential(
+         nn.Linear(in_features=in_feat, out_features=out_feat, bias=True),
+         activation
+     )
+
+ def getConvLayer(in_channel, out_channel, stride=1, padding=1, activation=nn.ReLU()):
+     return nn.Sequential(nn.Conv2d(in_channel,
+                                    out_channel,
+                                    kernel_size=3,
+                                    stride=stride,
+                                    padding=padding,
+                                    padding_mode='reflect'),
+                          activation)
+
+ def getConvTransposeLayer(in_channel, out_channel, kernel=3, stride=1, padding=1, activation=nn.ReLU()):
+     return nn.Sequential(nn.ConvTranspose2d(in_channel,
+                                             out_channel,
+                                             kernel_size=kernel,
+                                             stride=stride,
+                                             padding=padding),
+                          activation)
+
+
+ class ResidualBlock(nn.Module):
+     def __init__(self, in_channels, out_channels, stride=1):
+         super(ResidualBlock, self).__init__()
+         self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
+         self.relu = nn.ReLU()
+         self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
+         self.stride = stride
+
+         self.shortcut = nn.Sequential()
+         if stride != 1 or in_channels != out_channels:
+             self.shortcut = nn.Sequential(
+                 nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
+                 nn.BatchNorm2d(out_channels)
+             )
+
+     def forward(self, x):
+         residual = x
+
+         out = self.conv1(x)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+
+         out = out + self.shortcut(residual)
+         out = self.relu(out)
+         return out
+
+
+ # Alternative bottleneck ResidualBlock with a depthwise conv (disabled):
+ # class ResidualBlock(nn.Module):
+ #     def __init__(self, in_channels, out_channels, stride=1, expansion=4):
+ #         super().__init__()
+ #         mid_channels = out_channels // expansion
+ #         self.pw_reduce = nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False)
+ #         self.bn1 = nn.BatchNorm2d(mid_channels)
+ #         self.dw = nn.Conv2d(mid_channels, mid_channels, kernel_size=3,
+ #                             stride=stride, padding=1, groups=mid_channels, bias=False)
+ #         self.bn2 = nn.BatchNorm2d(mid_channels)
+ #         self.pw_expand = nn.Conv2d(mid_channels, out_channels, kernel_size=1, bias=False)
+ #         self.bn3 = nn.BatchNorm2d(out_channels)
+ #         self.relu = nn.ReLU(inplace=True)
+ #         self.stride = stride
+ #         if stride != 1 or in_channels != out_channels:
+ #             self.shortcut = nn.Sequential(
+ #                 nn.Conv2d(in_channels, out_channels, kernel_size=1,
+ #                           stride=stride, bias=False),
+ #                 nn.BatchNorm2d(out_channels),
+ #             )
+ #         else:
+ #             self.shortcut = nn.Identity()
+
+ #     def forward(self, x):
+ #         identity = x
+ #         out = self.pw_reduce(x)
+ #         out = self.bn1(out)
+ #         out = self.relu(out)
+ #         out = self.dw(out)
+ #         out = self.bn2(out)
+ #         out = self.relu(out)
+ #         out = self.pw_expand(out)
+ #         out = self.bn3(out)
+ #         out += self.shortcut(identity)
+ #         out = self.relu(out)
+ #         return out
+
+ class FeatureNet(nn.Module):
+     def __init__(self, height, width):
+         super().__init__()
+         model = torchvision.models.resnet152(pretrained=False)
+         layers = list(model.children())
+         self.FeatureEncoder = torch.nn.Sequential(*layers[:5].copy())
+         self.expand_layer = ResidualBlock(256, 500)
+
+     def forward(self, x):
+         x = self.FeatureEncoder(x)
+         x = self.expand_layer(x)
+         return x
+
+     def apply_feature_encoder(self, x):
+         x = self.FeatureEncoder(x)
+         x = self.expand_layer(x)
+         return x
+
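+ # FeatureNet keeps the first five children of ResNet-152 (conv1, bn1, relu,
+ # maxpool, layer1), i.e. 256-channel features at 1/4 input resolution, and
+ # expands them to 500 channels so they can be summed with the decoder
+ # activations at that scale (see `x + lower_skip_list[1] + imagenet_features`
+ # below).
+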
+ class Encoder(nn.Module):
+     def __init__(self, height, width, total_image_input=1):
+         super().__init__()
+         self.height = height
+         self.width = width
+         self.encoder_pre = ResidualBlock((total_image_input * 3), 20)
+         self.encoder_layer1 = ResidualBlock(20, 30)
+         self.encoder_layer2 = ResidualBlock(30, 50)
+
+         self.encoder_layer3 = nn.Sequential(
+             ResidualBlock(50, 100),
+             nn.MaxPool2d(kernel_size=2, stride=2)
+         )
+
+         self.encoder_layer4 = ResidualBlock(100, 500)
+         self.encoder_layer5 = nn.Sequential(
+             ResidualBlock(500, 500),
+             nn.MaxPool2d(kernel_size=2, stride=2)
+         )
+
+         self.encoder_layer6 = ResidualBlock(500, 500)
+         self.encoder_layer7 = nn.Sequential(
+             ResidualBlock(500, 500),
+             nn.MaxPool2d(kernel_size=2, stride=2)
+         )
+
+         self.encoder_layer8 = ResidualBlock(500, 1000)
+         self.encoder_layer9 = nn.Sequential(
+             ResidualBlock(1000, 1000),
+             nn.MaxPool2d(kernel_size=2, stride=2)
+         )
+
+         self.encoder_layer10 = ResidualBlock(1000, 1000)
+         self.encoder_layer11 = ResidualBlock(1000, 1000)
+
+     def forward(self, x, height=None, width=None):
+         if height is None and width is None:
+             height = self.height
+             width = self.width
+
+         x = self.encoder_pre(x)
+         x = self.encoder_layer1(x)
+         x = self.encoder_layer2(x)
+         skip1 = self.encoder_layer3(x)
+
+         x = self.encoder_layer4(skip1)
+         skip2 = self.encoder_layer5(x)
+
+         x = self.encoder_layer6(skip2)
+         skip3 = self.encoder_layer7(x)
+
+         x = self.encoder_layer8(skip3)
+         skip4 = self.encoder_layer9(x)
+
+         x = self.encoder_layer10(skip4)
+         x = self.encoder_layer11(x)
+
+         return x, [skip1, skip2, skip3, skip4]
+
+ class DecoderRGB(nn.Module):
+     def __init__(self, height, width):
+         super().__init__()
+         self.height = height
+         self.width = width
+         self.decoder_layer1 = ResidualBlock(1000, 1000)
+         self.decoder_layer2 = ResidualBlock(1000, 1000)
+         self.decoder_layer3 = ResidualBlock(1000, 1000)
+
+         self.decoder_layer4 = nn.Sequential(
+             nn.ConvTranspose2d(1000, 500, 2, stride=2, padding=0),
+             nn.ReLU(True)
+         )
+         self.decoder_layer5 = ResidualBlock(500, 500)
+
+         self.decoder_layer6 = nn.Sequential(
+             nn.ConvTranspose2d(500, 500, 2, stride=2, padding=0),
+             nn.ReLU(True)
+         )
+         self.decoder_layer7 = ResidualBlock(500, 500)
+
+         self.decoder_layer8 = nn.Sequential(
+             nn.ConvTranspose2d(500, 100, 2, stride=2, padding=0),
+             nn.ReLU(True)
+         )
+         self.decoder_layer9 = ResidualBlock(100, 100)
+
+         self.decoder_layer10 = nn.Sequential(
+             nn.ConvTranspose2d(100, 100, 2, stride=2, padding=0),
+             nn.ReLU(True)
+         )
+         self.decoder_layer11 = ResidualBlock(100, 100)
+         self.decoder_layer12 = ResidualBlock(100, 96)
+         self.decoder_layer13 = ResidualBlock(96, 96)
+         self.decoder_layer14 = ResidualBlock(96, 96)
+         self.decoder_layer15 = nn.Sequential(
+             nn.Conv2d(96, 96, 3, stride=1, padding=1),
+             nn.Sigmoid()
+         )
+         self.decoder_layer16 = nn.Sequential(
+             nn.Conv2d(96, 96, 3, stride=1, padding=1),
+             nn.Sigmoid()
+         )
+
+     def forward(self, x, lower_skip_list, imagenet_features, height=None, width=None):
+         if height is None and width is None:
+             height = self.height
+             width = self.width
+
+         x = self.decoder_layer1(x)
+         x = self.decoder_layer2(x)
+         x = x + lower_skip_list[3]
+
+         x = self.decoder_layer3(x)
+         x = self.decoder_layer4(x)
+         x = x + lower_skip_list[2]
+
+         x = self.decoder_layer5(x)
+         x = self.decoder_layer6(x)
+         x = x + lower_skip_list[1] + imagenet_features
+
+         x = self.decoder_layer7(x)
+         x = self.decoder_layer8(x)
+         x = x + lower_skip_list[0]
+
+         x = self.decoder_layer9(x)
+         x = self.decoder_layer10(x)
+         x = self.decoder_layer11(x)
+         x = self.decoder_layer12(x)
+         x = self.decoder_layer13(x)
+         x = self.decoder_layer14(x)
+         x = self.decoder_layer15(x)
+         x = self.decoder_layer16(x)
+         x = x.view(x.size()[0], 32, 3, height, width)
+         return x
+
+ class DecoderSigma(nn.Module):
+     def __init__(self, height, width):
+         super().__init__()
+         self.height = height
+         self.width = width
+         self.decoder_layer1 = ResidualBlock(1000, 1000)
+         self.decoder_layer2 = ResidualBlock(1000, 1000)
+         self.decoder_layer3 = ResidualBlock(1000, 1000)
+
+         self.decoder_layer4 = nn.Sequential(
+             nn.ConvTranspose2d(1000, 500, 2, stride=2, padding=0),
+             nn.ReLU(True)
+         )
+         self.decoder_layer5 = ResidualBlock(500, 500)
+
+         self.decoder_layer6 = nn.Sequential(
+             nn.ConvTranspose2d(500, 500, 2, stride=2, padding=0),
+             nn.ReLU(True)
+         )
+         self.decoder_layer7 = ResidualBlock(500, 500)
+
+         self.decoder_layer8 = nn.Sequential(
+             nn.ConvTranspose2d(500, 100, 2, stride=2, padding=0),
+             nn.ReLU(True)
+         )
+         self.decoder_layer9 = ResidualBlock(100, 100)
+
+         self.decoder_layer10 = nn.Sequential(
+             nn.ConvTranspose2d(100, 100, 2, stride=2, padding=0),
+             nn.ReLU(True)
+         )
+         self.decoder_layer11 = ResidualBlock(100, 100)
+         self.decoder_layer12 = ResidualBlock(100, 50)
+         self.decoder_layer13 = ResidualBlock(50, 40)
+         self.decoder_layer14 = ResidualBlock(40, 32)
+         self.decoder_layer15 = nn.Sequential(
+             nn.Conv2d(32, 32, 3, stride=1, padding=1),
+             nn.ReLU(True)
+         )
+         self.decoder_layer16 = nn.Sequential(
+             nn.Conv2d(32, 32, 3, stride=1, padding=1),
+             nn.ReLU(True)
+         )
+
+     def forward(self, x, lower_skip_list, imagenet_features, height=None, width=None):
+         if height is None and width is None:
+             height = self.height
+             width = self.width
+
+         x = self.decoder_layer1(x)
+         x = self.decoder_layer2(x)
+         x = x + lower_skip_list[3]
+
+         x = self.decoder_layer3(x)
+         x = self.decoder_layer4(x)
+         x = x + lower_skip_list[2]
+
+         x = self.decoder_layer5(x)
+         x = self.decoder_layer6(x)
+         x = x + lower_skip_list[1] + imagenet_features
+
+         x = self.decoder_layer7(x)
+         x = self.decoder_layer8(x)
+         x = x + lower_skip_list[0]
+
+         x = self.decoder_layer9(x)
+         x = self.decoder_layer10(x)
+         x = self.decoder_layer11(x)
+         x = self.decoder_layer12(x)
+         x = self.decoder_layer13(x)
+         x = self.decoder_layer14(x)
+         x = self.decoder_layer15(x)
+         x = self.decoder_layer16(x)
+         x = x.view(x.size()[0], 32, 1, height, width)
+         return x
+
+
+ class DecoderDepth(nn.Module):
+     def __init__(self, height, width):
+         super().__init__()
+         self.height = height
+         self.width = width
+         self.decoder_layer1 = ResidualBlock(1000, 1000)
+         self.decoder_layer2 = ResidualBlock(1000, 1000)
+         self.decoder_layer3 = ResidualBlock(1000, 1000)
+
+         self.decoder_layer4 = nn.Sequential(
+             nn.ConvTranspose2d(1000, 500, 2, stride=2, padding=0),
+             nn.ReLU(True)
+         )
+         self.decoder_layer5 = ResidualBlock(500, 500)
+
+         self.decoder_layer6 = nn.Sequential(
+             nn.ConvTranspose2d(500, 500, 2, stride=2, padding=0),
+             nn.ReLU(True)
+         )
+         self.decoder_layer7 = ResidualBlock(500, 500)
+
+         self.decoder_layer8 = nn.Sequential(
+             nn.ConvTranspose2d(500, 100, 2, stride=2, padding=0),
+             nn.ReLU(True)
+         )
+         self.decoder_layer9 = ResidualBlock(100, 100)
+
+         self.decoder_layer10 = nn.Sequential(
+             nn.ConvTranspose2d(100, 100, 2, stride=2, padding=0),
+             nn.ReLU(True)
+         )
+         self.decoder_layer11 = ResidualBlock(100, 100)
+         self.decoder_layer12 = ResidualBlock(100, 50)
+         self.decoder_layer13 = ResidualBlock(50, 40)
+         self.decoder_layer14 = ResidualBlock(40, 16)
+         self.decoder_layer15 = nn.Sequential(
+             nn.Conv2d(16, 8, 3, stride=1, padding=1),
+             nn.ReLU(True)
+         )
+         self.decoder_layer16 = nn.Sequential(
+             nn.Conv2d(8, 1, 3, stride=1, padding=1),
+             nn.ReLU(True)
+         )
+
+     def forward(self, x, lower_skip_list, imagenet_features, height=None, width=None):
+         if height is None and width is None:
+             height = self.height
+             width = self.width
+
+         x = self.decoder_layer1(x)
+         x = self.decoder_layer2(x)
+         x = x + lower_skip_list[3]
+
+         x = self.decoder_layer3(x)
+         x = self.decoder_layer4(x)
+         x = x + lower_skip_list[2]
+
+         x = self.decoder_layer5(x)
+         x = self.decoder_layer6(x)
+         x = x + lower_skip_list[1] + imagenet_features
+
+         x = self.decoder_layer7(x)
+         x = self.decoder_layer8(x)
+         x = x + lower_skip_list[0]
+
+         x = self.decoder_layer9(x)
+         x = self.decoder_layer10(x)
+         x = self.decoder_layer11(x)
+         x = self.decoder_layer12(x)
+         x = self.decoder_layer13(x)
+         x = self.decoder_layer14(x)
+         x = self.decoder_layer15(x)
+         x = self.decoder_layer16(x)
+         return x
+
+ class MMPI(nn.Module):
+     def __init__(self, total_image_input=1, height=384, width=384):
+         super().__init__()
+         self.height = height
+         self.width = width
+         self.feature_encoder = FeatureNet(height, width)
+         self.lower_encoder = Encoder(height, width, total_image_input)
+         self.merge_decoder_rgb = DecoderRGB(height, width)
+         self.merge_decoder_sigma = DecoderSigma(height, width)
+         self.depth_decoder = DecoderDepth(height, width)
+
+     def forward(self, x, height=None, width=None):
+         if height is None and width is None:
+             height = self.height
+             width = self.width
+
+         imagenet_features = self.feature_encoder.apply_feature_encoder(x)
+         lower_feature, skip_list = self.lower_encoder(x, height, width)
+
+         merged_feature_rgb = self.merge_decoder_rgb(lower_feature, skip_list, imagenet_features, height, width)
+         merged_feature_sigma = self.merge_decoder_sigma(lower_feature, skip_list, imagenet_features, height, width)
+         merged_feature_depth = self.depth_decoder(lower_feature, skip_list, imagenet_features)
+
+         return merged_feature_rgb, merged_feature_sigma, merged_feature_depth
+
+     def get_rgb_sigma(self, x, height=None, width=None):
+         if height is None and width is None:
+             height = self.height
+             width = self.width
+
+         imagenet_features = self.feature_encoder.apply_feature_encoder(x)
+         lower_feature, skip_list = self.lower_encoder(x, height, width)
+         merged_feature_rgb = self.merge_decoder_rgb(lower_feature, skip_list, imagenet_features, height, width)
+         merged_feature_sigma = self.merge_decoder_sigma(lower_feature, skip_list, imagenet_features, height, width)
+         return merged_feature_rgb, merged_feature_sigma
+
+     def get_depth(self, x, height=None, width=None):
+         if height is None and width is None:
+             height = self.height
+             width = self.width
+
+         imagenet_features = self.feature_encoder.apply_feature_encoder(x)
+         lower_feature, skip_list = self.lower_encoder(x, height, width)
+         merged_feature_depth = self.depth_decoder(lower_feature, skip_list, imagenet_features)
+         return merged_feature_depth
+
+     def get_layer_depth(self, x, grid, height=None, width=None):
+         if height is None and width is None:
+             height = self.height
+             width = self.width
+
+         imagenet_features = self.feature_encoder.apply_feature_encoder(x)
+         lower_feature, skip_list = self.lower_encoder(x, height, width)
+
+         rgb_layers = self.merge_decoder_rgb(lower_feature, skip_list, imagenet_features, height, width)
+         sigma_layers = self.merge_decoder_sigma(lower_feature, skip_list, imagenet_features, height, width)
+
+         pred_mpi_planes = torch.randn((1, 4, height, width)).to(params.DEVICE)
+         for i in range(params.params_num_planes):
+             RGBA = torch.cat((rgb_layers[0, i, :, :, :], sigma_layers[0, i, :, :, :]), dim=0).unsqueeze(0)
+             pred_mpi_planes = torch.cat((pred_mpi_planes, RGBA), dim=0)
+
+         pred_mpi_planes = pred_mpi_planes[1:, :, :, :].unsqueeze(0)
+
+         sigma = pred_mpi_planes[:, :, 3, :, :]
+         B, D, H, W = sigma.shape
+
+         pred_mpi_disp = grid
+         disp_sorted, _ = pred_mpi_disp.sort(dim=1)
+         delta = disp_sorted[:, 1:] - disp_sorted[:, :-1]
+         delta_last = delta[:, -1:]
+         delta = torch.cat([delta, delta_last], dim=1)
+
+         delta = delta.unsqueeze(-1).unsqueeze(-1).expand_as(sigma)
+
+         alpha = 1.0 - torch.exp(-delta * sigma)
+
+         transmittance = torch.cumprod(1 - alpha + 1e-7, dim=1)
+         shifted_transmittance = torch.ones_like(transmittance)
+         shifted_transmittance[:, 1:, :, :] = transmittance[:, :-1, :, :]
+
+         disparity = pred_mpi_disp.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, H, W)
+
+         disparity_map = (disparity * alpha * shifted_transmittance).sum(dim=1, keepdim=True)
+
+         return disparity_map
+
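+     # get_layer_depth collapses the predicted MPI into an expected disparity
+     # map with standard volume-rendering weights: per plane i,
+     #   alpha_i = 1 - exp(-delta_i * sigma_i)
+     #   T_i     = prod_{j<i} (1 - alpha_j)      (accumulated transmittance)
+     #   disparity = sum_i d_i * alpha_i * T_i
+     # where d_i are the plane disparities in `grid` and delta_i their spacings
+     # (a small 1e-7 term keeps the cumulative product numerically stable).
+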
+     def get_layers(self, x, height=None, width=None):
+         if height is None and width is None:
+             height = self.height
+             width = self.width
+
+         imagenet_features = self.feature_encoder.apply_feature_encoder(x)
+         lower_feature, skip_list = self.lower_encoder(x, height, width)
+         merged_feature_rgb = self.merge_decoder_rgb(lower_feature, skip_list, imagenet_features, height, width)
+         merged_feature_sigma = self.merge_decoder_sigma(lower_feature, skip_list, imagenet_features, height, width)
+         return merged_feature_rgb, merged_feature_sigma
+
model_Medium.py ADDED
@@ -0,0 +1,535 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import warnings
+ warnings.filterwarnings("ignore")
+ import torchvision
+ import parameters as params
+ import timm
+
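+ # model_Medium.py mirrors model_Large.py with narrower channel widths: the
+ # encoder/decoder trunk uses 200/500 channels where the Large model uses
+ # 500/1000, and FeatureNet expands the ResNet features to 200 channels
+ # instead of 500.
+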
10
+ class DinoV2FeatureExtractor(nn.Module):
11
+ def __init__(self, out_channels=256, out_size=(64, 64)):
12
+ super().__init__()
13
+ self.dino = timm.create_model('vit_base_patch14_dinov2.lvd142m', pretrained=False)
14
+ self.dino.eval()
15
+ for p in self.dino.parameters():
16
+ p.requires_grad = False
17
+
18
+ self.out_size = out_size
19
+ self.feat_proj = nn.Sequential(
20
+ nn.Conv2d(self.dino.embed_dim, out_channels, kernel_size=1),
21
+ nn.ReLU(),
22
+ )
23
+
24
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
25
+ x = F.interpolate(x, size=(518, 518), mode='bilinear', align_corners=False)
26
+ patch_tokens = self.dino.forward_features(x)
27
+ patch_tokens = patch_tokens[:, 1:]
28
+ B, N, C = patch_tokens.shape
29
+ h = w = int(N ** 0.5)
30
+ feat_map = patch_tokens.transpose(1, 2).reshape(B, C, h, w) # [B, C, H', W']
31
+ feat_map = F.interpolate(feat_map, size=self.out_size, mode='bilinear', align_corners=False)
32
+ return self.feat_proj(feat_map)
33
+
34
+ def getLinearLayer(in_feat, out_feat, activation=nn.ReLU(True)):
35
+ return nn.Sequential(
36
+ nn.Linear(in_features=in_feat, out_features=out_feat, bias=True),
37
+ activation
38
+ )
39
+
40
+ def getConvLayer(in_channel,out_channel,stride=1,padding=1,activation=nn.ReLU()):
41
+ return nn.Sequential(nn.Conv2d(in_channel,
42
+ out_channel,
43
+ kernel_size=3,
44
+ stride=stride,
45
+ padding=padding,
46
+ padding_mode='reflect'),
47
+ activation)
48
+
49
+ def getConvTransposeLayer(in_channel, out_channel,kernel=3,stride=1,padding=1,activation=nn.ReLU()):
50
+ return nn.Sequential(nn.ConvTranspose2d(in_channel,
51
+ out_channel,
52
+ kernel_size = kernel,
53
+ stride=stride,
54
+ padding=padding),
55
+ activation)
56
+
57
+
58
+ class ResidualBlock(nn.Module):
59
+ def __init__(self, in_channels, out_channels, stride=1):
60
+ super(ResidualBlock, self).__init__()
61
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
62
+ self.relu = nn.ReLU()
63
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
64
+ self.stride = stride
65
+
66
+ self.shortcut = nn.Sequential()
67
+ if stride != 1 or in_channels != out_channels:
68
+ self.shortcut = nn.Sequential(
69
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
70
+ nn.BatchNorm2d(out_channels)
71
+ )
72
+
73
+ def forward(self, x):
74
+ residual = x
75
+
76
+ out = self.conv1(x)
77
+ out = self.relu(out)
78
+
79
+ out = self.conv2(out)
80
+
81
+ out = out + self.shortcut(residual)
82
+ out = self.relu(out)
83
+ return out
84
+
85
+
86
+ # class ResidualBlock(nn.Module):
87
+ # def __init__(self, in_channels, out_channels, stride=1, expansion=4):
88
+ # super().__init__()
89
+ # mid_channels = out_channels // expansion
90
+ # self.pw_reduce = nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False)
91
+ # self.bn1 = nn.BatchNorm2d(mid_channels)
92
+ # self.dw = nn.Conv2d(mid_channels, mid_channels, kernel_size=3,
93
+ # stride=stride, padding=1, groups=mid_channels, bias=False)
94
+ # self.bn2 = nn.BatchNorm2d(mid_channels)
95
+ # self.pw_expand = nn.Conv2d(mid_channels, out_channels, kernel_size=1, bias=False)
96
+ # self.bn3 = nn.BatchNorm2d(out_channels)
97
+ # self.relu = nn.ReLU(inplace=True)
98
+ # self.stride = stride
99
+ # if stride != 1 or in_channels != out_channels:
100
+ # self.shortcut = nn.Sequential(
101
+ # nn.Conv2d(in_channels, out_channels, kernel_size=1,
102
+ # stride=stride, bias=False),
103
+ # nn.BatchNorm2d(out_channels),
104
+ # )
105
+ # else:
106
+ # self.shortcut = nn.Identity()
107
+
108
+ # def forward(self, x):
109
+ # identity = x
110
+
111
+ # out = self.pw_reduce(x)
112
+ # out = self.bn1(out)
113
+ # out = self.relu(out)
114
+
115
+ # out = self.dw(out)
116
+ # out = self.bn2(out)
117
+ # out = self.relu(out)
118
+
119
+ # out = self.pw_expand(out)
120
+ # out = self.bn3(out)
121
+
122
+ # out += self.shortcut(identity)
123
+ # out = self.relu(out)
124
+ # return out
125
+
126
+ class FeatureNet(nn.Module):
127
+ def __init__(self,height,width):
128
+ super().__init__()
129
+ model = torchvision.models.resnet152(pretrained=False)
130
+ layers = list(model.children())
131
+ self.FeatureEncoder = torch.nn.Sequential(*layers[:5].copy())
132
+ self.expand_layer = ResidualBlock(256, 200)
133
+
134
+ def forward(self, x):
135
+ x = self.FeatureEncoder(x)
136
+ x = self.expand_layer(x)
137
+ return x
138
+
139
+ def apply_feature_encoder(self, x):
140
+ x = self.FeatureEncoder(x)
141
+ x = self.expand_layer(x)
142
+ return x
143
+
144
+ class Encoder(nn.Module):
145
+ def __init__(self, height, width, total_image_input=1):
146
+ super().__init__()
147
+ self.height = height
148
+ self.width = width
149
+ self.encoder_pre = ResidualBlock((total_image_input*3), 20)
150
+ self.encoder_layer1 = ResidualBlock(20, 30)
151
+ self.encoder_layer2 = ResidualBlock(30, 50)
152
+
153
+ self.encoder_layer3 = nn.Sequential(
154
+ ResidualBlock(50, 100),
155
+ nn.MaxPool2d(kernel_size=2, stride=2)
156
+ )
157
+
158
+ self.encoder_layer4 = ResidualBlock(100, 200)
159
+ self.encoder_layer5 = nn.Sequential(
160
+ ResidualBlock(200, 200),
161
+ nn.MaxPool2d(kernel_size=2, stride=2)
162
+ )
163
+
164
+ self.encoder_layer6 = ResidualBlock(200, 200)
165
+ self.encoder_layer7 = nn.Sequential(
166
+ ResidualBlock(200, 200),
167
+ nn.MaxPool2d(kernel_size=2, stride=2)
168
+ )
169
+
170
+ self.encoder_layer8 = ResidualBlock(200, 500)
171
+ self.encoder_layer9 = nn.Sequential(
172
+ ResidualBlock(500, 500),
173
+ nn.MaxPool2d(kernel_size=2, stride=2)
174
+ )
175
+
176
+ self.encoder_layer10 = ResidualBlock(500, 500)
177
+ self.encoder_layer11 = ResidualBlock(500, 500)
178
+
179
+ def forward(self, x, height=None, width=None):
180
+ if height == None and width == None:
181
+ height = self.height
182
+ width = self.width
183
+
184
+ x = self.encoder_pre(x)
185
+ x = self.encoder_layer1(x)
186
+ x = self.encoder_layer2(x)
187
+ skip1 = self.encoder_layer3(x)
188
+
189
+ x = self.encoder_layer4(skip1)
190
+ skip2 = self.encoder_layer5(x)
191
+
192
+ x = self.encoder_layer6(skip2)
193
+ skip3 = self.encoder_layer7(x)
194
+
195
+ x = self.encoder_layer8(skip3)
196
+ skip4 = self.encoder_layer9(x)
197
+
198
+ x = self.encoder_layer10(skip4)
199
+ x = self.encoder_layer11(x)
200
+
201
+ return x, [skip1, skip2, skip3, skip4]
202
+
203
+ class DecoderRGB(nn.Module):
204
+ def __init__(self,height,width):
205
+ super().__init__()
206
+ self.height = height
207
+ self.width = width
208
+ self.decoder_layer1 = ResidualBlock(500, 500)
209
+ self.decoder_layer2 = ResidualBlock(500, 500)
210
+ self.decoder_layer3 = ResidualBlock(500, 500)
211
+
212
+ self.decoder_layer4 = nn.Sequential(
213
+ nn.ConvTranspose2d(500, 200, 2, stride=2, padding=0),
214
+ nn.ReLU(True)
215
+ )
216
+ self.decoder_layer5 = ResidualBlock(200, 200)
217
+
218
+ self.decoder_layer6 = nn.Sequential(
219
+ nn.ConvTranspose2d(200, 200, 2, stride=2, padding=0),
220
+ nn.ReLU(True)
221
+ )
222
+ self.decoder_layer7 = ResidualBlock(200, 200)
223
+
224
+ self.decoder_layer8 = nn.Sequential(
225
+ nn.ConvTranspose2d(200, 100, 2, stride=2, padding=0),
226
+ nn.ReLU(True)
227
+ )
228
+ self.decoder_layer9 = ResidualBlock(100, 100)
229
+
230
+ self.decoder_layer10 = nn.Sequential(
231
+ nn.ConvTranspose2d(100, 100, 2, stride=2, padding=0),
232
+ nn.ReLU(True)
233
+ )
234
+ self.decoder_layer11 = ResidualBlock(100, 100)
235
+ self.decoder_layer12 = ResidualBlock(100, 96)
236
+ self.decoder_layer13 = ResidualBlock(96, 96)
237
+ self.decoder_layer14 = ResidualBlock(96, 96)
238
+ self.decoder_layer15 = nn.Sequential(
239
+ nn.Conv2d(96, 96, 3, stride=1, padding=1),
240
+ nn.Sigmoid()
241
+ )
242
+ self.decoder_layer16 = nn.Sequential(
243
+ nn.Conv2d(96, 96, 3, stride=1, padding=1),
244
+ nn.Sigmoid()
245
+ )
246
+
247
+ def forward(self, x, lower_skip_list, imagenet_features, height=None, width=None):
248
+ if height == None and width == None:
249
+ height = self.height
250
+ width = self.width
251
+
252
+ x = self.decoder_layer1(x)
253
+ x = self.decoder_layer2(x)
254
+ x = x + lower_skip_list[3]
255
+
256
+ x = self.decoder_layer3(x)
257
+ x = self.decoder_layer4(x)
258
+ x = x + lower_skip_list[2]
259
+
260
+ x = self.decoder_layer5(x)
261
+ x = self.decoder_layer6(x)
262
+ x = x + lower_skip_list[1] + imagenet_features
263
+
264
+ x = self.decoder_layer7(x)
265
+ x = self.decoder_layer8(x)
266
+ x = x + lower_skip_list[0]
267
+
268
+ x = self.decoder_layer9(x)
269
+ x = self.decoder_layer10(x)
270
+ x = self.decoder_layer11(x)
271
+ x = self.decoder_layer12(x)
272
+ x = self.decoder_layer13(x)
273
+ x = self.decoder_layer14(x)
274
+ x = self.decoder_layer15(x)
275
+ x = self.decoder_layer16(x)
276
+ x = x.view(x.size()[0], 32, 3, height, width)
277
+ return x
278
+
279
+ class DecoderSigma(nn.Module):
280
+ def __init__(self,height,width):
281
+ super().__init__()
282
+ self.height = height
283
+ self.width = width
284
+ self.decoder_layer1 = ResidualBlock(500, 500)
285
+ self.decoder_layer2 = ResidualBlock(500, 500)
286
+ self.decoder_layer3 = ResidualBlock(500, 500)
287
+
288
+ self.decoder_layer4 = nn.Sequential(
289
+ nn.ConvTranspose2d(500, 200, 2, stride=2, padding=0),
290
+ nn.ReLU(True)
291
+ )
292
+ self.decoder_layer5 = ResidualBlock(200, 200)
293
+
294
+ self.decoder_layer6 = nn.Sequential(
295
+ nn.ConvTranspose2d(200, 200, 2, stride=2, padding=0),
296
+ nn.ReLU(True)
297
+ )
298
+ self.decoder_layer7 = ResidualBlock(200, 200)
299
+
300
+ self.decoder_layer8 = nn.Sequential(
301
+ nn.ConvTranspose2d(200, 100, 2, stride=2, padding=0),
302
+ nn.ReLU(True)
303
+ )
304
+ self.decoder_layer9 = ResidualBlock(100, 100)
305
+
306
+ self.decoder_layer10 = nn.Sequential(
307
+ nn.ConvTranspose2d(100, 100, 2, stride=2, padding=0),
308
+ nn.ReLU(True)
309
+ )
310
+ self.decoder_layer11 = ResidualBlock(100, 100)
311
+ self.decoder_layer12 = ResidualBlock(100, 50)
312
+ self.decoder_layer13 = ResidualBlock(50, 40)
313
+ self.decoder_layer14 = ResidualBlock(40, 32)
314
+ self.decoder_layer15 = nn.Sequential(
315
+ nn.Conv2d(32, 32, 3, stride=1, padding=1),
316
+ nn.ReLU(True)
317
+ )
318
+ self.decoder_layer16 = nn.Sequential(
319
+ nn.Conv2d(32, 32, 3, stride=1, padding=1),
320
+ nn.ReLU(True)
321
+ )
322
+
323
+ def forward(self, x, lower_skip_list, imagenet_features, height=None, width=None):
324
+ if height == None and width == None:
325
+ height = self.height
326
+ width = self.width
327
+
328
+ x = self.decoder_layer1(x)
329
+ x = self.decoder_layer2(x)
330
+ x = x + lower_skip_list[3]
331
+
332
+ x = self.decoder_layer3(x)
333
+ x = self.decoder_layer4(x)
334
+ x = x + lower_skip_list[2]
335
+
336
+ x = self.decoder_layer5(x)
337
+ x = self.decoder_layer6(x)
338
+ x = x + lower_skip_list[1] + imagenet_features
339
+
340
+ x = self.decoder_layer7(x)
341
+ x = self.decoder_layer8(x)
342
+ x = x + lower_skip_list[0]
343
+
344
+ x = self.decoder_layer9(x)
345
+ x = self.decoder_layer10(x)
346
+ x = self.decoder_layer11(x)
347
+ x = self.decoder_layer12(x)
348
+ x = self.decoder_layer13(x)
349
+ x = self.decoder_layer14(x)
350
+ x = self.decoder_layer15(x)
351
+ x = self.decoder_layer16(x)
352
+ x = x.view(x.size()[0], 32, 1, height, width)
353
+ return x
354
+
355
+
356
+ class DecoderDepth(nn.Module):
+     def __init__(self, height, width):
+         super().__init__()
+         self.height = height
+         self.width = width
+         self.decoder_layer1 = ResidualBlock(500, 500)
+         self.decoder_layer2 = ResidualBlock(500, 500)
+         self.decoder_layer3 = ResidualBlock(500, 500)
+
+         self.decoder_layer4 = nn.Sequential(
+             nn.ConvTranspose2d(500, 200, 2, stride=2, padding=0),
+             nn.ReLU(True)
+         )
+         self.decoder_layer5 = ResidualBlock(200, 200)
+
+         self.decoder_layer6 = nn.Sequential(
+             nn.ConvTranspose2d(200, 200, 2, stride=2, padding=0),
+             nn.ReLU(True)
+         )
+         self.decoder_layer7 = ResidualBlock(200, 200)
+
+         self.decoder_layer8 = nn.Sequential(
+             nn.ConvTranspose2d(200, 100, 2, stride=2, padding=0),
+             nn.ReLU(True)
+         )
+         self.decoder_layer9 = ResidualBlock(100, 100)
+
+         self.decoder_layer10 = nn.Sequential(
+             nn.ConvTranspose2d(100, 100, 2, stride=2, padding=0),
+             nn.ReLU(True)
+         )
+         self.decoder_layer11 = ResidualBlock(100, 100)
+         self.decoder_layer12 = ResidualBlock(100, 50)
+         self.decoder_layer13 = ResidualBlock(50, 40)
+         self.decoder_layer14 = ResidualBlock(40, 16)
+         self.decoder_layer15 = nn.Sequential(
+             nn.Conv2d(16, 8, 3, stride=1, padding=1),
+             nn.ReLU(True)
+         )
+         self.decoder_layer16 = nn.Sequential(
+             nn.Conv2d(8, 1, 3, stride=1, padding=1),
+             nn.ReLU(True)
+         )
+
+     def forward(self, x, lower_skip_list, imagenet_features, height=None, width=None):
+         if height is None and width is None:
+             height = self.height
+             width = self.width
+
+         x = self.decoder_layer1(x)
+         x = self.decoder_layer2(x)
+         x = x + lower_skip_list[3]
+
+         x = self.decoder_layer3(x)
+         x = self.decoder_layer4(x)
+         x = x + lower_skip_list[2]
+
+         x = self.decoder_layer5(x)
+         x = self.decoder_layer6(x)
+         x = x + lower_skip_list[1] + imagenet_features
+
+         x = self.decoder_layer7(x)
+         x = self.decoder_layer8(x)
+         x = x + lower_skip_list[0]
+
+         x = self.decoder_layer9(x)
+         x = self.decoder_layer10(x)
+         x = self.decoder_layer11(x)
+         x = self.decoder_layer12(x)
+         x = self.decoder_layer13(x)
+         x = self.decoder_layer14(x)
+         x = self.decoder_layer15(x)
+         x = self.decoder_layer16(x)
+         return x
+
+ class MMPI(nn.Module):
+     def __init__(self, total_image_input=1, height=384, width=384):
+         super().__init__()
+         self.height = height
+         self.width = width
+         self.feature_encoder = FeatureNet(height, width)
+         self.lower_encoder = Encoder(height, width, total_image_input)
+         self.merge_decoder_rgb = DecoderRGB(height, width)
+         self.merge_decoder_sigma = DecoderSigma(height, width)
+         self.depth_decoder = DecoderDepth(height, width)
+
+     def forward(self, x, height=None, width=None):
+         if height is None and width is None:
+             height = self.height
+             width = self.width
+
+         imagenet_features = self.feature_encoder.apply_feature_encoder(x)
+         lower_feature, skip_list = self.lower_encoder(x, height, width)
+
+         merged_feature_rgb = self.merge_decoder_rgb(lower_feature, skip_list, imagenet_features, height, width)
+         merged_feature_sigma = self.merge_decoder_sigma(lower_feature, skip_list, imagenet_features, height, width)
+
+         merged_feature_depth = self.depth_decoder(lower_feature, skip_list, imagenet_features)
+
+         return merged_feature_rgb, merged_feature_sigma, merged_feature_depth
+
+     def get_rgb_sigma(self, x, height=None, width=None):
+         if height is None and width is None:
+             height = self.height
+             width = self.width
+
+         imagenet_features = self.feature_encoder.apply_feature_encoder(x)
+         lower_feature, skip_list = self.lower_encoder(x, height, width)
+         merged_feature_rgb = self.merge_decoder_rgb(lower_feature, skip_list, imagenet_features, height, width)
+         merged_feature_sigma = self.merge_decoder_sigma(lower_feature, skip_list, imagenet_features, height, width)
+         return merged_feature_rgb, merged_feature_sigma
+
+     def get_depth(self, x, height=None, width=None):
+         if height is None and width is None:
+             height = self.height
+             width = self.width
+
+         imagenet_features = self.feature_encoder.apply_feature_encoder(x)
+         lower_feature, skip_list = self.lower_encoder(x, height, width)
+         merged_feature_depth = self.depth_decoder(lower_feature, skip_list, imagenet_features)
+         return merged_feature_depth
+
+     def get_layer_depth(self, x, grid, height=None, width=None):
+         if height is None and width is None:
+             height = self.height
+             width = self.width
+
+         imagenet_features = self.feature_encoder.apply_feature_encoder(x)
+         lower_feature, skip_list = self.lower_encoder(x, height, width)
+
+         rgb_layers = self.merge_decoder_rgb(lower_feature, skip_list, imagenet_features, height, width)
+         sigma_layers = self.merge_decoder_sigma(lower_feature, skip_list, imagenet_features, height, width)
+
+         # Stack the predicted planes into 1xSx4xHxW RGBA form; the random
+         # first entry is only a placeholder and is sliced off after the loop.
+         pred_mpi_planes = torch.randn((1, 4, height, width)).to(params.DEVICE)
+         for i in range(params.params_num_planes):
+             RGBA = torch.cat((rgb_layers[0, i, :, :, :], sigma_layers[0, i, :, :, :]), dim=0).unsqueeze(0)
+             pred_mpi_planes = torch.cat((pred_mpi_planes, RGBA), dim=0)
+
+         pred_mpi_planes = pred_mpi_planes[1:, :, :, :].unsqueeze(0)
+
+         sigma = pred_mpi_planes[:, :, 3, :, :]
+         B, D, H, W = sigma.shape
+
+         # Per-plane spacing along the disparity axis (last interval repeated).
+         pred_mpi_disp = grid
+         disp_sorted, _ = pred_mpi_disp.sort(dim=1)
+         delta = disp_sorted[:, 1:] - disp_sorted[:, :-1]
+         delta_last = delta[:, -1:]
+         delta = torch.cat([delta, delta_last], dim=1)
+
+         delta = delta.unsqueeze(-1).unsqueeze(-1).expand_as(sigma)
+
+         alpha = 1.0 - torch.exp(-delta * sigma)
+
+         # Accumulated transmittance, shifted so plane i is attenuated by the planes in front of it.
+         transmittance = torch.cumprod(1 - alpha + 1e-7, dim=1)
+         shifted_transmittance = torch.ones_like(transmittance)
+         shifted_transmittance[:, 1:, :, :] = transmittance[:, :-1, :, :]
+
+         disparity = pred_mpi_disp.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, H, W)
+
+         disparity_map = (disparity * alpha * shifted_transmittance).sum(dim=1, keepdim=True)
+
+         return disparity_map
+
+     def get_layers(self, x, height=None, width=None):
+         if height is None and width is None:
+             height = self.height
+             width = self.width
+
+         imagenet_features = self.feature_encoder.apply_feature_encoder(x)
+         lower_feature, skip_list = self.lower_encoder(x, height, width)
+         merged_feature_rgb = self.merge_decoder_rgb(lower_feature, skip_list, imagenet_features, height, width)
+         merged_feature_sigma = self.merge_decoder_sigma(lower_feature, skip_list, imagenet_features, height, width)
+         return merged_feature_rgb, merged_feature_sigma
+
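For reference, `get_layer_depth` above composites the per-plane densities into a single disparity map with standard over-compositing; a sketch of the math as inferred from the code (not stated explicitly in this commit): with planes sorted by disparity $d_i$, plane spacing $\Delta_i$, and predicted density $\sigma_i$,

$$\alpha_i = 1 - e^{-\sigma_i \Delta_i}, \qquad T_i = \prod_{j<i} \left(1 - \alpha_j\right), \qquad \hat{D}(p) = \sum_i d_i\, \alpha_i(p)\, T_i(p),$$

where the small $10^{-7}$ added inside the cumulative product guards against zero transmittance.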
model_Small.py ADDED
@@ -0,0 +1,544 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import warnings
+ warnings.filterwarnings("ignore")
+ import torchvision
+ import parameters as params
+ import timm
+
+ class DinoV2FeatureExtractor(nn.Module):
+     def __init__(self, out_channels=256, out_size=(64, 64)):
+         super().__init__()
+         self.dino = timm.create_model('vit_base_patch14_dinov2.lvd142m', pretrained=False)
+         self.dino.eval()
+         for p in self.dino.parameters():
+             p.requires_grad = False
+
+         self.out_size = out_size
+         self.feat_proj = nn.Sequential(
+             nn.Conv2d(self.dino.embed_dim, out_channels, kernel_size=1),
+             nn.ReLU(),
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = F.interpolate(x, size=(518, 518), mode='bilinear', align_corners=False)
+         patch_tokens = self.dino.forward_features(x)
+         patch_tokens = patch_tokens[:, 1:]  # drop the CLS token
+         B, N, C = patch_tokens.shape
+         h = w = int(N ** 0.5)
+         feat_map = patch_tokens.transpose(1, 2).reshape(B, C, h, w)  # [B, C, H', W']
+         feat_map = F.interpolate(feat_map, size=self.out_size, mode='bilinear', align_corners=False)
+         return self.feat_proj(feat_map)
+
+ def getLinearLayer(in_feat, out_feat, activation=nn.ReLU(True)):
+     return nn.Sequential(
+         nn.Linear(in_features=in_feat, out_features=out_feat, bias=True),
+         activation
+     )
+
+ def getConvLayer(in_channel, out_channel, stride=1, padding=1, activation=nn.ReLU()):
+     return nn.Sequential(nn.Conv2d(in_channel,
+                                    out_channel,
+                                    kernel_size=3,
+                                    stride=stride,
+                                    padding=padding,
+                                    padding_mode='reflect'),
+                          activation)
+
+ def getConvTransposeLayer(in_channel, out_channel, kernel=3, stride=1, padding=1, activation=nn.ReLU()):
+     return nn.Sequential(nn.ConvTranspose2d(in_channel,
+                                             out_channel,
+                                             kernel_size=kernel,
+                                             stride=stride,
+                                             padding=padding),
+                          activation)
+
+
+ class ResidualBlock(nn.Module):
+     def __init__(self, in_channels, out_channels, stride=1):
+         super(ResidualBlock, self).__init__()
+         self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
+         self.relu = nn.ReLU()
+         self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
+         self.stride = stride
+
+         self.shortcut = nn.Sequential()
+         if stride != 1 or in_channels != out_channels:
+             self.shortcut = nn.Sequential(
+                 nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
+                 nn.BatchNorm2d(out_channels)
+             )
+
+     def forward(self, x):
+         residual = x
+
+         out = self.conv1(x)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+
+         out = out + self.shortcut(residual)
+         out = self.relu(out)
+         return out
+
+
+ # Alternative bottleneck-style ResidualBlock, kept commented out in the commit:
+ # class ResidualBlock(nn.Module):
+ #     def __init__(self, in_channels, out_channels, stride=1, expansion=4):
+ #         super().__init__()
+ #         mid_channels = out_channels // expansion
+ #         self.pw_reduce = nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False)
+ #         self.bn1 = nn.BatchNorm2d(mid_channels)
+ #         self.dw = nn.Conv2d(mid_channels, mid_channels, kernel_size=3,
+ #                             stride=stride, padding=1, groups=mid_channels, bias=False)
+ #         self.bn2 = nn.BatchNorm2d(mid_channels)
+ #         self.pw_expand = nn.Conv2d(mid_channels, out_channels, kernel_size=1, bias=False)
+ #         self.bn3 = nn.BatchNorm2d(out_channels)
+ #         self.relu = nn.ReLU(inplace=True)
+ #         self.stride = stride
+ #         if stride != 1 or in_channels != out_channels:
+ #             self.shortcut = nn.Sequential(
+ #                 nn.Conv2d(in_channels, out_channels, kernel_size=1,
+ #                           stride=stride, bias=False),
+ #                 nn.BatchNorm2d(out_channels),
+ #             )
+ #         else:
+ #             self.shortcut = nn.Identity()
+ #
+ #     def forward(self, x):
+ #         identity = x
+ #
+ #         out = self.pw_reduce(x)
+ #         out = self.bn1(out)
+ #         out = self.relu(out)
+ #
+ #         out = self.dw(out)
+ #         out = self.bn2(out)
+ #         out = self.relu(out)
+ #
+ #         out = self.pw_expand(out)
+ #         out = self.bn3(out)
+ #
+ #         out += self.shortcut(identity)
+ #         out = self.relu(out)
+ #         return out
+
+ class FeatureNet(nn.Module):
+     def __init__(self, height, width):
+         super().__init__()
+         # conv1..layer1 of ResNet-152 (stem + first stage): 256 channels at 1/4 resolution
+         model = torchvision.models.resnet152(pretrained=False)
+         layers = list(model.children())
+         self.FeatureEncoder = torch.nn.Sequential(*layers[:5].copy())
+         del model
+
+     def forward(self, x):
+         x = self.FeatureEncoder(x)
+         return x
+
+     def apply_feature_encoder(self, x):
+         x = self.FeatureEncoder(x)
+         return x
+
+ class Encoder(nn.Module):
+     def __init__(self, height, width, total_image_input=1):
+         super().__init__()
+         self.height = height
+         self.width = width
+         self.encoder_pre = ResidualBlock((total_image_input*3), 20)
+         self.encoder_layer1 = ResidualBlock(20, 30)
+         self.encoder_layer2 = ResidualBlock(30, 50)
+
+         self.encoder_layer3 = nn.Sequential(
+             ResidualBlock(50, 100),
+             nn.MaxPool2d(kernel_size=2, stride=2)
+         )
+
+         self.encoder_layer4 = ResidualBlock(100, 100)
+         self.encoder_layer5 = nn.Sequential(
+             ResidualBlock(100, 100),
+             nn.MaxPool2d(kernel_size=2, stride=2)
+         )
+
+         self.encoder_layer6 = ResidualBlock(100, 100)
+         self.encoder_layer7 = nn.Sequential(
+             ResidualBlock(100, 100),
+             nn.MaxPool2d(kernel_size=2, stride=2)
+         )
+
+         self.encoder_layer8 = ResidualBlock(100, 100)
+         self.encoder_layer9 = nn.Sequential(
+             ResidualBlock(100, 100),
+             nn.MaxPool2d(kernel_size=2, stride=2)
+         )
+
+         self.encoder_layer10 = ResidualBlock(100, 100)
+         self.encoder_layer11 = ResidualBlock(100, 100)
+
+     def forward(self, x, height=None, width=None):
+         if height is None and width is None:
+             height = self.height
+             width = self.width
+
+         x = self.encoder_pre(x)
+         x = self.encoder_layer1(x)
+         x = self.encoder_layer2(x)
+         skip1 = self.encoder_layer3(x)
+
+         x = self.encoder_layer4(skip1)
+         skip2 = self.encoder_layer5(x)
+
+         x = self.encoder_layer6(skip2)
+         skip3 = self.encoder_layer7(x)
+
+         x = self.encoder_layer8(skip3)
+         skip4 = self.encoder_layer9(x)
+
+         x = self.encoder_layer10(skip4)
+         x = self.encoder_layer11(x)
+
+         return x, [skip1, skip2, skip3, skip4]
+
+ class DecoderRGB(nn.Module):
+     def __init__(self, height, width):
+         super().__init__()
+         self.height = height
+         self.width = width
+         self.decoder_layer1 = ResidualBlock(100, 100)
+         self.decoder_layer2 = ResidualBlock(100, 100)
+         self.decoder_layer3 = ResidualBlock(100, 100)
+
+         self.decoder_layer4 = nn.Sequential(
+             nn.ConvTranspose2d(100, 100, 2, stride=2, padding=0),
+             nn.ReLU(True)
+         )
+         self.decoder_layer5 = ResidualBlock(100, 100)
+
+         self.decoder_layer6 = nn.Sequential(
+             nn.ConvTranspose2d(100, 100, 2, stride=2, padding=0),
+             nn.ReLU(True)
+         )
+         self.decoder_layer7 = ResidualBlock(100, 100)
+
+         self.decoder_layer8 = nn.Sequential(
+             nn.ConvTranspose2d(100, 100, 2, stride=2, padding=0),
+             nn.ReLU(True)
+         )
+         self.decoder_layer9 = ResidualBlock(100, 100)
+
+         self.decoder_layer10 = nn.Sequential(
+             nn.ConvTranspose2d(100, 100, 2, stride=2, padding=0),
+             nn.ReLU(True)
+         )
+         self.decoder_layer11 = ResidualBlock(100, 100)
+         self.decoder_layer12 = ResidualBlock(100, 96)
+         self.decoder_layer13 = ResidualBlock(96, 96)
+         self.decoder_layer14 = ResidualBlock(96, 96)
+         self.decoder_layer15 = nn.Sequential(
+             nn.Conv2d(96, 96, 3, stride=1, padding=1),
+             nn.Sigmoid()
+         )
+         self.decoder_layer16 = nn.Sequential(
+             nn.Conv2d(96, 96, 3, stride=1, padding=1),
+             nn.Sigmoid()
+         )
+
+     def forward(self, x, lower_skip_list, upper_skip_list, height=None, width=None):
+         if height is None and width is None:
+             height = self.height
+             width = self.width
+
+         x = self.decoder_layer1(x)
+         x = self.decoder_layer2(x)
+         x = x + lower_skip_list[3] + upper_skip_list[1]
+
+         x = self.decoder_layer3(x)
+         x = self.decoder_layer4(x)
+         x = x + lower_skip_list[2] + upper_skip_list[0]
+
+         x = self.decoder_layer5(x)
+         x = self.decoder_layer6(x)
+         x = x + lower_skip_list[1]
+
+         x = self.decoder_layer7(x)
+         x = self.decoder_layer8(x)
+         x = x + lower_skip_list[0]
+
+         x = self.decoder_layer9(x)
+         x = self.decoder_layer10(x)
+         x = self.decoder_layer11(x)
+         x = self.decoder_layer12(x)
+         x = self.decoder_layer13(x)
+         x = self.decoder_layer14(x)
+         x = self.decoder_layer15(x)
+         x = self.decoder_layer16(x)
+         x = x.view(x.size()[0], 32, 3, height, width)  # 96 channels -> 32 RGB planes
+         return x
+
+ class DecoderSigma(nn.Module):
+     def __init__(self, height, width):
+         super().__init__()
+         self.height = height
+         self.width = width
+         self.decoder_layer1 = ResidualBlock(100, 100)
+         self.decoder_layer2 = ResidualBlock(100, 100)
+         self.decoder_layer3 = ResidualBlock(100, 100)
+
+         self.decoder_layer4 = nn.Sequential(
+             nn.ConvTranspose2d(100, 100, 2, stride=2, padding=0),
+             nn.ReLU(True)
+         )
+         self.decoder_layer5 = ResidualBlock(100, 100)
+
+         self.decoder_layer6 = nn.Sequential(
+             nn.ConvTranspose2d(100, 100, 2, stride=2, padding=0),
+             nn.ReLU(True)
+         )
+         self.decoder_layer7 = ResidualBlock(100, 100)
+
+         self.decoder_layer8 = nn.Sequential(
+             nn.ConvTranspose2d(100, 100, 2, stride=2, padding=0),
+             nn.ReLU(True)
+         )
+         self.decoder_layer9 = ResidualBlock(100, 100)
+
+         self.decoder_layer10 = nn.Sequential(
+             nn.ConvTranspose2d(100, 50, 2, stride=2, padding=0),
+             nn.ReLU(True)
+         )
+         self.decoder_layer11 = nn.Sequential(
+             nn.Conv2d(50, 32, 3, stride=1, padding=1),
+             nn.ReLU(True)
+         )
+         self.decoder_layer12 = nn.Sequential(
+             nn.Conv2d(32, 32, 3, stride=1, padding=1),
+             nn.ReLU(True)
+         )
+
+     def forward(self, x, lower_skip_list, upper_skip_list, height=None, width=None):
+         if height is None and width is None:
+             height = self.height
+             width = self.width
+
+         x = self.decoder_layer1(x)
+         x = self.decoder_layer2(x)
+         x = x + lower_skip_list[3] + upper_skip_list[1]
+
+         x = self.decoder_layer3(x)
+         x = self.decoder_layer4(x)
+         x = x + lower_skip_list[2] + upper_skip_list[0]
+
+         x = self.decoder_layer5(x)
+         x = self.decoder_layer6(x)
+         x = x + lower_skip_list[1]
+
+         x = self.decoder_layer7(x)
+         x = self.decoder_layer8(x)
+         x = x + lower_skip_list[0]
+
+         x = self.decoder_layer9(x)
+         x = self.decoder_layer10(x)
+         x = self.decoder_layer11(x)
+         x = self.decoder_layer12(x)
+         x = x.view(x.size()[0], 32, 1, height, width)  # 32 channels -> 32 density planes
+         return x
+
+
+ class DecoderDepth(nn.Module):
+     def __init__(self, height, width):
+         super().__init__()
+         self.height = height
+         self.width = width
+         self.decoder_layer1 = ResidualBlock(100, 100)
+         self.decoder_layer2 = ResidualBlock(100, 100)
+         self.decoder_layer3 = ResidualBlock(100, 100)
+
+         self.decoder_layer4 = nn.Sequential(
+             nn.ConvTranspose2d(100, 100, 2, stride=2, padding=0),
+             nn.ReLU(True)
+         )
+         self.decoder_layer5 = ResidualBlock(100, 100)
+
+         self.decoder_layer6 = nn.Sequential(
+             nn.ConvTranspose2d(100, 100, 2, stride=2, padding=0),
+             nn.ReLU(True)
+         )
+         self.decoder_layer7 = ResidualBlock(100, 100)
+
+         self.decoder_layer8 = nn.Sequential(
+             nn.ConvTranspose2d(100, 100, 2, stride=2, padding=0),
+             nn.ReLU(True)
+         )
+         self.decoder_layer9 = ResidualBlock(100, 50)
+
+         self.decoder_layer10 = nn.Sequential(
+             nn.ConvTranspose2d(50, 20, 2, stride=2, padding=0),
+             nn.ReLU(True)
+         )
+         self.decoder_layer11 = nn.Sequential(
+             nn.Conv2d(20, 5, 3, stride=1, padding=1),
+             nn.ReLU(True)
+         )
+         self.decoder_layer12 = nn.Sequential(
+             nn.Conv2d(5, 1, 3, stride=1, padding=1),
+             nn.ReLU(True)
+         )
+
+     def forward(self, x, lower_skip_list, upper_skip_list, height=None, width=None):
+         if height is None and width is None:
+             height = self.height
+             width = self.width
+
+         x = self.decoder_layer1(x)
+         x = self.decoder_layer2(x)
+         x = x + lower_skip_list[3] + upper_skip_list[1]
+
+         x = self.decoder_layer3(x)
+         x = self.decoder_layer4(x)
+         x = x + lower_skip_list[2] + upper_skip_list[0]
+
+         x = self.decoder_layer5(x)
+         x = self.decoder_layer6(x)
+         x = x + lower_skip_list[1]
+
+         x = self.decoder_layer7(x)
+         x = self.decoder_layer8(x)
+         x = x + lower_skip_list[0]
+
+         x = self.decoder_layer9(x)
+         x = self.decoder_layer10(x)
+         x = self.decoder_layer11(x)
+         x = self.decoder_layer12(x)
+         return x
+
+ class MMPI(nn.Module):
+     def __init__(self, total_image_input=1, height=384, width=384):
+         super().__init__()
+         self.height = height
+         self.width = width
+         self.feature_encoder = FeatureNet(height, width)
+         self.lower_encoder = Encoder(height, width, total_image_input)
+         self.merge_decoder_rgb = DecoderRGB(height, width)
+         self.merge_decoder_sigma = DecoderSigma(height, width)
+         self.depth_decoder = DecoderDepth(height, width)
+         self.upper_encoder_extra_1 = nn.Sequential(
+             ResidualBlock(256, 100),
+             nn.MaxPool2d(kernel_size=2, stride=2)
+         )
+         self.upper_encoder_extra_2 = nn.Sequential(
+             ResidualBlock(100, 100),
+             nn.MaxPool2d(kernel_size=2, stride=2)
+         )
+
+     def forward(self, x, height=None, width=None):
+         if height is None and width is None:
+             height = self.height
+             width = self.width
+
+         upper_features_1 = self.feature_encoder.apply_feature_encoder(x)
+         upper_features_1 = self.upper_encoder_extra_1(upper_features_1)
+         upper_features_2 = self.upper_encoder_extra_2(upper_features_1)
+
+         lower_feature, skip_list = self.lower_encoder(x, height, width)
+
+         merged_feature_rgb = self.merge_decoder_rgb(lower_feature, skip_list, [upper_features_1, upper_features_2], height, width)
+         merged_feature_sigma = self.merge_decoder_sigma(lower_feature, skip_list, [upper_features_1, upper_features_2], height, width)
+
+         merged_feature_depth = self.depth_decoder(lower_feature, skip_list, [upper_features_1, upper_features_2])
+
+         return merged_feature_rgb, merged_feature_sigma, merged_feature_depth
+
+     def get_rgb_sigma(self, x, height=None, width=None):
+         if height is None and width is None:
+             height = self.height
+             width = self.width
+
+         upper_features_1 = self.feature_encoder.apply_feature_encoder(x)
+         upper_features_1 = self.upper_encoder_extra_1(upper_features_1)
+         upper_features_2 = self.upper_encoder_extra_2(upper_features_1)
+
+         lower_feature, skip_list = self.lower_encoder(x, height, width)
+
+         merged_feature_rgb = self.merge_decoder_rgb(lower_feature, skip_list, [upper_features_1, upper_features_2], height, width)
+         merged_feature_sigma = self.merge_decoder_sigma(lower_feature, skip_list, [upper_features_1, upper_features_2], height, width)
+
+         return merged_feature_rgb, merged_feature_sigma
+
+     def get_depth(self, x, height=None, width=None):
+         if height is None and width is None:
+             height = self.height
+             width = self.width
+
+         upper_features_1 = self.feature_encoder.apply_feature_encoder(x)
+         upper_features_1 = self.upper_encoder_extra_1(upper_features_1)
+         upper_features_2 = self.upper_encoder_extra_2(upper_features_1)
+
+         lower_feature, skip_list = self.lower_encoder(x, height, width)
+
+         merged_feature_depth = self.depth_decoder(lower_feature, skip_list, [upper_features_1, upper_features_2])
+         return merged_feature_depth
+
+     def get_layer_depth(self, x, grid, height=None, width=None):
+         if height is None and width is None:
+             height = self.height
+             width = self.width
+
+         upper_features_1 = self.feature_encoder.apply_feature_encoder(x)
+         upper_features_1 = self.upper_encoder_extra_1(upper_features_1)
+         upper_features_2 = self.upper_encoder_extra_2(upper_features_1)
+
+         lower_feature, skip_list = self.lower_encoder(x, height, width)
+
+         rgb_layers = self.merge_decoder_rgb(lower_feature, skip_list, [upper_features_1, upper_features_2], height, width)
+         sigma_layers = self.merge_decoder_sigma(lower_feature, skip_list, [upper_features_1, upper_features_2], height, width)
+
+         # Stack the predicted planes into 1xSx4xHxW RGBA form; the random
+         # first entry is only a placeholder and is sliced off after the loop.
+         pred_mpi_planes = torch.randn((1, 4, height, width)).to(params.DEVICE)
+         for i in range(params.params_num_planes):
+             RGBA = torch.cat((rgb_layers[0, i, :, :, :], sigma_layers[0, i, :, :, :]), dim=0).unsqueeze(0)
+             pred_mpi_planes = torch.cat((pred_mpi_planes, RGBA), dim=0)
+
+         pred_mpi_planes = pred_mpi_planes[1:, :, :, :].unsqueeze(0)
+
+         sigma = pred_mpi_planes[:, :, 3, :, :]
+         B, D, H, W = sigma.shape
+
+         # Per-plane spacing along the disparity axis (last interval repeated).
+         pred_mpi_disp = grid
+         disp_sorted, _ = pred_mpi_disp.sort(dim=1)
+         delta = disp_sorted[:, 1:] - disp_sorted[:, :-1]
+         delta_last = delta[:, -1:]
+         delta = torch.cat([delta, delta_last], dim=1)
+
+         delta = delta.unsqueeze(-1).unsqueeze(-1).expand_as(sigma)
+
+         alpha = 1.0 - torch.exp(-delta * sigma)
+
+         # Accumulated transmittance, shifted so plane i is attenuated by the planes in front of it.
+         transmittance = torch.cumprod(1 - alpha + 1e-7, dim=1)
+         shifted_transmittance = torch.ones_like(transmittance)
+         shifted_transmittance[:, 1:, :, :] = transmittance[:, :-1, :, :]
+
+         disparity = pred_mpi_disp.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, H, W)
+
+         disparity_map = (disparity * alpha * shifted_transmittance).sum(dim=1, keepdim=True)
+
+         return disparity_map
+
+     def get_layers(self, x, height=None, width=None):
+         if height is None and width is None:
+             height = self.height
+             width = self.width
+
+         upper_features_1 = self.feature_encoder.apply_feature_encoder(x)
+         upper_features_1 = self.upper_encoder_extra_1(upper_features_1)
+         upper_features_2 = self.upper_encoder_extra_2(upper_features_1)
+
+         lower_feature, skip_list = self.lower_encoder(x, height, width)
+
+         merged_feature_rgb = self.merge_decoder_rgb(lower_feature, skip_list, [upper_features_1, upper_features_2], height, width)
+         merged_feature_sigma = self.merge_decoder_sigma(lower_feature, skip_list, [upper_features_1, upper_features_2], height, width)
+
+         return merged_feature_rgb, merged_feature_sigma
+
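As a quick orientation to the shapes involved, here is a hypothetical smoke test for the small model above (a sketch, not part of the commit; it assumes the file's imports, including `timm`, which is not pinned in `requirements.txt` below, are installed). The expected output shapes follow from the `view` calls in the decoders:

```python
# Hypothetical smoke test for model_Small.MMPI; not in the commit.
import torch
from model_Small import MMPI

model = MMPI(total_image_input=1, height=384, width=384).eval()
x = torch.rand(1, 3, 384, 384)  # a single RGB image in [0, 1]
with torch.no_grad():
    rgb, sigma, depth = model(x)
print(rgb.shape)    # expected torch.Size([1, 32, 3, 384, 384]) -- 32 RGB planes
print(sigma.shape)  # expected torch.Size([1, 32, 1, 384, 384]) -- 32 density planes
print(depth.shape)  # expected torch.Size([1, 1, 384, 384])     -- auxiliary depth head
```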
parameters.py ADDED
@@ -0,0 +1,48 @@
+ import os
+ import torch
+
+ params_height = 256
+ params_width = 256
+ params_m = 32
+ params_number_input = 1
+ params_step_size = 2
+ params_gamma = 0.2
+ params_num_planes = 32
+
+ TRAIN_LOCATION = "./lf_train.txt"
+ VALIDATION_LOCATION = "./lf_validate.txt"
+ TEST_LOCATION = "./lf_test.txt"
+ LOG_FILE_LOCATION = "./logs/training_log_0.txt"
+ CHECKPOINT_LOCATION = "./checkpoint/"
+ RESUME_CHECKPOINT_LOCATION = "./checkpoint/checkpoint_best.pth"
+ START_CHECKPOINT_LOCATION = "./checkpoint/checkpoint_init.pth"
+ DEVICE = "cpu"
+
+ BATCH_SIZE = 32
+ LEARNING_RATE = 0.0001
+ NUM_EPOCHS = 150
+ START_EPOCH = 0
+ PRINT_INTERVAL = 20
+ T_max = 150
+
+ os.makedirs("./logs", exist_ok=True)
+ os.makedirs("./checkpoint", exist_ok=True)
+ os.makedirs("./output", exist_ok=True)
+
+ def uniform_planes(a: float, b: float, n: int) -> torch.Tensor:
+     """
+     Return n values uniformly spaced *within* (a, b),
+     i.e. excluding the exact endpoints a and b.
+     """
+     step = (b - a) / (n + 1)
+     # torch.arange(1, n+1) gives [1, 2, ..., n]
+     return a + step * torch.arange(1, n + 1, dtype=torch.float32)
+
+ def get_disparity_all_src():
+     d1 = uniform_planes(0.0, 0.4, 20)
+     d2 = uniform_planes(0.4, 1.0, 12)
+     disparities = torch.cat([d1, d2], dim=0)
+     return disparities
+
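A hypothetical sanity check of the plane layout defined above (not part of the commit): `get_disparity_all_src` returns 20 + 12 = 32 disparities, matching `params_num_planes`, with the bins denser near disparity zero (far planes):

```python
# Hypothetical check of the disparity bins; not in the commit.
import parameters as params

d = params.get_disparity_all_src()
print(d.shape)                         # torch.Size([32]) == params.params_num_planes
print(float(d.min()), float(d.max())) # strictly inside (0, 1); endpoints excluded
```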
post-install.sh ADDED
@@ -0,0 +1,2 @@
+ #!/bin/bash
+ pip install "pytorch3d @ git+https://github.com/facebookresearch/pytorch3d.git@89653419d0973396f3eff1a381ba09a07fffc2ed"
requirements.txt ADDED
@@ -0,0 +1,19 @@
+ numpy==1.26.4
+ torch==2.1.0
+ torchvision==0.16.0
+ pytorch-lightning==2.1.3
+ pytorch-msssim==1.0.0
+ pytorchvideo==0.1.5
+ grpcio==1.57.0
+ opencv-contrib-python==4.10.0.84
+ opencv-python==4.6.0.66
+ pillow==10.4.0
+ pillow_heif==0.15.0
+ matplotlib==3.7.2
+ matplotlib-inline==0.1.6
+ transformers==4.43.3
+ tqdm==4.65.0
+ moviepy==1.0.3
+ scikit-image==0.21.0
+ scikit-learn==1.3.0
+ scipy==1.11.2
utils.py ADDED
@@ -0,0 +1,243 @@
+ import numpy as np
+ from scipy.ndimage import map_coordinates
+
+
+ def xyzcube(face_w):
+     '''
+     Return the xyz coordinates of the unit cube in [F R B L U D] format.
+     '''
+     out = np.zeros((face_w, face_w * 6, 3), np.float32)
+     rng = np.linspace(-0.5, 0.5, num=face_w, dtype=np.float32)
+     grid = np.stack(np.meshgrid(rng, -rng), -1)
+
+     # Front face (z = 0.5)
+     out[:, 0*face_w:1*face_w, [0, 1]] = grid
+     out[:, 0*face_w:1*face_w, 2] = 0.5
+
+     # Right face (x = 0.5)
+     out[:, 1*face_w:2*face_w, [2, 1]] = grid
+     out[:, 1*face_w:2*face_w, 0] = 0.5
+
+     # Back face (z = -0.5)
+     out[:, 2*face_w:3*face_w, [0, 1]] = grid
+     out[:, 2*face_w:3*face_w, 2] = -0.5
+
+     # Left face (x = -0.5)
+     out[:, 3*face_w:4*face_w, [2, 1]] = grid
+     out[:, 3*face_w:4*face_w, 0] = -0.5
+
+     # Up face (y = 0.5)
+     out[:, 4*face_w:5*face_w, [0, 2]] = grid
+     out[:, 4*face_w:5*face_w, 1] = 0.5
+
+     # Down face (y = -0.5)
+     out[:, 5*face_w:6*face_w, [0, 2]] = grid
+     out[:, 5*face_w:6*face_w, 1] = -0.5
+
+     return out
+
+
+ def equirect_uvgrid(h, w):
+     u = np.linspace(-np.pi, np.pi, num=w, dtype=np.float32)
+     v = np.linspace(np.pi, -np.pi, num=h, dtype=np.float32) / 2
+
+     return np.stack(np.meshgrid(u, v), axis=-1)
+
+
+ def equirect_facetype(h, w):
+     '''
+     0F 1R 2B 3L 4U 5D
+     '''
+     tp = np.roll(np.arange(4).repeat(w // 4)[None, :].repeat(h, 0), 3 * w // 8, 1)
+
+     # Prepare ceil mask (np.bool is removed in recent numpy; use the builtin bool)
+     mask = np.zeros((h, w // 4), bool)
+     idx = np.linspace(-np.pi, np.pi, w // 4) / 4
+     idx = h // 2 - np.round(np.arctan(np.cos(idx)) * h / np.pi).astype(int)
+     for i, j in enumerate(idx):
+         mask[:j, i] = 1
+     mask = np.roll(np.concatenate([mask] * 4, 1), 3 * w // 8, 1)
+
+     tp[mask] = 4
+     tp[np.flip(mask, 0)] = 5
+
+     return tp.astype(np.int32)
+
+
+ def xyzpers(h_fov, v_fov, u, v, out_hw, in_rot):
+     out = np.ones((*out_hw, 3), np.float32)
+
+     x_max = np.tan(h_fov / 2)
+     y_max = np.tan(v_fov / 2)
+     x_rng = np.linspace(-x_max, x_max, num=out_hw[1], dtype=np.float32)
+     y_rng = np.linspace(-y_max, y_max, num=out_hw[0], dtype=np.float32)
+     out[..., :2] = np.stack(np.meshgrid(x_rng, -y_rng), -1)
+     Rx = rotation_matrix(v, [1, 0, 0])
+     Ry = rotation_matrix(u, [0, 1, 0])
+     Ri = rotation_matrix(in_rot, np.array([0, 0, 1.0]).dot(Rx).dot(Ry))
+
+     return out.dot(Rx).dot(Ry).dot(Ri)
+
+
+ def xyz2uv(xyz):
+     '''
+     xyz: ndarray in shape of [..., 3]
+     '''
+     x, y, z = np.split(xyz, 3, axis=-1)
+     u = np.arctan2(x, z)
+     c = np.sqrt(x**2 + z**2)
+     v = np.arctan2(y, c)
+
+     return np.concatenate([u, v], axis=-1)
+
+
+ def uv2unitxyz(uv):
+     u, v = np.split(uv, 2, axis=-1)
+     y = np.sin(v)
+     c = np.cos(v)
+     x = c * np.sin(u)
+     z = c * np.cos(u)
+
+     return np.concatenate([x, y, z], axis=-1)
+
+
+ def uv2coor(uv, h, w):
+     '''
+     uv: ndarray in shape of [..., 2]
+     h: int, height of the equirectangular image
+     w: int, width of the equirectangular image
+     '''
+     u, v = np.split(uv, 2, axis=-1)
+     coor_x = (u / (2 * np.pi) + 0.5) * w - 0.5
+     coor_y = (-v / np.pi + 0.5) * h - 0.5
+
+     return np.concatenate([coor_x, coor_y], axis=-1)
+
+
+ def coor2uv(coorxy, h, w):
+     coor_x, coor_y = np.split(coorxy, 2, axis=-1)
+     u = ((coor_x + 0.5) / w - 0.5) * 2 * np.pi
+     v = -((coor_y + 0.5) / h - 0.5) * np.pi
+
+     return np.concatenate([u, v], axis=-1)
+
+
+ def sample_equirec(e_img, coor_xy, order):
+     w = e_img.shape[1]
+     coor_x, coor_y = np.split(coor_xy, 2, axis=-1)
+     pad_u = np.roll(e_img[[0]], w // 2, 1)
+     pad_d = np.roll(e_img[[-1]], w // 2, 1)
+     e_img = np.concatenate([e_img, pad_d, pad_u], 0)
+     return map_coordinates(e_img, [coor_y, coor_x],
+                            order=order, mode='wrap')[..., 0]
+
+
+ def sample_cubefaces(cube_faces, tp, coor_y, coor_x, order):
+     cube_faces = cube_faces.copy()
+     cube_faces[1] = np.flip(cube_faces[1], 1)
+     cube_faces[2] = np.flip(cube_faces[2], 1)
+     cube_faces[4] = np.flip(cube_faces[4], 0)
+
+     # Pad up down
+     pad_ud = np.zeros((6, 2, cube_faces.shape[2]))
+     pad_ud[0, 0] = cube_faces[5, 0, :]
+     pad_ud[0, 1] = cube_faces[4, -1, :]
+     pad_ud[1, 0] = cube_faces[5, :, -1]
+     pad_ud[1, 1] = cube_faces[4, ::-1, -1]
+     pad_ud[2, 0] = cube_faces[5, -1, ::-1]
+     pad_ud[2, 1] = cube_faces[4, 0, ::-1]
+     pad_ud[3, 0] = cube_faces[5, ::-1, 0]
+     pad_ud[3, 1] = cube_faces[4, :, 0]
+     pad_ud[4, 0] = cube_faces[0, 0, :]
+     pad_ud[4, 1] = cube_faces[2, 0, ::-1]
+     pad_ud[5, 0] = cube_faces[2, -1, ::-1]
+     pad_ud[5, 1] = cube_faces[0, -1, :]
+     cube_faces = np.concatenate([cube_faces, pad_ud], 1)
+
+     # Pad left right
+     pad_lr = np.zeros((6, cube_faces.shape[1], 2))
+     pad_lr[0, :, 0] = cube_faces[1, :, 0]
+     pad_lr[0, :, 1] = cube_faces[3, :, -1]
+     pad_lr[1, :, 0] = cube_faces[2, :, 0]
+     pad_lr[1, :, 1] = cube_faces[0, :, -1]
+     pad_lr[2, :, 0] = cube_faces[3, :, 0]
+     pad_lr[2, :, 1] = cube_faces[1, :, -1]
+     pad_lr[3, :, 0] = cube_faces[0, :, 0]
+     pad_lr[3, :, 1] = cube_faces[2, :, -1]
+     pad_lr[4, 1:-1, 0] = cube_faces[1, 0, ::-1]
+     pad_lr[4, 1:-1, 1] = cube_faces[3, 0, :]
+     pad_lr[5, 1:-1, 0] = cube_faces[1, -2, :]
+     pad_lr[5, 1:-1, 1] = cube_faces[3, -2, ::-1]
+     cube_faces = np.concatenate([cube_faces, pad_lr], 2)
+
+     return map_coordinates(cube_faces, [tp, coor_y, coor_x], order=order, mode='wrap')
+
+
+ def cube_h2list(cube_h):
+     assert cube_h.shape[0] * 6 == cube_h.shape[1]
+     return np.split(cube_h, 6, axis=1)
+
+
+ def cube_list2h(cube_list):
+     assert len(cube_list) == 6
+     assert sum(face.shape == cube_list[0].shape for face in cube_list) == 6
+     return np.concatenate(cube_list, axis=1)
+
+
+ def cube_h2dict(cube_h):
+     cube_list = cube_h2list(cube_h)
+     return dict([(k, cube_list[i])
+                  for i, k in enumerate(['F', 'R', 'B', 'L', 'U', 'D'])])
+
+
+ def cube_dict2h(cube_dict, face_k=['F', 'R', 'B', 'L', 'U', 'D']):
+     assert len(face_k) == 6
+     return cube_list2h([cube_dict[k] for k in face_k])
+
+
+ def cube_h2dice(cube_h):
+     assert cube_h.shape[0] * 6 == cube_h.shape[1]
+     w = cube_h.shape[0]
+     cube_dice = np.zeros((w * 3, w * 4, cube_h.shape[2]), dtype=cube_h.dtype)
+     cube_list = cube_h2list(cube_h)
+     # Order: F R B L U D
+     sxy = [(1, 1), (2, 1), (3, 1), (0, 1), (1, 0), (1, 2)]
+     for i, (sx, sy) in enumerate(sxy):
+         face = cube_list[i]
+         if i in [1, 2]:
+             face = np.flip(face, axis=1)
+         if i == 4:
+             face = np.flip(face, axis=0)
+         cube_dice[sy*w:(sy+1)*w, sx*w:(sx+1)*w] = face
+     return cube_dice
+
+
+ def cube_dice2h(cube_dice):
+     w = cube_dice.shape[0] // 3
+     assert cube_dice.shape[0] == w * 3 and cube_dice.shape[1] == w * 4
+     cube_h = np.zeros((w, w * 6, cube_dice.shape[2]), dtype=cube_dice.dtype)
+     # Order: F R B L U D
+     sxy = [(1, 1), (2, 1), (3, 1), (0, 1), (1, 0), (1, 2)]
+     for i, (sx, sy) in enumerate(sxy):
+         face = cube_dice[sy*w:(sy+1)*w, sx*w:(sx+1)*w]
+         if i in [1, 2]:
+             face = np.flip(face, axis=1)
+         if i == 4:
+             face = np.flip(face, axis=0)
+         cube_h[:, i*w:(i+1)*w] = face
+     return cube_h
+
+
+ def rotation_matrix(rad, ax):
+     ax = np.array(ax)
+     assert len(ax.shape) == 1 and ax.shape[0] == 3
+     ax = ax / np.sqrt((ax**2).sum())
+     R = np.diag([np.cos(rad)] * 3)
+     R = R + np.outer(ax, ax) * (1.0 - np.cos(rad))
+
+     ax = ax * np.sin(rad)
+     R = R + np.array([[0, -ax[2], ax[1]],
+                       [ax[2], 0, -ax[0]],
+                       [-ax[1], ax[0], 0]])
+
+     return R
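For reference, `rotation_matrix` above is Rodrigues' rotation formula (inferred from the code, not stated in the commit): for a unit axis $\hat{a}$ and angle $\theta$,

$$R = \cos\theta\, I + (1 - \cos\theta)\,\hat{a}\hat{a}^{\top} + \sin\theta\,[\hat{a}]_{\times},$$

where $[\hat{a}]_{\times}$ is the skew-symmetric cross-product matrix built in the final step of the function.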
utils/.DS_Store ADDED (binary, 6.15 kB)
utils/__init__.py ADDED (file without changes)
utils/__pycache__/__init__.cpython-38.pyc ADDED (binary, 164 Bytes)
utils/__pycache__/__init__.cpython-39.pyc ADDED (binary, 156 Bytes)
utils/__pycache__/rendererBackbone.cpython-39.pyc ADDED (binary, 4.02 kB)
utils/__pycache__/utils.cpython-38.pyc ADDED (binary, 4.03 kB)
utils/__pycache__/utils.cpython-39.pyc ADDED (binary, 4.04 kB)
utils/mpi/__init__.py ADDED (file without changes)
utils/mpi/__pycache__/__init__.cpython-38.pyc ADDED (binary, 168 Bytes)
utils/mpi/__pycache__/__init__.cpython-39.pyc ADDED (binary, 160 Bytes)
utils/mpi/__pycache__/homography_sampler.cpython-38.pyc ADDED (binary, 4.62 kB)
utils/mpi/__pycache__/homography_sampler.cpython-39.pyc ADDED (binary, 4.64 kB)
utils/mpi/__pycache__/mpi_rendering.cpython-38.pyc ADDED (binary, 7.43 kB)
utils/mpi/__pycache__/mpi_rendering.cpython-39.pyc ADDED (binary, 7.45 kB)
utils/mpi/__pycache__/rendering_utils.cpython-38.pyc ADDED (binary, 4.09 kB)
utils/mpi/__pycache__/rendering_utils.cpython-39.pyc ADDED (binary, 4.07 kB)
utils/mpi/homography_sampler.py ADDED
@@ -0,0 +1,159 @@
+ import torch
+ import numpy as np
+ from scipy.spatial.transform import Rotation
+
+
+ def inverse(matrices):
+     """
+     torch.inverse() sometimes produces outputs containing NaN when the batch size is 2.
+     Ref: https://github.com/pytorch/pytorch/issues/47272
+     This function keeps re-inverting the matrix until it succeeds or the maximum number of tries is reached.
+     :param matrices: Bx3x3
+     """
+     inv = None
+     max_tries = 5
+     while (inv is None) or (torch.isnan(inv)).any():
+         # torch.cuda.synchronize()
+         inv = torch.inverse(matrices)
+
+         # Break out of the loop when the inverse is successful or there are no more tries
+         max_tries -= 1
+         if max_tries == 0:
+             break
+
+     # Raise an Exception if the inverse still contains nan
+     if (torch.isnan(inv)).any():
+         raise Exception("Matrix inverse contains nan!")
+     return inv
+
+
+ class HomographySample:
+     def __init__(self, H_tgt, W_tgt, device=None):
+         if device is None:
+             self.device = torch.device("cpu")
+         else:
+             self.device = device
+
+         self.Height_tgt = H_tgt
+         self.Width_tgt = W_tgt
+         self.meshgrid = self.grid_generation(self.Height_tgt, self.Width_tgt, self.device)
+         self.meshgrid = self.meshgrid.permute(2, 0, 1).contiguous()  # 3xHxW
+
+         self.n = self.plane_normal_generation(self.device)
+
+     @staticmethod
+     def grid_generation(H, W, device):
+         x = np.linspace(0, W-1, W)
+         y = np.linspace(0, H-1, H)
+         xv, yv = np.meshgrid(x, y)  # HxW
+         xv = torch.from_numpy(xv.astype(np.float32)).to(dtype=torch.float32, device=device)
+         yv = torch.from_numpy(yv.astype(np.float32)).to(dtype=torch.float32, device=device)
+         ones = torch.ones_like(xv)
+         meshgrid = torch.stack((xv, yv, ones), dim=2)  # HxWx3
+         return meshgrid
+
+     @staticmethod
+     def plane_normal_generation(device):
+         n = torch.tensor([0, 0, 1], dtype=torch.float32, device=device)
+         return n
+
+     @staticmethod
+     def euler_to_rotation_matrix(x_angle, y_angle, z_angle, seq='xyz', degrees=False):
+         """
+         Return a rotation matrix rot_mtx that transforms tgt points into the src frame,
+         i.e., rot_mtx * p_tgt = p_src. Therefore the negatives of x/y/z_angle are used.
+         :param x_angle: rotation about x (roll)
+         :param y_angle: rotation about y (pitch)
+         :param z_angle: rotation about z (yaw)
+         :return: 3x3 rotation matrix
+         """
+         r = Rotation.from_euler(seq,
+                                 [-x_angle, -y_angle, -z_angle],
+                                 degrees=degrees)
+         rot_mtx = r.as_matrix().astype(np.float32)
+         return rot_mtx
+
+     def sample(self, src_BCHW, d_src_B,
+                G_tgt_src,
+                K_src_inv, K_tgt):
+         """
+         Coordinate system: x, y are the image directions, z points along the depth direction.
+         :param src_BCHW: torch float tensor in [0, 1], rgb/rgba, BxCxHxW.
+                          Assumed to be at position P = [I|0]
+         :param d_src_B: distance of the image plane to the src camera origin
+         :param G_tgt_src: Bx4x4
+         :param K_src_inv: Bx3x3
+         :param K_tgt: Bx3x3
+         :return: tgt_BCHW
+         """
+         # parameter processing ------ begin ------
+         B, channels, Height_src, Width_src = src_BCHW.size(0), src_BCHW.size(1), src_BCHW.size(2), src_BCHW.size(3)
+         R_tgt_src = G_tgt_src[:, 0:3, 0:3]
+         t_tgt_src = G_tgt_src[:, 0:3, 3]
+
+         Height_tgt = self.Height_tgt
+         Width_tgt = self.Width_tgt
+         # if R_src_tgt is None:
+         #     R_src_tgt = torch.eye(3, dtype=torch.float32, device=src_BCHW.device)
+         #     R_src_tgt = R_src_tgt.unsqueeze(0).expand(B, 3, 3)
+         # if t_src_tgt is None:
+         #     t_src_tgt = torch.tensor([0, 0, 0],
+         #                              dtype=torch.float32,
+         #                              device=src_BCHW.device)
+         #     t_src_tgt = t_src_tgt.unsqueeze(0).expand(B, 3)
+
+         # relationship between FoV and focal length:
+         # assume W > H
+         # W / 2 = f*tan(\theta / 2)
+         # here we default the horizontal FoV to 53.13 degrees;
+         # the vertical FoV can be computed from H/2 = W*tan(\theta/2)
+
+         R_tgt_src = R_tgt_src.to(device=src_BCHW.device)
+         t_tgt_src = t_tgt_src.to(device=src_BCHW.device)
+         K_src_inv = K_src_inv.to(device=src_BCHW.device)
+         K_tgt = K_tgt.to(device=src_BCHW.device)
+         # parameter processing ------ end ------
+
+         # the goal is to compute H_src_tgt, which maps a tgt pixel to a src pixel,
+         # so we compute H_tgt_src first and then invert it
+         n = self.n.to(device=src_BCHW.device)
+         n = n.unsqueeze(0).repeat(B, 1)  # Bx3
+         # Bx3x3 - (Bx3x1 * Bx1x3)
+         # note here we use -d_src, because the plane function is n^T * X - d_src = 0
+         d_src_B33 = d_src_B.reshape(B, 1, 1).repeat(1, 3, 3)  # B -> Bx3x3
+         R_tnd = R_tgt_src - torch.matmul(t_tgt_src.unsqueeze(2), n.unsqueeze(1)) / -d_src_B33
+         H_tgt_src = torch.matmul(K_tgt,
+                                  torch.matmul(R_tnd, K_src_inv))
+
+         # TODO: fix cuda inverse
+         with torch.no_grad():
+             H_src_tgt = inverse(H_tgt_src)
+
+         # create the tgt image grid, and map it to src
+         meshgrid_tgt_homo = self.meshgrid.to(src_BCHW.device)
+         # 3xHxW -> Bx3xHxW
+         meshgrid_tgt_homo = meshgrid_tgt_homo.unsqueeze(0).expand(B, 3, Height_tgt, Width_tgt)
+
+         # warp meshgrid_tgt_homo to meshgrid_src
+         meshgrid_tgt_homo_B3N = meshgrid_tgt_homo.view(B, 3, -1)  # Bx3xHW
+         meshgrid_src_homo_B3N = torch.matmul(H_src_tgt, meshgrid_tgt_homo_B3N)  # Bx3x3 * Bx3xHW -> Bx3xHW
+         # Bx3xHW -> Bx3xHxW -> BxHxWx3
+         meshgrid_src_homo = meshgrid_src_homo_B3N.view(B, 3, Height_tgt, Width_tgt).permute(0, 2, 3, 1)
+         meshgrid_src = meshgrid_src_homo[:, :, :, 0:2] / meshgrid_src_homo[:, :, :, 2:]  # BxHxWx2
+
+         valid_mask_x = torch.logical_and(meshgrid_src[:, :, :, 0] < Width_src,
+                                          meshgrid_src[:, :, :, 0] > -1)
+         valid_mask_y = torch.logical_and(meshgrid_src[:, :, :, 1] < Height_src,
+                                          meshgrid_src[:, :, :, 1] > -1)
+         valid_mask = torch.logical_and(valid_mask_x, valid_mask_y)  # BxHxW
+
+         # sample from src_BCHW
+         # normalize meshgrid_src to [-1, 1]
+         meshgrid_src[:, :, :, 0] = (meshgrid_src[:, :, :, 0] + 0.5) / (Width_src * 0.5) - 1
+         meshgrid_src[:, :, :, 1] = (meshgrid_src[:, :, :, 1] + 0.5) / (Height_src * 0.5) - 1
+         tgt_BCHW = torch.nn.functional.grid_sample(src_BCHW, grid=meshgrid_src, padding_mode='border',
+                                                    align_corners=False)
+         # BxCxHxW, BxHxW
+         return tgt_BCHW, valid_mask
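For reference, the warp built in `sample` is the standard plane-induced homography (inferred from the code, which places the source camera at $P = [I\,|\,0]$ and the plane at $n^{\top}X = d$ in the source frame):

$$H_{tgt \leftarrow src} = K_{tgt}\left(R_{tgt \leftarrow src} + \frac{t_{tgt \leftarrow src}\, n^{\top}}{d}\right)K_{src}^{-1},$$

which the code then inverts so that each target pixel can be pulled from the source image with `grid_sample`.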
utils/mpi/mpi_rendering.py ADDED
@@ -0,0 +1,272 @@
+ import torch
+
+ from utils.mpi.homography_sampler import HomographySample
+ from utils.mpi.rendering_utils import transform_G_xyz, sample_pdf, gather_pixel_by_pxpy
+
+
+ def render(rgb_BS3HW, sigma_BS1HW, xyz_BS3HW, use_alpha=False, is_bg_depth_inf=False):
+     if not use_alpha:
+         imgs_syn, depth_syn, blend_weights, weights = plane_volume_rendering(
+             rgb_BS3HW,
+             sigma_BS1HW,
+             xyz_BS3HW,
+             is_bg_depth_inf
+         )
+     else:
+         imgs_syn, weights = alpha_composition(sigma_BS1HW, rgb_BS3HW)
+         depth_syn, _ = alpha_composition(sigma_BS1HW, xyz_BS3HW[:, :, 2:])
+         # No rgb blending with alpha composition
+         blend_weights = torch.cumprod(1 - sigma_BS1HW + 1e-6, dim=1)
+         # blend_weights = torch.zeros_like(rgb_BS3HW).cuda()
+     return imgs_syn, depth_syn, blend_weights, weights
+
+
+ def alpha_composition(alpha_BK1HW, value_BKCHW):
+     """
+     Composition equation from 'Single-View View Synthesis with Multiplane Images'.
+     K is the number of planes, k=0 means the nearest plane, k=K-1 means the farthest plane.
+     :param alpha_BK1HW: alpha at each of the K planes
+     :param value_BKCHW: rgb/disparity at each of the K planes
+     :return:
+     """
+     B, K, _, H, W = alpha_BK1HW.size()
+     alpha_comp_cumprod = torch.cumprod(1 - alpha_BK1HW, dim=1)  # BxKx1xHxW
+
+     preserve_ratio = torch.cat((torch.ones((B, 1, 1, H, W), dtype=alpha_BK1HW.dtype, device=alpha_BK1HW.device),
+                                 alpha_comp_cumprod[:, 0:K-1, :, :, :]), dim=1)  # BxKx1xHxW
+     weights = alpha_BK1HW * preserve_ratio  # BxKx1xHxW
+     value_composed = torch.sum(value_BKCHW * weights, dim=1, keepdim=False)  # Bx3xHxW
+
+     return value_composed, weights
+
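In the notation of the docstring above, `alpha_composition` implements the over-compositing equation from the cited paper: with per-plane alphas $\alpha_k$ ordered front to back,

$$w_k = \alpha_k \prod_{j=0}^{k-1}\left(1 - \alpha_j\right), \qquad \hat{C} = \sum_{k=0}^{K-1} w_k\, c_k,$$

where $c_k$ is the per-plane RGB (or disparity) value and `preserve_ratio` is the shifted cumulative product of $(1-\alpha_j)$.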
+
+
+ def plane_volume_rendering(rgb_BS3HW, sigma_BS1HW, xyz_BS3HW, is_bg_depth_inf):
+     B, S, _, H, W = sigma_BS1HW.size()
+
+     xyz_diff_BS3HW = xyz_BS3HW[:, 1:, :, :, :] - xyz_BS3HW[:, 0:-1, :, :, :]  # Bx(S-1)x3xHxW
+     xyz_dist_BS1HW = torch.norm(xyz_diff_BS3HW, dim=2, keepdim=True)  # Bx(S-1)x1xHxW
+
+     xyz_dist_BS1HW = torch.cat((xyz_dist_BS1HW,
+                                 torch.full((B, 1, 1, H, W),
+                                            fill_value=1e3,
+                                            dtype=xyz_BS3HW.dtype,
+                                            device=xyz_BS3HW.device)),
+                                dim=1)  # BxSx1xHxW
+     transparency = torch.exp(-sigma_BS1HW * xyz_dist_BS1HW)  # BxSx1xHxW
+     alpha = 1 - transparency  # BxSx1xHxW
+
+     # add small eps to avoid zero transparency_acc
+     # torch.cumprod maps [a, b, c] -> [a, a*b, a*b*c]; we need [1, a, a*b]
+     transparency_acc = torch.cumprod(transparency + 1e-6, dim=1)  # BxSx1xHxW
+     transparency_acc = torch.cat((torch.ones((B, 1, 1, H, W), dtype=transparency.dtype, device=transparency.device),
+                                   transparency_acc[:, 0:-1, :, :, :]),
+                                  dim=1)  # BxSx1xHxW
+
+     weights = transparency_acc * alpha  # BxSx1xHxW
+     rgb_out, depth_out = weighted_sum_mpi(rgb_BS3HW, xyz_BS3HW, weights, is_bg_depth_inf)
+
+     return rgb_out, depth_out, transparency_acc, weights
+
+
+ def weighted_sum_mpi(rgb_BS3HW, xyz_BS3HW, weights, is_bg_depth_inf):
+     weights_sum = torch.sum(weights, dim=1, keepdim=False)  # Bx1xHxW
+     rgb_out = torch.sum(weights * rgb_BS3HW, dim=1, keepdim=False)  # Bx3xHxW
+
+     if is_bg_depth_inf:
+         # for the DTU dataset, set a large depth if weights_sum is small
+         depth_out = torch.sum(weights * xyz_BS3HW[:, :, 2:, :, :], dim=1, keepdim=False) \
+                     + (1 - weights_sum) * 1000
+     else:
+         depth_out = torch.sum(weights * xyz_BS3HW[:, :, 2:, :, :], dim=1, keepdim=False) \
+                     / (weights_sum + 1e-5)  # Bx1xHxW
+
+     return rgb_out, depth_out
+
+
+ def get_xyz_from_depth(meshgrid_homo,
+                        depth,
+                        K_inv):
+     """
+     :param meshgrid_homo: 3xHxW
+     :param depth: Bx1xHxW
+     :param K_inv: Bx3x3
+     :return:
+     """
+     H, W = meshgrid_homo.size(1), meshgrid_homo.size(2)
+     B, _, H_d, W_d = depth.size()
+     assert H == H_d and W == W_d
+
+     # 3xHxW -> Bx3xHxW
+     meshgrid_src_homo = meshgrid_homo.unsqueeze(0).repeat(B, 1, 1, 1)
+     meshgrid_src_homo_B3N = meshgrid_src_homo.reshape(B, 3, -1)
+     xyz_src = torch.matmul(K_inv, meshgrid_src_homo_B3N)  # Bx3xHW
+     xyz_src = xyz_src.reshape(B, 3, H, W) * depth  # Bx3xHxW
+
+     return xyz_src
+
+
+ def disparity_consistency_src_to_tgt(meshgrid_homo, K_src_inv, disparity_src,
+                                      G_tgt_src, K_tgt, disparity_tgt):
+     """
+     :param meshgrid_homo: 3xHxW
+     :param K_src_inv: Bx3x3
+     :param disparity_src: Bx1xHxW
+     :param G_tgt_src: Bx4x4
+     :param K_tgt: Bx3x3
+     :param disparity_tgt: Bx1xHxW
+     :return:
+     """
+     B, _, H, W = disparity_src.size()
+     depth_src = torch.reciprocal(disparity_src)
+     xyz_src_B3N = get_xyz_from_depth(meshgrid_homo, depth_src, K_src_inv).view(B, 3, H*W)
+
+     xyz_tgt_B3N = transform_G_xyz(G_tgt_src, xyz_src_B3N, is_return_homo=False)
+     K_xyz_tgt_B3N = torch.matmul(K_tgt, xyz_tgt_B3N)
+     pxpy_tgt_B2N = K_xyz_tgt_B3N[:, 0:2, :] / K_xyz_tgt_B3N[:, 2:, :]  # Bx2xN
+
+     pxpy_tgt_mask = torch.logical_and(
+         torch.logical_and(pxpy_tgt_B2N[:, 0:1, :] >= 0,
+                           pxpy_tgt_B2N[:, 0:1, :] <= W - 1),
+         torch.logical_and(pxpy_tgt_B2N[:, 1:2, :] >= 0,
+                           pxpy_tgt_B2N[:, 1:2, :] <= H - 1)
+     )  # Bx1xN
+
+     disparity_src = torch.reciprocal(xyz_tgt_B3N[:, 2:, :])  # Bx1xN
+     disparity_tgt = gather_pixel_by_pxpy(disparity_tgt, pxpy_tgt_B2N)  # Bx1xN
+
+     depth_diff = torch.abs(disparity_src - disparity_tgt)
+     return torch.mean(depth_diff[pxpy_tgt_mask])
+
+
+ def get_src_xyz_from_plane_disparity(meshgrid_src_homo,
+                                      mpi_disparity_src,
+                                      K_src_inv):
+     """
+     :param meshgrid_src_homo: 3xHxW
+     :param mpi_disparity_src: BxS
+     :param K_src_inv: Bx3x3
+     :return:
+     """
+     B, S = mpi_disparity_src.size()
+     H, W = meshgrid_src_homo.size(1), meshgrid_src_homo.size(2)
+     mpi_depth_src = torch.reciprocal(mpi_disparity_src)  # BxS
+
+     K_src_inv_Bs33 = K_src_inv.unsqueeze(1).repeat(1, S, 1, 1).reshape(B * S, 3, 3)
+
+     # 3xHxW -> BxSx3xHxW
+     meshgrid_src_homo = meshgrid_src_homo.unsqueeze(0).unsqueeze(1).repeat(B, S, 1, 1, 1)
+     meshgrid_src_homo_Bs3N = meshgrid_src_homo.reshape(B * S, 3, -1)
+     xyz_src = torch.matmul(K_src_inv_Bs33, meshgrid_src_homo_Bs3N)  # BSx3xHW
+     xyz_src = xyz_src.reshape(B, S, 3, H * W) * mpi_depth_src.unsqueeze(2).unsqueeze(3)  # BxSx3xHW
+     xyz_src_BS3HW = xyz_src.reshape(B, S, 3, H, W)
+
+     return xyz_src_BS3HW
+
+
+ def get_tgt_xyz_from_plane_disparity(xyz_src_BS3HW,
+                                      G_tgt_src):
+     """
+     :param xyz_src_BS3HW: BxSx3xHxW
+     :param G_tgt_src: Bx4x4
+     :return:
+     """
+     B, S, _, H, W = xyz_src_BS3HW.size()
+     G_tgt_src_Bs33 = G_tgt_src.unsqueeze(1).repeat(1, S, 1, 1).reshape(B*S, 4, 4)
+     xyz_tgt = transform_G_xyz(G_tgt_src_Bs33, xyz_src_BS3HW.reshape(B*S, 3, H*W))  # BSx3xHW
+     xyz_tgt_BS3HW = xyz_tgt.reshape(B, S, 3, H, W)  # BxSx3xHxW
+     return xyz_tgt_BS3HW
+
+
+ def render_tgt_rgb_depth(H_sampler: HomographySample,
+                          mpi_rgb_src,
+                          mpi_sigma_src,
+                          mpi_disparity_src,
+                          xyz_tgt_BS3HW,
+                          G_tgt_src,
+                          K_src_inv, K_tgt,
+                          use_alpha=False,
+                          is_bg_depth_inf=False):
+     """
+     :param H_sampler:
+     :param mpi_rgb_src: BxSx3xHxW
+     :param mpi_sigma_src: BxSx1xHxW
+     :param mpi_disparity_src: BxS
+     :param xyz_tgt_BS3HW: BxSx3xHxW
+     :param G_tgt_src: Bx4x4
+     :param K_src_inv: Bx3x3
+     :param K_tgt: Bx3x3
+     :return:
+     """
+     B, S, _, H, W = mpi_rgb_src.size()
+     mpi_depth_src = torch.reciprocal(mpi_disparity_src)  # BxS
+
+     # note that here we concat the mpi_src with xyz_tgt, because H_sampler will sample them for the tgt frame;
+     # mpi_src is the same in whatever frame, but xyz has to be in the tgt frame
+     mpi_xyz_src = torch.cat((mpi_rgb_src, mpi_sigma_src, xyz_tgt_BS3HW), dim=2)  # BxSx(3+1+3)xHxW
+
+     # homography warping of mpi_src into the tgt frame
+     G_tgt_src_Bs44 = G_tgt_src.unsqueeze(1).repeat(1, S, 1, 1).contiguous().reshape(B*S, 4, 4)  # BSx4x4
+     K_src_inv_Bs33 = K_src_inv.unsqueeze(1).repeat(1, S, 1, 1).contiguous().reshape(B*S, 3, 3)  # BSx3x3
+     K_tgt_Bs33 = K_tgt.unsqueeze(1).repeat(1, S, 1, 1).contiguous().reshape(B*S, 3, 3)  # BSx3x3
+
+     # BSxCxHxW, BSxHxW
+     tgt_mpi_xyz_BsCHW, tgt_mask_BsHW = H_sampler.sample(mpi_xyz_src.view(B*S, 7, H, W),
+                                                         mpi_depth_src.view(B*S),
+                                                         G_tgt_src_Bs44,
+                                                         K_src_inv_Bs33,
+                                                         K_tgt_Bs33)
+
+     # mpi composition
+     tgt_mpi_xyz = tgt_mpi_xyz_BsCHW.view(B, S, 7, H, W)
+     tgt_rgb_BS3HW = tgt_mpi_xyz[:, :, 0:3, :, :]
+     tgt_sigma_BS1HW = tgt_mpi_xyz[:, :, 3:4, :, :]
+     tgt_xyz_BS3HW = tgt_mpi_xyz[:, :, 4:, :, :]
+
+     tgt_mask_BSHW = tgt_mask_BsHW.view(B, S, H, W)
+     tgt_mask_BSHW = torch.where(tgt_mask_BSHW,
+                                 torch.ones((B, S, H, W), dtype=torch.float32, device=mpi_rgb_src.device),
+                                 torch.zeros((B, S, H, W), dtype=torch.float32, device=mpi_rgb_src.device))
+
+     # Bx3xHxW, Bx1xHxW, Bx1xHxW
+     tgt_z_BS1HW = tgt_xyz_BS3HW[:, :, -1:]
+     tgt_sigma_BS1HW = torch.where(tgt_z_BS1HW >= 0,
+                                   tgt_sigma_BS1HW,
+                                   torch.zeros_like(tgt_sigma_BS1HW, device=tgt_sigma_BS1HW.device))
+     tgt_rgb_syn, tgt_depth_syn, _, _ = render(tgt_rgb_BS3HW, tgt_sigma_BS1HW, tgt_xyz_BS3HW,
+                                               use_alpha=use_alpha,
+                                               is_bg_depth_inf=is_bg_depth_inf)
+     tgt_mask = torch.sum(tgt_mask_BSHW, dim=1, keepdim=True)  # Bx1xHxW
+
+     return tgt_rgb_syn, tgt_depth_syn, tgt_mask
+
+
+ def predict_mpi_coarse_to_fine(mpi_predictor, src_imgs, xyz_src_BS3HW_coarse,
+                                disparity_coarse_src, S_fine, is_bg_depth_inf):
+     if S_fine > 0:
+         with torch.no_grad():
+             # predict coarse mpi
+             mpi_coarse_src_list = mpi_predictor(src_imgs, disparity_coarse_src)  # BxS_coarsex4xHxW
+             mpi_coarse_rgb_src = mpi_coarse_src_list[0][:, :, 0:3, :, :]  # BxSx3xHxW
+             mpi_coarse_sigma_src = mpi_coarse_src_list[0][:, :, 3:, :, :]  # BxSx1xHxW
+             _, _, _, weights = plane_volume_rendering(
+                 mpi_coarse_rgb_src,
+                 mpi_coarse_sigma_src,
+                 xyz_src_BS3HW_coarse,
+                 is_bg_depth_inf
+             )
+             weights = weights.mean((2, 3, 4)).unsqueeze(1).unsqueeze(2)
+
+             # sample fine disparity
+             disparity_fine_src = sample_pdf(disparity_coarse_src.unsqueeze(1).unsqueeze(2), weights, S_fine)
+             disparity_fine_src = disparity_fine_src.squeeze(2).squeeze(1)
+
+         # assemble coarse and fine disparity
+         disparity_all_src = torch.cat((disparity_coarse_src, disparity_fine_src), dim=1)  # Bx(S_coarse + S_fine)
+         disparity_all_src, _ = torch.sort(disparity_all_src, dim=1, descending=True)
+         mpi_all_src_list = mpi_predictor(src_imgs, disparity_all_src)  # Bx(S_coarse+S_fine)x4xHxW
+         return mpi_all_src_list, disparity_all_src
+     else:
+         mpi_coarse_src_list = mpi_predictor(src_imgs, disparity_coarse_src)  # BxS_coarsex4xHxW
+         return mpi_coarse_src_list, disparity_coarse_src
utils/mpi/rendering_utils.py ADDED
@@ -0,0 +1,139 @@
+ import torch
+
+
+ def transform_G_xyz(G, xyz, is_return_homo=False):
+     """
+     Apply a rigid transform G to a batch of 3D points.
+     :param G: Bx4x4
+     :param xyz: Bx3xN
+     :return: Bx3xN, or Bx4xN in homogeneous coordinates if is_return_homo
+     """
+     assert len(G.size()) == len(xyz.size())
+     if len(G.size()) == 2:
+         G_B44 = G.unsqueeze(0)
+         xyz_B3N = xyz.unsqueeze(0)
+     else:
+         G_B44 = G
+         xyz_B3N = xyz
+     xyz_B4N = torch.cat((xyz_B3N, torch.ones_like(xyz_B3N[:, 0:1, :])), dim=1)
+     G_xyz_B4N = torch.matmul(G_B44, xyz_B4N)
+     if is_return_homo:
+         return G_xyz_B4N
+     else:
+         return G_xyz_B4N[:, 0:3, :]
+
+
+ def gather_pixel_by_pxpy(img, pxpy):
+     """
+     Gather per-pixel values from img at the given (px, py) coordinates.
+     :param img: Bx3xHxW
+     :param pxpy: Bx2xN
+     :return: BxCxN
+     """
+     with torch.no_grad():
+         B, C, H, W = img.size()
+         if pxpy.dtype == torch.float32:
+             pxpy_int = torch.round(pxpy).to(torch.int64)
+         else:
+             pxpy_int = pxpy.to(torch.int64)
+         pxpy_int[:, 0, :] = torch.clamp(pxpy_int[:, 0, :], min=0, max=W - 1)
+         pxpy_int[:, 1, :] = torch.clamp(pxpy_int[:, 1, :], min=0, max=H - 1)
+         pxpy_idx = pxpy_int[:, 0:1, :] + W * pxpy_int[:, 1:2, :]  # Bx1xN_pt
+         rgb = torch.gather(img.view(B, C, H * W), dim=2,
+                            index=pxpy_idx.repeat(1, C, 1))  # BxCxN_pt
+     return rgb
+
+
+ def uniformly_sample_disparity_from_bins(batch_size, disparity_np, device):
+     """
+     Draw one stratified disparity sample per bin. The bin edges must run from
+     large to small disparity, i.e. depth from small (near) to large (far).
+     :param batch_size: B
+     :param disparity_np: numpy array of S+1 disparity bin edges
+     :param device: device to place the samples on
+     :return: BxS disparity samples
+     """
+     assert disparity_np[0] > disparity_np[-1]
+     S = disparity_np.shape[0] - 1
+
+     B = batch_size
+     bin_edges = torch.from_numpy(disparity_np).to(dtype=torch.float32, device=device)  # S+1
+     interval = bin_edges[1:] - bin_edges[0:-1]  # S
+     bin_edges_start = bin_edges[0:-1].unsqueeze(0).repeat(B, 1)  # S -> BxS
+     interval = interval.unsqueeze(0).repeat(B, 1)  # S -> BxS
+
+     random_float = torch.rand((B, S), dtype=torch.float32, device=device)  # BxS
+     disparity_array = bin_edges_start + interval * random_float
+     return disparity_array  # BxS
+
+
+ def uniformly_sample_disparity_from_linspace_bins(batch_size, num_bins, start, end, device):
+     """
+     Draw one stratified disparity sample per bin of a uniform grid. Disparity must
+     run from large to small, i.e. depth from small (near) to large (far).
+     :param batch_size: B
+     :param num_bins: S
+     :param start: first bin edge (largest disparity, nearest depth)
+     :param end: last bin edge (smallest disparity, farthest depth)
+     :param device: device to place the samples on
+     :return: BxS disparity samples
+     """
+     assert start > end
+
+     B, S = batch_size, num_bins
+     bin_edges = torch.linspace(start, end, num_bins + 1, dtype=torch.float32, device=device)  # S+1
+     interval = bin_edges[1] - bin_edges[0]  # scalar
+     bin_edges_start = bin_edges[0:-1].unsqueeze(0).repeat(B, 1)  # S -> BxS
+
+     random_float = torch.rand((B, S), dtype=torch.float32, device=device)  # BxS
+     disparity_array = bin_edges_start + interval * random_float
+     return disparity_array  # BxS
+
+
+ def sample_pdf(values, weights, N_samples):
+     """
+     Draw samples from the distribution approximated by values and weights,
+     i.e. weights = p(values), via inverse-CDF sampling.
+     :param values: Bx1xNxS
+     :param weights: Bx1xNxS
+     :param N_samples: number of samples to draw
+     :return: Bx1xNxN_samples
+     """
+     B, N, S = weights.size(0), weights.size(2), weights.size(3)
+     assert values.size() == (B, 1, N, S)
+
+     # convert values to bin edges
+     bin_edges = (values[:, :, :, 1:] + values[:, :, :, :-1]) * 0.5  # Bx1xNx(S-1)
+     bin_edges = torch.cat((values[:, :, :, 0:1],
+                            bin_edges,
+                            values[:, :, :, -1:]), dim=3)  # Bx1xNx(S+1)
+
+     pdf = weights / (torch.sum(weights, dim=3, keepdim=True) + 1e-5)  # Bx1xNxS
+     cdf = torch.cumsum(pdf, dim=3)  # Bx1xNxS
+     cdf = torch.cat((torch.zeros((B, 1, N, 1), dtype=cdf.dtype, device=cdf.device),
+                      cdf), dim=3)  # Bx1xNx(S+1)
+
+     # uniform samples over the cdf values
+     u = torch.rand((B, 1, N, N_samples), dtype=weights.dtype, device=weights.device)  # Bx1xNxN_samples
+
+     # locate each sample on the cdf array
+     cdf_idx = torch.searchsorted(cdf, u, right=True)  # Bx1xNxN_samples
+     cdf_idx_lower = torch.clamp(cdf_idx - 1, min=0)  # Bx1xNxN_samples
+     cdf_idx_upper = torch.clamp(cdf_idx, max=S)  # Bx1xNxN_samples
+
+     # linear approximation within each bin
+     cdf_idx_lower_upper = torch.cat((cdf_idx_lower, cdf_idx_upper), dim=3)  # Bx1xNx(N_samples*2)
+     cdf_bounds_N2 = torch.gather(cdf, index=cdf_idx_lower_upper, dim=3)  # Bx1xNx(N_samples*2)
+     cdf_bounds = torch.stack((cdf_bounds_N2[..., 0:N_samples], cdf_bounds_N2[..., N_samples:]), dim=4)
+     bin_bounds_N2 = torch.gather(bin_edges, index=cdf_idx_lower_upper, dim=3)  # Bx1xNx(N_samples*2)
+     bin_bounds = torch.stack((bin_bounds_N2[..., 0:N_samples], bin_bounds_N2[..., N_samples:]), dim=4)
+
+     # avoid zero-width cdf intervals
+     cdf_intervals = cdf_bounds[:, :, :, :, 1] - cdf_bounds[:, :, :, :, 0]  # Bx1xNxN_samples
+     bin_intervals = bin_bounds[:, :, :, :, 1] - bin_bounds[:, :, :, :, 0]  # Bx1xNxN_samples
+     u_cdf_lower = u - cdf_bounds[:, :, :, :, 0]  # Bx1xNxN_samples
+     # cdf_intervals can be 0 because of the cdf_idx_lower/upper clamp above; fall back to the bin midpoint
+     t = u_cdf_lower / torch.clamp(cdf_intervals, min=1e-5)
+     t = torch.where(cdf_intervals <= 1e-4,
+                     torch.full_like(u_cdf_lower, 0.5),
+                     t)
+
+     samples = bin_bounds[:, :, :, :, 0] + t * bin_intervals
+     return samples
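These samplers implement stratified sampling over disparity bins (`uniformly_sample_disparity_from_linspace_bins` jitters one sample inside each bin of a uniform grid), while `sample_pdf` re-draws samples via the inverse CDF of a weight distribution. A minimal usage sketch; the bin count and disparity range are illustrative values, not ones prescribed by the repo:

```python
import torch
from utils.mpi.rendering_utils import uniformly_sample_disparity_from_linspace_bins

# one jittered disparity sample per bin; disparity runs from near (large) to far (small)
disp = uniformly_sample_disparity_from_linspace_bins(
    batch_size=1, num_bins=32, start=1.0, end=0.001, device="cpu"
)
print(disp.shape)  # torch.Size([1, 32])
```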
utils/rendererBackbone.py ADDED
@@ -0,0 +1,147 @@
+ # System Imports
+ import os
+ import math
+ import argparse
+ import time
+
+ # Common Libs
+ import numpy as np
+ from pathlib import Path
+ import cv2
+ import tkinter as tk
+ import threading
+ import queue
+
+ # Torch Imports
+ import torch
+ import torch.nn.functional as F
+ from torchvision import transforms
+ from torchvision.utils import save_image
+
+ # 3rd party imports
+ from transformers import DPTForDepthEstimation, DPTImageProcessor
+ from tqdm import tqdm
+ import mediapipe as mp
+ from PIL import Image, ImageTk
+ from moviepy.editor import ImageSequenceClip
+
+ # From Codebase
+ from utils.mpi import mpi_rendering
+ from utils.mpi.homography_sampler import HomographySample
+ from utils.utils import (
+     image_to_tensor,
+     disparity_to_tensor,
+     render_3dphoto,
+     render_novel_view,
+ )
+ from model.AdaMPI import MPIPredictor
+ from parameters import *
+
+
+ #=================================================
+ # Define the MPI Layers Processing Module Here
+ #=================================================
+ def processMPIs(src_imgs, mpi_all_src, disparity_all_src, k_src, k_tgt, save_path=None):
+     h, w = mpi_all_src.shape[-2:]
+     device = mpi_all_src.device
+     homography_sampler = HomographySample(h, w, device)
+     k_src_inv = torch.inverse(k_src)
+
+     # preprocess the predicted MPI
+     xyz_src_BS3HW = mpi_rendering.get_src_xyz_from_plane_disparity(
+         homography_sampler.meshgrid,
+         disparity_all_src,
+         k_src_inv,
+     )
+     mpi_all_rgb_src = mpi_all_src[:, :, 0:3, :, :]  # BxSx3xHxW
+     mpi_all_sigma_src = mpi_all_src[:, :, 3:, :, :]  # BxSx1xHxW
+     _, _, blend_weights, _ = mpi_rendering.render(
+         mpi_all_rgb_src,
+         mpi_all_sigma_src,
+         xyz_src_BS3HW,
+         use_alpha=False,
+         is_bg_depth_inf=False,
+     )
+     mpi_all_rgb_src = blend_weights * src_imgs.unsqueeze(1) + (1 - blend_weights) * mpi_all_rgb_src
+
+     return mpi_all_rgb_src, mpi_all_sigma_src, disparity_all_src, k_src_inv, k_tgt, homography_sampler
+
+
+ def cropFOV(image, original_fov, new_fov):
+     """Center-crop an image so its field of view shrinks from original_fov to new_fov."""
+     image = np.array(image)
+     if new_fov >= original_fov:
+         raise ValueError("New FoV must be smaller than the original FoV")
+
+     crop_ratio = new_fov / original_fov
+     height, width = image.shape[:2]
+
+     new_width = int(width * crop_ratio)
+     new_height = int(height * crop_ratio)
+
+     start_x = (width - new_width) // 2
+     start_y = (height - new_height) // 2
+
+     cropped_image = image[start_y:start_y + new_height, start_x:start_x + new_width]
+     cropped_image = Image.fromarray(cropped_image)
+     return cropped_image
+
+
+ def renderSingleFrame(mpi_all_rgb_src, mpi_all_sigma_src, disparity_all_src, cam_ext, k_src_inv, k_tgt, homography_sampler):
+     # `device` is provided by the star import from parameters above
+     frame = render_novel_view(
+         mpi_all_rgb_src,
+         mpi_all_sigma_src,
+         disparity_all_src,
+         cam_ext.to(device),
+         k_src_inv,
+         k_tgt,
+         homography_sampler,
+     )
+     frame_np = frame[0].permute(1, 2, 0).contiguous().cpu().numpy()  # [h,w,3]
+     frame_np = np.clip(np.round(frame_np * 255), a_min=0, a_max=255).astype(np.uint8)
+     im = Image.fromarray(frame_np)
+     return im
+
+
+ class VideoCapture:
+     """Threaded cv2.VideoCapture wrapper whose read() always returns the latest frame."""
+
+     def __init__(self, name):
+         self.cap = cv2.VideoCapture(name)
+         self.q = queue.Queue()
+         t = threading.Thread(target=self._reader)
+         t.daemon = True
+         t.start()
+
+     def _reader(self):
+         # keep only the most recent frame in the queue, discarding stale ones
+         while True:
+             ret, frame = self.cap.read()
+             if not ret:
+                 break
+             if not self.q.empty():
+                 try:
+                     self.q.get_nowait()
+                 except queue.Empty:
+                     pass
+             self.q.put(frame)
+
+     def read(self):
+         return self.q.get()
+
+
+ def captureBackground(capture_device):
+     frame_background = capture_device.read()
+     img = cv2.cvtColor(frame_background, cv2.COLOR_BGR2RGB)
+     im_pil = Image.fromarray(img)
+     return im_pil
+
+
+ def getImageTensor(pil_image, height, width, unsqueeze=True):
+     t = transforms.Compose([transforms.CenterCrop((height, width)), transforms.ToTensor()])
+     rgb = t(pil_image)
+
+     if unsqueeze:
+         rgb = rgb.unsqueeze(0)
+     return rgb
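A sketch of how these capture helpers chain together for a live webcam source; the device index and crop size below are illustrative, not values fixed by the repo:

```python
from utils.rendererBackbone import VideoCapture, captureBackground, getImageTensor

cap = VideoCapture(0)                       # threaded capture; read() returns the latest frame
background = captureBackground(cap)         # one frame as an RGB PIL image
rgb = getImageTensor(background, 256, 384)  # 1x3x256x384 tensor after center crop
```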
utils/utils.py ADDED
@@ -0,0 +1,150 @@
+ import os
+ import math
+ from PIL import Image
+ import cv2
+ from tqdm import tqdm
+ import torch
+ import torch.nn.functional as F
+ from torchvision import transforms
+ from torchvision.utils import save_image
+ import numpy as np
+ from moviepy.editor import ImageSequenceClip
+
+ from utils.mpi import mpi_rendering
+ from utils.mpi.homography_sampler import HomographySample
+
+
+ def image_to_tensor(img_path, unsqueeze=True):
+     rgb = transforms.ToTensor()(Image.open(img_path))
+     if unsqueeze:
+         rgb = rgb.unsqueeze(0)
+     return rgb
+
+
+ def disparity_to_tensor(disp_path, unsqueeze=True):
+     # read a 16-bit disparity map and normalize it to [0, 1]
+     disp = cv2.imread(disp_path, -1) / (2 ** 16 - 1)
+     disp = torch.from_numpy(disp)[None, ...]
+     if unsqueeze:
+         disp = disp.unsqueeze(0)
+     return disp.float()
+
+
+ def gen_swing_path(num_frames=90, r_x=0.14, r_y=0., r_z=0.10):
+     """Return a list of [4, 4] camera pose matrices along a swing trajectory."""
+     t = torch.arange(num_frames) / (num_frames - 1)
+     poses = torch.eye(4).repeat(num_frames, 1, 1)
+     poses[:, 0, 3] = r_x * torch.sin(2. * math.pi * t)
+     poses[:, 1, 3] = r_y * torch.cos(2. * math.pi * t)
+     poses[:, 2, 3] = r_z * (torch.cos(2. * math.pi * t) - 1.)
+     return poses.unbind()
+
+
+ def render_3dphoto(
+     src_imgs,  # [b,3,h,w]
+     mpi_all_src,  # [b,s,4,h,w]
+     disparity_all_src,  # [b,s]
+     k_src,  # [b,3,3]
+     k_tgt,  # [b,3,3]
+     save_path,
+ ):
+     h, w = mpi_all_src.shape[-2:]
+     device = mpi_all_src.device
+     homography_sampler = HomographySample(h, w, device)
+     k_src_inv = torch.inverse(k_src)
+
+     # preprocess the predicted MPI
+     xyz_src_BS3HW = mpi_rendering.get_src_xyz_from_plane_disparity(
+         homography_sampler.meshgrid,
+         disparity_all_src,
+         k_src_inv,
+     )
+     mpi_all_rgb_src = mpi_all_src[:, :, 0:3, :, :]  # BxSx3xHxW
+     mpi_all_sigma_src = mpi_all_src[:, :, 3:, :, :]  # BxSx1xHxW
+     _, _, blend_weights, _ = mpi_rendering.render(
+         mpi_all_rgb_src,
+         mpi_all_sigma_src,
+         xyz_src_BS3HW,
+         use_alpha=False,
+         is_bg_depth_inf=False,
+     )
+     mpi_all_rgb_src = blend_weights * src_imgs.unsqueeze(1) + (1 - blend_weights) * mpi_all_rgb_src
+
+     # render novel views
+     swing_path_list = gen_swing_path()
+     frames = []
+     for cam_ext in tqdm(swing_path_list):
+         frame = render_novel_view(
+             mpi_all_rgb_src,
+             mpi_all_sigma_src,
+             disparity_all_src,
+             cam_ext,
+             k_src_inv,
+             k_tgt,
+             homography_sampler,
+         )
+         frame_np = frame[0].permute(1, 2, 0).contiguous().cpu().numpy()  # [h,w,3]
+         frame_np = np.clip(np.round(frame_np * 255), a_min=0, a_max=255).astype(np.uint8)
+         frames.append(frame_np)
+     rgb_clip = ImageSequenceClip(frames, fps=30)
+     rgb_clip.write_videofile(save_path, verbose=False, codec='mpeg4', logger=None, bitrate='2000k')
+
+
+ def render_novel_view(
+     mpi_all_rgb_src,
+     mpi_all_sigma_src,
+     disparity_all_src,
+     G_tgt_src,
+     K_src_inv,
+     K_tgt,
+     homography_sampler,
+ ):
+     xyz_src_BS3HW = mpi_rendering.get_src_xyz_from_plane_disparity(
+         homography_sampler.meshgrid,
+         disparity_all_src,
+         K_src_inv
+     )
+
+     xyz_tgt_BS3HW = mpi_rendering.get_tgt_xyz_from_plane_disparity(
+         xyz_src_BS3HW,
+         G_tgt_src
+     )
+
+     tgt_imgs_syn, _, _ = mpi_rendering.render_tgt_rgb_depth(
+         homography_sampler,
+         mpi_all_rgb_src,
+         mpi_all_sigma_src,
+         disparity_all_src,
+         xyz_tgt_BS3HW,
+         G_tgt_src,
+         K_src_inv,
+         K_tgt,
+         use_alpha=False,
+         is_bg_depth_inf=False,
+     )
+
+     return tgt_imgs_syn
+
+
+ class AverageMeter(object):
+     """Computes and stores the average and current value"""
+     def __init__(self, name, fmt=":f"):
+         self.name = name
+         self.fmt = fmt
+         self.reset()
+
+     def reset(self):
+         self.val = 0
+         self.avg = 0
+         self.sum = 0
+         self.count = 0
+
+     def update(self, val, n=1):
+         self.val = val
+         self.sum += val * n
+         self.count += n
+         self.avg = self.sum / self.count
+
+     def __str__(self):
+         return f"{self.name:s}: {self.avg:.6f}"
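Tying `utils/utils.py` together: `gen_swing_path` yields the camera extrinsics that `render_3dphoto` feeds one by one to `render_novel_view`. A condensed sketch of that loop, assuming `mpi_rgb`, `mpi_sigma`, `disparity`, `k_inv`, `k`, and `sampler` are placeholder names for tensors obtained exactly as inside `render_3dphoto`:

```python
from utils.utils import gen_swing_path, render_novel_view

for cam_ext in gen_swing_path(num_frames=90):
    frame = render_novel_view(mpi_rgb, mpi_sigma, disparity, cam_ext, k_inv, k, sampler)
    # frame: [b,3,h,w] synthesized target view with values in [0, 1]
```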