initial commit
- README.md +30 -14
- app.py +188 -0
- helperFunctions.py +26 -0
- helper_image_functions.py +290 -0
- model_Large.py +535 -0
- model_Medium.py +535 -0
- model_Small.py +544 -0
- parameters.py +48 -0
- post-install.sh +2 -0
- requirements.txt +19 -0
- utils.py +243 -0
- utils/.DS_Store +0 -0
- utils/__init__.py +0 -0
- utils/__pycache__/__init__.cpython-38.pyc +0 -0
- utils/__pycache__/__init__.cpython-39.pyc +0 -0
- utils/__pycache__/rendererBackbone.cpython-39.pyc +0 -0
- utils/__pycache__/utils.cpython-38.pyc +0 -0
- utils/__pycache__/utils.cpython-39.pyc +0 -0
- utils/mpi/__init__.py +0 -0
- utils/mpi/__pycache__/__init__.cpython-38.pyc +0 -0
- utils/mpi/__pycache__/__init__.cpython-39.pyc +0 -0
- utils/mpi/__pycache__/homography_sampler.cpython-38.pyc +0 -0
- utils/mpi/__pycache__/homography_sampler.cpython-39.pyc +0 -0
- utils/mpi/__pycache__/mpi_rendering.cpython-38.pyc +0 -0
- utils/mpi/__pycache__/mpi_rendering.cpython-39.pyc +0 -0
- utils/mpi/__pycache__/rendering_utils.cpython-38.pyc +0 -0
- utils/mpi/__pycache__/rendering_utils.cpython-39.pyc +0 -0
- utils/mpi/homography_sampler.py +159 -0
- utils/mpi/mpi_rendering.py +272 -0
- utils/mpi/rendering_utils.py +139 -0
- utils/rendererBackbone.py +147 -0
- utils/utils.py +150 -0
README.md
CHANGED
@@ -1,14 +1,30 @@
<div align="center">
<a href="#"><img src='https://img.shields.io/badge/-Paper-00629B?style=flat&logo=ieee&logoColor=white' alt='arXiv'></a>
<a href='https://realistic3d-miun.github.io/Research/RT_MPINet/index.html'><img src='https://img.shields.io/badge/Project_Page-Website-green?logo=googlechrome&logoColor=white' alt='Project Page'></a>
<a href='https://huggingface.co/spaces/3ZadeSSG/RT-MPINet'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Demo_(RT_MPINet)-blue'></a>
</div>

# RT-MPINet
#### Real-Time View Synthesis with Multiplane Image Network using Multimodal Supervision (RT-MPINet)

We present a real-time multiplane image (MPI) network. Unlike existing MPI-based approaches, which often rely on a separate depth estimation network to guide the estimation of MPI parameters, our method predicts these parameters directly from a single RGB image. To guide the network, we use a multimodal training strategy with joint supervision from view synthesis and depth estimation losses. More details can be found in the paper.

**Please head to the [Project Page](https://realistic3d-miun.github.io/Research/RT_MPINet/index.html) for supplementary materials and the full code.**

## Acknowledgements
- We thank the authors of [AdaMPI](https://github.com/yxuhan/AdaMPI) for their implementation of the homography renderer, which is used in this codebase under the `./utils` directory.
- We thank the author of the [Deepview renderer](https://github.com/Findeton/deepview) template, which was used in our project page.

## Citation
If you use our work, please use the following citation:
```
@inproceedings{gond2025rtmpi,
  title={Real-Time View Synthesis with Multiplane Image Network using Multimodal Supervision},
  author={Gond, Manu and Shamshirgarha, Mohammadreza and Zerman, Emin and Knorr, Sebastian and Sj{\"o}str{\"o}m, M{\aa}rten},
  booktitle={2025 IEEE 27th International Workshop on Multimedia Signal Processing (MMSP)},
  pages={},
  year={2025},
  organization={IEEE}
}
```
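For quick reference, the demo's inference path (see `app.py` below) reduces to a handful of calls. The following is a minimal sketch assembled from that script, not an official API example: `input.png` and the `0.05` camera shift are placeholder values, and the checkpoint under `./checkpoint/` must be present locally.

```python
import torch
from PIL import Image
import torchvision.transforms as transforms

import parameters as params
import helperFunctions as helper
from model_Medium import MMPI
from utils.mpi.homography_sampler import HomographySample
from utils.utils import render_novel_view

H = W = 384
DEVICE = "cpu"

# Load the medium model from its committed checkpoint location.
model = MMPI(total_image_input=params.params_number_input, height=H, width=W)
model = helper.load_Checkpoint("./checkpoint/checkpoint_RT_MPI_Medium.pth", model, load_cpu=True)
model.to(DEVICE).eval()

# "input.png" is a placeholder path for any RGB image.
img = Image.open("input.png").convert("RGB")
x = transforms.Compose([transforms.Resize((H, W)),
                        transforms.ToTensor()])(img).unsqueeze(0).to(DEVICE)

# Normalized pinhole intrinsics scaled to the working resolution (same values as app.py).
k = torch.tensor([[0.58, 0, 0.5], [0, 0.58, 0.5], [0, 0, 1]]).to(DEVICE)
k[0, :] *= H
k[1, :] *= W
k = k.unsqueeze(0)

grid = params.get_disparity_all_src().unsqueeze(0).to(DEVICE)
pose = torch.eye(4).unsqueeze(0).to(DEVICE)
pose[0, 0, 3] = 0.05  # small sideways camera shift (placeholder magnitude)

with torch.no_grad():
    rgb_layers, sigma_layers = model.get_layers(x, height=H, width=W)
    novel_view = render_novel_view(rgb_layers, sigma_layers, grid, pose,
                                   torch.inverse(k), k,
                                   HomographySample(H, W, DEVICE))
```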
app.py
ADDED
@@ -0,0 +1,188 @@
import gradio as gr
import torch
import numpy as np
import cv2
import tempfile
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from model_Small import MMPI as MMPI_S
from model_Medium import MMPI as MMPI_M
from model_Large import MMPI as MMPI_L
import helperFunctions as helper
import socket
import parameters as params
from utils.mpi.homography_sampler import HomographySample
from utils.utils import (
    render_novel_view,
)

# Checkpoint locations for all models
MODEL_S_LOCATION = "./checkpoint/checkpoint_RT_MPI_Small.pth"
MODEL_M_LOCATION = "./checkpoint/checkpoint_RT_MPI_Medium.pth"
MODEL_L_LOCATION = "./checkpoint/checkpoint_RT_MPI_Large.pth"

DEVICE = "cpu"

def getPositionVector(x, y, z, pose):
    # Write the translation (x, y, z) into the 4x4 camera pose
    pose[0, 0, 3] = x
    pose[0, 1, 3] = y
    pose[0, 2, 3] = z
    return pose

def generateCircularTrajectory(radius, num_frames):
    angles = np.linspace(0, 2 * np.pi, num_frames, endpoint=False)
    return [[radius * np.cos(angle), radius * np.sin(angle), 0] for angle in angles]

def generateWiggleTrajectory(radius, num_frames):
    angles = np.linspace(0, 2 * np.pi, num_frames, endpoint=False)
    return [[radius * np.cos(angle), 0, radius * np.sin(angle)] for angle in angles]

def create_video_from_memory(frames, fps=60):
    if not frames:
        return None
    height, width, _ = frames[0].shape
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    temp_video = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
    out = cv2.VideoWriter(temp_video.name, fourcc, fps, (width, height))
    for frame in frames:
        out.write(frame)
    out.release()
    return temp_video.name

def process_image(img, video_type, radius, num_frames, num_loops, model_type, resolution):
    # Parse resolution string
    height, width = map(int, resolution.lower().split("x"))

    # Select model class and checkpoint
    if model_type == "Small":
        model_class = MMPI_S
        checkpoint = MODEL_S_LOCATION
    elif model_type == "Medium":
        model_class = MMPI_M
        checkpoint = MODEL_M_LOCATION
    else:
        model_class = MMPI_L
        checkpoint = MODEL_L_LOCATION

    # Load model
    model = model_class(total_image_input=params.params_number_input, height=height, width=width)
    model = helper.load_Checkpoint(checkpoint, model, load_cpu=True)
    model.to(DEVICE)
    model.eval()

    # Center-crop the input to a square before resizing
    min_side = min(img.width, img.height)
    left = (img.width - min_side) // 2
    top = (img.height - min_side) // 2
    right = left + min_side
    bottom = top + min_side
    img = img.crop((left, top, right, bottom))

    if video_type == "Circle":
        trajectory = generateCircularTrajectory(radius, num_frames)
    elif video_type == "Swing":
        trajectory = generateWiggleTrajectory(radius, num_frames)
    else:
        trajectory = generateCircularTrajectory(radius, num_frames)

    transform = transforms.Compose([
        transforms.Resize((height, width)),
        transforms.ToTensor()
    ])
    img_input = transform(img).to(DEVICE).unsqueeze(0)

    grid = params.get_disparity_all_src().unsqueeze(0).to(DEVICE)
    k_tgt = torch.tensor([
        [0.58, 0, 0.5],
        [0, 0.58, 0.5],
        [0, 0, 1]]).to(DEVICE)
    k_tgt[0, :] *= height
    k_tgt[1, :] *= width
    k_tgt = k_tgt.unsqueeze(0)
    k_src_inv = torch.inverse(k_tgt)
    pose = torch.eye(4).to(DEVICE).unsqueeze(0)

    homography_sampler = HomographySample(height, width, DEVICE)

    with torch.no_grad():
        rgb_layers, sigma_layers = model.get_layers(img_input, height=height, width=width)

        predicted_depth = model.get_depth(img_input)
        predicted_depth = (predicted_depth - predicted_depth.min()) / (predicted_depth.max() - predicted_depth.min())
        img_predicted_depth = predicted_depth.squeeze().cpu().detach().numpy()
        img_predicted_depth_colored = plt.get_cmap('inferno')(img_predicted_depth / np.max(img_predicted_depth))[:, :, :3]
        img_predicted_depth_colored = (img_predicted_depth_colored * 255).astype(np.uint8)
        img_predicted_depth_colored = Image.fromarray(img_predicted_depth_colored)

        layer_depth = model.get_layer_depth(img_input, grid)
        img_layer_depth = layer_depth.squeeze().cpu().detach().numpy()
        img_layer_depth_colored = plt.get_cmap('inferno')(img_layer_depth / np.max(img_layer_depth))[:, :, :3]
        img_layer_depth_colored = (img_layer_depth_colored * 255).astype(np.uint8)
        img_layer_depth_colored = Image.fromarray(img_layer_depth_colored)

    single_loop_frames = []
    for idx, pose_coords in enumerate(trajectory):
        #print(f" - Rendering frame {idx + 1}/{len(trajectory)}", end="\r")
        with torch.no_grad():
            target_pose = getPositionVector(pose_coords[0], pose_coords[1], pose_coords[2], pose)
            output_img = render_novel_view(rgb_layers,
                                           sigma_layers,
                                           grid,
                                           target_pose,
                                           k_src_inv,
                                           k_tgt,
                                           homography_sampler)

        img_np = output_img.detach().cpu().squeeze(0).permute(1, 2, 0).numpy()
        img_np = (img_np * 255).astype(np.uint8)
        img_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
        single_loop_frames.append(img_bgr)

    final_frames = single_loop_frames * int(num_loops)

    video_path = create_video_from_memory(final_frames)
    #print("Video generation complete!")

    return video_path, img_predicted_depth_colored, img_layer_depth_colored

with gr.Blocks(title="RT-MPINet", theme="default") as demo:
    gr.Markdown(
        """
        ## Parallax Video Generator via Real-Time Multiplane Image Network (RT-MPINet)
        We use a smaller 256x256 model for faster inference on CPU instances.

        #### Notes:
        1. Use a higher number of frames (>80) and loops (>4) to get a smoother video.
        2. The default uses 60 frames and 4 camera loops for fast video generation.
        3. We have 3 models available (the larger the model, the slower the inference):
            * **Small:** 6.6 Million parameters
            * **Medium:** 69 Million parameters
            * **Large:** 288 Million parameters (Not available in this demo due to storage limits; you need to download this model and run it locally)
        """)
    with gr.Row():
        img_input = gr.Image(type="pil", label="Upload Image")
        video_type = gr.Dropdown(["Circle", "Swing"], label="Video Type", value="Swing")
        with gr.Column():
            with gr.Accordion("Advanced Settings", open=False):
                radius = gr.Slider(0.001, 0.1, value=0.05, label="Radius (for Circle/Swing)")
                num_frames = gr.Slider(10, 180, value=60, step=1, label="Frames per Loop")
                num_loops = gr.Slider(1, 10, value=4, step=1, label="Number of Loops")
        with gr.Column():
            model_type_dropdown = gr.Dropdown(["Small", "Medium"], label="Model Type", value="Medium")
            resolution_dropdown = gr.Dropdown(["256x256", "384x384", "512x512"], label="Input Resolution", value="384x384")
            generate_btn = gr.Button("Generate Video", variant="primary")

    with gr.Row():
        video_output = gr.Video(label="Generated Video")
        depth_output = gr.Image(label="Depth Map - From Depth Decoder")
        layer_depth_output = gr.Image(label="Layer Depth Map - From MPI Layers")

    def toggle_custom_path(video_type_selection):
        is_custom = (video_type_selection == "Custom")
        return gr.update(visible=is_custom)

    generate_btn.click(fn=process_image,
                       inputs=[img_input, video_type, radius, num_frames, num_loops, model_type_dropdown, resolution_dropdown],
                       outputs=[video_output, depth_output, layer_depth_output])

demo.launch()
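A note on the intrinsics above: `process_image` specifies `k_tgt` in normalized image coordinates and then scales it to pixels, multiplying the first row by `height` and the second by `width` (equivalent to the usual per-axis scaling here, since the demo only offers square resolutions). A small worked example of the resulting matrix, a sketch for illustration only:

```python
import torch

H = W = 384  # one of the demo's square resolutions
k = torch.tensor([[0.58, 0.0, 0.5],
                  [0.0, 0.58, 0.5],
                  [0.0, 0.0, 1.0]])
k[0, :] *= H  # fx = 0.58 * 384 = 222.72, cx = 0.5 * 384 = 192.0
k[1, :] *= W  # fy = 222.72, cy = 192.0
# The same matrix serves as both source and target intrinsics; its inverse
# (k_src_inv in app.py) is what the homography-based renderer consumes.
print(k)
```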
helperFunctions.py
ADDED
@@ -0,0 +1,26 @@
import torch
import os
import torch.nn.functional as F

def save_checkpoint(model, filelocation, save_parallel=True):
    if save_parallel:
        torch.save(model.module.state_dict(), filelocation)
    else:
        torch.save(model.state_dict(), filelocation)

def load_Checkpoint(fileLocation, model, load_cpu=False):
    if load_cpu:
        model.load_state_dict(torch.load(fileLocation, map_location=lambda storage, loc: storage))
    else:
        model.load_state_dict(torch.load(fileLocation))
    return model

def writeLog(logList, filename):
    with open(filename, 'w') as outfile:
        outfile.write("\n".join(logList))


def kl_loss(mu, logvar):
    return -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).mean()
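The `kl_loss` above is the standard closed-form KL divergence between a diagonal Gaussian and the standard normal, averaged over all elements; assuming `logvar` denotes $\log\sigma^2$, each element contributes

$$
D_{\mathrm{KL}}\big(\mathcal{N}(\mu,\sigma^{2})\,\|\,\mathcal{N}(0,1)\big)
= -\tfrac{1}{2}\left(1 + \log\sigma^{2} - \mu^{2} - \sigma^{2}\right).
$$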
helper_image_functions.py
ADDED
@@ -0,0 +1,290 @@
'''
Author: Manu Gond (manu.gond@miun.se)
Date: Nov-15-2022
Objective: Accumulation of some general functions which I
           use daily in my code related to image tasks.
           The function names and parameters are self-explanatory.
Requirements: Installed python libraries which have been imported.
'''

import torch
from torchvision.utils import save_image
from torchvision.transforms import transforms
import torchmetrics
import cv2
import numpy as np
from PIL import Image
import utils


#======================= Read and Write =====================#
def readImage(location):
    image = Image.open(location).convert("RGB")
    return image


def writeImage(image, location):
    image.save(location)


def writeTensorImage(image, filename):
    save_image(image, filename)


def removeChannel(sourceLocation, targetLocation):
    img = readImage(sourceLocation)
    writeImage(img, targetLocation)


def getImageTransform(width, height):
    transform = transforms.Compose([transforms.Resize((height, width)),
                                    transforms.ToTensor()])
    return transform


def convertTensor(image):
    transform = getImageTransform(image.size[0], image.size[1])
    image = transform(image)
    return image


#=================== 360 Images =======================#

def rotateERP180(image):
    '''
    :param image: PIL Image
    :return: BxHxW Torch Tensor Image
    '''
    W = image.size[0]
    H = image.size[1]
    transform = getImageTransform(W, H)
    image = transform(image)
    image1 = image[:, :, 0:(W//2)]
    image2 = image[:, :, (W//2):W]
    image3 = torch.zeros(image.size())
    image3[:, :, 0:(W//2)] = image2
    image3[:, :, (W//2):W] = image1
    return image3


def convertERP2Cube(e_img, face_w=256, mode='bilinear', cube_format='dice'):
    '''
    e_img: ndarray in shape of [H, W, *]
    face_w: int, the length of each face of the cubemap
    '''
    assert len(e_img.shape) == 3
    h, w = e_img.shape[:2]
    if mode == 'bilinear':
        order = 1
    elif mode == 'nearest':
        order = 0
    else:
        raise NotImplementedError('unknown mode')

    xyz = utils.xyzcube(face_w)
    uv = utils.xyz2uv(xyz)
    coor_xy = utils.uv2coor(uv, h, w)

    cubemap = np.stack([
        utils.sample_equirec(e_img[..., i], coor_xy, order=order)
        for i in range(e_img.shape[2])
    ], axis=-1)

    if cube_format == 'horizon':
        pass
    elif cube_format == 'list':
        cubemap = utils.cube_h2list(cubemap)
    elif cube_format == 'dict':
        cubemap = utils.cube_h2dict(cubemap)
    elif cube_format == 'dice':
        cubemap = utils.cube_h2dice(cubemap)
    else:
        raise NotImplementedError()
    return cubemap


def convertCube2ERP(cubemap, h, w, mode='bilinear', cube_format='dice'):
    if mode == 'bilinear':
        order = 1
    elif mode == 'nearest':
        order = 0
    else:
        raise NotImplementedError('unknown mode')

    if cube_format == 'horizon':
        pass
    elif cube_format == 'list':
        cubemap = utils.cube_list2h(cubemap)
    elif cube_format == 'dict':
        cubemap = utils.cube_dict2h(cubemap)
    elif cube_format == 'dice':
        cubemap = utils.cube_dice2h(cubemap)
    else:
        raise NotImplementedError('unknown cube_format')
    assert len(cubemap.shape) == 3
    assert cubemap.shape[0] * 6 == cubemap.shape[1]
    assert w % 8 == 0
    face_w = cubemap.shape[0]

    uv = utils.equirect_uvgrid(h, w)
    u, v = np.split(uv, 2, axis=-1)
    u = u[..., 0]
    v = v[..., 0]
    cube_faces = np.stack(np.split(cubemap, 6, 1), 0)

    # Get face id to each pixel: 0F 1R 2B 3L 4U 5D
    tp = utils.equirect_facetype(h, w)
    coor_x = np.zeros((h, w))
    coor_y = np.zeros((h, w))

    for i in range(4):
        mask = (tp == i)
        coor_x[mask] = 0.5 * np.tan(u[mask] - np.pi * i / 2)
        coor_y[mask] = -0.5 * np.tan(v[mask]) / np.cos(u[mask] - np.pi * i / 2)

    mask = (tp == 4)
    c = 0.5 * np.tan(np.pi / 2 - v[mask])
    coor_x[mask] = c * np.sin(u[mask])
    coor_y[mask] = c * np.cos(u[mask])

    mask = (tp == 5)
    c = 0.5 * np.tan(np.pi / 2 - np.abs(v[mask]))
    coor_x[mask] = c * np.sin(u[mask])
    coor_y[mask] = -c * np.cos(u[mask])

    # Final renormalize
    coor_x = (np.clip(coor_x, -0.5, 0.5) + 0.5) * face_w
    coor_y = (np.clip(coor_y, -0.5, 0.5) + 0.5) * face_w

    equirec = np.stack([
        utils.sample_cubefaces(cube_faces[..., i], tp, coor_y, coor_x, order=order)
        for i in range(cube_faces.shape[3])
    ], axis=-1)
    return equirec


def convertCube2Slices(image):
    '''
    :param image: PIL Image (a cubemap in 'dice' layout)
    :return: List of Torch Tensors, CxHxW
    '''
    image = convertTensor(image)
    C, H, W = image.size()
    #print(C, H, W)
    top = torch.zeros((C, W//4, W//4))
    left = torch.zeros(top.size())
    front = torch.zeros(top.size())
    right = torch.zeros(top.size())
    back = torch.zeros(top.size())
    bottom = torch.zeros(top.size())

    top = image[:, 0:H//3, (W//4):(W//4)*2]
    left = image[:, (H//3):(H//3)*2, 0:W//4]
    front = image[:, (H//3):(H//3)*2, (W//4):(W//4)*2]
    right = image[:, (H//3):(H//3)*2, (W//4)*2:(W//4)*3]
    back = image[:, (H//3):(H//3)*2, (W//4)*3:]
    bottom = image[:, (H//3)*2:, (W//4):(W//4)*2]

    '''
    save_image(top, 'top.png')
    save_image(left, 'left.png')
    save_image(front, 'front.png')
    save_image(right, 'right.png')
    save_image(back, 'back.png')
    save_image(bottom, 'bottom.png')
    '''
    return [top, left, front, right, back, bottom]

def convertSlicesToCube(imageList):
    '''
    top = convertTensor(readImage(imageList[0]))
    left = convertTensor(readImage(imageList[1]))
    front = convertTensor(readImage(imageList[2]))
    right = convertTensor(readImage(imageList[3]))
    back = convertTensor(readImage(imageList[4]))
    bottom = convertTensor(readImage(imageList[5]))
    '''
    top = imageList[0]
    left = imageList[1]
    front = imageList[2]
    right = imageList[3]
    back = imageList[4]
    bottom = imageList[5]

    C, H, W = 3, top.size()[1]*3, top.size()[2]*4
    cube = torch.zeros((C, H, W))

    cube[:, 0:H//3, (W//4):(W//4)*2] = top
    cube[:, (H//3):(H//3)*2, 0:W//4] = left
    cube[:, (H//3):(H//3)*2, (W//4):(W//4)*2] = front
    cube[:, (H//3):(H//3)*2, (W//4)*2:(W//4)*3] = right
    cube[:, (H//3):(H//3)*2, (W//4)*3:] = back
    cube[:, (H//3)*2:, (W//4):(W//4)*2] = bottom

    return cube


#=================== Quality Measures =======================#
'''
Predicted Shape : BxCxHxW
Original Shape  : BxCxHxW
Data Type: Torch Tensor
'''
def getSSIM(predicted, original):
    SSIM = torchmetrics.StructuralSimilarityIndexMeasure()
    return SSIM(predicted, original).item()


def getPSNR(predicted, original):
    PSNR = torchmetrics.PeakSignalNoiseRatio()
    return PSNR(predicted, original).item()


def getMSE(predicted, original):
    MSE = torchmetrics.MeanSquaredError()
    return MSE(predicted, original).item()


def getMAE(predicted, original):
    MAE = torchmetrics.MeanAbsoluteError()
    return MAE(predicted, original).item()


if __name__ == "__main__":

    '''
    img = readImage("31_image_0_0.png")
    img = convertERP2Cube(e_img=np.asarray(img), face_w=256)
    img = Image.fromarray(img.astype('uint8'), 'RGB')
    convertCube2Slices(img)
    '''
    #image = convertSlicesToCube(["top.png", "left.png", "front.png", "right.png", "back.png", "bottom.png"])
    #writeTensorImage(image, 'this.png')

    '''
    writeImage(img, 'cube.png')

    img = readImage('cube.png')
    img = convertCube2ERP(np.asarray(img), 512, 1024)
    img = Image.fromarray(img.astype('uint8'), 'RGB')
    writeImage(img, 'cubeERP.png')


    img1 = readImage("31_image_0_0.png")
    img2 = readImage("cubeERP.png")
    img1 = convertTensor(img1)
    img2 = convertTensor(img2)
    print(getSSIM(img1.unsqueeze(0), img2.unsqueeze(0)))
    '''

    #img = rotateERP180(img)
    #writeTensorImage(img, 'rotated_image.png')
    #img = convertTensor(img)
    #print(getMAE(img.unsqueeze(0), img.unsqueeze(0)))
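The commented-out `__main__` block above sketches how these helpers chain together. A minimal round-trip, reassembled from those comments (assuming the `utils` package exposes the cubemap samplers these functions call, and with `pano.png` as a placeholder 1024x512 ERP image):

```python
import numpy as np
from PIL import Image
import helper_image_functions as hif

erp = hif.readImage("pano.png")                          # equirectangular panorama (PIL)
cube = hif.convertERP2Cube(np.asarray(erp), face_w=256)  # 'dice' layout: 3x4 grid of faces
cube_img = Image.fromarray(cube.astype('uint8'), 'RGB')

faces = hif.convertCube2Slices(cube_img)       # [top, left, front, right, back, bottom]
cube_back = hif.convertSlicesToCube(faces)     # reassemble the dice image as a tensor

erp_back = hif.convertCube2ERP(np.asarray(cube_img), 512, 1024)  # back to 512x1024 ERP
```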
model_Large.py
ADDED
@@ -0,0 +1,535 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import warnings
warnings.filterwarnings("ignore")
import torchvision
import parameters as params
import timm

class DinoV2FeatureExtractor(nn.Module):
    def __init__(self, out_channels=256, out_size=(64, 64)):
        super().__init__()
        self.dino = timm.create_model('vit_base_patch14_dinov2.lvd142m', pretrained=False)
        self.dino.eval()
        for p in self.dino.parameters():
            p.requires_grad = False

        self.out_size = out_size
        self.feat_proj = nn.Sequential(
            nn.Conv2d(self.dino.embed_dim, out_channels, kernel_size=1),
            nn.ReLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.interpolate(x, size=(518, 518), mode='bilinear', align_corners=False)
        patch_tokens = self.dino.forward_features(x)
        patch_tokens = patch_tokens[:, 1:]
        B, N, C = patch_tokens.shape
        h = w = int(N ** 0.5)
        feat_map = patch_tokens.transpose(1, 2).reshape(B, C, h, w)  # [B, C, H', W']
        feat_map = F.interpolate(feat_map, size=self.out_size, mode='bilinear', align_corners=False)
        return self.feat_proj(feat_map)

def getLinearLayer(in_feat, out_feat, activation=nn.ReLU(True)):
    return nn.Sequential(
        nn.Linear(in_features=in_feat, out_features=out_feat, bias=True),
        activation
    )

def getConvLayer(in_channel, out_channel, stride=1, padding=1, activation=nn.ReLU()):
    return nn.Sequential(nn.Conv2d(in_channel,
                                   out_channel,
                                   kernel_size=3,
                                   stride=stride,
                                   padding=padding,
                                   padding_mode='reflect'),
                         activation)

def getConvTransposeLayer(in_channel, out_channel, kernel=3, stride=1, padding=1, activation=nn.ReLU()):
    return nn.Sequential(nn.ConvTranspose2d(in_channel,
                                            out_channel,
                                            kernel_size=kernel,
                                            stride=stride,
                                            padding=padding),
                         activation)


class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.stride = stride

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.relu(out)

        out = self.conv2(out)

        out = out + self.shortcut(residual)
        out = self.relu(out)
        return out


# class ResidualBlock(nn.Module):
#     def __init__(self, in_channels, out_channels, stride=1, expansion=4):
#         super().__init__()
#         mid_channels = out_channels // expansion
#         self.pw_reduce = nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False)
#         self.bn1 = nn.BatchNorm2d(mid_channels)
#         self.dw = nn.Conv2d(mid_channels, mid_channels, kernel_size=3,
#                             stride=stride, padding=1, groups=mid_channels, bias=False)
#         self.bn2 = nn.BatchNorm2d(mid_channels)
#         self.pw_expand = nn.Conv2d(mid_channels, out_channels, kernel_size=1, bias=False)
#         self.bn3 = nn.BatchNorm2d(out_channels)
#         self.relu = nn.ReLU(inplace=True)
#         self.stride = stride
#         if stride != 1 or in_channels != out_channels:
#             self.shortcut = nn.Sequential(
#                 nn.Conv2d(in_channels, out_channels, kernel_size=1,
#                           stride=stride, bias=False),
#                 nn.BatchNorm2d(out_channels),
#             )
#         else:
#             self.shortcut = nn.Identity()

#     def forward(self, x):
#         identity = x

#         out = self.pw_reduce(x)
#         out = self.bn1(out)
#         out = self.relu(out)

#         out = self.dw(out)
#         out = self.bn2(out)
#         out = self.relu(out)

#         out = self.pw_expand(out)
#         out = self.bn3(out)

#         out += self.shortcut(identity)
#         out = self.relu(out)
#         return out

class FeatureNet(nn.Module):
    def __init__(self, height, width):
        super().__init__()
        model = torchvision.models.resnet152(pretrained=False)
        layers = list(model.children())
        self.FeatureEncoder = torch.nn.Sequential(*layers[:5].copy())
        self.expand_layer = ResidualBlock(256, 500)

    def forward(self, x):
        x = self.FeatureEncoder(x)
        x = self.expand_layer(x)
        return x

    def apply_feature_encoder(self, x):
        x = self.FeatureEncoder(x)
        x = self.expand_layer(x)
        return x

class Encoder(nn.Module):
    def __init__(self, height, width, total_image_input=1):
        super().__init__()
        self.height = height
        self.width = width
        self.encoder_pre = ResidualBlock((total_image_input*3), 20)
        self.encoder_layer1 = ResidualBlock(20, 30)
        self.encoder_layer2 = ResidualBlock(30, 50)

        self.encoder_layer3 = nn.Sequential(
            ResidualBlock(50, 100),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.encoder_layer4 = ResidualBlock(100, 500)
        self.encoder_layer5 = nn.Sequential(
            ResidualBlock(500, 500),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.encoder_layer6 = ResidualBlock(500, 500)
        self.encoder_layer7 = nn.Sequential(
            ResidualBlock(500, 500),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.encoder_layer8 = ResidualBlock(500, 1000)
        self.encoder_layer9 = nn.Sequential(
            ResidualBlock(1000, 1000),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.encoder_layer10 = ResidualBlock(1000, 1000)
        self.encoder_layer11 = ResidualBlock(1000, 1000)

    def forward(self, x, height=None, width=None):
        if height is None and width is None:
            height = self.height
            width = self.width

        x = self.encoder_pre(x)
        x = self.encoder_layer1(x)
        x = self.encoder_layer2(x)
        skip1 = self.encoder_layer3(x)

        x = self.encoder_layer4(skip1)
        skip2 = self.encoder_layer5(x)

        x = self.encoder_layer6(skip2)
        skip3 = self.encoder_layer7(x)

        x = self.encoder_layer8(skip3)
        skip4 = self.encoder_layer9(x)

        x = self.encoder_layer10(skip4)
        x = self.encoder_layer11(x)

        return x, [skip1, skip2, skip3, skip4]

class DecoderRGB(nn.Module):
    def __init__(self, height, width):
        super().__init__()
        self.height = height
        self.width = width
        self.decoder_layer1 = ResidualBlock(1000, 1000)
        self.decoder_layer2 = ResidualBlock(1000, 1000)
        self.decoder_layer3 = ResidualBlock(1000, 1000)

        self.decoder_layer4 = nn.Sequential(
            nn.ConvTranspose2d(1000, 500, 2, stride=2, padding=0),
            nn.ReLU(True)
        )
        self.decoder_layer5 = ResidualBlock(500, 500)

        self.decoder_layer6 = nn.Sequential(
            nn.ConvTranspose2d(500, 500, 2, stride=2, padding=0),
            nn.ReLU(True)
        )
        self.decoder_layer7 = ResidualBlock(500, 500)

        self.decoder_layer8 = nn.Sequential(
            nn.ConvTranspose2d(500, 100, 2, stride=2, padding=0),
            nn.ReLU(True)
        )
        self.decoder_layer9 = ResidualBlock(100, 100)

        self.decoder_layer10 = nn.Sequential(
            nn.ConvTranspose2d(100, 100, 2, stride=2, padding=0),
            nn.ReLU(True)
        )
        self.decoder_layer11 = ResidualBlock(100, 100)
        self.decoder_layer12 = ResidualBlock(100, 96)
        self.decoder_layer13 = ResidualBlock(96, 96)
        self.decoder_layer14 = ResidualBlock(96, 96)
        self.decoder_layer15 = nn.Sequential(
            nn.Conv2d(96, 96, 3, stride=1, padding=1),
            nn.Sigmoid()
        )
        self.decoder_layer16 = nn.Sequential(
            nn.Conv2d(96, 96, 3, stride=1, padding=1),
            nn.Sigmoid()
        )

    def forward(self, x, lower_skip_list, imagenet_features, height=None, width=None):
        if height is None and width is None:
            height = self.height
            width = self.width

        x = self.decoder_layer1(x)
        x = self.decoder_layer2(x)
        x = x + lower_skip_list[3]

        x = self.decoder_layer3(x)
        x = self.decoder_layer4(x)
        x = x + lower_skip_list[2]

        x = self.decoder_layer5(x)
        x = self.decoder_layer6(x)
        x = x + lower_skip_list[1] + imagenet_features

        x = self.decoder_layer7(x)
        x = self.decoder_layer8(x)
        x = x + lower_skip_list[0]

        x = self.decoder_layer9(x)
        x = self.decoder_layer10(x)
        x = self.decoder_layer11(x)
        x = self.decoder_layer12(x)
        x = self.decoder_layer13(x)
        x = self.decoder_layer14(x)
        x = self.decoder_layer15(x)
        x = self.decoder_layer16(x)
        x = x.view(x.size()[0], 32, 3, height, width)  # 32 planes x RGB
        return x

class DecoderSigma(nn.Module):
    def __init__(self, height, width):
        super().__init__()
        self.height = height
        self.width = width
        self.decoder_layer1 = ResidualBlock(1000, 1000)
        self.decoder_layer2 = ResidualBlock(1000, 1000)
        self.decoder_layer3 = ResidualBlock(1000, 1000)

        self.decoder_layer4 = nn.Sequential(
            nn.ConvTranspose2d(1000, 500, 2, stride=2, padding=0),
            nn.ReLU(True)
        )
        self.decoder_layer5 = ResidualBlock(500, 500)

        self.decoder_layer6 = nn.Sequential(
            nn.ConvTranspose2d(500, 500, 2, stride=2, padding=0),
            nn.ReLU(True)
        )
        self.decoder_layer7 = ResidualBlock(500, 500)

        self.decoder_layer8 = nn.Sequential(
            nn.ConvTranspose2d(500, 100, 2, stride=2, padding=0),
            nn.ReLU(True)
        )
        self.decoder_layer9 = ResidualBlock(100, 100)

        self.decoder_layer10 = nn.Sequential(
            nn.ConvTranspose2d(100, 100, 2, stride=2, padding=0),
            nn.ReLU(True)
        )
        self.decoder_layer11 = ResidualBlock(100, 100)
        self.decoder_layer12 = ResidualBlock(100, 50)
        self.decoder_layer13 = ResidualBlock(50, 40)
        self.decoder_layer14 = ResidualBlock(40, 32)
        self.decoder_layer15 = nn.Sequential(
            nn.Conv2d(32, 32, 3, stride=1, padding=1),
            nn.ReLU(True)
        )
        self.decoder_layer16 = nn.Sequential(
            nn.Conv2d(32, 32, 3, stride=1, padding=1),
            nn.ReLU(True)
        )

    def forward(self, x, lower_skip_list, imagenet_features, height=None, width=None):
        if height is None and width is None:
            height = self.height
            width = self.width

        x = self.decoder_layer1(x)
        x = self.decoder_layer2(x)
        x = x + lower_skip_list[3]

        x = self.decoder_layer3(x)
        x = self.decoder_layer4(x)
        x = x + lower_skip_list[2]

        x = self.decoder_layer5(x)
        x = self.decoder_layer6(x)
        x = x + lower_skip_list[1] + imagenet_features

        x = self.decoder_layer7(x)
        x = self.decoder_layer8(x)
        x = x + lower_skip_list[0]

        x = self.decoder_layer9(x)
        x = self.decoder_layer10(x)
        x = self.decoder_layer11(x)
        x = self.decoder_layer12(x)
        x = self.decoder_layer13(x)
        x = self.decoder_layer14(x)
        x = self.decoder_layer15(x)
        x = self.decoder_layer16(x)
        x = x.view(x.size()[0], 32, 1, height, width)  # 32 planes x density
        return x


class DecoderDepth(nn.Module):
    def __init__(self, height, width):
        super().__init__()
        self.height = height
        self.width = width
        self.decoder_layer1 = ResidualBlock(1000, 1000)
        self.decoder_layer2 = ResidualBlock(1000, 1000)
        self.decoder_layer3 = ResidualBlock(1000, 1000)

        self.decoder_layer4 = nn.Sequential(
            nn.ConvTranspose2d(1000, 500, 2, stride=2, padding=0),
            nn.ReLU(True)
        )
        self.decoder_layer5 = ResidualBlock(500, 500)

        self.decoder_layer6 = nn.Sequential(
            nn.ConvTranspose2d(500, 500, 2, stride=2, padding=0),
            nn.ReLU(True)
        )
        self.decoder_layer7 = ResidualBlock(500, 500)

        self.decoder_layer8 = nn.Sequential(
            nn.ConvTranspose2d(500, 100, 2, stride=2, padding=0),
            nn.ReLU(True)
        )
        self.decoder_layer9 = ResidualBlock(100, 100)

        self.decoder_layer10 = nn.Sequential(
            nn.ConvTranspose2d(100, 100, 2, stride=2, padding=0),
            nn.ReLU(True)
        )
        self.decoder_layer11 = ResidualBlock(100, 100)
        self.decoder_layer12 = ResidualBlock(100, 50)
        self.decoder_layer13 = ResidualBlock(50, 40)
        self.decoder_layer14 = ResidualBlock(40, 16)
        self.decoder_layer15 = nn.Sequential(
            nn.Conv2d(16, 8, 3, stride=1, padding=1),
            nn.ReLU(True)
        )
        self.decoder_layer16 = nn.Sequential(
            nn.Conv2d(8, 1, 3, stride=1, padding=1),
            nn.ReLU(True)
        )

    def forward(self, x, lower_skip_list, imagenet_features, height=None, width=None):
        if height is None and width is None:
            height = self.height
            width = self.width

        x = self.decoder_layer1(x)
        x = self.decoder_layer2(x)
        x = x + lower_skip_list[3]

        x = self.decoder_layer3(x)
        x = self.decoder_layer4(x)
        x = x + lower_skip_list[2]

        x = self.decoder_layer5(x)
        x = self.decoder_layer6(x)
        x = x + lower_skip_list[1] + imagenet_features

        x = self.decoder_layer7(x)
        x = self.decoder_layer8(x)
        x = x + lower_skip_list[0]

        x = self.decoder_layer9(x)
        x = self.decoder_layer10(x)
        x = self.decoder_layer11(x)
        x = self.decoder_layer12(x)
        x = self.decoder_layer13(x)
        x = self.decoder_layer14(x)
        x = self.decoder_layer15(x)
        x = self.decoder_layer16(x)
        return x

class MMPI(nn.Module):
    def __init__(self, total_image_input=1, height=384, width=384):
        super().__init__()
        self.height = height
        self.width = width
        self.feature_encoder = FeatureNet(height, width)
        self.lower_encoder = Encoder(height, width, total_image_input)
        self.merge_decoder_rgb = DecoderRGB(height, width)
        self.merge_decoder_sigma = DecoderSigma(height, width)
        self.depth_decoder = DecoderDepth(height, width)

    def forward(self, x, height=None, width=None):
        if height is None and width is None:
            height = self.height
            width = self.width

        imagenet_features = self.feature_encoder.apply_feature_encoder(x)
        lower_feature, skip_list = self.lower_encoder(x, height, width)

        merged_feature_rgb = self.merge_decoder_rgb(lower_feature, skip_list, imagenet_features, height, width)
        merged_feature_sigma = self.merge_decoder_sigma(lower_feature, skip_list, imagenet_features, height, width)

        merged_feature_depth = self.depth_decoder(lower_feature, skip_list, imagenet_features)

        return merged_feature_rgb, merged_feature_sigma, merged_feature_depth

    def get_rgb_sigma(self, x, height=None, width=None):
        if height is None and width is None:
            height = self.height
            width = self.width

        imagenet_features = self.feature_encoder.apply_feature_encoder(x)
        lower_feature, skip_list = self.lower_encoder(x, height, width)
        merged_feature_rgb = self.merge_decoder_rgb(lower_feature, skip_list, imagenet_features, height, width)
        merged_feature_sigma = self.merge_decoder_sigma(lower_feature, skip_list, imagenet_features, height, width)
        return merged_feature_rgb, merged_feature_sigma

    def get_depth(self, x, height=None, width=None):
        if height is None and width is None:
            height = self.height
            width = self.width

        imagenet_features = self.feature_encoder.apply_feature_encoder(x)
        lower_feature, skip_list = self.lower_encoder(x, height, width)
        merged_feature_depth = self.depth_decoder(lower_feature, skip_list, imagenet_features)
        return merged_feature_depth

    def get_layer_depth(self, x, grid, height=None, width=None):
        if height is None and width is None:
            height = self.height
            width = self.width

        imagenet_features = self.feature_encoder.apply_feature_encoder(x)
        lower_feature, skip_list = self.lower_encoder(x, height, width)

        rgb_layers = self.merge_decoder_rgb(lower_feature, skip_list, imagenet_features, height, width)
        sigma_layers = self.merge_decoder_sigma(lower_feature, skip_list, imagenet_features, height, width)

        # Stack per-plane RGBA; the random seed plane is dropped again below
        pred_mpi_planes = torch.randn((1, 4, height, width)).to(params.DEVICE)
        for i in range(params.params_num_planes):
            RGBA = torch.cat((rgb_layers[0, i, :, :, :], sigma_layers[0, i, :, :, :]), dim=0).unsqueeze(0)
            pred_mpi_planes = torch.cat((pred_mpi_planes, RGBA), dim=0)

        pred_mpi_planes = pred_mpi_planes[1:, :, :, :].unsqueeze(0)

        sigma = pred_mpi_planes[:, :, 3, :, :]
        B, D, H, W = sigma.shape

        pred_mpi_disp = grid
        disp_sorted, _ = pred_mpi_disp.sort(dim=1)
        delta = disp_sorted[:, 1:] - disp_sorted[:, :-1]
        delta_last = delta[:, -1:]
        delta = torch.cat([delta, delta_last], dim=1)

        delta = delta.unsqueeze(-1).unsqueeze(-1).expand_as(sigma)

        alpha = 1.0 - torch.exp(-delta * sigma)

        transmittance = torch.cumprod(1 - alpha + 1e-7, dim=1)
        shifted_transmittance = torch.ones_like(transmittance)
        shifted_transmittance[:, 1:, :, :] = transmittance[:, :-1, :, :]

        disparity = pred_mpi_disp.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, H, W)

        disparity_map = (disparity * alpha * shifted_transmittance).sum(dim=1, keepdim=True)

        return disparity_map

    def get_layers(self, x, height=None, width=None):
        if height is None and width is None:
            height = self.height
            width = self.width

        imagenet_features = self.feature_encoder.apply_feature_encoder(x)
        lower_feature, skip_list = self.lower_encoder(x, height, width)
        merged_feature_rgb = self.merge_decoder_rgb(lower_feature, skip_list, imagenet_features, height, width)
        merged_feature_sigma = self.merge_decoder_sigma(lower_feature, skip_list, imagenet_features, height, width)
        return merged_feature_rgb, merged_feature_sigma
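For reference, `MMPI.get_layer_depth` composites the per-plane disparities with standard MPI alpha compositing. Restating the code above (this is a reading of the implementation, not a formula quoted from the paper), with sorted plane disparities $d_i$ and plane spacings $\delta_i$:

$$
\alpha_i = 1 - e^{-\delta_i \sigma_i}, \qquad
T_i = \prod_{j<i}\left(1 - \alpha_j + \epsilon\right), \qquad
\hat{d}(p) = \sum_{i} d_i\,\alpha_i(p)\,T_i(p),
$$

where $\epsilon = 10^{-7}$ is added for numerical stability.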
model_Medium.py
ADDED
@@ -0,0 +1,535 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import warnings
warnings.filterwarnings("ignore")
import torchvision
import parameters as params
import timm

class DinoV2FeatureExtractor(nn.Module):
    def __init__(self, out_channels=256, out_size=(64, 64)):
        super().__init__()
        self.dino = timm.create_model('vit_base_patch14_dinov2.lvd142m', pretrained=False)
        self.dino.eval()
        for p in self.dino.parameters():
            p.requires_grad = False

        self.out_size = out_size
        self.feat_proj = nn.Sequential(
            nn.Conv2d(self.dino.embed_dim, out_channels, kernel_size=1),
            nn.ReLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.interpolate(x, size=(518, 518), mode='bilinear', align_corners=False)
        patch_tokens = self.dino.forward_features(x)
        patch_tokens = patch_tokens[:, 1:]  # drop the class token, keep the patch tokens
        B, N, C = patch_tokens.shape
        h = w = int(N ** 0.5)
        feat_map = patch_tokens.transpose(1, 2).reshape(B, C, h, w)  # [B, C, H', W']
        feat_map = F.interpolate(feat_map, size=self.out_size, mode='bilinear', align_corners=False)
        return self.feat_proj(feat_map)

def getLinearLayer(in_feat, out_feat, activation=nn.ReLU(True)):
    return nn.Sequential(
        nn.Linear(in_features=in_feat, out_features=out_feat, bias=True),
        activation
    )

def getConvLayer(in_channel, out_channel, stride=1, padding=1, activation=nn.ReLU()):
    return nn.Sequential(nn.Conv2d(in_channel,
                                   out_channel,
                                   kernel_size=3,
                                   stride=stride,
                                   padding=padding,
                                   padding_mode='reflect'),
                         activation)

def getConvTransposeLayer(in_channel, out_channel, kernel=3, stride=1, padding=1, activation=nn.ReLU()):
    return nn.Sequential(nn.ConvTranspose2d(in_channel,
                                            out_channel,
                                            kernel_size=kernel,
                                            stride=stride,
                                            padding=padding),
                         activation)


class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.stride = stride

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.relu(out)

        out = self.conv2(out)

        out = out + self.shortcut(residual)
        out = self.relu(out)
        return out


# Alternative inverted-bottleneck residual block, kept commented out in the commit:
# class ResidualBlock(nn.Module):
#     def __init__(self, in_channels, out_channels, stride=1, expansion=4):
#         super().__init__()
#         mid_channels = out_channels // expansion
#         self.pw_reduce = nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False)
#         self.bn1 = nn.BatchNorm2d(mid_channels)
#         self.dw = nn.Conv2d(mid_channels, mid_channels, kernel_size=3,
#                             stride=stride, padding=1, groups=mid_channels, bias=False)
#         self.bn2 = nn.BatchNorm2d(mid_channels)
#         self.pw_expand = nn.Conv2d(mid_channels, out_channels, kernel_size=1, bias=False)
#         self.bn3 = nn.BatchNorm2d(out_channels)
#         self.relu = nn.ReLU(inplace=True)
#         self.stride = stride
#         if stride != 1 or in_channels != out_channels:
#             self.shortcut = nn.Sequential(
#                 nn.Conv2d(in_channels, out_channels, kernel_size=1,
#                           stride=stride, bias=False),
#                 nn.BatchNorm2d(out_channels),
#             )
#         else:
#             self.shortcut = nn.Identity()
#
#     def forward(self, x):
#         identity = x
#
#         out = self.pw_reduce(x)
#         out = self.bn1(out)
#         out = self.relu(out)
#
#         out = self.dw(out)
#         out = self.bn2(out)
#         out = self.relu(out)
#
#         out = self.pw_expand(out)
#         out = self.bn3(out)
#
#         out += self.shortcut(identity)
#         out = self.relu(out)
#         return out

class FeatureNet(nn.Module):
    def __init__(self, height, width):
        super().__init__()
        model = torchvision.models.resnet152(pretrained=False)
        layers = list(model.children())
        self.FeatureEncoder = torch.nn.Sequential(*layers[:5].copy())  # ResNet-152 stem + layer1
        self.expand_layer = ResidualBlock(256, 200)

    def forward(self, x):
        x = self.FeatureEncoder(x)
        x = self.expand_layer(x)
        return x

    def apply_feature_encoder(self, x):
        x = self.FeatureEncoder(x)
        x = self.expand_layer(x)
        return x

class Encoder(nn.Module):
    def __init__(self, height, width, total_image_input=1):
        super().__init__()
        self.height = height
        self.width = width
        self.encoder_pre = ResidualBlock((total_image_input * 3), 20)
        self.encoder_layer1 = ResidualBlock(20, 30)
        self.encoder_layer2 = ResidualBlock(30, 50)

        self.encoder_layer3 = nn.Sequential(
            ResidualBlock(50, 100),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.encoder_layer4 = ResidualBlock(100, 200)
        self.encoder_layer5 = nn.Sequential(
            ResidualBlock(200, 200),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.encoder_layer6 = ResidualBlock(200, 200)
        self.encoder_layer7 = nn.Sequential(
            ResidualBlock(200, 200),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.encoder_layer8 = ResidualBlock(200, 500)
        self.encoder_layer9 = nn.Sequential(
            ResidualBlock(500, 500),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.encoder_layer10 = ResidualBlock(500, 500)
        self.encoder_layer11 = ResidualBlock(500, 500)

    def forward(self, x, height=None, width=None):
        if height is None and width is None:
            height = self.height
            width = self.width

        x = self.encoder_pre(x)
        x = self.encoder_layer1(x)
        x = self.encoder_layer2(x)
        skip1 = self.encoder_layer3(x)

        x = self.encoder_layer4(skip1)
        skip2 = self.encoder_layer5(x)

        x = self.encoder_layer6(skip2)
        skip3 = self.encoder_layer7(x)

        x = self.encoder_layer8(skip3)
        skip4 = self.encoder_layer9(x)

        x = self.encoder_layer10(skip4)
        x = self.encoder_layer11(x)

        return x, [skip1, skip2, skip3, skip4]

class DecoderRGB(nn.Module):
    def __init__(self, height, width):
        super().__init__()
        self.height = height
        self.width = width
        self.decoder_layer1 = ResidualBlock(500, 500)
        self.decoder_layer2 = ResidualBlock(500, 500)
        self.decoder_layer3 = ResidualBlock(500, 500)

        self.decoder_layer4 = nn.Sequential(
            nn.ConvTranspose2d(500, 200, 2, stride=2, padding=0),
            nn.ReLU(True)
        )
        self.decoder_layer5 = ResidualBlock(200, 200)

        self.decoder_layer6 = nn.Sequential(
            nn.ConvTranspose2d(200, 200, 2, stride=2, padding=0),
            nn.ReLU(True)
        )
        self.decoder_layer7 = ResidualBlock(200, 200)

        self.decoder_layer8 = nn.Sequential(
            nn.ConvTranspose2d(200, 100, 2, stride=2, padding=0),
            nn.ReLU(True)
        )
        self.decoder_layer9 = ResidualBlock(100, 100)

        self.decoder_layer10 = nn.Sequential(
            nn.ConvTranspose2d(100, 100, 2, stride=2, padding=0),
            nn.ReLU(True)
        )
        self.decoder_layer11 = ResidualBlock(100, 100)
        self.decoder_layer12 = ResidualBlock(100, 96)
        self.decoder_layer13 = ResidualBlock(96, 96)
        self.decoder_layer14 = ResidualBlock(96, 96)
        self.decoder_layer15 = nn.Sequential(
            nn.Conv2d(96, 96, 3, stride=1, padding=1),
            nn.Sigmoid()
        )
        self.decoder_layer16 = nn.Sequential(
            nn.Conv2d(96, 96, 3, stride=1, padding=1),
            nn.Sigmoid()
        )

    def forward(self, x, lower_skip_list, imagenet_features, height=None, width=None):
        if height is None and width is None:
            height = self.height
            width = self.width

        x = self.decoder_layer1(x)
        x = self.decoder_layer2(x)
        x = x + lower_skip_list[3]

        x = self.decoder_layer3(x)
        x = self.decoder_layer4(x)
        x = x + lower_skip_list[2]

        x = self.decoder_layer5(x)
        x = self.decoder_layer6(x)
        x = x + lower_skip_list[1] + imagenet_features

        x = self.decoder_layer7(x)
        x = self.decoder_layer8(x)
        x = x + lower_skip_list[0]

        x = self.decoder_layer9(x)
        x = self.decoder_layer10(x)
        x = self.decoder_layer11(x)
        x = self.decoder_layer12(x)
        x = self.decoder_layer13(x)
        x = self.decoder_layer14(x)
        x = self.decoder_layer15(x)
        x = self.decoder_layer16(x)
        x = x.view(x.size()[0], 32, 3, height, width)  # (B, planes, RGB, H, W)
        return x

class DecoderSigma(nn.Module):
    def __init__(self, height, width):
        super().__init__()
        self.height = height
        self.width = width
        self.decoder_layer1 = ResidualBlock(500, 500)
        self.decoder_layer2 = ResidualBlock(500, 500)
        self.decoder_layer3 = ResidualBlock(500, 500)

        self.decoder_layer4 = nn.Sequential(
            nn.ConvTranspose2d(500, 200, 2, stride=2, padding=0),
            nn.ReLU(True)
        )
        self.decoder_layer5 = ResidualBlock(200, 200)

        self.decoder_layer6 = nn.Sequential(
            nn.ConvTranspose2d(200, 200, 2, stride=2, padding=0),
            nn.ReLU(True)
        )
        self.decoder_layer7 = ResidualBlock(200, 200)

        self.decoder_layer8 = nn.Sequential(
            nn.ConvTranspose2d(200, 100, 2, stride=2, padding=0),
            nn.ReLU(True)
        )
        self.decoder_layer9 = ResidualBlock(100, 100)

        self.decoder_layer10 = nn.Sequential(
            nn.ConvTranspose2d(100, 100, 2, stride=2, padding=0),
            nn.ReLU(True)
        )
        self.decoder_layer11 = ResidualBlock(100, 100)
        self.decoder_layer12 = ResidualBlock(100, 50)
        self.decoder_layer13 = ResidualBlock(50, 40)
        self.decoder_layer14 = ResidualBlock(40, 32)
        self.decoder_layer15 = nn.Sequential(
            nn.Conv2d(32, 32, 3, stride=1, padding=1),
            nn.ReLU(True)
        )
        self.decoder_layer16 = nn.Sequential(
            nn.Conv2d(32, 32, 3, stride=1, padding=1),
            nn.ReLU(True)
        )

    def forward(self, x, lower_skip_list, imagenet_features, height=None, width=None):
        if height is None and width is None:
            height = self.height
            width = self.width

        x = self.decoder_layer1(x)
        x = self.decoder_layer2(x)
        x = x + lower_skip_list[3]

        x = self.decoder_layer3(x)
        x = self.decoder_layer4(x)
        x = x + lower_skip_list[2]

        x = self.decoder_layer5(x)
        x = self.decoder_layer6(x)
        x = x + lower_skip_list[1] + imagenet_features

        x = self.decoder_layer7(x)
        x = self.decoder_layer8(x)
        x = x + lower_skip_list[0]

        x = self.decoder_layer9(x)
        x = self.decoder_layer10(x)
        x = self.decoder_layer11(x)
        x = self.decoder_layer12(x)
        x = self.decoder_layer13(x)
        x = self.decoder_layer14(x)
        x = self.decoder_layer15(x)
        x = self.decoder_layer16(x)
        x = x.view(x.size()[0], 32, 1, height, width)  # (B, planes, density, H, W)
        return x


class DecoderDepth(nn.Module):
    def __init__(self, height, width):
        super().__init__()
        self.height = height
        self.width = width
        self.decoder_layer1 = ResidualBlock(500, 500)
        self.decoder_layer2 = ResidualBlock(500, 500)
        self.decoder_layer3 = ResidualBlock(500, 500)

        self.decoder_layer4 = nn.Sequential(
            nn.ConvTranspose2d(500, 200, 2, stride=2, padding=0),
            nn.ReLU(True)
        )
        self.decoder_layer5 = ResidualBlock(200, 200)

        self.decoder_layer6 = nn.Sequential(
            nn.ConvTranspose2d(200, 200, 2, stride=2, padding=0),
            nn.ReLU(True)
        )
        self.decoder_layer7 = ResidualBlock(200, 200)

        self.decoder_layer8 = nn.Sequential(
            nn.ConvTranspose2d(200, 100, 2, stride=2, padding=0),
            nn.ReLU(True)
        )
        self.decoder_layer9 = ResidualBlock(100, 100)

        self.decoder_layer10 = nn.Sequential(
            nn.ConvTranspose2d(100, 100, 2, stride=2, padding=0),
            nn.ReLU(True)
        )
        self.decoder_layer11 = ResidualBlock(100, 100)
        self.decoder_layer12 = ResidualBlock(100, 50)
        self.decoder_layer13 = ResidualBlock(50, 40)
        self.decoder_layer14 = ResidualBlock(40, 16)
        self.decoder_layer15 = nn.Sequential(
            nn.Conv2d(16, 8, 3, stride=1, padding=1),
            nn.ReLU(True)
        )
        self.decoder_layer16 = nn.Sequential(
            nn.Conv2d(8, 1, 3, stride=1, padding=1),
            nn.ReLU(True)
        )

    def forward(self, x, lower_skip_list, imagenet_features, height=None, width=None):
        if height is None and width is None:
            height = self.height
            width = self.width

        x = self.decoder_layer1(x)
        x = self.decoder_layer2(x)
        x = x + lower_skip_list[3]

        x = self.decoder_layer3(x)
        x = self.decoder_layer4(x)
        x = x + lower_skip_list[2]

        x = self.decoder_layer5(x)
        x = self.decoder_layer6(x)
        x = x + lower_skip_list[1] + imagenet_features

        x = self.decoder_layer7(x)
        x = self.decoder_layer8(x)
        x = x + lower_skip_list[0]

        x = self.decoder_layer9(x)
        x = self.decoder_layer10(x)
        x = self.decoder_layer11(x)
        x = self.decoder_layer12(x)
        x = self.decoder_layer13(x)
        x = self.decoder_layer14(x)
        x = self.decoder_layer15(x)
        x = self.decoder_layer16(x)
        return x

class MMPI(nn.Module):
    def __init__(self, total_image_input=1, height=384, width=384):
        super().__init__()
        self.height = height
        self.width = width
        self.feature_encoder = FeatureNet(height, width)
        self.lower_encoder = Encoder(height, width, total_image_input)
        self.merge_decoder_rgb = DecoderRGB(height, width)
        self.merge_decoder_sigma = DecoderSigma(height, width)
        self.depth_decoder = DecoderDepth(height, width)

    def forward(self, x, height=None, width=None):
        if height is None and width is None:
            height = self.height
            width = self.width

        imagenet_features = self.feature_encoder.apply_feature_encoder(x)
        lower_feature, skip_list = self.lower_encoder(x, height, width)

        merged_feature_rgb = self.merge_decoder_rgb(lower_feature, skip_list, imagenet_features, height, width)
        merged_feature_sigma = self.merge_decoder_sigma(lower_feature, skip_list, imagenet_features, height, width)

        merged_feature_depth = self.depth_decoder(lower_feature, skip_list, imagenet_features)

        return merged_feature_rgb, merged_feature_sigma, merged_feature_depth

    def get_rgb_sigma(self, x, height=None, width=None):
        if height is None and width is None:
            height = self.height
            width = self.width

        imagenet_features = self.feature_encoder.apply_feature_encoder(x)
        lower_feature, skip_list = self.lower_encoder(x, height, width)
        merged_feature_rgb = self.merge_decoder_rgb(lower_feature, skip_list, imagenet_features, height, width)
        merged_feature_sigma = self.merge_decoder_sigma(lower_feature, skip_list, imagenet_features, height, width)
        return merged_feature_rgb, merged_feature_sigma

    def get_depth(self, x, height=None, width=None):
        if height is None and width is None:
            height = self.height
            width = self.width

        imagenet_features = self.feature_encoder.apply_feature_encoder(x)
        lower_feature, skip_list = self.lower_encoder(x, height, width)
        merged_feature_depth = self.depth_decoder(lower_feature, skip_list, imagenet_features)
        return merged_feature_depth

    def get_layer_depth(self, x, grid, height=None, width=None):
        if height is None and width is None:
            height = self.height
            width = self.width

        imagenet_features = self.feature_encoder.apply_feature_encoder(x)
        lower_feature, skip_list = self.lower_encoder(x, height, width)

        rgb_layers = self.merge_decoder_rgb(lower_feature, skip_list, imagenet_features, height, width)
        sigma_layers = self.merge_decoder_sigma(lower_feature, skip_list, imagenet_features, height, width)

        # Stack per-plane RGB and density into a (1, D, 4, H, W) MPI volume;
        # the random first plane only serves as a cat target and is dropped below
        pred_mpi_planes = torch.randn((1, 4, height, width)).to(params.DEVICE)
        for i in range(params.params_num_planes):
            RGBA = torch.cat((rgb_layers[0, i, :, :, :], sigma_layers[0, i, :, :, :]), dim=0).unsqueeze(0)
            pred_mpi_planes = torch.cat((pred_mpi_planes, RGBA), dim=0)

        pred_mpi_planes = pred_mpi_planes[1:, :, :, :].unsqueeze(0)

        sigma = pred_mpi_planes[:, :, 3, :, :]
        B, D, H, W = sigma.shape

        # Convert density to per-plane alpha using the inter-plane disparity spacing
        pred_mpi_disp = grid
        disp_sorted, _ = pred_mpi_disp.sort(dim=1)
        delta = disp_sorted[:, 1:] - disp_sorted[:, :-1]
        delta_last = delta[:, -1:]
        delta = torch.cat([delta, delta_last], dim=1)

        delta = delta.unsqueeze(-1).unsqueeze(-1).expand_as(sigma)

        alpha = 1.0 - torch.exp(-delta * sigma)

        # Front-to-back compositing: weight each plane by the transmittance
        # of all planes in front of it
        transmittance = torch.cumprod(1 - alpha + 1e-7, dim=1)
        shifted_transmittance = torch.ones_like(transmittance)
        shifted_transmittance[:, 1:, :, :] = transmittance[:, :-1, :, :]

        disparity = pred_mpi_disp.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, H, W)

        disparity_map = (disparity * alpha * shifted_transmittance).sum(dim=1, keepdim=True)

        return disparity_map

    def get_layers(self, x, height=None, width=None):
        if height is None and width is None:
            height = self.height
            width = self.width

        imagenet_features = self.feature_encoder.apply_feature_encoder(x)
        lower_feature, skip_list = self.lower_encoder(x, height, width)
        merged_feature_rgb = self.merge_decoder_rgb(lower_feature, skip_list, imagenet_features, height, width)
        merged_feature_sigma = self.merge_decoder_sigma(lower_feature, skip_list, imagenet_features, height, width)
        return merged_feature_rgb, merged_feature_sigma
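
A quick smoke test (not part of the commit) can confirm the wiring above. This is a minimal sketch assuming the file is importable as model_Medium and that the input resolution is divisible by 16, matching the four pooling stages in the encoder and the four transposed convolutions in each decoder:

import torch
from model_Medium import MMPI  # assumed module name for this file

model = MMPI(total_image_input=1, height=384, width=384)
model.eval()

x = torch.randn(1, 3, 384, 384)  # one dummy RGB image
with torch.no_grad():
    rgb, sigma, depth = model(x)

print(rgb.shape)    # torch.Size([1, 32, 3, 384, 384]), per-plane color
print(sigma.shape)  # torch.Size([1, 32, 1, 384, 384]), per-plane density
print(depth.shape)  # torch.Size([1, 1, 384, 384]), auxiliary depth head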
model_Small.py
ADDED
@@ -0,0 +1,544 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import warnings
warnings.filterwarnings("ignore")
import torchvision
import parameters as params
import timm

class DinoV2FeatureExtractor(nn.Module):
    def __init__(self, out_channels=256, out_size=(64, 64)):
        super().__init__()
        self.dino = timm.create_model('vit_base_patch14_dinov2.lvd142m', pretrained=False)
        self.dino.eval()
        for p in self.dino.parameters():
            p.requires_grad = False

        self.out_size = out_size
        self.feat_proj = nn.Sequential(
            nn.Conv2d(self.dino.embed_dim, out_channels, kernel_size=1),
            nn.ReLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.interpolate(x, size=(518, 518), mode='bilinear', align_corners=False)
        patch_tokens = self.dino.forward_features(x)
        patch_tokens = patch_tokens[:, 1:]  # drop the class token, keep the patch tokens
        B, N, C = patch_tokens.shape
        h = w = int(N ** 0.5)
        feat_map = patch_tokens.transpose(1, 2).reshape(B, C, h, w)  # [B, C, H', W']
        feat_map = F.interpolate(feat_map, size=self.out_size, mode='bilinear', align_corners=False)
        return self.feat_proj(feat_map)

def getLinearLayer(in_feat, out_feat, activation=nn.ReLU(True)):
    return nn.Sequential(
        nn.Linear(in_features=in_feat, out_features=out_feat, bias=True),
        activation
    )

def getConvLayer(in_channel, out_channel, stride=1, padding=1, activation=nn.ReLU()):
    return nn.Sequential(nn.Conv2d(in_channel,
                                   out_channel,
                                   kernel_size=3,
                                   stride=stride,
                                   padding=padding,
                                   padding_mode='reflect'),
                         activation)

def getConvTransposeLayer(in_channel, out_channel, kernel=3, stride=1, padding=1, activation=nn.ReLU()):
    return nn.Sequential(nn.ConvTranspose2d(in_channel,
                                            out_channel,
                                            kernel_size=kernel,
                                            stride=stride,
                                            padding=padding),
                         activation)


class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.stride = stride

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.relu(out)

        out = self.conv2(out)

        out = out + self.shortcut(residual)
        out = self.relu(out)
        return out


# Alternative inverted-bottleneck residual block, kept commented out in the commit:
# class ResidualBlock(nn.Module):
#     def __init__(self, in_channels, out_channels, stride=1, expansion=4):
#         super().__init__()
#         mid_channels = out_channels // expansion
#         self.pw_reduce = nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False)
#         self.bn1 = nn.BatchNorm2d(mid_channels)
#         self.dw = nn.Conv2d(mid_channels, mid_channels, kernel_size=3,
#                             stride=stride, padding=1, groups=mid_channels, bias=False)
#         self.bn2 = nn.BatchNorm2d(mid_channels)
#         self.pw_expand = nn.Conv2d(mid_channels, out_channels, kernel_size=1, bias=False)
#         self.bn3 = nn.BatchNorm2d(out_channels)
#         self.relu = nn.ReLU(inplace=True)
#         self.stride = stride
#         if stride != 1 or in_channels != out_channels:
#             self.shortcut = nn.Sequential(
#                 nn.Conv2d(in_channels, out_channels, kernel_size=1,
#                           stride=stride, bias=False),
#                 nn.BatchNorm2d(out_channels),
#             )
#         else:
#             self.shortcut = nn.Identity()
#
#     def forward(self, x):
#         identity = x
#
#         out = self.pw_reduce(x)
#         out = self.bn1(out)
#         out = self.relu(out)
#
#         out = self.dw(out)
#         out = self.bn2(out)
#         out = self.relu(out)
#
#         out = self.pw_expand(out)
#         out = self.bn3(out)
#
#         out += self.shortcut(identity)
#         out = self.relu(out)
#         return out

class FeatureNet(nn.Module):
    def __init__(self, height, width):
        super().__init__()
        model = torchvision.models.resnet152(pretrained=False)
        layers = list(model.children())
        self.FeatureEncoder = torch.nn.Sequential(*layers[:5].copy())  # ResNet-152 stem + layer1
        del model

    def forward(self, x):
        x = self.FeatureEncoder(x)
        return x

    def apply_feature_encoder(self, x):
        x = self.FeatureEncoder(x)
        return x

class Encoder(nn.Module):
    def __init__(self, height, width, total_image_input=1):
        super().__init__()
        self.height = height
        self.width = width
        self.encoder_pre = ResidualBlock((total_image_input * 3), 20)
        self.encoder_layer1 = ResidualBlock(20, 30)
        self.encoder_layer2 = ResidualBlock(30, 50)

        self.encoder_layer3 = nn.Sequential(
            ResidualBlock(50, 100),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.encoder_layer4 = ResidualBlock(100, 100)
        self.encoder_layer5 = nn.Sequential(
            ResidualBlock(100, 100),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.encoder_layer6 = ResidualBlock(100, 100)
        self.encoder_layer7 = nn.Sequential(
            ResidualBlock(100, 100),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.encoder_layer8 = ResidualBlock(100, 100)
        self.encoder_layer9 = nn.Sequential(
            ResidualBlock(100, 100),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.encoder_layer10 = ResidualBlock(100, 100)
        self.encoder_layer11 = ResidualBlock(100, 100)

    def forward(self, x, height=None, width=None):
        if height is None and width is None:
            height = self.height
            width = self.width

        x = self.encoder_pre(x)
        x = self.encoder_layer1(x)
        x = self.encoder_layer2(x)
        skip1 = self.encoder_layer3(x)

        x = self.encoder_layer4(skip1)
        skip2 = self.encoder_layer5(x)

        x = self.encoder_layer6(skip2)
        skip3 = self.encoder_layer7(x)

        x = self.encoder_layer8(skip3)
        skip4 = self.encoder_layer9(x)

        x = self.encoder_layer10(skip4)
        x = self.encoder_layer11(x)

        return x, [skip1, skip2, skip3, skip4]

class DecoderRGB(nn.Module):
    def __init__(self, height, width):
        super().__init__()
        self.height = height
        self.width = width
        self.decoder_layer1 = ResidualBlock(100, 100)
        self.decoder_layer2 = ResidualBlock(100, 100)
        self.decoder_layer3 = ResidualBlock(100, 100)

        self.decoder_layer4 = nn.Sequential(
            nn.ConvTranspose2d(100, 100, 2, stride=2, padding=0),
            nn.ReLU(True)
        )
        self.decoder_layer5 = ResidualBlock(100, 100)

        self.decoder_layer6 = nn.Sequential(
            nn.ConvTranspose2d(100, 100, 2, stride=2, padding=0),
            nn.ReLU(True)
        )
        self.decoder_layer7 = ResidualBlock(100, 100)

        self.decoder_layer8 = nn.Sequential(
            nn.ConvTranspose2d(100, 100, 2, stride=2, padding=0),
            nn.ReLU(True)
        )
        self.decoder_layer9 = ResidualBlock(100, 100)

        self.decoder_layer10 = nn.Sequential(
            nn.ConvTranspose2d(100, 100, 2, stride=2, padding=0),
            nn.ReLU(True)
        )
        self.decoder_layer11 = ResidualBlock(100, 100)
        self.decoder_layer12 = ResidualBlock(100, 96)
        self.decoder_layer13 = ResidualBlock(96, 96)
        self.decoder_layer14 = ResidualBlock(96, 96)
        self.decoder_layer15 = nn.Sequential(
            nn.Conv2d(96, 96, 3, stride=1, padding=1),
            nn.Sigmoid()
        )
        self.decoder_layer16 = nn.Sequential(
            nn.Conv2d(96, 96, 3, stride=1, padding=1),
            nn.Sigmoid()
        )

    def forward(self, x, lower_skip_list, upper_skip_list, height=None, width=None):
        if height is None and width is None:
            height = self.height
            width = self.width

        x = self.decoder_layer1(x)
        x = self.decoder_layer2(x)
        x = x + lower_skip_list[3] + upper_skip_list[1]

        x = self.decoder_layer3(x)
        x = self.decoder_layer4(x)
        x = x + lower_skip_list[2] + upper_skip_list[0]

        x = self.decoder_layer5(x)
        x = self.decoder_layer6(x)
        x = x + lower_skip_list[1]

        x = self.decoder_layer7(x)
        x = self.decoder_layer8(x)
        x = x + lower_skip_list[0]

        x = self.decoder_layer9(x)
        x = self.decoder_layer10(x)
        x = self.decoder_layer11(x)
        x = self.decoder_layer12(x)
        x = self.decoder_layer13(x)
        x = self.decoder_layer14(x)
        x = self.decoder_layer15(x)
        x = self.decoder_layer16(x)
        x = x.view(x.size()[0], 32, 3, height, width)  # (B, planes, RGB, H, W)
        return x

class DecoderSigma(nn.Module):
    def __init__(self, height, width):
        super().__init__()
        self.height = height
        self.width = width
        self.decoder_layer1 = ResidualBlock(100, 100)
        self.decoder_layer2 = ResidualBlock(100, 100)
        self.decoder_layer3 = ResidualBlock(100, 100)

        self.decoder_layer4 = nn.Sequential(
            nn.ConvTranspose2d(100, 100, 2, stride=2, padding=0),
            nn.ReLU(True)
        )
        self.decoder_layer5 = ResidualBlock(100, 100)

        self.decoder_layer6 = nn.Sequential(
            nn.ConvTranspose2d(100, 100, 2, stride=2, padding=0),
            nn.ReLU(True)
        )
        self.decoder_layer7 = ResidualBlock(100, 100)

        self.decoder_layer8 = nn.Sequential(
            nn.ConvTranspose2d(100, 100, 2, stride=2, padding=0),
            nn.ReLU(True)
        )
        self.decoder_layer9 = ResidualBlock(100, 100)

        self.decoder_layer10 = nn.Sequential(
            nn.ConvTranspose2d(100, 50, 2, stride=2, padding=0),
            nn.ReLU(True)
        )
        self.decoder_layer11 = nn.Sequential(
            nn.Conv2d(50, 32, 3, stride=1, padding=1),
            nn.ReLU(True)
        )
        self.decoder_layer12 = nn.Sequential(
            nn.Conv2d(32, 32, 3, stride=1, padding=1),
            nn.ReLU(True)
        )

    def forward(self, x, lower_skip_list, upper_skip_list, height=None, width=None):
        if height is None and width is None:
            height = self.height
            width = self.width

        x = self.decoder_layer1(x)
        x = self.decoder_layer2(x)
        x = x + lower_skip_list[3] + upper_skip_list[1]

        x = self.decoder_layer3(x)
        x = self.decoder_layer4(x)
        x = x + lower_skip_list[2] + upper_skip_list[0]

        x = self.decoder_layer5(x)
        x = self.decoder_layer6(x)
        x = x + lower_skip_list[1]

        x = self.decoder_layer7(x)
        x = self.decoder_layer8(x)
        x = x + lower_skip_list[0]

        x = self.decoder_layer9(x)
        x = self.decoder_layer10(x)
        x = self.decoder_layer11(x)
        x = self.decoder_layer12(x)
        x = x.view(x.size()[0], 32, 1, height, width)  # (B, planes, density, H, W)
        return x


class DecoderDepth(nn.Module):
    def __init__(self, height, width):
        super().__init__()
        self.height = height
        self.width = width
        self.decoder_layer1 = ResidualBlock(100, 100)
        self.decoder_layer2 = ResidualBlock(100, 100)
        self.decoder_layer3 = ResidualBlock(100, 100)

        self.decoder_layer4 = nn.Sequential(
            nn.ConvTranspose2d(100, 100, 2, stride=2, padding=0),
            nn.ReLU(True)
        )
        self.decoder_layer5 = ResidualBlock(100, 100)

        self.decoder_layer6 = nn.Sequential(
            nn.ConvTranspose2d(100, 100, 2, stride=2, padding=0),
            nn.ReLU(True)
        )
        self.decoder_layer7 = ResidualBlock(100, 100)

        self.decoder_layer8 = nn.Sequential(
            nn.ConvTranspose2d(100, 100, 2, stride=2, padding=0),
            nn.ReLU(True)
        )
        self.decoder_layer9 = ResidualBlock(100, 50)

        self.decoder_layer10 = nn.Sequential(
            nn.ConvTranspose2d(50, 20, 2, stride=2, padding=0),
            nn.ReLU(True)
        )
        self.decoder_layer11 = nn.Sequential(
            nn.Conv2d(20, 5, 3, stride=1, padding=1),
            nn.ReLU(True)
        )
        self.decoder_layer12 = nn.Sequential(
            nn.Conv2d(5, 1, 3, stride=1, padding=1),
            nn.ReLU(True)
        )

    def forward(self, x, lower_skip_list, upper_skip_list, height=None, width=None):
        if height is None and width is None:
            height = self.height
            width = self.width

        x = self.decoder_layer1(x)
        x = self.decoder_layer2(x)
        x = x + lower_skip_list[3] + upper_skip_list[1]

        x = self.decoder_layer3(x)
        x = self.decoder_layer4(x)
        x = x + lower_skip_list[2] + upper_skip_list[0]

        x = self.decoder_layer5(x)
        x = self.decoder_layer6(x)
        x = x + lower_skip_list[1]

        x = self.decoder_layer7(x)
        x = self.decoder_layer8(x)
        x = x + lower_skip_list[0]

        x = self.decoder_layer9(x)
        x = self.decoder_layer10(x)
        x = self.decoder_layer11(x)
        x = self.decoder_layer12(x)
        return x

class MMPI(nn.Module):
    def __init__(self, total_image_input=1, height=384, width=384):
        super().__init__()
        self.height = height
        self.width = width
        self.feature_encoder = FeatureNet(height, width)
        self.lower_encoder = Encoder(height, width, total_image_input)
        self.merge_decoder_rgb = DecoderRGB(height, width)
        self.merge_decoder_sigma = DecoderSigma(height, width)
        self.depth_decoder = DecoderDepth(height, width)
        self.upper_encoder_extra_1 = nn.Sequential(
            ResidualBlock(256, 100),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.upper_encoder_extra_2 = nn.Sequential(
            ResidualBlock(100, 100),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

    def forward(self, x, height=None, width=None):
        if height is None and width is None:
            height = self.height
            width = self.width

        upper_features_1 = self.feature_encoder.apply_feature_encoder(x)
        upper_features_1 = self.upper_encoder_extra_1(upper_features_1)
        upper_features_2 = self.upper_encoder_extra_2(upper_features_1)

        lower_feature, skip_list = self.lower_encoder(x, height, width)

        merged_feature_rgb = self.merge_decoder_rgb(lower_feature, skip_list, [upper_features_1, upper_features_2], height, width)
        merged_feature_sigma = self.merge_decoder_sigma(lower_feature, skip_list, [upper_features_1, upper_features_2], height, width)

        merged_feature_depth = self.depth_decoder(lower_feature, skip_list, [upper_features_1, upper_features_2])

        return merged_feature_rgb, merged_feature_sigma, merged_feature_depth

    def get_rgb_sigma(self, x, height=None, width=None):
        if height is None and width is None:
            height = self.height
            width = self.width

        upper_features_1 = self.feature_encoder.apply_feature_encoder(x)
        upper_features_1 = self.upper_encoder_extra_1(upper_features_1)
        upper_features_2 = self.upper_encoder_extra_2(upper_features_1)

        lower_feature, skip_list = self.lower_encoder(x, height, width)

        merged_feature_rgb = self.merge_decoder_rgb(lower_feature, skip_list, [upper_features_1, upper_features_2], height, width)
        merged_feature_sigma = self.merge_decoder_sigma(lower_feature, skip_list, [upper_features_1, upper_features_2], height, width)

        return merged_feature_rgb, merged_feature_sigma

    def get_depth(self, x, height=None, width=None):
        if height is None and width is None:
            height = self.height
            width = self.width

        upper_features_1 = self.feature_encoder.apply_feature_encoder(x)
        upper_features_1 = self.upper_encoder_extra_1(upper_features_1)
        upper_features_2 = self.upper_encoder_extra_2(upper_features_1)

        lower_feature, skip_list = self.lower_encoder(x, height, width)

        merged_feature_depth = self.depth_decoder(lower_feature, skip_list, [upper_features_1, upper_features_2])
        return merged_feature_depth

    def get_layer_depth(self, x, grid, height=None, width=None):
        if height is None and width is None:
            height = self.height
            width = self.width

        upper_features_1 = self.feature_encoder.apply_feature_encoder(x)
        upper_features_1 = self.upper_encoder_extra_1(upper_features_1)
        upper_features_2 = self.upper_encoder_extra_2(upper_features_1)

        lower_feature, skip_list = self.lower_encoder(x, height, width)

        rgb_layers = self.merge_decoder_rgb(lower_feature, skip_list, [upper_features_1, upper_features_2], height, width)
        sigma_layers = self.merge_decoder_sigma(lower_feature, skip_list, [upper_features_1, upper_features_2], height, width)

        # Stack per-plane RGB and density into a (1, D, 4, H, W) MPI volume;
        # the random first plane only serves as a cat target and is dropped below
        pred_mpi_planes = torch.randn((1, 4, height, width)).to(params.DEVICE)
        for i in range(params.params_num_planes):
            RGBA = torch.cat((rgb_layers[0, i, :, :, :], sigma_layers[0, i, :, :, :]), dim=0).unsqueeze(0)
            pred_mpi_planes = torch.cat((pred_mpi_planes, RGBA), dim=0)

        pred_mpi_planes = pred_mpi_planes[1:, :, :, :].unsqueeze(0)

        sigma = pred_mpi_planes[:, :, 3, :, :]
        B, D, H, W = sigma.shape

        # Convert density to per-plane alpha using the inter-plane disparity spacing
        pred_mpi_disp = grid
        disp_sorted, _ = pred_mpi_disp.sort(dim=1)
        delta = disp_sorted[:, 1:] - disp_sorted[:, :-1]
        delta_last = delta[:, -1:]
        delta = torch.cat([delta, delta_last], dim=1)

        delta = delta.unsqueeze(-1).unsqueeze(-1).expand_as(sigma)

        alpha = 1.0 - torch.exp(-delta * sigma)

        # Front-to-back compositing: weight each plane by the transmittance
        # of all planes in front of it
        transmittance = torch.cumprod(1 - alpha + 1e-7, dim=1)
        shifted_transmittance = torch.ones_like(transmittance)
        shifted_transmittance[:, 1:, :, :] = transmittance[:, :-1, :, :]

        disparity = pred_mpi_disp.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, H, W)

        disparity_map = (disparity * alpha * shifted_transmittance).sum(dim=1, keepdim=True)

        return disparity_map

    def get_layers(self, x, height=None, width=None):
        if height is None and width is None:
            height = self.height
            width = self.width

        upper_features_1 = self.feature_encoder.apply_feature_encoder(x)
        upper_features_1 = self.upper_encoder_extra_1(upper_features_1)
        upper_features_2 = self.upper_encoder_extra_2(upper_features_1)

        lower_feature, skip_list = self.lower_encoder(x, height, width)

        merged_feature_rgb = self.merge_decoder_rgb(lower_feature, skip_list, [upper_features_1, upper_features_2], height, width)
        merged_feature_sigma = self.merge_decoder_sigma(lower_feature, skip_list, [upper_features_1, upper_features_2], height, width)

        return merged_feature_rgb, merged_feature_sigma
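
The get_layer_depth method above renders an expected-disparity map from the predicted planes with volume-rendering style compositing: per plane, alpha = 1 - exp(-delta * sigma), weighted by the accumulated transmittance of the planes in front. A minimal driving sketch (not part of the commit), assuming this file is importable as model_Small and reusing the 32-plane disparity grid from parameters.py:

import torch
import parameters as params
from model_Small import MMPI  # assumed module name for this file

model = MMPI(total_image_input=1, height=256, width=256)
model.eval()

x = torch.randn(1, 3, 256, 256)
grid = params.get_disparity_all_src().unsqueeze(0)  # (1, 32) plane disparities
with torch.no_grad():
    disp_map = model.get_layer_depth(x, grid)

print(disp_map.shape)  # torch.Size([1, 1, 256, 256]), expected disparity per pixel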
parameters.py
ADDED
@@ -0,0 +1,48 @@
import os
import torch

params_height = 256
params_width = 256
params_m = 32
params_number_input = 1
params_step_size = 2
params_gamma = 0.2
params_num_planes = 32

TRAIN_LOCATION = "./lf_train.txt"
VALIDATION_LOCATION = "./lf_validate.txt"
TEST_LOCATION = "./lf_test.txt"
LOG_FILE_LOCATION = "./logs/training_log_0.txt"
CHECKPOINT_LOCATION = "./checkpoint/"
RESUME_CHECKPOINT_LOCATION = "./checkpoint/checkpoint_best.pth"
START_CHECKPOINT_LOCATION = "./checkpoint/checkpoint_init.pth"
DEVICE = "cpu"

BATCH_SIZE = 32
LEARNING_RATE = 0.0001
NUM_EPOCHS = 150
START_EPOCH = 0
PRINT_INTERVAL = 20
T_max = 150

os.makedirs("./logs", exist_ok=True)
os.makedirs("./checkpoint", exist_ok=True)
os.makedirs("./output", exist_ok=True)

def uniform_planes(a: float, b: float, n: int) -> torch.Tensor:
    """
    Return n values uniformly spaced *within* (a, b),
    i.e. excluding the exact endpoints a and b.
    """
    step = (b - a) / (n + 1)
    # torch.arange(1, n+1) gives [1, 2, ..., n]
    return a + step * torch.arange(1, n + 1, dtype=torch.float32)

def get_disparity_all_src():
    d1 = uniform_planes(0.0, 0.4, 20)
    d2 = uniform_planes(0.4, 1.0, 12)
    disparities = torch.cat([d1, d2], dim=0)
    return disparities
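
The two uniform_planes calls place the 32 planes non-uniformly in disparity: 20 planes inside (0, 0.4) and 12 inside (0.4, 1.0), so the low-disparity (distant) range is sampled more densely. A quick check of the resulting grid, assuming nothing beyond this file:

import parameters as params

d = params.get_disparity_all_src()
print(d.shape)       # torch.Size([32]), matches params_num_planes
print(d[0].item())   # ~0.0190, i.e. 0.4 / 21, first plane inside (0, 0.4)
print(d[-1].item())  # ~0.9538, i.e. 0.4 + 12 * 0.6 / 13, last plane inside (0.4, 1.0)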
post-install.sh
ADDED
@@ -0,0 +1,2 @@
#!/bin/bash
pip install "pytorch3d @ git+https://github.com/facebookresearch/pytorch3d.git@89653419d0973396f3eff1a381ba09a07fffc2ed"
requirements.txt
ADDED
@@ -0,0 +1,19 @@
numpy==1.26.4
torch==2.1.0
torchvision==0.16.0
pytorch-lightning==2.1.3
pytorch-msssim==1.0.0
pytorchvideo==0.1.5
grpcio==1.57.0
opencv-contrib-python==4.10.0.84
opencv-python==4.6.0.66
pillow==10.4.0
pillow_heif==0.15.0
matplotlib==3.7.2
matplotlib-inline==0.1.6
transformers==4.43.3
tqdm==4.65.0
moviepy==1.0.3
scikit-image==0.21.0
scikit-learn==1.3.0
scipy==1.11.2
utils.py
ADDED
@@ -0,0 +1,243 @@
| 1 |
+
import numpy as np
|
| 2 |
+
from scipy.ndimage import map_coordinates
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def xyzcube(face_w):
|
| 6 |
+
'''
|
| 7 |
+
Return the xyz cordinates of the unit cube in [F R B L U D] format.
|
| 8 |
+
'''
|
| 9 |
+
out = np.zeros((face_w, face_w * 6, 3), np.float32)
|
| 10 |
+
rng = np.linspace(-0.5, 0.5, num=face_w, dtype=np.float32)
|
| 11 |
+
grid = np.stack(np.meshgrid(rng, -rng), -1)
|
| 12 |
+
|
| 13 |
+
# Front face (z = 0.5)
|
| 14 |
+
out[:, 0*face_w:1*face_w, [0, 1]] = grid
|
| 15 |
+
out[:, 0*face_w:1*face_w, 2] = 0.5
|
| 16 |
+
|
| 17 |
+
# Right face (x = 0.5)
|
| 18 |
+
out[:, 1*face_w:2*face_w, [2, 1]] = grid
|
| 19 |
+
out[:, 1*face_w:2*face_w, 0] = 0.5
|
| 20 |
+
|
| 21 |
+
# Back face (z = -0.5)
|
| 22 |
+
out[:, 2*face_w:3*face_w, [0, 1]] = grid
|
| 23 |
+
out[:, 2*face_w:3*face_w, 2] = -0.5
|
| 24 |
+
|
| 25 |
+
# Left face (x = -0.5)
|
| 26 |
+
out[:, 3*face_w:4*face_w, [2, 1]] = grid
|
| 27 |
+
out[:, 3*face_w:4*face_w, 0] = -0.5
|
| 28 |
+
|
| 29 |
+
# Up face (y = 0.5)
|
| 30 |
+
out[:, 4*face_w:5*face_w, [0, 2]] = grid
|
| 31 |
+
out[:, 4*face_w:5*face_w, 1] = 0.5
|
| 32 |
+
|
| 33 |
+
# Down face (y = -0.5)
|
| 34 |
+
out[:, 5*face_w:6*face_w, [0, 2]] = grid
|
| 35 |
+
out[:, 5*face_w:6*face_w, 1] = -0.5
|
| 36 |
+
|
| 37 |
+
return out
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def equirect_uvgrid(h, w):
|
| 41 |
+
u = np.linspace(-np.pi, np.pi, num=w, dtype=np.float32)
|
| 42 |
+
v = np.linspace(np.pi, -np.pi, num=h, dtype=np.float32) / 2
|
| 43 |
+
|
| 44 |
+
return np.stack(np.meshgrid(u, v), axis=-1)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def equirect_facetype(h, w):
|
| 48 |
+
'''
|
| 49 |
+
0F 1R 2B 3L 4U 5D
|
| 50 |
+
'''
|
| 51 |
+
tp = np.roll(np.arange(4).repeat(w // 4)[None, :].repeat(h, 0), 3 * w // 8, 1)
|
| 52 |
+
|
| 53 |
+
# Prepare ceil mask
|
| 54 |
+
mask = np.zeros((h, w // 4), np.bool)
|
| 55 |
+
idx = np.linspace(-np.pi, np.pi, w // 4) / 4
|
| 56 |
+
idx = h // 2 - np.round(np.arctan(np.cos(idx)) * h / np.pi).astype(int)
|
| 57 |
+
for i, j in enumerate(idx):
|
| 58 |
+
mask[:j, i] = 1
|
| 59 |
+
mask = np.roll(np.concatenate([mask] * 4, 1), 3 * w // 8, 1)
|
| 60 |
+
|
| 61 |
+
tp[mask] = 4
|
| 62 |
+
tp[np.flip(mask, 0)] = 5
|
| 63 |
+
|
| 64 |
+
return tp.astype(np.int32)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def xyzpers(h_fov, v_fov, u, v, out_hw, in_rot):
|
| 68 |
+
out = np.ones((*out_hw, 3), np.float32)
|
| 69 |
+
|
| 70 |
+
x_max = np.tan(h_fov / 2)
|
| 71 |
+
y_max = np.tan(v_fov / 2)
|
| 72 |
+
x_rng = np.linspace(-x_max, x_max, num=out_hw[1], dtype=np.float32)
|
| 73 |
+
y_rng = np.linspace(-y_max, y_max, num=out_hw[0], dtype=np.float32)
|
| 74 |
+
out[..., :2] = np.stack(np.meshgrid(x_rng, -y_rng), -1)
|
| 75 |
+
Rx = rotation_matrix(v, [1, 0, 0])
|
| 76 |
+
Ry = rotation_matrix(u, [0, 1, 0])
|
| 77 |
+
Ri = rotation_matrix(in_rot, np.array([0, 0, 1.0]).dot(Rx).dot(Ry))
|
| 78 |
+
|
| 79 |
+
return out.dot(Rx).dot(Ry).dot(Ri)
|
| 80 |
+
|
| 81 |
+
+def xyz2uv(xyz):
+    '''
+    xyz: ndarray in shape of [..., 3]
+    '''
+    x, y, z = np.split(xyz, 3, axis=-1)
+    u = np.arctan2(x, z)
+    c = np.sqrt(x**2 + z**2)
+    v = np.arctan2(y, c)
+
+    return np.concatenate([u, v], axis=-1)
+
+
+def uv2unitxyz(uv):
+    u, v = np.split(uv, 2, axis=-1)
+    y = np.sin(v)
+    c = np.cos(v)
+    x = c * np.sin(u)
+    z = c * np.cos(u)
+
+    return np.concatenate([x, y, z], axis=-1)
+
+
+def uv2coor(uv, h, w):
+    '''
+    uv: ndarray in shape of [..., 2]
+    h: int, height of the equirectangular image
+    w: int, width of the equirectangular image
+    '''
+    u, v = np.split(uv, 2, axis=-1)
+    coor_x = (u / (2 * np.pi) + 0.5) * w - 0.5
+    coor_y = (-v / np.pi + 0.5) * h - 0.5
+
+    return np.concatenate([coor_x, coor_y], axis=-1)
+
+
+def coor2uv(coorxy, h, w):
+    coor_x, coor_y = np.split(coorxy, 2, axis=-1)
+    u = ((coor_x + 0.5) / w - 0.5) * 2 * np.pi
+    v = -((coor_y + 0.5) / h - 0.5) * np.pi
+
+    return np.concatenate([u, v], axis=-1)
+
+
+def sample_equirec(e_img, coor_xy, order):
+    w = e_img.shape[1]
+    coor_x, coor_y = np.split(coor_xy, 2, axis=-1)
+    pad_u = np.roll(e_img[[0]], w // 2, 1)
+    pad_d = np.roll(e_img[[-1]], w // 2, 1)
+    e_img = np.concatenate([e_img, pad_d, pad_u], 0)
+    return map_coordinates(e_img, [coor_y, coor_x],
+                           order=order, mode='wrap')[..., 0]
+
+
+def sample_cubefaces(cube_faces, tp, coor_y, coor_x, order):
+    cube_faces = cube_faces.copy()
+    cube_faces[1] = np.flip(cube_faces[1], 1)
+    cube_faces[2] = np.flip(cube_faces[2], 1)
+    cube_faces[4] = np.flip(cube_faces[4], 0)
+
+    # Pad up down
+    pad_ud = np.zeros((6, 2, cube_faces.shape[2]))
+    pad_ud[0, 0] = cube_faces[5, 0, :]
+    pad_ud[0, 1] = cube_faces[4, -1, :]
+    pad_ud[1, 0] = cube_faces[5, :, -1]
+    pad_ud[1, 1] = cube_faces[4, ::-1, -1]
+    pad_ud[2, 0] = cube_faces[5, -1, ::-1]
+    pad_ud[2, 1] = cube_faces[4, 0, ::-1]
+    pad_ud[3, 0] = cube_faces[5, ::-1, 0]
+    pad_ud[3, 1] = cube_faces[4, :, 0]
+    pad_ud[4, 0] = cube_faces[0, 0, :]
+    pad_ud[4, 1] = cube_faces[2, 0, ::-1]
+    pad_ud[5, 0] = cube_faces[2, -1, ::-1]
+    pad_ud[5, 1] = cube_faces[0, -1, :]
+    cube_faces = np.concatenate([cube_faces, pad_ud], 1)
+
+    # Pad left right
+    pad_lr = np.zeros((6, cube_faces.shape[1], 2))
+    pad_lr[0, :, 0] = cube_faces[1, :, 0]
+    pad_lr[0, :, 1] = cube_faces[3, :, -1]
+    pad_lr[1, :, 0] = cube_faces[2, :, 0]
+    pad_lr[1, :, 1] = cube_faces[0, :, -1]
+    pad_lr[2, :, 0] = cube_faces[3, :, 0]
+    pad_lr[2, :, 1] = cube_faces[1, :, -1]
+    pad_lr[3, :, 0] = cube_faces[0, :, 0]
+    pad_lr[3, :, 1] = cube_faces[2, :, -1]
+    pad_lr[4, 1:-1, 0] = cube_faces[1, 0, ::-1]
+    pad_lr[4, 1:-1, 1] = cube_faces[3, 0, :]
+    pad_lr[5, 1:-1, 0] = cube_faces[1, -2, :]
+    pad_lr[5, 1:-1, 1] = cube_faces[3, -2, ::-1]
+    cube_faces = np.concatenate([cube_faces, pad_lr], 2)
+
+    return map_coordinates(cube_faces, [tp, coor_y, coor_x], order=order, mode='wrap')
+
+
+def cube_h2list(cube_h):
+    assert cube_h.shape[0] * 6 == cube_h.shape[1]
+    return np.split(cube_h, 6, axis=1)
+
+
+def cube_list2h(cube_list):
+    assert len(cube_list) == 6
+    assert sum(face.shape == cube_list[0].shape for face in cube_list) == 6
+    return np.concatenate(cube_list, axis=1)
+
+
+def cube_h2dict(cube_h):
+    cube_list = cube_h2list(cube_h)
+    return dict([(k, cube_list[i])
+                 for i, k in enumerate(['F', 'R', 'B', 'L', 'U', 'D'])])
+
+
+def cube_dict2h(cube_dict, face_k=['F', 'R', 'B', 'L', 'U', 'D']):
+    assert len(face_k) == 6
+    return cube_list2h([cube_dict[k] for k in face_k])
+
+
+def cube_h2dice(cube_h):
+    assert cube_h.shape[0] * 6 == cube_h.shape[1]
+    w = cube_h.shape[0]
+    cube_dice = np.zeros((w * 3, w * 4, cube_h.shape[2]), dtype=cube_h.dtype)
+    cube_list = cube_h2list(cube_h)
+    # Order: F R B L U D
+    sxy = [(1, 1), (2, 1), (3, 1), (0, 1), (1, 0), (1, 2)]
+    for i, (sx, sy) in enumerate(sxy):
+        face = cube_list[i]
+        if i in [1, 2]:
+            face = np.flip(face, axis=1)
+        if i == 4:
+            face = np.flip(face, axis=0)
+        cube_dice[sy*w:(sy+1)*w, sx*w:(sx+1)*w] = face
+    return cube_dice
+
+
+def cube_dice2h(cube_dice):
+    w = cube_dice.shape[0] // 3
+    assert cube_dice.shape[0] == w * 3 and cube_dice.shape[1] == w * 4
+    cube_h = np.zeros((w, w * 6, cube_dice.shape[2]), dtype=cube_dice.dtype)
+    # Order: F R B L U D
+    sxy = [(1, 1), (2, 1), (3, 1), (0, 1), (1, 0), (1, 2)]
+    for i, (sx, sy) in enumerate(sxy):
+        face = cube_dice[sy*w:(sy+1)*w, sx*w:(sx+1)*w]
+        if i in [1, 2]:
+            face = np.flip(face, axis=1)
+        if i == 4:
+            face = np.flip(face, axis=0)
+        cube_h[:, i*w:(i+1)*w] = face
+    return cube_h
+
+
+def rotation_matrix(rad, ax):
+    ax = np.array(ax)
+    assert len(ax.shape) == 1 and ax.shape[0] == 3
+    ax = ax / np.sqrt((ax**2).sum())
+    R = np.diag([np.cos(rad)] * 3)
+    R = R + np.outer(ax, ax) * (1.0 - np.cos(rad))
+
+    ax = ax * np.sin(rad)
+    R = R + np.array([[0, -ax[2], ax[1]],
+                      [ax[2], 0, -ax[0]],
+                      [-ax[1], ax[0], 0]])
+
+    return R
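
The uv-to-pixel mapping above is the core of these equirectangular helpers, and its round trip is easy to sanity-check. A minimal standalone sketch (the formulas are copied from `uv2coor` / `coor2uv` above; the image size and sample angles are made-up values):

```python
import numpy as np

# Round-trip check of the uv <-> pixel mapping implemented by uv2coor/coor2uv.
# Replicated inline because the repo ships both a root-level utils.py and a
# utils/ package, so importing the module by name is ambiguous.
h, w = 512, 1024                                 # equirectangular image size
u = np.array([-np.pi / 2, 0.0, np.pi / 2])       # longitude, in [-pi, pi)
v = np.array([np.pi / 4, 0.0, -np.pi / 4])       # latitude, in [-pi/2, pi/2]

coor_x = (u / (2 * np.pi) + 0.5) * w - 0.5       # uv2coor: u -> pixel column
coor_y = (-v / np.pi + 0.5) * h - 0.5            # uv2coor: v -> pixel row

u_back = ((coor_x + 0.5) / w - 0.5) * 2 * np.pi  # coor2uv: column -> u
v_back = -((coor_y + 0.5) / h - 0.5) * np.pi     # coor2uv: row -> v

assert np.allclose(u, u_back) and np.allclose(v, v_back)
```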
utils/.DS_Store ADDED
Binary file (6.15 kB)

utils/__init__.py ADDED
File without changes

utils/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (164 Bytes)

utils/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (156 Bytes)

utils/__pycache__/rendererBackbone.cpython-39.pyc ADDED
Binary file (4.02 kB)

utils/__pycache__/utils.cpython-38.pyc ADDED
Binary file (4.03 kB)

utils/__pycache__/utils.cpython-39.pyc ADDED
Binary file (4.04 kB)

utils/mpi/__init__.py ADDED
File without changes

utils/mpi/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (168 Bytes)

utils/mpi/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (160 Bytes)

utils/mpi/__pycache__/homography_sampler.cpython-38.pyc ADDED
Binary file (4.62 kB)

utils/mpi/__pycache__/homography_sampler.cpython-39.pyc ADDED
Binary file (4.64 kB)

utils/mpi/__pycache__/mpi_rendering.cpython-38.pyc ADDED
Binary file (7.43 kB)

utils/mpi/__pycache__/mpi_rendering.cpython-39.pyc ADDED
Binary file (7.45 kB)

utils/mpi/__pycache__/rendering_utils.cpython-38.pyc ADDED
Binary file (4.09 kB)

utils/mpi/__pycache__/rendering_utils.cpython-39.pyc ADDED
Binary file (4.07 kB)
utils/mpi/homography_sampler.py ADDED
@@ -0,0 +1,159 @@
+import torch
+import numpy as np
+from scipy.spatial.transform import Rotation
+
+
+def inverse(matrices):
+    """
+    torch.inverse() sometimes produces outputs with nan when the batch size is 2.
+    Ref https://github.com/pytorch/pytorch/issues/47272
+    This function keeps inverting the matrix until it succeeds or the maximum number of tries is reached.
+    :param matrices: Bx3x3
+    """
+    inverse = None
+    max_tries = 5
+    while (inverse is None) or (torch.isnan(inverse)).any():
+        # torch.cuda.synchronize()
+        inverse = torch.inverse(matrices)
+
+        # Break out of the loop when the inverse is successful or there are no more tries
+        max_tries -= 1
+        if max_tries == 0:
+            break
+
+    # Raise an Exception if the inverse contains nan
+    if (torch.isnan(inverse)).any():
+        raise Exception("Matrix inverse contains nan!")
+    return inverse
+
+
+class HomographySample:
+    def __init__(self, H_tgt, W_tgt, device=None):
+        if device is None:
+            self.device = torch.device("cpu")
+        else:
+            self.device = device
+
+        self.Height_tgt = H_tgt
+        self.Width_tgt = W_tgt
+        self.meshgrid = self.grid_generation(self.Height_tgt, self.Width_tgt, self.device)
+        self.meshgrid = self.meshgrid.permute(2, 0, 1).contiguous()  # 3xHxW
+
+        self.n = self.plane_normal_generation(self.device)
+
+    @staticmethod
+    def grid_generation(H, W, device):
+        x = np.linspace(0, W-1, W)
+        y = np.linspace(0, H-1, H)
+        xv, yv = np.meshgrid(x, y)  # HxW
+        xv = torch.from_numpy(xv.astype(np.float32)).to(dtype=torch.float32, device=device)
+        yv = torch.from_numpy(yv.astype(np.float32)).to(dtype=torch.float32, device=device)
+        ones = torch.ones_like(xv)
+        meshgrid = torch.stack((xv, yv, ones), dim=2)  # HxWx3
+        return meshgrid
+
+    @staticmethod
+    def plane_normal_generation(device):
+        n = torch.tensor([0, 0, 1], dtype=torch.float32, device=device)
+        return n
+
+    @staticmethod
+    def euler_to_rotation_matrix(x_angle, y_angle, z_angle, seq='xyz', degrees=False):
+        """
+        Note that here we want to return a rotation matrix rot_mtx that transforms the tgt points into the src frame,
+        i.e., rot_mtx * p_tgt = p_src.
+        Therefore we negate x/y/z_angle.
+        :param x_angle, y_angle, z_angle: Euler angles about the respective axes
+        :return: 3x3 rotation matrix
+        """
+        r = Rotation.from_euler(seq,
+                                [-x_angle, -y_angle, -z_angle],
+                                degrees=degrees)
+        rot_mtx = r.as_matrix().astype(np.float32)
+        return rot_mtx
+
+    def sample(self, src_BCHW, d_src_B,
+               G_tgt_src,
+               K_src_inv, K_tgt):
+        """
+        Coordinate system: x, y are the image directions, z points in the depth direction.
+        :param src_BCHW: torch float tensor, 0-1, rgb/rgba, BxCxHxW;
+                         assumed to be at position P=[I|0]
+        :param d_src_B: distance of the image plane to the src camera origin
+        :param G_tgt_src: Bx4x4
+        :param K_src_inv: Bx3x3
+        :param K_tgt: Bx3x3
+        :return: tgt_BCHW
+        """
+        # parameter processing ------ begin ------
+        B, channels, Height_src, Width_src = src_BCHW.size(0), src_BCHW.size(1), src_BCHW.size(2), src_BCHW.size(3)
+        R_tgt_src = G_tgt_src[:, 0:3, 0:3]
+        t_tgt_src = G_tgt_src[:, 0:3, 3]
+
+        Height_tgt = self.Height_tgt
+        Width_tgt = self.Width_tgt
+        # if R_src_tgt is None:
+        #     R_src_tgt = torch.eye(3, dtype=torch.float32, device=src_BCHW.device)
+        #     R_src_tgt = R_src_tgt.unsqueeze(0).expand(B, 3, 3)
+        # if t_src_tgt is None:
+        #     t_src_tgt = torch.tensor([0, 0, 0],
+        #                              dtype=torch.float32,
+        #                              device=src_BCHW.device)
+        #     t_src_tgt = t_src_tgt.unsqueeze(0).expand(B, 3)
+
+        # relationship between FoV and focal length:
+        # assume W > H
+        # W / 2 = f*tan(\theta / 2)
+        # here we default the horizontal FoV to 53.13 degrees
+        # the vertical FoV can be computed as H/2 = W*tan(\theta/2)
+
+        R_tgt_src = R_tgt_src.to(device=src_BCHW.device)
+        t_tgt_src = t_tgt_src.to(device=src_BCHW.device)
+        K_src_inv = K_src_inv.to(device=src_BCHW.device)
+        K_tgt = K_tgt.to(device=src_BCHW.device)
+        # parameter processing ------ end ------
+
+        # the goal is to compute H_src_tgt, which maps a tgt pixel to a src pixel,
+        # so we compute H_tgt_src first, and then invert it
+        n = self.n.to(device=src_BCHW.device)
+        n = n.unsqueeze(0).repeat(B, 1)  # Bx3
+        # Bx3x3 - (Bx3x1 * Bx1x3)
+        # note here we use -d_src, because the plane function is n^T * X - d_src = 0
+        d_src_B33 = d_src_B.reshape(B, 1, 1).repeat(1, 3, 3)  # B -> Bx3x3
+        R_tnd = R_tgt_src - torch.matmul(t_tgt_src.unsqueeze(2), n.unsqueeze(1)) / -d_src_B33
+        H_tgt_src = torch.matmul(K_tgt,
+                                 torch.matmul(R_tnd, K_src_inv))
+
+        # TODO: fix cuda inverse
+        with torch.no_grad():
+            H_src_tgt = inverse(H_tgt_src)
+
+        # create tgt image grid, and map to src
+        meshgrid_tgt_homo = self.meshgrid.to(src_BCHW.device)
+        # 3xHxW -> Bx3xHxW
+        meshgrid_tgt_homo = meshgrid_tgt_homo.unsqueeze(0).expand(B, 3, Height_tgt, Width_tgt)
+
+        # warp meshgrid_tgt_homo to meshgrid_src
+        meshgrid_tgt_homo_B3N = meshgrid_tgt_homo.view(B, 3, -1)  # Bx3xHW
+        meshgrid_src_homo_B3N = torch.matmul(H_src_tgt, meshgrid_tgt_homo_B3N)  # Bx3x3 * Bx3xHW -> Bx3xHW
+        # Bx3xHW -> Bx3xHxW -> BxHxWx3
+        meshgrid_src_homo = meshgrid_src_homo_B3N.view(B, 3, Height_tgt, Width_tgt).permute(0, 2, 3, 1)
+        meshgrid_src = meshgrid_src_homo[:, :, :, 0:2] / meshgrid_src_homo[:, :, :, 2:]  # BxHxWx2
+
+        valid_mask_x = torch.logical_and(meshgrid_src[:, :, :, 0] < Width_src,
+                                         meshgrid_src[:, :, :, 0] > -1)
+        valid_mask_y = torch.logical_and(meshgrid_src[:, :, :, 1] < Height_src,
+                                         meshgrid_src[:, :, :, 1] > -1)
+        valid_mask = torch.logical_and(valid_mask_x, valid_mask_y)  # BxHxW
+
+        # sample from src_BCHW
+        # normalize meshgrid_src to [-1,1]
+        meshgrid_src[:, :, :, 0] = (meshgrid_src[:, :, :, 0]+0.5) / (Width_src * 0.5) - 1
+        meshgrid_src[:, :, :, 1] = (meshgrid_src[:, :, :, 1]+0.5) / (Height_src * 0.5) - 1
+        tgt_BCHW = torch.nn.functional.grid_sample(src_BCHW, grid=meshgrid_src, padding_mode='border',
+                                                   align_corners=False)
+        # BxCxHxW, BxHxW
+        return tgt_BCHW, valid_mask
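
A minimal sketch of how `HomographySample` can be driven (the shapes and camera values below are made up, not taken from the RT-MPINet pipeline): with an identity target pose, warping should roughly reproduce the source image.

```python
import torch
from utils.mpi.homography_sampler import HomographySample

B, H, W = 2, 64, 64
sampler = HomographySample(H, W, device=torch.device("cpu"))

src = torch.rand(B, 3, H, W)                           # BxCxHxW source image
d_src = torch.full((B,), 2.0)                          # image-plane depth per batch item
G_tgt_src = torch.eye(4).unsqueeze(0).repeat(B, 1, 1)  # identity tgt-from-src pose
K = torch.tensor([[float(W), 0.0, W / 2.0],
                  [0.0, float(W), H / 2.0],
                  [0.0, 0.0, 1.0]]).unsqueeze(0).repeat(B, 1, 1)  # Bx3x3 pinhole intrinsics

tgt, valid = sampler.sample(src, d_src, G_tgt_src, torch.inverse(K), K)
print(tgt.shape, valid.shape)  # torch.Size([2, 3, 64, 64]) torch.Size([2, 64, 64])
```

Note that the `inverse` wrapper above exists precisely because `torch.inverse` can intermittently return NaNs for small batches; the retry loop is the workaround.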
utils/mpi/mpi_rendering.py ADDED
@@ -0,0 +1,272 @@
+import torch
+
+from utils.mpi.homography_sampler import HomographySample
+from utils.mpi.rendering_utils import transform_G_xyz, sample_pdf, gather_pixel_by_pxpy
+
+
+def render(rgb_BS3HW, sigma_BS1HW, xyz_BS3HW, use_alpha=False, is_bg_depth_inf=False):
+    if not use_alpha:
+        imgs_syn, depth_syn, blend_weights, weights = plane_volume_rendering(
+            rgb_BS3HW,
+            sigma_BS1HW,
+            xyz_BS3HW,
+            is_bg_depth_inf
+        )
+    else:
+        imgs_syn, weights = alpha_composition(sigma_BS1HW, rgb_BS3HW)
+        depth_syn, _ = alpha_composition(sigma_BS1HW, xyz_BS3HW[:, :, 2:])
+        # No rgb blending with alpha composition
+        blend_weights = torch.cumprod(1 - sigma_BS1HW + 1e-6, dim=1)
+        # blend_weights = torch.zeros_like(rgb_BS3HW).cuda()
+    return imgs_syn, depth_syn, blend_weights, weights
+
+
+def alpha_composition(alpha_BK1HW, value_BKCHW):
+    """
+    composition equation from 'Single-View View Synthesis with Multiplane Images'
+    K is the number of planes, k=0 means the nearest plane, k=K-1 means the farthest plane
+    :param alpha_BK1HW: alpha at each of the K planes
+    :param value_BKCHW: rgb/disparity at each of the K planes
+    :return:
+    """
+    B, K, _, H, W = alpha_BK1HW.size()
+    alpha_comp_cumprod = torch.cumprod(1 - alpha_BK1HW, dim=1)  # BxKx1xHxW
+
+    preserve_ratio = torch.cat((torch.ones((B, 1, 1, H, W), dtype=alpha_BK1HW.dtype, device=alpha_BK1HW.device),
+                                alpha_comp_cumprod[:, 0:K-1, :, :, :]), dim=1)  # BxKx1xHxW
+    weights = alpha_BK1HW * preserve_ratio  # BxKx1xHxW
+    value_composed = torch.sum(value_BKCHW * weights, dim=1, keepdim=False)  # Bx3xHxW
+
+    return value_composed, weights
+
+
+def plane_volume_rendering(rgb_BS3HW, sigma_BS1HW, xyz_BS3HW, is_bg_depth_inf):
+    B, S, _, H, W = sigma_BS1HW.size()
+
+    xyz_diff_BS3HW = xyz_BS3HW[:, 1:, :, :, :] - xyz_BS3HW[:, 0:-1, :, :, :]  # Bx(S-1)x3xHxW
+    xyz_dist_BS1HW = torch.norm(xyz_diff_BS3HW, dim=2, keepdim=True)  # Bx(S-1)x1xHxW
+
+    xyz_dist_BS1HW = torch.cat((xyz_dist_BS1HW,
+                                torch.full((B, 1, 1, H, W),
+                                           fill_value=1e3,
+                                           dtype=xyz_BS3HW.dtype,
+                                           device=xyz_BS3HW.device)),
+                               dim=1)  # BxSx1xHxW
+    transparency = torch.exp(-sigma_BS1HW * xyz_dist_BS1HW)  # BxSx1xHxW
+    alpha = 1 - transparency  # BxSx1xHxW
+
+    # add small eps to avoid zero transparency_acc
+    # torch.cumprod is like: [a, b, c] -> [a, a*b, a*b*c]; we need to modify it to [1, a, a*b]
+    transparency_acc = torch.cumprod(transparency + 1e-6, dim=1)  # BxSx1xHxW
+    transparency_acc = torch.cat((torch.ones((B, 1, 1, H, W), dtype=transparency.dtype, device=transparency.device),
+                                  transparency_acc[:, 0:-1, :, :, :]),
+                                 dim=1)  # BxSx1xHxW
+
+    weights = transparency_acc * alpha  # BxSx1xHxW
+    rgb_out, depth_out = weighted_sum_mpi(rgb_BS3HW, xyz_BS3HW, weights, is_bg_depth_inf)
+
+    return rgb_out, depth_out, transparency_acc, weights
+
+
+def weighted_sum_mpi(rgb_BS3HW, xyz_BS3HW, weights, is_bg_depth_inf):
+    weights_sum = torch.sum(weights, dim=1, keepdim=False)  # Bx1xHxW
+    rgb_out = torch.sum(weights * rgb_BS3HW, dim=1, keepdim=False)  # Bx3xHxW
+
+    if is_bg_depth_inf:
+        # for dtu dataset, set large depth if weight_sum is small
+        depth_out = torch.sum(weights * xyz_BS3HW[:, :, 2:, :, :], dim=1, keepdim=False) \
+                    + (1 - weights_sum) * 1000
+    else:
+        depth_out = torch.sum(weights * xyz_BS3HW[:, :, 2:, :, :], dim=1, keepdim=False) \
+                    / (weights_sum + 1e-5)  # Bx1xHxW
+
+    return rgb_out, depth_out
+
+
+def get_xyz_from_depth(meshgrid_homo,
+                       depth,
+                       K_inv):
+    """
+    :param meshgrid_homo: 3xHxW
+    :param depth: Bx1xHxW
+    :param K_inv: Bx3x3
+    :return:
+    """
+    H, W = meshgrid_homo.size(1), meshgrid_homo.size(2)
+    B, _, H_d, W_d = depth.size()
+    assert H == H_d and W == W_d
+
+    # 3xHxW -> Bx3xHxW
+    meshgrid_src_homo = meshgrid_homo.unsqueeze(0).repeat(B, 1, 1, 1)
+    meshgrid_src_homo_B3N = meshgrid_src_homo.reshape(B, 3, -1)
+    xyz_src = torch.matmul(K_inv, meshgrid_src_homo_B3N)  # Bx3xHW
+    xyz_src = xyz_src.reshape(B, 3, H, W) * depth  # Bx3xHxW
+
+    return xyz_src
+
+
+def disparity_consistency_src_to_tgt(meshgrid_homo, K_src_inv, disparity_src,
+                                     G_tgt_src, K_tgt, disparity_tgt):
+    """
+    :param meshgrid_homo: 3xHxW
+    :param K_src_inv: Bx3x3
+    :param disparity_src: Bx1xHxW
+    :param G_tgt_src: Bx4x4
+    :param K_tgt: Bx3x3
+    :param disparity_tgt: Bx1xHxW
+    :return:
+    """
+    B, _, H, W = disparity_src.size()
+    depth_src = torch.reciprocal(disparity_src)
+    xyz_src_B3N = get_xyz_from_depth(meshgrid_homo, depth_src, K_src_inv).view(B, 3, H*W)
+
+    xyz_tgt_B3N = transform_G_xyz(G_tgt_src, xyz_src_B3N, is_return_homo=False)
+    K_xyz_tgt_B3N = torch.matmul(K_tgt, xyz_tgt_B3N)
+    pxpy_tgt_B2N = K_xyz_tgt_B3N[:, 0:2, :] / K_xyz_tgt_B3N[:, 2:, :]  # Bx2xN
+
+    pxpy_tgt_mask = torch.logical_and(
+        torch.logical_and(pxpy_tgt_B2N[:, 0:1, :] >= 0,
+                          pxpy_tgt_B2N[:, 0:1, :] <= W - 1),
+        torch.logical_and(pxpy_tgt_B2N[:, 1:2, :] >= 0,
+                          pxpy_tgt_B2N[:, 1:2, :] <= H - 1)
+    )  # B1N
+
+    disparity_src = torch.reciprocal(xyz_tgt_B3N[:, 2:, :])  # Bx1xN
+    disparity_tgt = gather_pixel_by_pxpy(disparity_tgt, pxpy_tgt_B2N)  # Bx1xN
+
+    depth_diff = torch.abs(disparity_src - disparity_tgt)
+    return torch.mean(depth_diff[pxpy_tgt_mask])
+
+
+def get_src_xyz_from_plane_disparity(meshgrid_src_homo,
+                                     mpi_disparity_src,
+                                     K_src_inv):
+    """
+    :param meshgrid_src_homo: 3xHxW
+    :param mpi_disparity_src: BxS
+    :param K_src_inv: Bx3x3
+    :return:
+    """
+    B, S = mpi_disparity_src.size()
+    H, W = meshgrid_src_homo.size(1), meshgrid_src_homo.size(2)
+    mpi_depth_src = torch.reciprocal(mpi_disparity_src)  # BxS
+
+    K_src_inv_Bs33 = K_src_inv.unsqueeze(1).repeat(1, S, 1, 1).reshape(B * S, 3, 3)
+
+    # 3xHxW -> BxSx3xHxW
+    meshgrid_src_homo = meshgrid_src_homo.unsqueeze(0).unsqueeze(1).repeat(B, S, 1, 1, 1)
+    meshgrid_src_homo_Bs3N = meshgrid_src_homo.reshape(B * S, 3, -1)
+    xyz_src = torch.matmul(K_src_inv_Bs33, meshgrid_src_homo_Bs3N)  # BSx3xHW
+    xyz_src = xyz_src.reshape(B, S, 3, H * W) * mpi_depth_src.unsqueeze(2).unsqueeze(3)  # BxSx3xHW
+    xyz_src_BS3HW = xyz_src.reshape(B, S, 3, H, W)
+
+    return xyz_src_BS3HW
+
+
+def get_tgt_xyz_from_plane_disparity(xyz_src_BS3HW,
+                                     G_tgt_src):
+    """
+    :param xyz_src_BS3HW: BxSx3xHxW
+    :param G_tgt_src: Bx4x4
+    :return:
+    """
+    B, S, _, H, W = xyz_src_BS3HW.size()
+    G_tgt_src_Bs33 = G_tgt_src.unsqueeze(1).repeat(1, S, 1, 1).reshape(B*S, 4, 4)
+    xyz_tgt = transform_G_xyz(G_tgt_src_Bs33, xyz_src_BS3HW.reshape(B*S, 3, H*W))  # Bsx3xHW
+    xyz_tgt_BS3HW = xyz_tgt.reshape(B, S, 3, H, W)  # BxSx3xHxW
+    return xyz_tgt_BS3HW
+
+
+def render_tgt_rgb_depth(H_sampler: HomographySample,
+                         mpi_rgb_src,
+                         mpi_sigma_src,
+                         mpi_disparity_src,
+                         xyz_tgt_BS3HW,
+                         G_tgt_src,
+                         K_src_inv, K_tgt,
+                         use_alpha=False,
+                         is_bg_depth_inf=False):
+    """
+    :param H_sampler:
+    :param mpi_rgb_src: BxSx3xHxW
+    :param mpi_sigma_src: BxSx1xHxW
+    :param mpi_disparity_src: BxS
+    :param xyz_tgt_BS3HW: BxSx3xHxW
+    :param G_tgt_src: Bx4x4
+    :param K_src_inv: Bx3x3
+    :param K_tgt: Bx3x3
+    :return:
+    """
+    B, S, _, H, W = mpi_rgb_src.size()
+    mpi_depth_src = torch.reciprocal(mpi_disparity_src)  # BxS
+
+    # note that here we concat the mpi_src with xyz_tgt, because H_sampler will sample them for the tgt frame
+    # mpi_src is the same in whatever frame, but xyz has to be in the tgt frame
+    mpi_xyz_src = torch.cat((mpi_rgb_src, mpi_sigma_src, xyz_tgt_BS3HW), dim=2)  # BxSx(3+1+3)xHxW
+
+    # homography warping of mpi_src into the tgt frame
+    G_tgt_src_Bs44 = G_tgt_src.unsqueeze(1).repeat(1, S, 1, 1).contiguous().reshape(B*S, 4, 4)  # Bsx4x4
+    K_src_inv_Bs33 = K_src_inv.unsqueeze(1).repeat(1, S, 1, 1).contiguous().reshape(B*S, 3, 3)  # Bsx3x3
+    K_tgt_Bs33 = K_tgt.unsqueeze(1).repeat(1, S, 1, 1).contiguous().reshape(B*S, 3, 3)  # Bsx3x3
+
+    # BsxCxHxW, BsxHxW
+    tgt_mpi_xyz_BsCHW, tgt_mask_BsHW = H_sampler.sample(mpi_xyz_src.view(B*S, 7, H, W),
+                                                        mpi_depth_src.view(B*S),
+                                                        G_tgt_src_Bs44,
+                                                        K_src_inv_Bs33,
+                                                        K_tgt_Bs33)
+
+    # mpi composition
+    tgt_mpi_xyz = tgt_mpi_xyz_BsCHW.view(B, S, 7, H, W)
+    tgt_rgb_BS3HW = tgt_mpi_xyz[:, :, 0:3, :, :]
+    tgt_sigma_BS1HW = tgt_mpi_xyz[:, :, 3:4, :, :]
+    tgt_xyz_BS3HW = tgt_mpi_xyz[:, :, 4:, :, :]
+
+    tgt_mask_BSHW = tgt_mask_BsHW.view(B, S, H, W)
+    tgt_mask_BSHW = torch.where(tgt_mask_BSHW,
+                                torch.ones((B, S, H, W), dtype=torch.float32, device=mpi_rgb_src.device),
+                                torch.zeros((B, S, H, W), dtype=torch.float32, device=mpi_rgb_src.device))
+
+    # Bx3xHxW, Bx1xHxW, Bx1xHxW
+    tgt_z_BS1HW = tgt_xyz_BS3HW[:, :, -1:]
+    tgt_sigma_BS1HW = torch.where(tgt_z_BS1HW >= 0,
+                                  tgt_sigma_BS1HW,
+                                  torch.zeros_like(tgt_sigma_BS1HW, device=tgt_sigma_BS1HW.device))
+    tgt_rgb_syn, tgt_depth_syn, _, _ = render(tgt_rgb_BS3HW, tgt_sigma_BS1HW, tgt_xyz_BS3HW,
+                                              use_alpha=use_alpha,
+                                              is_bg_depth_inf=is_bg_depth_inf)
+    tgt_mask = torch.sum(tgt_mask_BSHW, dim=1, keepdim=True)  # Bx1xHxW
+
+    return tgt_rgb_syn, tgt_depth_syn, tgt_mask
+
+
+def predict_mpi_coarse_to_fine(mpi_predictor, src_imgs, xyz_src_BS3HW_coarse,
+                               disparity_coarse_src, S_fine, is_bg_depth_inf):
+    if S_fine > 0:
+        with torch.no_grad():
+            # predict coarse mpi
+            mpi_coarse_src_list = mpi_predictor(src_imgs, disparity_coarse_src)  # BxS_coarsex4xHxW
+            mpi_coarse_rgb_src = mpi_coarse_src_list[0][:, :, 0:3, :, :]  # BxSx3xHxW
+            mpi_coarse_sigma_src = mpi_coarse_src_list[0][:, :, 3:, :, :]  # BxSx1xHxW
+            _, _, _, weights = plane_volume_rendering(
+                mpi_coarse_rgb_src,
+                mpi_coarse_sigma_src,
+                xyz_src_BS3HW_coarse,
+                is_bg_depth_inf
+            )
+            weights = weights.mean((2, 3, 4)).unsqueeze(1).unsqueeze(2)
+
+            # sample fine disparity
+            disparity_fine_src = sample_pdf(disparity_coarse_src.unsqueeze(1).unsqueeze(2), weights, S_fine)
+            disparity_fine_src = disparity_fine_src.squeeze(2).squeeze(1)
+
+        # assemble coarse and fine disparity
+        disparity_all_src = torch.cat((disparity_coarse_src, disparity_fine_src), dim=1)  # Bx(S_coarse + S_fine)
+        disparity_all_src, _ = torch.sort(disparity_all_src, dim=1, descending=True)
+        mpi_all_src_list = mpi_predictor(src_imgs, disparity_all_src)  # Bx(S_coarse+S_fine)x4xHxW
+        return mpi_all_src_list, disparity_all_src
+    else:
+        mpi_coarse_src_list = mpi_predictor(src_imgs, disparity_coarse_src)  # BxS_coarsex4xHxW
+        return mpi_coarse_src_list, disparity_coarse_src
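
For intuition: `plane_volume_rendering` collapses the S planes with NeRF-style weights w_k = T_k * alpha_k, while `alpha_composition` uses the over-compositing rule w_k = alpha_k * prod_{j<k}(1 - alpha_j). A minimal sketch with made-up shapes (a random 4-plane fronto-parallel MPI, not output of the actual network):

```python
import torch
from utils.mpi import mpi_rendering

B, S, H, W = 1, 4, 32, 32
rgb_BS3HW = torch.rand(B, S, 3, H, W)    # per-plane colors
sigma_BS1HW = torch.rand(B, S, 1, H, W)  # per-plane densities

# Fronto-parallel planes at depths 1..S: only the z coordinate is non-zero.
depths = torch.arange(1, S + 1, dtype=torch.float32).view(1, S, 1, 1, 1)
xyz_BS3HW = torch.zeros(B, S, 3, H, W)
xyz_BS3HW[:, :, 2:] = depths.expand(B, S, 1, H, W)

rgb, depth, blend_w, weights = mpi_rendering.render(
    rgb_BS3HW, sigma_BS1HW, xyz_BS3HW, use_alpha=False, is_bg_depth_inf=False)
print(rgb.shape, depth.shape)  # torch.Size([1, 3, 32, 32]) torch.Size([1, 1, 32, 32])
```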
utils/mpi/rendering_utils.py ADDED
@@ -0,0 +1,139 @@
+import torch
+
+
+def transform_G_xyz(G, xyz, is_return_homo=False):
+    """
+    :param G: Bx4x4
+    :param xyz: Bx3xN
+    :return:
+    """
+    assert len(G.size()) == len(xyz.size())
+    if len(G.size()) == 2:
+        G_B44 = G.unsqueeze(0)
+        xyz_B3N = xyz.unsqueeze(0)
+    else:
+        G_B44 = G
+        xyz_B3N = xyz
+    xyz_B4N = torch.cat((xyz_B3N, torch.ones_like(xyz_B3N[:, 0:1, :])), dim=1)
+    G_xyz_B4N = torch.matmul(G_B44, xyz_B4N)
+    if is_return_homo:
+        return G_xyz_B4N
+    else:
+        return G_xyz_B4N[:, 0:3, :]
+
+
+def gather_pixel_by_pxpy(img, pxpy):
+    """
+    :param img: Bx3xHxW
+    :param pxpy: Bx2xN
+    :return:
+    """
+    with torch.no_grad():
+        B, C, H, W = img.size()
+        if pxpy.dtype == torch.float32:
+            pxpy_int = torch.round(pxpy).to(torch.int64)
+        else:
+            pxpy_int = pxpy
+        pxpy_int = pxpy_int.to(torch.int64)
+        pxpy_int[:, 0, :] = torch.clamp(pxpy_int[:, 0, :], min=0, max=W-1)
+        pxpy_int[:, 1, :] = torch.clamp(pxpy_int[:, 1, :], min=0, max=H-1)
+        pxpy_idx = pxpy_int[:, 0:1, :] + W * pxpy_int[:, 1:2, :]  # Bx1xN_pt
+        rgb = torch.gather(img.view(B, C, H * W), dim=2,
+                           index=pxpy_idx.repeat(1, C, 1))  # BxCxN_pt
+    return rgb
+
+
+def uniformly_sample_disparity_from_bins(batch_size, disparity_np, device):
+    """
+    In the disparity dimension, values must go from large to small, i.e., depth from small (near) to large (far).
+    :param batch_size:
+    :param disparity_np: bin edges, array of S+1 disparities
+    :param device:
+    :return:
+    """
+    assert disparity_np[0] > disparity_np[-1]
+    S = disparity_np.shape[0] - 1
+
+    B = batch_size
+    bin_edges = torch.from_numpy(disparity_np).to(dtype=torch.float32, device=device)  # S+1
+    interval = bin_edges[1:] - bin_edges[0:-1]  # S
+    bin_edges_start = bin_edges[0:-1].unsqueeze(0).repeat(B, 1)  # S -> BxS
+    # bin_edges_end = bin_edges[1:].unsqueeze(0).repeat(B, 1)  # S -> BxS
+    interval = interval.unsqueeze(0).repeat(B, 1)  # S -> BxS
+
+    random_float = torch.rand((B, S), dtype=torch.float32, device=device)  # BxS
+    disparity_array = bin_edges_start + interval * random_float
+    return disparity_array  # BxS
+
+
+def uniformly_sample_disparity_from_linspace_bins(batch_size, num_bins, start, end, device):
+    """
+    In the disparity dimension, values must go from large to small, i.e., depth from small (near) to large (far).
+    :param batch_size:
+    :param num_bins:
+    :param start:
+    :param end:
+    :return:
+    """
+    assert start > end
+
+    B, S = batch_size, num_bins
+    bin_edges = torch.linspace(start, end, num_bins+1, dtype=torch.float32, device=device)  # S+1
+    interval = bin_edges[1] - bin_edges[0]  # scalar
+    bin_edges_start = bin_edges[0:-1].unsqueeze(0).repeat(B, 1)  # S -> BxS
+    # bin_edges_end = bin_edges[1:].unsqueeze(0).repeat(B, 1)  # S -> BxS
+
+    random_float = torch.rand((B, S), dtype=torch.float32, device=device)  # BxS
+    disparity_array = bin_edges_start + interval * random_float
+    return disparity_array  # BxS
+
+
+def sample_pdf(values, weights, N_samples):
+    """
+    Draw samples from the distribution approximated by values and weights.
+    The probability distribution can be denoted as weights = p(values).
+    :param values: Bx1xNxS
+    :param weights: Bx1xNxS
+    :param N_samples: number of samples to draw
+    :return:
+    """
+    B, N, S = weights.size(0), weights.size(2), weights.size(3)
+    assert values.size() == (B, 1, N, S)
+
+    # convert values to bin edges
+    bin_edges = (values[:, :, :, 1:] + values[:, :, :, :-1]) * 0.5  # Bx1xNx(S-1)
+    bin_edges = torch.cat((values[:, :, :, 0:1],
+                           bin_edges,
+                           values[:, :, :, -1:]), dim=3)  # Bx1xNx(S+1)
+
+    pdf = weights / (torch.sum(weights, dim=3, keepdim=True) + 1e-5)  # Bx1xNxS
+    cdf = torch.cumsum(pdf, dim=3)  # Bx1xNxS
+    cdf = torch.cat((torch.zeros((B, 1, N, 1), dtype=cdf.dtype, device=cdf.device),
+                     cdf), dim=3)  # Bx1xNx(S+1)
+
+    # uniform sample over the cdf values
+    u = torch.rand((B, 1, N, N_samples), dtype=weights.dtype, device=weights.device)  # Bx1xNxN_samples
+
+    # get the index on the cdf array
+    cdf_idx = torch.searchsorted(cdf, u, right=True)  # Bx1xNxN_samples
+    cdf_idx_lower = torch.clamp(cdf_idx-1, min=0)  # Bx1xNxN_samples
+    cdf_idx_upper = torch.clamp(cdf_idx, max=S)  # Bx1xNxN_samples
+
+    # linear approximation for each bin
+    cdf_idx_lower_upper = torch.cat((cdf_idx_lower, cdf_idx_upper), dim=3)  # Bx1xNx(N_samplesx2)
+    cdf_bounds_N2 = torch.gather(cdf, index=cdf_idx_lower_upper, dim=3)  # Bx1xNx(N_samplesx2)
+    cdf_bounds = torch.stack((cdf_bounds_N2[..., 0:N_samples], cdf_bounds_N2[..., N_samples:]), dim=4)
+    bin_bounds_N2 = torch.gather(bin_edges, index=cdf_idx_lower_upper, dim=3)  # Bx1xNx(N_samplesx2)
+    bin_bounds = torch.stack((bin_bounds_N2[..., 0:N_samples], bin_bounds_N2[..., N_samples:]), dim=4)
+
+    # avoid zero cdf_intervals
+    cdf_intervals = cdf_bounds[:, :, :, :, 1] - cdf_bounds[:, :, :, :, 0]  # Bx1xNxN_samples
+    bin_intervals = bin_bounds[:, :, :, :, 1] - bin_bounds[:, :, :, :, 0]  # Bx1xNxN_samples
+    u_cdf_lower = u - cdf_bounds[:, :, :, :, 0]  # Bx1xNxN_samples
+    # there is the case that cdf_interval = 0, caused by the cdf_idx_lower/upper clamp above; it needs special handling
+    t = u_cdf_lower / torch.clamp(cdf_intervals, min=1e-5)
+    t = torch.where(cdf_intervals <= 1e-4,
+                    torch.full_like(u_cdf_lower, 0.5),
+                    t)
+
+    samples = bin_bounds[:, :, :, :, 0] + t*bin_intervals
+    return samples
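
`sample_pdf` is the standard inverse-transform sampling trick used for coarse-to-fine plane placement: build a CDF over the coarse bins, draw uniforms, and invert linearly per bin. A minimal sketch with made-up numbers:

```python
import torch
from utils.mpi.rendering_utils import sample_pdf

B, N, S = 1, 1, 16
values = torch.linspace(1.0, 0.1, S).view(B, 1, N, S)   # coarse disparities, near -> far
weights = torch.linspace(1.0, 0.0, S).view(B, 1, N, S)  # put more mass near the camera

fine = sample_pdf(values, weights, N_samples=8)         # Bx1xNx8 new disparity samples
print(fine.shape)                                       # torch.Size([1, 1, 1, 8])
# The samples stay inside [0.1, 1.0] and cluster where the weights are large.
```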
utils/rendererBackbone.py ADDED
@@ -0,0 +1,147 @@
+# System Imports
+import os
+import math
+import argparse
+import time
+
+# Common Libs
+import numpy as np
+from pathlib import Path
+import cv2
+import tkinter as tk
+import threading
+import queue
+
+# Torch Imports
+import torch
+import torch.nn.functional as F
+from torchvision import transforms
+from torchvision.utils import save_image
+
+# 3rd party imports
+from transformers import DPTForDepthEstimation, DPTImageProcessor
+from tqdm import tqdm
+import mediapipe as mp
+from PIL import Image, ImageTk
+from moviepy.editor import ImageSequenceClip
+
+# From Codebase
+from utils.mpi import mpi_rendering
+from utils.mpi.homography_sampler import HomographySample
+from utils.utils import (
+    image_to_tensor,
+    disparity_to_tensor,
+    render_3dphoto,
+    render_novel_view,
+)
+from model.AdaMPI import MPIPredictor
+from parameters import *
+
+
+#=================================================
+# Define the MPI Layers Processing Module Here
+#=================================================
+def processMPIs(src_imgs, mpi_all_src, disparity_all_src, k_src, k_tgt, save_path=None):
+    h, w = mpi_all_src.shape[-2:]
+    device = mpi_all_src.device
+    homography_sampler = HomographySample(h, w, device)
+    k_src_inv = torch.inverse(k_src)
+
+    # preprocess the predicted MPI
+    xyz_src_BS3HW = mpi_rendering.get_src_xyz_from_plane_disparity(
+        homography_sampler.meshgrid,
+        disparity_all_src,
+        k_src_inv,
+    )
+    mpi_all_rgb_src = mpi_all_src[:, :, 0:3, :, :]  # BxSx3xHxW
+    mpi_all_sigma_src = mpi_all_src[:, :, 3:, :, :]  # BxSx1xHxW
+    _, _, blend_weights, _ = mpi_rendering.render(
+        mpi_all_rgb_src,
+        mpi_all_sigma_src,
+        xyz_src_BS3HW,
+        use_alpha=False,
+        is_bg_depth_inf=False,
+    )
+    mpi_all_rgb_src = blend_weights * src_imgs.unsqueeze(1) + (1 - blend_weights) * mpi_all_rgb_src
+
+    return mpi_all_rgb_src, mpi_all_sigma_src, disparity_all_src, k_src_inv, k_tgt, homography_sampler
+
+
+def cropFOV(image, original_fov, new_fov):
+    image = np.array(image)
+    if new_fov >= original_fov:
+        raise ValueError("New FoV must be smaller than the original FoV")
+
+    crop_ratio = new_fov / original_fov
+    height, width = image.shape[:2]
+
+    new_width = int(width * crop_ratio)
+    new_height = int(height * crop_ratio)
+
+    start_x = (width - new_width) // 2
+    start_y = (height - new_height) // 2
+
+    cropped_image = image[start_y:start_y + new_height, start_x:start_x + new_width]
+    cropped_image = Image.fromarray(cropped_image)
+    return cropped_image
+
+
+def renderSingleFrame(mpi_all_rgb_src, mpi_all_sigma_src, disparity_all_src, cam_ext, k_src_inv, k_tgt, homography_sampler):
+    frame = render_novel_view(
+        mpi_all_rgb_src,
+        mpi_all_sigma_src,
+        disparity_all_src,
+        cam_ext.to(device),
+        k_src_inv,
+        k_tgt,
+        homography_sampler,
+    )
+    frame_np = frame[0].permute(1, 2, 0).contiguous().cpu().numpy()  # [b,h,w,3]
+    frame_np = np.clip(np.round(frame_np * 255), a_min=0, a_max=255).astype(np.uint8)
+    im = Image.fromarray(frame_np)
+    return im
+
+
+class VideoCapture:
+    def __init__(self, name):
+        self.cap = cv2.VideoCapture(name)
+        self.q = queue.Queue()
+        t = threading.Thread(target=self._reader)
+        t.daemon = True
+        t.start()
+
+    def _reader(self):
+        while True:
+            ret, frame = self.cap.read()
+            if not ret:
+                break
+            if not self.q.empty():
+                try:
+                    self.q.get_nowait()
+                except queue.Empty:
+                    pass
+            self.q.put(frame)
+
+    def read(self):
+        return self.q.get()
+
+
+def captureBackground(capture_device):
+    frame_background = capture_device.read()
+    img = cv2.cvtColor(frame_background, cv2.COLOR_BGR2RGB)
+    im_pil = Image.fromarray(img)
+    return im_pil
+
+
+def getImageTensor(pil_image, height, width, unsqueeze=True):
+    t = transforms.Compose([transforms.CenterCrop((height, width)), transforms.ToTensor()])
+    rgb = t(pil_image)
+
+    if unsqueeze:
+        rgb = rgb.unsqueeze(0)
+    return rgb
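
`cropFOV` simulates a narrower field of view with a center crop whose side ratio is `new_fov / original_fov`. A standalone sketch of that logic (the 90/60-degree values and image size are made up); note the linear ratio is an approximation, since an exact perspective crop would use `tan(new_fov/2) / tan(original_fov/2)`:

```python
import numpy as np
from PIL import Image

original_fov, new_fov = 90, 60                 # degrees; must satisfy new_fov < original_fov
image = np.zeros((480, 640, 3), dtype=np.uint8)

crop_ratio = new_fov / original_fov            # fraction of each side to keep
height, width = image.shape[:2]
new_w, new_h = int(width * crop_ratio), int(height * crop_ratio)
start_x, start_y = (width - new_w) // 2, (height - new_h) // 2

cropped = Image.fromarray(image[start_y:start_y + new_h, start_x:start_x + new_w])
print(cropped.size)                            # (426, 320)
```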
utils/utils.py ADDED
@@ -0,0 +1,150 @@
+import os
+import math
+from PIL import Image
+import cv2
+from tqdm import tqdm
+import torch
+import torch.nn.functional as F
+from torchvision import transforms
+from torchvision.utils import save_image
+import numpy as np
+from moviepy.editor import ImageSequenceClip
+
+from utils.mpi import mpi_rendering
+from utils.mpi.homography_sampler import HomographySample
+
+
+def image_to_tensor(img_path, unsqueeze=True):
+    rgb = transforms.ToTensor()(Image.open(img_path))
+    if unsqueeze:
+        rgb = rgb.unsqueeze(0)
+    return rgb
+
+
+def disparity_to_tensor(disp_path, unsqueeze=True):
+    disp = cv2.imread(disp_path, -1) / (2 ** 16 - 1)
+    disp = torch.from_numpy(disp)[None, ...]
+    if unsqueeze:
+        disp = disp.unsqueeze(0)
+    return disp.float()
+
+
+def gen_swing_path(num_frames=90, r_x=0.14, r_y=0., r_z=0.10):
+    """Return a list of [4, 4] camera pose matrices."""
+    t = torch.arange(num_frames) / (num_frames - 1)
+    poses = torch.eye(4).repeat(num_frames, 1, 1)
+    poses[:, 0, 3] = r_x * torch.sin(2. * math.pi * t)
+    poses[:, 1, 3] = r_y * torch.cos(2. * math.pi * t)
+    poses[:, 2, 3] = r_z * (torch.cos(2. * math.pi * t) - 1.)
+    return poses.unbind()
+
+
+def render_3dphoto(
+    src_imgs,  # [b,3,h,w]
+    mpi_all_src,  # [b,s,4,h,w]
+    disparity_all_src,  # [b,s]
+    k_src,  # [b,3,3]
+    k_tgt,  # [b,3,3]
+    save_path,
+):
+    h, w = mpi_all_src.shape[-2:]
+    device = mpi_all_src.device
+    homography_sampler = HomographySample(h, w, device)
+    k_src_inv = torch.inverse(k_src)
+
+    # preprocess the predicted MPI
+    xyz_src_BS3HW = mpi_rendering.get_src_xyz_from_plane_disparity(
+        homography_sampler.meshgrid,
+        disparity_all_src,
+        k_src_inv,
+    )
+    mpi_all_rgb_src = mpi_all_src[:, :, 0:3, :, :]  # BxSx3xHxW
+    mpi_all_sigma_src = mpi_all_src[:, :, 3:, :, :]  # BxSx1xHxW
+    _, _, blend_weights, _ = mpi_rendering.render(
+        mpi_all_rgb_src,
+        mpi_all_sigma_src,
+        xyz_src_BS3HW,
+        use_alpha=False,
+        is_bg_depth_inf=False,
+    )
+    mpi_all_rgb_src = blend_weights * src_imgs.unsqueeze(1) + (1 - blend_weights) * mpi_all_rgb_src
+
+    # render novel views
+    swing_path_list = gen_swing_path()
+    frames = []
+    for cam_ext in tqdm(swing_path_list):
+        frame = render_novel_view(
+            mpi_all_rgb_src,
+            mpi_all_sigma_src,
+            disparity_all_src,
+            cam_ext,
+            k_src_inv,
+            k_tgt,
+            homography_sampler,
+        )
+        frame_np = frame[0].permute(1, 2, 0).contiguous().cpu().numpy()  # [b,h,w,3]
+        frame_np = np.clip(np.round(frame_np * 255), a_min=0, a_max=255).astype(np.uint8)
+        frames.append(frame_np)
+    rgb_clip = ImageSequenceClip(frames, fps=30)
+    rgb_clip.write_videofile(save_path, verbose=False, codec='mpeg4', logger=None, bitrate='2000k')
+
+
+def render_novel_view(
+    mpi_all_rgb_src,
+    mpi_all_sigma_src,
+    disparity_all_src,
+    G_tgt_src,
+    K_src_inv,
+    K_tgt,
+    homography_sampler,
+):
+    xyz_src_BS3HW = mpi_rendering.get_src_xyz_from_plane_disparity(
+        homography_sampler.meshgrid,
+        disparity_all_src,
+        K_src_inv
+    )
+
+    xyz_tgt_BS3HW = mpi_rendering.get_tgt_xyz_from_plane_disparity(
+        xyz_src_BS3HW,
+        G_tgt_src
+    )
+
+    tgt_imgs_syn, _, _ = mpi_rendering.render_tgt_rgb_depth(
+        homography_sampler,
+        mpi_all_rgb_src,
+        mpi_all_sigma_src,
+        disparity_all_src,
+        xyz_tgt_BS3HW,
+        G_tgt_src,
+        K_src_inv,
+        K_tgt,
+        use_alpha=False,
+        is_bg_depth_inf=False,
+    )
+
+    return tgt_imgs_syn
+
+
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+    def __init__(self, name, fmt=":f"):
+        self.name = name
+        self.fmt = fmt
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+    def __str__(self):
+        # fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
+        # return fmtstr.format(**self.__dict__)
+        return f"{self.name:s}: {self.avg:.6f}"
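
`gen_swing_path` just varies the translation column of an identity pose, tracing an ellipse in x/z; each returned 4x4 matrix is a `G_tgt_src` for `render_novel_view`. A minimal sketch (6 frames instead of the default 90; it assumes the requirements.txt dependencies such as moviepy and cv2 are installed, since importing `utils.utils` pulls them in):

```python
import torch
from utils.utils import gen_swing_path

poses = gen_swing_path(num_frames=6, r_x=0.14, r_y=0.0, r_z=0.10)
print(len(poses), poses[0].shape)  # 6 torch.Size([4, 4])

for G_tgt_src in poses:
    # Only the translation changes frame to frame; the rotation stays identity.
    print(f"x={G_tgt_src[0, 3].item():+.3f}  z={G_tgt_src[2, 3].item():+.3f}")
# Each pose would be passed as G_tgt_src to render_novel_view together with
# the MPI tensors (rgb, sigma, disparities) predicted by the network.
```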