initial commit

Files changed:
- .gitattributes +1 -1
- .gitignore +21 -0
- .huggingface.yaml +3 -0
- README.md +97 -14
- __init__.py +0 -0
- app.py +238 -0
- depth_only_parameters.py +21 -0
- helperFunctions.py +26 -0
- helper_image_functions.py +290 -0
- models/__init__.py +0 -0
- models/depth_only_lite_model.py +234 -0
- models/depth_only_model.py +232 -0
- requirements.txt +122 -0
- rff_torch.py +53 -0
- utils.py +243 -0
.gitattributes
CHANGED

```diff
@@ -32,4 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.xz filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
```
.gitignore
ADDED

```
# Python
__pycache__/
*.py[cod]
*$py.class

# Models and Engines
*.onnx
*.onnx.data
*.pth
*.engine

# Images
*.png
*.jpeg
*.JPG

# Videos
*.mp4

# Logs
logs/
```
.huggingface.yaml
ADDED

```yaml
sdk: gradio
python_version: '3.12'
requirements_file: requirements.txt
```
README.md
CHANGED

<div align="center">
<a href="#"><img src='https://img.shields.io/badge/-Paper-00629B?style=flat&logo=ieee&logoColor=white' alt='arXiv'></a>
<a href='https://realistic3d-miun.github.io/PVSDNet/'><img src='https://img.shields.io/badge/Project_Page-Website-green?logo=googlechrome&logoColor=white' alt='Project Page'></a>
<a href='https://huggingface.co/spaces/3ZadeSSG/PVSDNet'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Demo_(Coming_Soon)-blue'></a>
</div>

# PVSDNet: Joint Depth Prediction and View Synthesis via Shared Latent Spaces in Real-Time

## Supplementary Video (head to the Project Page for more visual results)
[Watch the supplementary video](https://youtu.be/49s2UPvRA6I)

# 1. PVSDNet - Joint Depth and View
**Note:** Will be added soon.

## 1.A. Normal Inference (Recommended for minimal setup)
**Note:** Will be added soon.

## 1.B. Faster Inference (For best possible FPS)
**Note:** Will be added soon.

# 2. PVSDNet Depth-Only Model
This model is a variant of the original PVSDNet model that predicts only depth, not the target views. The model core is the same, except that the rendering network and the positional encoding are removed.

* Download the checkpoints from the following table and place them in the `checkpoint_onnx` directory.

| Model | Size | Checkpoint |
|-----------------|--------|----------------|
| PVSDNet-Depth-Only | 1.11 GB | [Download](https://huggingface.co/3ZadeSSG/PVSDNet-Depth-Only/resolve/main/depth_only_model.pth) |
| PVSDNet-Depth-Only-Lite | 279 MB | [Download](https://huggingface.co/3ZadeSSG/PVSDNet-Depth-Only/resolve/main/depth_only_lite_model.pth) |

## 2.A. Normal Inference (Recommended for minimal setup)

## 2.B. Faster Inference (For best possible FPS)
You need to set up your own TRT engine for this purpose.

* Make sure you modify `depth_only_parameters.py` to set the resolution you need. By default we have kept it at `384x384`.

* Run `export_onnx_depth.py` to convert the regular PyTorch checkpoints into ONNX (a sketch of this step follows the section):
```
python export_onnx_depth.py
```
* Create the TRT engine directory:
```
mkdir TRT_Engine
```
* Build the TRT engines from the generated ONNX files (which by default will be located in `checkpoint_onnx`):
```
trtexec --onnx=./checkpoint_onnx/depth_only_model.onnx --saveEngine=./TRT_Engine/depth_only_model_fp16.engine --fp16
```
```
trtexec --onnx=./checkpoint_onnx/depth_only_lite_model.onnx --saveEngine=./TRT_Engine/depth_only_lite_model_fp16.engine --fp16
```
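`export_onnx_depth.py` itself is not part of this commit, so the following is only a minimal sketch of what that export step can look like, reusing the checkpoint-loading conventions from `app.py` and the paths in `depth_only_parameters.py`; the input/output tensor names and the opset are assumptions:

```python
# Hypothetical sketch of the ONNX export step (export_onnx_depth.py is not
# included in this commit); follows the checkpoint-loading style of app.py.
import torch
import depth_only_parameters as params
import helperFunctions as helper
from models.depth_only_model import PVSDNet

model = PVSDNet(total_image_input=params.params_number_input)
model = helper.load_Checkpoint(params.MODEL_Large_Location, model, load_cpu=True)
model.eval()

# Trace with a dummy input at the resolution set in depth_only_parameters.py
dummy = torch.randn(1, 3, params.params_height, params.params_width)
torch.onnx.export(
    model,
    dummy,
    f"{params.ONNX_PATH}/depth_only_model.onnx",
    input_names=["image"],   # assumed tensor names
    output_names=["depth"],
    opset_version=17,        # assumed opset
)
```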
## 2.C. Predicting on Depth Datasets using Multi-Resolution Fusion

We run the scripts inside the `depth_dataset_predictor` directory. There are two sample images for each dataset to test the code.
* First we build the TRT engine for each dataset, since we use multi-resolution fusion (a Python sketch of the build step follows the table below):
```
python depth_dataset_predictor/build_trt_<dataset_name>.py
```
* Then we run the prediction script:
```
python depth_dataset_predictor/predict_<dataset_name>_TensorRT.py
```

| Dataset | Step 1 | Step 2 |
|---|---|---|
| ETH3D | `python depth_dataset_predictor/build_trt_ETH3D.py` | `python depth_dataset_predictor/predict_ETH3D_TensorRT.py` |
| Sintel | `python depth_dataset_predictor/build_trt_Sintel.py` | `python depth_dataset_predictor/predict_Sintel_TensorRT.py` |
| KITTI | `python depth_dataset_predictor/build_trt_KITTI.py` | `python depth_dataset_predictor/predict_KITTI_TensorRT.py` |
| DIODE | `python depth_dataset_predictor/build_trt_DIODE.py` | `python depth_dataset_predictor/predict_DIODE_TensorRT.py` |
| NYU | `python depth_dataset_predictor/build_trt_NYU.py` | `python depth_dataset_predictor/predict_NYU_TensorRT.py` |
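The per-dataset `build_trt_*.py` scripts are not shown in this commit. As a rough sketch under the TensorRT 10 Python API pinned in `requirements.txt`, building an FP16 engine from one of the exported ONNX files (equivalent to the `trtexec --fp16` call in section 2.B) can look like this; the file paths are just the defaults from that section:

```python
# Minimal sketch of an FP16 TensorRT engine build (TensorRT 10 Python API);
# the actual build_trt_*.py scripts are not included in this commit.
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(0)  # explicit batch is the default in TRT 10
parser = trt.OnnxParser(network, logger)

with open("./checkpoint_onnx/depth_only_model.onnx", "rb") as f:
    if not parser.parse(f.read()):
        raise RuntimeError(parser.get_error(0))

config = builder.create_builder_config()
config.set_flag(trt.BuilderFlag.FP16)  # same effect as trtexec's --fp16

engine_bytes = builder.build_serialized_network(network, config)
with open("./TRT_Engine/depth_only_model_fp16.engine", "wb") as f:
    f.write(engine_bytes)
```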
## 2.D. Predicting on 1080p In-The-Wild Images/Videos using Multi-Resolution Fusion
As with the datasets above, we can use multi-resolution fusion to predict on 1080p in-the-wild images and videos.

* First we build the TRT engine:
```
python depth_in_wild_predictor/build_trt_1080p.py
```
* Then we run the prediction script for images:
```
python depth_in_wild_predictor/predict_1080p_TensorRT.py
```
OR run the prediction script for videos:
```
python depth_in_wild_predictor/predict_video_1080p_TensorRT.py
```
#### Note
* For any other resolution, you can modify the resolutions in the scripts above to suit your needs. We have kept the default at 1080p for this example.
* We recommend 3-6 resolutions for best results, but you can use 1-2 smaller resolutions when working with low-resolution images/videos, since the receptive field of the network can handle that without any issues. A sketch of the fusion idea follows.
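The fusion itself is simple: run the model at several scales, resize every depth map back to the input resolution, and average. A minimal sketch mirroring the logic in `app.py` (`model` is assumed to be any callable returning a 1x1xHxW depth tensor):

```python
import torch
import torch.nn.functional as F
import torchvision.transforms as T

def fuse_depth(image, model, resolutions):
    """image: PIL RGB image; resolutions: list of (width, height) pairs."""
    W, H = image.size
    preds = []
    for w, h in resolutions:
        x = T.Compose([T.Resize((h, w)), T.ToTensor()])(image).unsqueeze(0)
        with torch.no_grad():
            d = model(x)  # 1x1xhxw relative depth at this scale
        # Bring each scale's prediction back to the input resolution
        preds.append(F.interpolate(d, (H, W), mode="bilinear",
                                   align_corners=False)[0, 0])
    return sum(preds) / len(preds)  # plain average across scales
```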
__init__.py
ADDED

File without changes
app.py
ADDED

```python
import gradio as gr
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch.nn.functional as F
import torchvision.transforms as transforms
import depth_only_parameters as params

from models.depth_only_model import PVSDNet
from models.depth_only_lite_model import PVSDNet_Lite

import helperFunctions as helper
import socket
from huggingface_hub import hf_hub_download
import joblib

REPO_ID = "3ZadeSSG/PVSDNet-Depth-Only"
print("Downloading/Loading checkpoints from Hugging Face Hub...")
params.MODEL_Small_Location = hf_hub_download(
    repo_id=REPO_ID,
    filename="depth_only_lite_model.pth"
)

params.MODEL_Large_Location = hf_hub_download(
    repo_id=REPO_ID,
    filename="depth_only_model.pth"
)

print(f"Large Model loaded at: {params.MODEL_Large_Location}")
print(f"Lite Model loaded at: {params.MODEL_Small_Location}")


def get_valid_resolutions(width, height):
    """Dynamically determines valid resolutions based on input size.
    - Caps the highest resolution at 1024px to avoid unnecessary high-res computations.
    - Uses 6 resolutions for large images to improve multi-scale fusion quality.
    - Uses 4 resolutions for smaller images (< 512px width or height).
    """
    def make_divisible(n, base=16):
        return max(base, int(round(n / base) * base))

    max_resolution = 1024
    high_w, high_h = make_divisible(min(width, max_resolution)), make_divisible(min(height, max_resolution))

    # Calculate more intermediate steps for better fusion
    r80_w, r80_h = make_divisible(int(high_w // 1.25)), make_divisible(int(high_h // 1.25))
    r66_w, r66_h = make_divisible(int(high_w // 1.5)), make_divisible(int(high_h // 1.5))
    r50_w, r50_h = make_divisible(int(high_w // 2)), make_divisible(int(high_h // 2))
    r40_w, r40_h = make_divisible(int(high_w // 2.5)), make_divisible(int(high_h // 2.5))
    r33_w, r33_h = make_divisible(max(256, int(high_w // 3))), make_divisible(max(256, int(high_h // 3)))

    if width < 512 or height < 512:
        return [(high_w, high_h), (r80_w, r80_h), (r66_w, r66_h), (r50_w, r50_h)]
    else:
        return [
            (high_w, high_h),
            (r80_w, r80_h),
            (r66_w, r66_h),
            (r50_w, r50_h),
            (r40_w, r40_h),
            (r33_w, r33_h)
        ]


def get_transforms(resolutions):
    return [transforms.Compose([transforms.Resize((h, w)), transforms.ToTensor()]) for w, h in resolutions]

def get_prediction(image, transform, model):
    img_input = image.convert('RGB')
    img_input = transform(img_input).unsqueeze(0).to(params.DEVICE)
    depth_out = model(img_input).detach().squeeze(0).to("cpu")
    return depth_out

def predict_single_image(image, model_type):
    if image is None:
        return None, None

    # Select model class and checkpoint
    if model_type == "Lite":
        model_class = PVSDNet_Lite
        checkpoint = params.MODEL_Small_Location
    else:  # Default to "Large"
        model_class = PVSDNet
        checkpoint = params.MODEL_Large_Location

    model = model_class(total_image_input=params.params_number_input)
    model = helper.load_Checkpoint(checkpoint, model, load_cpu=True)
    model.to(params.DEVICE)
    model.eval()

    original_width, original_height = image.size

    resolutions = get_valid_resolutions(original_width, original_height)
    print(f"Resolutions: {resolutions} for Model Type: {model_type}")
    transforms_list = get_transforms(resolutions)

    depth_maps = [get_prediction(image, t, model) for t in transforms_list]

    depth_maps_resized = [
        F.interpolate(depth[None], (original_height, original_width), mode='bilinear', align_corners=False)[0, 0]
        for depth in depth_maps
    ]

    depth_final = sum(depth_maps_resized) / len(depth_maps_resized)

    depth_image = (depth_final - depth_final.min()) / (depth_final.max() - depth_final.min())

    img_out = depth_image.numpy()
    img_out_colored = plt.get_cmap('inferno')(img_out / np.max(img_out))[:, :, :3]
    img_out_colored = (img_out_colored * 255).astype(np.uint8)

    gray_scale_img_out = (depth_image.numpy() * 255).astype(np.uint8)

    return Image.fromarray(img_out_colored), Image.fromarray(gray_scale_img_out)

with gr.Blocks(title="PVSDNet-Depth-Only Model", theme="default") as demo:
    gr.Markdown(
        """
        ## PVSDNet-Depth-Only ZeroShot Relative Depth Estimation Model
        * Upload an image and get its depth estimation with multi-scale fusion.
        * Images use 2 - 6 resolutions for multi-scale fusion.

        **Note:** The Hugging Face demo is running on CPU, so inference speeds will be slow.
        ### Head to our [Project Page](https://realistic3d-miun.github.io/PVSDNet/) for more details about the models.
        """)

    with gr.Row():
        with gr.Column():
            img_input = gr.Image(type="pil", label="RGB Image", height=384)
            with gr.Accordion("Advanced Settings", open=False):
                model_type_dropdown = gr.Dropdown(["Large", "Lite"], label="Model Type", value="Large")
            generate_btn = gr.Button("Estimate Depth", variant="primary")

        with gr.Column():
            output_color = gr.Image(type="pil", label="Depth Map (Color)", height=384)
            output_gray = gr.Image(type="pil", label="Depth Map (Grayscale)", height=384)

    generate_btn.click(
        fn=predict_single_image,
        inputs=[img_input, model_type_dropdown],
        outputs=[output_color, output_gray]
    )

    gr.Markdown("### Example Samples")
    with gr.Column():
        with gr.Row():
            with gr.Column(scale=2): gr.Markdown("**Example Image (Click to load)**")
            with gr.Column(scale=1): gr.Markdown("**Resolution**")
            with gr.Column(scale=2): gr.Markdown("**Fusion Resolutions**")

        with gr.Row(variant="panel"):
            with gr.Column(scale=2):
                diode_preview = gr.Image("./samples/DIODE/00022_00195_outdoor_010_030.png", label="DIODE", height=120, interactive=False, show_label=True)
            with gr.Column(scale=1):
                gr.Markdown("1024 x 768")
            with gr.Column(scale=2):
                gr.Markdown("1024x768, 816x608, 688x512, 512x384, 416x304, 336x256")

        with gr.Row(variant="panel"):
            with gr.Column(scale=2):
                eth3d_preview = gr.Image("./samples/ETH3D/DSC_0243.JPG", label="ETH3D", height=120, interactive=False, show_label=True)
            with gr.Column(scale=1):
                gr.Markdown("6048 x 4032")
            with gr.Column(scale=2):
                gr.Markdown("1024x1024, 816x816, 688x688, 512x512, 416x416, 336x336")

        with gr.Row(variant="panel"):
            with gr.Column(scale=2):
                sintel_preview = gr.Image("./samples/Sintel/frame_0028_temple.png", label="Sintel", height=120, interactive=False, show_label=True)
            with gr.Column(scale=1):
                gr.Markdown("1024 x 436")
            with gr.Column(scale=2):
                gr.Markdown("1024x432, 816x352, 688x288, 512x224")

        with gr.Row(variant="panel"):
            with gr.Column(scale=2):
                kitti_preview = gr.Image("./samples/KITTI/2011_10_03_drive_0047_sync_image_0000000383_image_02.png", label="KITTI", height=120, interactive=False, show_label=True)
            with gr.Column(scale=1):
                gr.Markdown("1216 x 532")
            with gr.Column(scale=2):
                gr.Markdown("1024x352, 816x288, 688x240, 512x176")

        with gr.Row(variant="panel"):
            with gr.Column(scale=2):
                wild_1_preview = gr.Image("./samples/Wild/toy.jpeg", label="Wild Image 1", height=120, interactive=False, show_label=True)
            with gr.Column(scale=1):
                gr.Markdown("3019 x 3018")
            with gr.Column(scale=2):
                gr.Markdown("1024x1024, 816x816, 688x688, 512x512, 416x416, 336x336")

        with gr.Row(variant="panel"):
            with gr.Column(scale=2):
                wild_2_preview = gr.Image("./samples/Wild/hamburg.jpeg", label="Wild Image 2", height=120, interactive=False, show_label=True)
            with gr.Column(scale=1):
                gr.Markdown("1536 x 1920")
            with gr.Column(scale=2):
                gr.Markdown("1024x1024, 816x816, 688x688, 512x512, 416x416, 336x336")

        with gr.Row(variant="panel"):
            with gr.Column(scale=2):
                wild_3_preview = gr.Image("./samples/Wild/north_hill.jpeg", label="Wild Image 3", height=120, interactive=False, show_label=True)
            with gr.Column(scale=1):
                gr.Markdown("2320 x 2321")
            with gr.Column(scale=2):
                gr.Markdown("1024x1024, 816x816, 688x688, 512x512, 416x416, 336x336")

        with gr.Row(variant="panel"):
            with gr.Column(scale=2):
                wild_4_preview = gr.Image("./samples/Wild/EH.jpeg", label="Wild Image 4", height=120, interactive=False, show_label=True)
            with gr.Column(scale=1):
                gr.Markdown("1920 x 1080")
            with gr.Column(scale=2):
                gr.Markdown("1024x1024, 816x816, 688x688, 512x512, 416x416, 336x336")

        with gr.Row(variant="panel"):
            with gr.Column(scale=2):
                wild_5_preview = gr.Image("./samples/Wild/train_station.jpeg", label="Wild Image 5", height=120, interactive=False, show_label=True)
            with gr.Column(scale=1):
                gr.Markdown("1066 x 1060")
            with gr.Column(scale=2):
                gr.Markdown("1024x1024, 816x816, 688x688, 512x512, 416x416, 336x336")


    # Define click events to load images
    eth3d_preview.select(fn=lambda: Image.open("./samples/ETH3D/DSC_0243.JPG"), outputs=img_input)
    sintel_preview.select(fn=lambda: Image.open("./samples/Sintel/frame_0028_temple.png"), outputs=img_input)
    kitti_preview.select(fn=lambda: Image.open("./samples/KITTI/2011_10_03_drive_0047_sync_image_0000000383_image_02.png"), outputs=img_input)
    diode_preview.select(fn=lambda: Image.open("./samples/DIODE/00022_00195_outdoor_010_030.png"), outputs=img_input)

    wild_1_preview.select(fn=lambda: Image.open("./samples/Wild/toy.jpeg"), outputs=img_input)
    wild_2_preview.select(fn=lambda: Image.open("./samples/Wild/hamburg.jpeg"), outputs=img_input)
    wild_3_preview.select(fn=lambda: Image.open("./samples/Wild/north_hill.jpeg"), outputs=img_input)
    wild_4_preview.select(fn=lambda: Image.open("./samples/Wild/EH.jpeg"), outputs=img_input)
    wild_5_preview.select(fn=lambda: Image.open("./samples/Wild/train_station.jpeg"), outputs=img_input)


demo.launch()
```
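As a worked example, the 1024x768 DIODE sample in the demo's example table produces exactly the six fusion scales listed in the UI:

```python
# Worked example: the DIODE sample is 1024x768; each scale is rounded to a
# multiple of 16 by make_divisible, matching the values shown in the demo.
print(get_valid_resolutions(1024, 768))
# -> [(1024, 768), (816, 608), (688, 512), (512, 384), (416, 304), (336, 256)]
```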
depth_only_parameters.py
ADDED

```python
import os

params_height = 384
params_width = 384

params_number_input = 1

LOG_FILE_LOCATION = "./logs/training_log_0.txt"
CHECKPOINT_LOCATION = "./checkpoint/"
DEVICE = "cpu"
ONNX_PATH = "./checkpoint_onnx"

MODEL_Large_Location = "./checkpoint/depth_only_model.pth"
MODEL_Small_Location = "./checkpoint/depth_only_lite_model.pth"

os.makedirs(ONNX_PATH, exist_ok=True)
os.makedirs("./logs", exist_ok=True)
os.makedirs("./checkpoint", exist_ok=True)
os.makedirs("./output", exist_ok=True)
```
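Both encoder branches downsample by a factor of 16 in total (four stride-2 stages, which the decoder's four upsampling stages undo), so any resolution set here should keep both dimensions multiples of 16; this is also why `app.py`'s `make_divisible` rounds to a base of 16. A minimal sketch of overriding the default before an export or inference run:

```python
# Hypothetical override of the default 384x384 resolution; keep both
# dimensions multiples of 16 so the encoder/decoder feature maps line up.
import depth_only_parameters as params

params.params_height = 512
params.params_width = 512
```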
helperFunctions.py
ADDED

```python
import torch
import os
import torch.nn.functional as F

def save_checkpoint(model, filelocation, save_parallel=True):
    if save_parallel:
        torch.save(model.module.state_dict(), filelocation)
    else:
        torch.save(model.state_dict(), filelocation)

def load_Checkpoint(fileLocation, model, load_cpu=False):
    if load_cpu:
        model.load_state_dict(torch.load(fileLocation, map_location=lambda storage, loc: storage))
    else:
        model.load_state_dict(torch.load(fileLocation))
    return model

def writeLog(logList, filename):
    with open(filename, 'w') as outfile:
        outfile.write("\n".join(logList))


def kl_loss(mu, logvar):
    return -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).mean()
```
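`kl_loss` is the standard closed-form KL divergence between a diagonal Gaussian posterior $\mathcal{N}(\mu, \sigma^2)$ (with `logvar` $= \log\sigma^2$) and a standard normal prior; the code returns the mean of this per-element quantity:

$$
D_{\mathrm{KL}}\left(\mathcal{N}(\mu,\sigma^{2})\,\|\,\mathcal{N}(0,1)\right)
= -\tfrac{1}{2}\left(1 + \log\sigma^{2} - \mu^{2} - \sigma^{2}\right)
$$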
helper_image_functions.py
ADDED

```python
'''
Author: Manu Gond (manu.gond@miun.se)
Date: Nov-15-2022
Objective: Accumulation of some general functions which I
           use daily in my code related to image-related tasks.
           The function names and parameters are self-explanatory.
Requirements: Installed python libraries which have been imported.
'''

import torch
from torchvision.utils import save_image
from torchvision.transforms import transforms
import torchmetrics
import cv2
import numpy as np
from PIL import Image
import utils


#======================= Read and Write =====================#
def readImage(location):
    image = Image.open(location).convert("RGB")
    return image


def writeImage(image, location):
    image.save(location)


def writeTensorImage(image, filename):
    save_image(image, filename)


def removeChannel(sourceLocation, targetLocation):
    img = readImage(sourceLocation)
    writeImage(img, targetLocation)


def getImageTransform(width, height):
    transform = transforms.Compose([transforms.Resize((height, width)),
                                    transforms.ToTensor()])
    return transform


def convertTensor(image):
    transform = getImageTransform(image.size[0], image.size[1])
    image = transform(image)
    return image


#=================== 360 Images =======================#

def rotateERP180(image):
    '''
    :param image: PIL Image
    :return: BxHxW Torch Tensor Image
    '''
    W = image.size[0]
    H = image.size[1]
    transform = getImageTransform(W, H)
    image = transform(image)
    image1 = image[:, :, 0:(W//2)]
    image2 = image[:, :, (W//2):W]
    image3 = torch.zeros(image.size())
    image3[:, :, 0:(W//2)] = image2
    image3[:, :, (W//2):W] = image1
    return image3


def convertERP2Cube(e_img, face_w=256, mode='bilinear', cube_format='dice'):
    '''
    e_img: ndarray in shape of [H, W, *]
    face_w: int, the length of each face of the cubemap
    '''
    assert len(e_img.shape) == 3
    h, w = e_img.shape[:2]
    if mode == 'bilinear':
        order = 1
    elif mode == 'nearest':
        order = 0
    else:
        raise NotImplementedError('unknown mode')

    xyz = utils.xyzcube(face_w)
    uv = utils.xyz2uv(xyz)
    coor_xy = utils.uv2coor(uv, h, w)

    cubemap = np.stack([
        utils.sample_equirec(e_img[..., i], coor_xy, order=order)
        for i in range(e_img.shape[2])
    ], axis=-1)

    if cube_format == 'horizon':
        pass
    elif cube_format == 'list':
        cubemap = utils.cube_h2list(cubemap)
    elif cube_format == 'dict':
        cubemap = utils.cube_h2dict(cubemap)
    elif cube_format == 'dice':
        cubemap = utils.cube_h2dice(cubemap)
    else:
        raise NotImplementedError()
    return cubemap


def convertCube2ERP(cubemap, h, w, mode='bilinear', cube_format='dice'):
    if mode == 'bilinear':
        order = 1
    elif mode == 'nearest':
        order = 0
    else:
        raise NotImplementedError('unknown mode')

    if cube_format == 'horizon':
        pass
    elif cube_format == 'list':
        cubemap = utils.cube_list2h(cubemap)
    elif cube_format == 'dict':
        cubemap = utils.cube_dict2h(cubemap)
    elif cube_format == 'dice':
        cubemap = utils.cube_dice2h(cubemap)
    else:
        raise NotImplementedError('unknown cube_format')
    assert len(cubemap.shape) == 3
    assert cubemap.shape[0] * 6 == cubemap.shape[1]
    assert w % 8 == 0
    face_w = cubemap.shape[0]

    uv = utils.equirect_uvgrid(h, w)
    u, v = np.split(uv, 2, axis=-1)
    u = u[..., 0]
    v = v[..., 0]
    cube_faces = np.stack(np.split(cubemap, 6, 1), 0)

    # Get face id to each pixel: 0F 1R 2B 3L 4U 5D
    tp = utils.equirect_facetype(h, w)
    coor_x = np.zeros((h, w))
    coor_y = np.zeros((h, w))

    for i in range(4):
        mask = (tp == i)
        coor_x[mask] = 0.5 * np.tan(u[mask] - np.pi * i / 2)
        coor_y[mask] = -0.5 * np.tan(v[mask]) / np.cos(u[mask] - np.pi * i / 2)

    mask = (tp == 4)
    c = 0.5 * np.tan(np.pi / 2 - v[mask])
    coor_x[mask] = c * np.sin(u[mask])
    coor_y[mask] = c * np.cos(u[mask])

    mask = (tp == 5)
    c = 0.5 * np.tan(np.pi / 2 - np.abs(v[mask]))
    coor_x[mask] = c * np.sin(u[mask])
    coor_y[mask] = -c * np.cos(u[mask])

    # Final renormalize
    coor_x = (np.clip(coor_x, -0.5, 0.5) + 0.5) * face_w
    coor_y = (np.clip(coor_y, -0.5, 0.5) + 0.5) * face_w

    equirec = np.stack([
        utils.sample_cubefaces(cube_faces[..., i], tp, coor_y, coor_x, order=order)
        for i in range(cube_faces.shape[3])
    ], axis=-1)
    return equirec


def convertCube2Slices(image):
    '''
    :param image: Image numpy array
    :return: List of Torch Tensors, CxHxW
    '''
    image = convertTensor(image)
    C, H, W = image.size()
    #print(C,H,W)
    top = torch.zeros((C, W//4, W//4))
    left = torch.zeros(top.size())
    front = torch.zeros(top.size())
    right = torch.zeros(top.size())
    back = torch.zeros(top.size())
    bottom = torch.zeros(top.size())

    top = image[:, 0:H//3, (W//4):(W//4)*2]
    left = image[:, (H//3):(H//3)*2, 0:W//4]
    front = image[:, (H//3):(H//3)*2, (W//4):(W//4)*2]
    right = image[:, (H//3):(H//3)*2, (W//4)*2:(W//4)*3]
    back = image[:, (H // 3):(H // 3) * 2, (W // 4) * 3:]
    bottom = image[:, (H//3)*2:, (W//4):(W//4)*2]

    '''
    save_image(top, 'top.png')
    save_image(left, 'left.png')
    save_image(front, 'front.png')
    save_image(right, 'right.png')
    save_image(back, 'back.png')
    save_image(bottom, 'bottom.png')
    '''
    return [top, left, front, right, back, bottom]

def convertSlicesToCube(imageList):
    '''
    top = convertTensor(readImage(imageList[0]))
    left = convertTensor(readImage(imageList[1]))
    front = convertTensor(readImage(imageList[2]))
    right = convertTensor(readImage(imageList[3]))
    back = convertTensor(readImage(imageList[4]))
    bottom = convertTensor(readImage(imageList[5]))
    '''
    top = imageList[0]
    left = imageList[1]
    front = imageList[2]
    right = imageList[3]
    back = imageList[4]
    bottom = imageList[5]

    C, H, W = 3, top.size()[1]*3, top.size()[2]*4
    cube = torch.zeros((C, H, W))

    cube[:, 0:H//3, (W//4):(W//4)*2] = top
    cube[:, (H // 3):(H // 3) * 2, 0:W // 4] = left
    cube[:, (H // 3):(H // 3) * 2, (W // 4):(W // 4) * 2] = front
    cube[:, (H // 3):(H // 3) * 2, (W // 4) * 2:(W // 4) * 3] = right
    cube[:, (H // 3):(H // 3) * 2, (W // 4) * 3:] = back
    cube[:, (H // 3) * 2:, (W // 4):(W // 4) * 2] = bottom

    return cube


#=================== Quality Measures =======================#
'''
Predicted Shape : BxCxHxW
Original Shape  : BxCxHxW
Data Type: Torch Tensor
'''
def getSSIM(predicted, original):
    SSIM = torchmetrics.StructuralSimilarityIndexMeasure()
    return SSIM(predicted, original).item()


def getPSNR(predicted, original):
    PSNR = torchmetrics.PeakSignalNoiseRatio()
    return PSNR(predicted, original).item()


def getMSE(predicted, original):
    MSE = torchmetrics.MeanSquaredError()
    return MSE(predicted, original).item()


def getMAE(predicted, original):
    MAE = torchmetrics.MeanAbsoluteError()
    return MAE(predicted, original).item()


if __name__ == "__main__":
    '''
    img = readImage("31_image_0_0.png")
    img = convertERP2Cube(e_img=np.asarray(img), face_w=256)
    img = Image.fromarray(img.astype('uint8'),'RGB')
    convertCube2Slices(img)
    '''
    #image = convertSlicesToCube(["top.png", "left.png", "front.png", "right.png", "back.png", "bottom.png"])
    #writeTensorImage(image,'this.png')

    '''
    writeImage(img, 'cube.png')

    img = readImage('cube.png')
    img = convertCube2ERP(np.asarray(img),512,1024)
    img = Image.fromarray(img.astype('uint8'),'RGB')
    writeImage(img, 'cubeERP.png')


    img1 = readImage("31_image_0_0.png")
    img2 = readImage("cubeERP.png")
    img1 = convertTensor(img1)
    img2 = convertTensor(img2)
    print(getSSIM(img1.unsqueeze(0), img2.unsqueeze(0)))
    '''

    #img = rotateERP180(img)
    #writeTensorImage(img, 'rotated_image.png')
    #img = convertTensor(img)
    #print(getMAE(img.unsqueeze(0),img.unsqueeze(0)))
```
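Based on the commented-out `__main__` usage above, a typical equirectangular-to-cubemap roundtrip looks like the following sketch (the input filename is a placeholder, and `utils` must provide the sampling helpers imported above):

```python
import numpy as np
from PIL import Image
import helper_image_functions as hif

# ERP panorama -> 'dice'-layout cubemap (3x4 grid of 256px faces)
erp = hif.readImage("panorama.png")  # placeholder filename
cube = hif.convertERP2Cube(e_img=np.asarray(erp), face_w=256)
cube_img = Image.fromarray(cube.astype('uint8'), 'RGB')

# ... and back to a 512x1024 equirectangular image
erp_back = hif.convertCube2ERP(np.asarray(cube_img), 512, 1024)
```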
models/__init__.py
ADDED

File without changes
models/depth_only_lite_model.py
ADDED

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import warnings
warnings.filterwarnings("ignore")
import torchvision
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import depth_only_parameters as params

def getConvLayer(in_channel, out_channel, stride=1, padding=1, activation=nn.ReLU()):
    return nn.Sequential(
        nn.Conv2d(in_channel, out_channel,
                  kernel_size=3,
                  stride=stride,
                  padding=padding,
                  padding_mode='reflect'),
        activation
    )

def getConvTransposeLayer(in_channel, out_channel, kernel=3, stride=1, padding=1, activation=nn.ReLU()):
    return nn.Sequential(
        nn.ConvTranspose2d(in_channel,
                           out_channel,
                           kernel_size=kernel,
                           stride=stride,
                           padding=padding),
        activation
    )

class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)

class UnFlatten(nn.Module):
    def forward(self, input, size=1):
        return input.view(input.size(0), 1, params.params_height//8, params.params_width//8)

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.stride = stride

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.relu(out)
        out = self.conv2(out)
        out = out + self.shortcut(residual)
        out = self.relu(out)
        return out

class UpperEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        model = torchvision.models.resnet152(pretrained=True)
        layers = list(model.children())
        self.ResNetEncoder = nn.Sequential(*layers[:5].copy())
        del model

    def forward(self, x):
        x1 = x[:, 0:3, :, :]
        x1 = self.ResNetEncoder(x1)
        return x1

    def apply_resnet_encoder(self, x):
        x1 = x[:, 0:3, :, :]
        x1 = self.ResNetEncoder(x1)
        return x1

class LowerEncoder(nn.Module):
    def __init__(self, total_image_input=1):
        super().__init__()
        # Halved channels compared to the original
        self.encoder_pre = ResidualBlock(total_image_input*3, 10)
        self.encoder_layer1 = ResidualBlock(10, 15)
        self.encoder_layer2 = ResidualBlock(15, 25)

        self.encoder_layer3 = nn.Sequential(
            ResidualBlock(25, 50),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.encoder_layer4 = ResidualBlock(50, 100)
        self.encoder_layer5 = nn.Sequential(
            ResidualBlock(100, 200),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.encoder_layer6 = ResidualBlock(200, 300)
        self.encoder_layer7 = nn.Sequential(
            ResidualBlock(300, 400),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.encoder_layer8 = ResidualBlock(400, 500)
        self.encoder_layer9 = nn.Sequential(
            ResidualBlock(500, 600),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.encoder_layer10 = ResidualBlock(600, 700)
        self.encoder_layer11 = ResidualBlock(700, 800)

    def forward(self, x):
        x = self.encoder_pre(x)
        x = self.encoder_layer1(x)
        x = self.encoder_layer2(x)
        skip1 = self.encoder_layer3(x)

        x = self.encoder_layer4(skip1)
        skip2 = self.encoder_layer5(x)

        x = self.encoder_layer6(skip2)
        skip3 = self.encoder_layer7(x)

        x = self.encoder_layer8(skip3)
        skip4 = self.encoder_layer9(x)

        x = self.encoder_layer10(skip4)
        x = self.encoder_layer11(x)
        return x, [skip1, skip2, skip3, skip4]

class MergeDecoder(nn.Module):
    def __init__(self):
        super().__init__()
        # Halved channels for decoder blocks
        self.decoder_layer1 = ResidualBlock(800, 700)
        self.decoder_layer2 = ResidualBlock(700, 600)
        self.decoder_layer3 = ResidualBlock(600, 500)

        self.decoder_layer4 = nn.Sequential(
            nn.ConvTranspose2d(500, 400, kernel_size=2, stride=2, padding=0),
            nn.ReLU(True)
        )
        self.decoder_layer5 = ResidualBlock(400, 300)

        self.decoder_layer6 = nn.Sequential(
            nn.ConvTranspose2d(300, 200, kernel_size=2, stride=2, padding=0),
            nn.ReLU(True)
        )
        self.decoder_layer7 = ResidualBlock(200, 100)

        self.decoder_layer8 = nn.Sequential(
            nn.ConvTranspose2d(100, 50, kernel_size=2, stride=2, padding=0),
            nn.ReLU(True)
        )
        self.decoder_layer9 = ResidualBlock(50, 50)

        self.decoder_layer10 = nn.Sequential(
            nn.ConvTranspose2d(50, 50, kernel_size=2, stride=2, padding=0),
            nn.ReLU(True)
        )
        self.decoder_layer11 = ResidualBlock(50, 50)
        self.decoder_layer12 = ResidualBlock(50, 25)
        self.decoder_layer13 = ResidualBlock(25, 20)
        self.decoder_layer14 = ResidualBlock(20, 10)
        self.decoder_layer15 = nn.Sequential(
            nn.Conv2d(10, 4, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True)
        )
        self.decoder_layer16 = nn.Sequential(
            nn.Conv2d(4, 1, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True)
        )

    def forward(self, x, lower_skip_list, upper_skip_list):
        x = self.decoder_layer1(x)
        x = self.decoder_layer2(x)
        # Expecting lower_skip_list[3] and upper_skip_list[1] to have matching dimensions
        x = x + lower_skip_list[3] + upper_skip_list[1]

        x = self.decoder_layer3(x)
        x = self.decoder_layer4(x)
        x = x + lower_skip_list[2] + upper_skip_list[0]

        x = self.decoder_layer5(x)
        x = self.decoder_layer6(x)
        x = x + lower_skip_list[1]

        x = self.decoder_layer7(x)
        x = self.decoder_layer8(x)
        x = x + lower_skip_list[0]

        x = self.decoder_layer9(x)
        x = self.decoder_layer10(x)
        x = self.decoder_layer11(x)
        x = self.decoder_layer12(x)
        x = self.decoder_layer13(x)
        x = self.decoder_layer14(x)
        x = self.decoder_layer15(x)
        x = self.decoder_layer16(x)
        return x

class PVSDNet_Lite(nn.Module):
    def __init__(self, total_image_input=1):
        super().__init__()
        # Upper encoder remains mostly the same
        self.upper_encoder = UpperEncoder()
        self.lower_encoder = LowerEncoder(total_image_input)
        self.merge_decoder = MergeDecoder()
        # Halved extra layers for upper branch:
        self.upper_encoder_extra_1 = nn.Sequential(
            ResidualBlock(256, 400),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.upper_encoder_extra_2 = nn.Sequential(
            ResidualBlock(400, 600),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

    def forward(self, x):
        # First Encoder Branch (Upper)
        upper_features_1 = self.upper_encoder.apply_resnet_encoder(x)
        upper_features_1 = self.upper_encoder_extra_1(upper_features_1)
        upper_features_2 = self.upper_encoder_extra_2(upper_features_1)

        # Second Encoder Branch (Lower)
        lower_feature, skip_list = self.lower_encoder(x)

        # Merge and decode features
        merged_feature = self.merge_decoder(lower_feature, skip_list, [upper_features_1, upper_features_2])
        return merged_feature
```
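A minimal smoke test of the lite variant at the default 384x384 resolution from `depth_only_parameters.py`: both encoder branches downsample by 16x and the decoder's four transposed convolutions upsample back, so the predicted depth map matches the input resolution. (Note that constructing `PVSDNet_Lite` downloads ImageNet ResNet-152 weights on first use.)

```python
import torch
from models.depth_only_lite_model import PVSDNet_Lite

model = PVSDNet_Lite(total_image_input=1).eval()
x = torch.randn(1, 3, 384, 384)   # BxCxHxW RGB input
with torch.no_grad():
    depth = model(x)
print(depth.shape)                 # expected: torch.Size([1, 1, 384, 384])
```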
models/depth_only_model.py
ADDED

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import warnings
warnings.filterwarnings("ignore")
import torchvision
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import depth_only_parameters as params

def getConvLayer(in_channel, out_channel, stride=1, padding=1, activation=nn.ReLU()):
    return nn.Sequential(nn.Conv2d(in_channel,
                                   out_channel,
                                   kernel_size=3,
                                   stride=stride,
                                   padding=padding,
                                   padding_mode='reflect'),
                         activation)

def getConvTransposeLayer(in_channel, out_channel, kernel=3, stride=1, padding=1, activation=nn.ReLU()):
    return nn.Sequential(nn.ConvTranspose2d(in_channel,
                                            out_channel,
                                            kernel_size=kernel,
                                            stride=stride,
                                            padding=padding),
                         activation)


class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)

class UnFlatten(nn.Module):
    def forward(self, input, size=1):
        return input.view(input.size(0), 1, params.params_height//8, params.params_width//8)

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.stride = stride

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.relu(out)

        out = self.conv2(out)

        out = out + self.shortcut(residual)
        out = self.relu(out)
        return out

class UpperEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        model = torchvision.models.resnet152(pretrained=False)
        layers = list(model.children())
        self.ResNetEncoder = torch.nn.Sequential(*layers[:5].copy())
        del model

    def forward(self, x):
        x1 = x[:, 0:3, :, :]
        x1 = self.ResNetEncoder(x1)
        return x1

    def apply_resnet_encoder(self, x):
        x1 = x[:, 0:3, :, :]
        x1 = self.ResNetEncoder(x1)
        return x1


class LowerEncoder(nn.Module):
    def __init__(self, total_image_input=1):
        super().__init__()
        self.encoder_pre = ResidualBlock((total_image_input*3), 20)
        self.encoder_layer1 = ResidualBlock(20, 30)
        self.encoder_layer2 = ResidualBlock(30, 50)

        self.encoder_layer3 = nn.Sequential(
            ResidualBlock(50, 100),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.encoder_layer4 = ResidualBlock(100, 200)
        self.encoder_layer5 = nn.Sequential(
            ResidualBlock(200, 400),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.encoder_layer6 = ResidualBlock(400, 600)
        self.encoder_layer7 = nn.Sequential(
            ResidualBlock(600, 800),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.encoder_layer8 = ResidualBlock(800, 1000)
        self.encoder_layer9 = nn.Sequential(
            ResidualBlock(1000, 1200),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.encoder_layer10 = ResidualBlock(1200, 1400)
        self.encoder_layer11 = ResidualBlock(1400, 1600)

    def forward(self, x):
        x = self.encoder_pre(x)
        x = self.encoder_layer1(x)
        x = self.encoder_layer2(x)
        skip1 = self.encoder_layer3(x)

        x = self.encoder_layer4(skip1)
        skip2 = self.encoder_layer5(x)

        x = self.encoder_layer6(skip2)
        skip3 = self.encoder_layer7(x)

        x = self.encoder_layer8(skip3)
        skip4 = self.encoder_layer9(x)

        x = self.encoder_layer10(skip4)
        x = self.encoder_layer11(x)

        return x, [skip1, skip2, skip3, skip4]

class MergeDecoder(nn.Module):
    def __init__(self):
        super().__init__()

        self.decoder_layer1 = ResidualBlock(1600, 1400)
        self.decoder_layer2 = ResidualBlock(1400, 1200)
        self.decoder_layer3 = ResidualBlock(1200, 1000)

        self.decoder_layer4 = nn.Sequential(
            nn.ConvTranspose2d(1000, 800, 2, stride=2, padding=0),
            nn.ReLU(True)
        )
        self.decoder_layer5 = ResidualBlock(800, 600)

        self.decoder_layer6 = nn.Sequential(
            nn.ConvTranspose2d(600, 400, 2, stride=2, padding=0),
            nn.ReLU(True)
        )
        self.decoder_layer7 = ResidualBlock(400, 200)

        self.decoder_layer8 = nn.Sequential(
            nn.ConvTranspose2d(200, 100, 2, stride=2, padding=0),
            nn.ReLU(True)
        )
        self.decoder_layer9 = ResidualBlock(100, 100)

        self.decoder_layer10 = nn.Sequential(
            nn.ConvTranspose2d(100, 100, 2, stride=2, padding=0),
            nn.ReLU(True)
        )
        self.decoder_layer11 = ResidualBlock(100, 100)
        self.decoder_layer12 = ResidualBlock(100, 50)
        self.decoder_layer13 = ResidualBlock(50, 40)
        self.decoder_layer14 = ResidualBlock(40, 20)
        self.decoder_layer15 = nn.Sequential(
            nn.Conv2d(20, 8, 3, stride=1, padding=1),
            nn.ReLU(True)
        )
        self.decoder_layer16 = nn.Sequential(
            nn.Conv2d(8, 1, 3, stride=1, padding=1),
            nn.ReLU(True)
        )

    def forward(self, x, lower_skip_list, upper_skip_list):
        x = self.decoder_layer1(x)
        x = self.decoder_layer2(x)
        x = x + lower_skip_list[3] + upper_skip_list[1]

        x = self.decoder_layer3(x)
        x = self.decoder_layer4(x)
        x = x + lower_skip_list[2] + upper_skip_list[0]

        x = self.decoder_layer5(x)
        x = self.decoder_layer6(x)
        x = x + lower_skip_list[1]

        x = self.decoder_layer7(x)
        x = self.decoder_layer8(x)
        x = x + lower_skip_list[0]

        x = self.decoder_layer9(x)
        x = self.decoder_layer10(x)
        x = self.decoder_layer11(x)
        x = self.decoder_layer12(x)
        x = self.decoder_layer13(x)
        x = self.decoder_layer14(x)
        x = self.decoder_layer15(x)
        x = self.decoder_layer16(x)
        return x

class PVSDNet(nn.Module):
    def __init__(self, total_image_input=1):
        super().__init__()
        self.upper_encoder = UpperEncoder()
        self.lower_encoder = LowerEncoder(total_image_input)
        self.merge_decoder = MergeDecoder()

        self.upper_encoder_extra_1 = nn.Sequential(
            ResidualBlock(256, 800),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.upper_encoder_extra_2 = nn.Sequential(
            ResidualBlock(800, 1200),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

    def forward(self, x):
        upper_features_1 = self.upper_encoder.apply_resnet_encoder(x)
        upper_features_1 = self.upper_encoder_extra_1(upper_features_1)
        upper_features_2 = self.upper_encoder_extra_2(upper_features_1)

        lower_feature, skip_list = self.lower_encoder(x)

        merged_feature = self.merge_decoder(lower_feature, skip_list, [upper_features_1, upper_features_2])
        return merged_feature
```
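The large model shares the lite variant's topology with roughly doubled channel widths, and it builds ResNet-152 with `pretrained=False` since the released checkpoint already contains those weights. A quick sketch for comparing the two variants' sizes:

```python
# Compare parameter counts of the two variants.
from models.depth_only_model import PVSDNet
from models.depth_only_lite_model import PVSDNet_Lite

for name, cls in [("Large", PVSDNet), ("Lite", PVSDNet_Lite)]:
    model = cls(total_image_input=1)  # Lite fetches ResNet-152 weights on first use
    n_params = sum(p.numel() for p in model.parameters())
    print(f"{name}: {n_params / 1e6:.1f} M parameters")
```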
requirements.txt
ADDED
@@ -0,0 +1,122 @@
aiofiles==24.1.0
annotated-doc==0.0.4
annotated-types==0.7.0
anyio==4.12.1
av==16.0.1
blinker==1.9.0
brotli==1.2.0
certifi==2026.1.4
charset-normalizer==3.4.4
click==8.3.1
colorama==0.4.6
contourpy==1.3.3
cuda-toolkit==12.9.1
cycler==0.12.1
decorator==4.4.2
fastapi==0.128.0
ffmpy==1.0.0
filelock==3.20.0
Flask==3.1.2
fonttools==4.61.1
fsspec==2025.12.0
fvcore==0.1.5.post20221221
gradio==6.2.0
gradio_client==2.0.2
groovy==0.1.2
h11==0.16.0
hf-xet==1.2.0
httpcore==1.0.9
httpx==0.28.1
huggingface_hub==1.2.4
idna==3.11
ImageIO==2.37.2
imageio-ffmpeg==0.6.0
iopath==0.1.10
itsdangerous==2.2.0
Jinja2==3.1.6
joblib==1.5.3
kiwisolver==1.4.9
lazy_loader==0.4
Mako==1.3.10
markdown-it-py==4.0.0
MarkupSafe==2.1.5
matplotlib==3.10.8
matplotlib-inline==0.1.6
mdurl==0.1.2
ml_dtypes==0.5.4
moviepy==1.0.3
mpmath==1.3.0
networkx==3.6.1
numpy==1.26.4
nvidia-cuda-runtime-cu12==12.9.79
onnx==1.20.0
onnx-ir==0.1.14
onnxscript==0.5.7
opencv-python==4.6.0.66
orjson==3.11.5
packaging==25.0
pandas==2.3.3
parameterized==0.9.0
pillow==10.4.0
pillow_heif==0.15.0
platformdirs==4.5.1
portalocker==3.2.0
proglog==0.1.12
protobuf==6.33.2
pycuda==2025.1.2
pydantic==2.12.5
pydantic_core==2.41.5
pydub==0.25.1
Pygments==2.19.2
pyparsing==3.3.1
python-dateutil==2.9.0.post0
python-multipart==0.0.21
pytools==2025.2.5
pytorch-msssim==1.0.0
pytorchvideo==0.1.5
pytz==2025.2
pywin32==311
PyYAML==6.0.3
requests==2.32.5
rich==14.2.0
safehttpx==0.1.7
safetensors==0.7.0
scikit-image==0.26.0
scikit-learn==1.8.0
scipy==1.11.2
semantic-version==2.10.0
setuptools==80.9.0
shellingham==1.5.4
siphash24==1.8
six==1.17.0
starlette==0.50.0
sympy==1.14.0
tabulate==0.9.0
tensorrt_cu12==10.14.1.48.post1
tensorrt_cu12_bindings==10.14.1.48.post1
tensorrt_cu12_libs==10.14.1.48.post1
tensorrt_dispatch_cu12==10.14.1.48.post1
tensorrt_dispatch_cu12_bindings==10.14.1.48.post1
tensorrt_dispatch_cu12_libs==10.14.1.48.post1
tensorrt_lean_cu12==10.14.1.48.post1
tensorrt_lean_cu12_bindings==10.14.1.48.post1
tensorrt_lean_cu12_libs==10.14.1.48.post1
termcolor==3.3.0
threadpoolctl==3.6.0
tifffile==2025.12.20
timm==1.0.24
tomlkit==0.13.3
torch==2.9.1+cu130
torchvision==0.24.1+cu130
tqdm==4.65.0
traitlets==5.14.3
typer==0.21.1
typer-slim==0.21.1
typing-inspection==0.4.2
typing_extensions==4.15.0
tzdata==2025.3
urllib3==2.6.3
uvicorn==0.40.0
Werkzeug==3.1.5
wheel==0.45.1
yacs==0.1.8
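The pins above target a CUDA 13.0 PyTorch build (`torch==2.9.1+cu130`); the TensorRT and pycuda entries presumably serve the accelerated inference path. A hedged runtime check of the environment:

```python
# Environment check against the pins above — a sketch; exact outputs
# depend on the installed wheel and driver.
import torch

print(torch.__version__)          # expect '2.9.1+cu130'
print(torch.version.cuda)         # expect '13.0' for the cu130 build
print(torch.cuda.is_available())  # True only with a compatible NVIDIA driver
```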
rff_torch.py
ADDED
@@ -0,0 +1,53 @@
import numpy as np
import torch
from torch import Tensor
import torch.nn as nn

@torch.jit.script
def positional_encoding(
        v: Tensor,
        sigma: float,
        m: int) -> Tensor:
    r"""Computes :math:`\gamma(\mathbf{v}) = (\dots, \cos{2 \pi \sigma^{(j/m)} \mathbf{v}} , \sin{2 \pi \sigma^{(j/m)} \mathbf{v}}, \dots)`
    where :math:`j \in \{0, \dots, m-1\}`

    Args:
        v (Tensor): input tensor of shape :math:`(N, *, \text{input_size})`
        sigma (float): constant chosen based upon the domain of :attr:`v`
        m (int): number of frequencies to map to

    Returns:
        Tensor: mapped tensor of shape :math:`(N, *, 2 \cdot m \cdot \text{input_size})`

    See :class:`~rff.layers.PositionalEncoding` for more details.
    """
    j = torch.arange(m, device=v.device)
    coeffs = 2 * np.pi * sigma ** (j / m)
    vp = coeffs * torch.unsqueeze(v, -1)
    vp_cat = torch.cat((torch.cos(vp), torch.sin(vp)), dim=-1)
    return vp_cat.flatten(-2, -1)


class PositionalEncoding(nn.Module):
    """Layer for mapping coordinates using the positional encoding"""

    def __init__(self, sigma: float, m: int):
        r"""
        Args:
            sigma (float): frequency constant
            m (int): number of frequencies to map to
        """
        super().__init__()
        self.sigma = sigma
        self.m = m

    def forward(self, v: Tensor) -> Tensor:
        r"""Computes :math:`\gamma(\mathbf{v}) = (\dots, \cos{2 \pi \sigma^{(j/m)} \mathbf{v}} , \sin{2 \pi \sigma^{(j/m)} \mathbf{v}}, \dots)`

        Args:
            v (Tensor): input tensor of shape :math:`(N, *, \text{input_size})`

        Returns:
            Tensor: mapped tensor of shape :math:`(N, *, 2 \cdot m \cdot \text{input_size})`
        """
        return positional_encoding(v, self.sigma, self.m)
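A small usage sketch of the layer above (the `sigma` and `m` values are illustrative, not taken from this repo): per the docstring, an input of shape `(N, *, input_size)` maps to `(N, *, 2·m·input_size)` features.

```python
import torch
from rff_torch import PositionalEncoding

# Encode a batch of normalized 2-D coordinates.
enc = PositionalEncoding(sigma=10.0, m=16)
coords = torch.rand(4, 1024, 2)   # (N, *, input_size)
features = enc(coords)
print(features.shape)             # torch.Size([4, 1024, 64]) = 2 * 16 * 2
```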
utils.py
ADDED
@@ -0,0 +1,243 @@
import numpy as np
from scipy.ndimage import map_coordinates


def xyzcube(face_w):
    '''
    Return the xyz coordinates of the unit cube in [F R B L U D] format.
    '''
    out = np.zeros((face_w, face_w * 6, 3), np.float32)
    rng = np.linspace(-0.5, 0.5, num=face_w, dtype=np.float32)
    grid = np.stack(np.meshgrid(rng, -rng), -1)

    # Front face (z = 0.5)
    out[:, 0*face_w:1*face_w, [0, 1]] = grid
    out[:, 0*face_w:1*face_w, 2] = 0.5

    # Right face (x = 0.5)
    out[:, 1*face_w:2*face_w, [2, 1]] = grid
    out[:, 1*face_w:2*face_w, 0] = 0.5

    # Back face (z = -0.5)
    out[:, 2*face_w:3*face_w, [0, 1]] = grid
    out[:, 2*face_w:3*face_w, 2] = -0.5

    # Left face (x = -0.5)
    out[:, 3*face_w:4*face_w, [2, 1]] = grid
    out[:, 3*face_w:4*face_w, 0] = -0.5

    # Up face (y = 0.5)
    out[:, 4*face_w:5*face_w, [0, 2]] = grid
    out[:, 4*face_w:5*face_w, 1] = 0.5

    # Down face (y = -0.5)
    out[:, 5*face_w:6*face_w, [0, 2]] = grid
    out[:, 5*face_w:6*face_w, 1] = -0.5

    return out


def equirect_uvgrid(h, w):
    u = np.linspace(-np.pi, np.pi, num=w, dtype=np.float32)
    v = np.linspace(np.pi, -np.pi, num=h, dtype=np.float32) / 2

    return np.stack(np.meshgrid(u, v), axis=-1)


def equirect_facetype(h, w):
    '''
    0F 1R 2B 3L 4U 5D
    '''
    tp = np.roll(np.arange(4).repeat(w // 4)[None, :].repeat(h, 0), 3 * w // 8, 1)

    # Prepare ceil mask
    mask = np.zeros((h, w // 4), bool)  # np.bool was removed in NumPy >= 1.24 (numpy is pinned to 1.26.4)
    idx = np.linspace(-np.pi, np.pi, w // 4) / 4
    idx = h // 2 - np.round(np.arctan(np.cos(idx)) * h / np.pi).astype(int)
    for i, j in enumerate(idx):
        mask[:j, i] = 1
    mask = np.roll(np.concatenate([mask] * 4, 1), 3 * w // 8, 1)

    tp[mask] = 4
    tp[np.flip(mask, 0)] = 5

    return tp.astype(np.int32)


def xyzpers(h_fov, v_fov, u, v, out_hw, in_rot):
    out = np.ones((*out_hw, 3), np.float32)

    x_max = np.tan(h_fov / 2)
    y_max = np.tan(v_fov / 2)
    x_rng = np.linspace(-x_max, x_max, num=out_hw[1], dtype=np.float32)
    y_rng = np.linspace(-y_max, y_max, num=out_hw[0], dtype=np.float32)
    out[..., :2] = np.stack(np.meshgrid(x_rng, -y_rng), -1)
    Rx = rotation_matrix(v, [1, 0, 0])
    Ry = rotation_matrix(u, [0, 1, 0])
    Ri = rotation_matrix(in_rot, np.array([0, 0, 1.0]).dot(Rx).dot(Ry))

    return out.dot(Rx).dot(Ry).dot(Ri)


def xyz2uv(xyz):
    '''
    xyz: ndarray in shape of [..., 3]
    '''
    x, y, z = np.split(xyz, 3, axis=-1)
    u = np.arctan2(x, z)
    c = np.sqrt(x**2 + z**2)
    v = np.arctan2(y, c)

    return np.concatenate([u, v], axis=-1)


def uv2unitxyz(uv):
    u, v = np.split(uv, 2, axis=-1)
    y = np.sin(v)
    c = np.cos(v)
    x = c * np.sin(u)
    z = c * np.cos(u)

    return np.concatenate([x, y, z], axis=-1)


def uv2coor(uv, h, w):
    '''
    uv: ndarray in shape of [..., 2]
    h: int, height of the equirectangular image
    w: int, width of the equirectangular image
    '''
    u, v = np.split(uv, 2, axis=-1)
    coor_x = (u / (2 * np.pi) + 0.5) * w - 0.5
    coor_y = (-v / np.pi + 0.5) * h - 0.5

    return np.concatenate([coor_x, coor_y], axis=-1)


def coor2uv(coorxy, h, w):
    coor_x, coor_y = np.split(coorxy, 2, axis=-1)
    u = ((coor_x + 0.5) / w - 0.5) * 2 * np.pi
    v = -((coor_y + 0.5) / h - 0.5) * np.pi

    return np.concatenate([u, v], axis=-1)


def sample_equirec(e_img, coor_xy, order):
    w = e_img.shape[1]
    coor_x, coor_y = np.split(coor_xy, 2, axis=-1)
    pad_u = np.roll(e_img[[0]], w // 2, 1)
    pad_d = np.roll(e_img[[-1]], w // 2, 1)
    e_img = np.concatenate([e_img, pad_d, pad_u], 0)
    return map_coordinates(e_img, [coor_y, coor_x],
                           order=order, mode='wrap')[..., 0]


def sample_cubefaces(cube_faces, tp, coor_y, coor_x, order):
    cube_faces = cube_faces.copy()
    cube_faces[1] = np.flip(cube_faces[1], 1)
    cube_faces[2] = np.flip(cube_faces[2], 1)
    cube_faces[4] = np.flip(cube_faces[4], 0)

    # Pad up down
    pad_ud = np.zeros((6, 2, cube_faces.shape[2]))
    pad_ud[0, 0] = cube_faces[5, 0, :]
    pad_ud[0, 1] = cube_faces[4, -1, :]
    pad_ud[1, 0] = cube_faces[5, :, -1]
    pad_ud[1, 1] = cube_faces[4, ::-1, -1]
    pad_ud[2, 0] = cube_faces[5, -1, ::-1]
    pad_ud[2, 1] = cube_faces[4, 0, ::-1]
    pad_ud[3, 0] = cube_faces[5, ::-1, 0]
    pad_ud[3, 1] = cube_faces[4, :, 0]
    pad_ud[4, 0] = cube_faces[0, 0, :]
    pad_ud[4, 1] = cube_faces[2, 0, ::-1]
    pad_ud[5, 0] = cube_faces[2, -1, ::-1]
    pad_ud[5, 1] = cube_faces[0, -1, :]
    cube_faces = np.concatenate([cube_faces, pad_ud], 1)

    # Pad left right
    pad_lr = np.zeros((6, cube_faces.shape[1], 2))
    pad_lr[0, :, 0] = cube_faces[1, :, 0]
    pad_lr[0, :, 1] = cube_faces[3, :, -1]
    pad_lr[1, :, 0] = cube_faces[2, :, 0]
    pad_lr[1, :, 1] = cube_faces[0, :, -1]
    pad_lr[2, :, 0] = cube_faces[3, :, 0]
    pad_lr[2, :, 1] = cube_faces[1, :, -1]
    pad_lr[3, :, 0] = cube_faces[0, :, 0]
    pad_lr[3, :, 1] = cube_faces[2, :, -1]
    pad_lr[4, 1:-1, 0] = cube_faces[1, 0, ::-1]
    pad_lr[4, 1:-1, 1] = cube_faces[3, 0, :]
    pad_lr[5, 1:-1, 0] = cube_faces[1, -2, :]
    pad_lr[5, 1:-1, 1] = cube_faces[3, -2, ::-1]
    cube_faces = np.concatenate([cube_faces, pad_lr], 2)

    return map_coordinates(cube_faces, [tp, coor_y, coor_x], order=order, mode='wrap')


def cube_h2list(cube_h):
    assert cube_h.shape[0] * 6 == cube_h.shape[1]
    return np.split(cube_h, 6, axis=1)


def cube_list2h(cube_list):
    assert len(cube_list) == 6
    assert sum(face.shape == cube_list[0].shape for face in cube_list) == 6
    return np.concatenate(cube_list, axis=1)


def cube_h2dict(cube_h):
    cube_list = cube_h2list(cube_h)
    return dict([(k, cube_list[i])
                 for i, k in enumerate(['F', 'R', 'B', 'L', 'U', 'D'])])


def cube_dict2h(cube_dict, face_k=['F', 'R', 'B', 'L', 'U', 'D']):
    assert len(face_k) == 6
    return cube_list2h([cube_dict[k] for k in face_k])


def cube_h2dice(cube_h):
    assert cube_h.shape[0] * 6 == cube_h.shape[1]
    w = cube_h.shape[0]
    cube_dice = np.zeros((w * 3, w * 4, cube_h.shape[2]), dtype=cube_h.dtype)
    cube_list = cube_h2list(cube_h)
    # Order: F R B L U D
    sxy = [(1, 1), (2, 1), (3, 1), (0, 1), (1, 0), (1, 2)]
    for i, (sx, sy) in enumerate(sxy):
        face = cube_list[i]
        if i in [1, 2]:
            face = np.flip(face, axis=1)
        if i == 4:
            face = np.flip(face, axis=0)
        cube_dice[sy*w:(sy+1)*w, sx*w:(sx+1)*w] = face
    return cube_dice


def cube_dice2h(cube_dice):
    w = cube_dice.shape[0] // 3
    assert cube_dice.shape[0] == w * 3 and cube_dice.shape[1] == w * 4
    cube_h = np.zeros((w, w * 6, cube_dice.shape[2]), dtype=cube_dice.dtype)
    # Order: F R B L U D
    sxy = [(1, 1), (2, 1), (3, 1), (0, 1), (1, 0), (1, 2)]
    for i, (sx, sy) in enumerate(sxy):
        face = cube_dice[sy*w:(sy+1)*w, sx*w:(sx+1)*w]
        if i in [1, 2]:
            face = np.flip(face, axis=1)
        if i == 4:
            face = np.flip(face, axis=0)
        cube_h[:, i*w:(i+1)*w] = face
    return cube_h


def rotation_matrix(rad, ax):
    ax = np.array(ax)
    assert len(ax.shape) == 1 and ax.shape[0] == 3
    ax = ax / np.sqrt((ax**2).sum())
    R = np.diag([np.cos(rad)] * 3)
    R = R + np.outer(ax, ax) * (1.0 - np.cos(rad))

    ax = ax * np.sin(rad)
    R = R + np.array([[0, -ax[2], ax[1]],
                      [ax[2], 0, -ax[0]],
                      [-ax[1], ax[0], 0]])

    return R
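The helpers above chain into an equirectangular-to-cubemap resampler. A minimal sketch, assuming a single-channel input (which matches `sample_equirec`, since it indexes away the trailing split axis); the resolutions are illustrative:

```python
import numpy as np
from utils import xyzcube, xyz2uv, uv2coor, sample_equirec, cube_h2list

h, w, face_w = 256, 512, 128
e_img = np.random.rand(h, w).astype(np.float32)  # dummy equirectangular map

xyz = xyzcube(face_w)               # (face_w, 6*face_w, 3) cube-face directions
uv = xyz2uv(xyz)                    # longitude/latitude for every face pixel
coor_xy = uv2coor(uv, h, w)         # pixel coordinates in the equirect image
cube_h = sample_equirec(e_img, coor_xy, order=1)  # (face_w, 6*face_w)
faces = cube_h2list(cube_h)         # six face_w x face_w faces in F R B L U D order
```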