Save work before migration
Browse files- .gitattributes +1 -0
- .gitignore +2 -0
- README.md +18 -1
- app.py +248 -0
- examples/alice.jpg +3 -0
- examples/example_1.jpg +3 -0
- examples/example_2.jpg +3 -0
- examples/greenhouse.jpg +3 -0
- requirements.txt +18 -0
- src/__init__.py +0 -0
- src/pager.py +308 -0
- src/utils/__init__.py +0 -0
- src/utils/conv_padding.py +123 -0
- src/utils/geometry_utils.py +270 -0
- src/utils/loss.py +68 -0
- src/utils/lr_scheduler.py +41 -0
- src/utils/utils.py +214 -0
.gitattributes
CHANGED
|
@@ -33,4 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 36 |
*.jpg filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.jpg~ filter=lfs diff=lfs merge=lfs -text
|
| 37 |
*.jpg filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
.vscode/
|
README.md
CHANGED
|
@@ -8,7 +8,24 @@ sdk_version: 6.0.2
|
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: apache-2.0
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
---
|
| 13 |
|
| 14 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: apache-2.0
|
| 11 |
+
python_version: 3.10
|
| 12 |
+
models:
|
| 13 |
+
- prs-eth/PaGeR-depth
|
| 14 |
+
- prs-eth/PaGeR-normals
|
| 15 |
+
tags:
|
| 16 |
+
- computer-vision
|
| 17 |
+
- image-processing
|
| 18 |
+
- diffusion-models
|
| 19 |
+
- panorama
|
| 20 |
+
- geometry-estimation
|
| 21 |
+
- depth-estimation
|
| 22 |
+
- normal-estimation
|
| 23 |
+
- single-step-diffusion
|
| 24 |
+
preload_from_hub:
|
| 25 |
+
- prs-eth/PaGeR-depth
|
| 26 |
+
- prs-eth/PaGeR-normals
|
| 27 |
+
suggested_hardware: a10g-large
|
| 28 |
+
short_description: Panorama Geometry Estimation
|
| 29 |
---
|
| 30 |
|
| 31 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import gc
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
import argparse
|
| 6 |
+
import logging
|
| 7 |
+
import gradio as gr
|
| 8 |
+
from PIL import Image
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from omegaconf import OmegaConf
|
| 11 |
+
from tempfile import NamedTemporaryFile
|
| 12 |
+
from huggingface_hub import hf_hub_download
|
| 13 |
+
from matplotlib import pyplot as plt
|
| 14 |
+
from src.pager import Pager
|
| 15 |
+
from src.utils.geometry_utils import compute_edge_mask, erp_to_point_cloud_glb, erp_to_cubemap
|
| 16 |
+
from src.utils.utils import prepare_image_for_logging
|
| 17 |
+
|
| 18 |
+
# Log-space depth decoding bounds: the decoder output in [0, 1] is mapped to
# pred * DEPTH_RANGE + MIN_DEPTH before exponentiation (see
# Pager.process_depth_output).
MIN_DEPTH = np.log(1e-2)
DEPTH_RANGE = np.log(75.0)
# Point-cloud export settings: subsample the ERP grid and cap the point count
# so the in-browser 3D viewer stays responsive.
POINTCLOUD_DOWNSAMPLE_FACTOR = 2
MAX_POINTCLOUD_POINTS = 200000
# Bundled example panoramas (common raster formats only).
EXAMPLES_DIR = Path(__file__).parent / "examples"
EXAMPLE_IMAGES = [
    str(p)
    for p in sorted(EXAMPLES_DIR.glob("*"))
    if p.suffix.lower() in {".jpg", ".jpeg", ".png", ".webp"}
]
|
| 28 |
+
|
| 29 |
+
def parse_args(argv=None):
    """Parse command-line options for the panorama geometry demo.

    Args:
        argv: Optional list of argument strings. Defaults to ``sys.argv[1:]``
            (the original behaviour). Accepting an explicit list keeps the
            function testable and embeddable; backward compatible.

    Returns:
        argparse.Namespace with ``seed``, ``depth_checkpoint_path``,
        ``normals_checkpoint_path`` and ``enable_xformers``.
    """
    parser = argparse.ArgumentParser(
        description="Inference script for panorama depth estimation using diffusion models."
    )

    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="A seed for reproducibility.",
    )

    parser.add_argument(
        "--depth_checkpoint_path",
        default="prs-eth/PaGeR-depth",
        type=str,
        help="UNet checkpoint to load.",
    )

    parser.add_argument(
        "--normals_checkpoint_path",
        default="prs-eth/PaGeR-normals",
        type=str,
        help="UNet checkpoint to load.",
    )

    parser.add_argument(
        "--enable_xformers",
        action="store_true",
        help="Whether or not to use xformers.",
    )

    args = parser.parse_args(argv)
    return args
|
| 61 |
+
|
| 62 |
+
def _release_cuda_memory():
|
| 63 |
+
if torch.cuda.is_available():
|
| 64 |
+
torch.cuda.empty_cache()
|
| 65 |
+
torch.cuda.ipc_collect()
|
| 66 |
+
gc.collect()
|
| 67 |
+
|
| 68 |
+
def generate_ERP(input_rgb, modality):
    """Predict depth or surface normals for one equirectangular panorama.

    Args:
        input_rgb: HxWx3 uint8 array in ERP layout (expected 1024x2048 —
            callers resize before invoking; TODO confirm no other sizes reach
            this function).
        modality: "depth" or "normal".

    Returns:
        (pred_image, pred): a displayable RGB image (uint8 for depth) and the
        raw prediction as numpy arrays.
    """
    batch = {}
    input_rgb = torch.from_numpy(input_rgb).permute(2, 0, 1).to(torch.float32) / 255.0
    input_rgb = input_rgb * 2.0 - 1.0  # [0, 1] -> [-1, 1]
    batch['rgb_cubemap'] = erp_to_cubemap(input_rgb).unsqueeze(0).to(device)
    with torch.inference_mode():
        # BUGFIX: these were called unconditionally, but torch.cuda
        # bookkeeping calls raise on CPU-only hosts; guard on availability.
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        pred_cubemap = pager(batch, modality)
        if modality == "depth":
            pred, pred_image = pager.process_depth_output(pred_cubemap, orig_size=(1024, 2048),
                                                          min_depth=MIN_DEPTH,
                                                          depth_range=DEPTH_RANGE,
                                                          log_scale=pager.model_configs["depth"]["config"].log_scale)
            pred, pred_image = pred[0].cpu().numpy(), pred_image.cpu().numpy()
            # Clip the far 1% tail so a few distant pixels don't wash out the
            # colormap.
            pred_image = np.clip(pred_image, pred_image.min(), np.quantile(pred_image, 0.99))
            pred_image = prepare_image_for_logging(pred_image)
            pred_image = cmap(pred_image[0, ...] / 255.0)
            pred_image = (pred_image[..., :3] * 255).astype(np.uint8)
        elif modality == "normal":
            pred = pager.process_normal_output(pred_cubemap, orig_size=(1024, 2048))
            pred = pred.cpu().numpy()
            pred_image = pred.copy()
            pred_image = prepare_image_for_logging(pred_image).transpose(1, 2, 0)

    return pred_image, pred
|
| 95 |
+
|
| 96 |
+
def process_panorama(image_path, output_type, include_pointcloud):
    """Gradio callback: run inference on an uploaded ERP image.

    Args:
        image_path: path to the uploaded image (gr.Image with type="filepath").
        output_type: "Depth" or "Surface Normals" (the radio button value).
        include_pointcloud: whether to also export a .glb point cloud.

    Returns:
        A pair of gr.update values: the rendered panorama (with its label)
        and the point-cloud file path (or None) for the Model3D viewer.
    """
    # The model operates at a fixed 2048x1024 ERP resolution.
    loaded_image = Image.open(image_path).convert("RGB").resize((2048, 1024))
    input_rgb = np.array(loaded_image)

    modality = "depth" if output_type.lower() == "depth" else "normal"
    is_depth = modality == "depth"
    main_label = "Depth Output" if is_depth else "Surface Normal Output"
    pc_label = (
        "RGB-colored Point Cloud" if is_depth else "Surface Normals-Colored Point Cloud"
    )
    output_image, raw_pred = generate_ERP(input_rgb, modality)

    point_cloud = None
    if include_pointcloud:
        if is_depth:
            depth = np.squeeze(np.array(raw_pred))
            # Colors are rescaled to [-1, 1] -- presumably the range the GLB
            # exporter expects; TODO confirm against erp_to_point_cloud_glb.
            color = (input_rgb.astype(np.float32) / 127.5) - 1.0
        else:
            # Normals mode: color the points with the predicted normals and
            # run a second (depth) inference to obtain the geometry. Memory
            # is released between the two passes on purpose.
            color = np.array(raw_pred)
            color = np.transpose(color, (1, 2, 0))
            _release_cuda_memory()
            depth = np.squeeze(generate_ERP(input_rgb, "depth", )[1])

        # Mask out points on depth discontinuities to avoid streaking.
        edge_filtered_mask = compute_edge_mask(
            depth,
            abs_thresh=0.002,
            rel_thresh=0.002,
        )

        # Subsample so the GLB stays small enough for the browser viewer.
        if POINTCLOUD_DOWNSAMPLE_FACTOR > 1:
            depth = depth[::POINTCLOUD_DOWNSAMPLE_FACTOR, ::POINTCLOUD_DOWNSAMPLE_FACTOR]
            color = color[::POINTCLOUD_DOWNSAMPLE_FACTOR, ::POINTCLOUD_DOWNSAMPLE_FACTOR]
            edge_filtered_mask = edge_filtered_mask[::POINTCLOUD_DOWNSAMPLE_FACTOR, ::POINTCLOUD_DOWNSAMPLE_FACTOR]

        # delete=False: gradio serves the file after this function returns.
        tmp = NamedTemporaryFile(suffix=".glb", delete=False)
        erp_to_point_cloud_glb(
            color, depth, edge_filtered_mask, export_path=tmp.name)

        tmp.close()
        point_cloud = tmp.name

    _release_cuda_memory()

    return (
        gr.update(value=output_image, label=main_label),
        gr.update(value=point_cloud, label=pc_label),
    )
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def clear_pointcloud():
    """Blank the 3D viewer before a new inference run starts."""
    return gr.update(value=None)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
args = parse_args()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Plain stdout logger; propagate=False keeps messages from duplicating
# through the root logger (e.g. under the web server's own logging).
logger = logging.getLogger("simple")
handler = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter("%(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.setLevel(logging.INFO)
logger.propagate = False
cmap = plt.get_cmap("Spectral")  # colormap for the depth visualisation


# Resolve each checkpoint's config.yaml: try the Hugging Face Hub first,
# falling back to treating the checkpoint path as a local directory.
checkpoint_config = {}
try:
    depth_checkpoint_config_path = hf_hub_download(
        repo_id=args.depth_checkpoint_path,
        filename="config.yaml"
    )
except Exception as e:
    depth_checkpoint_config_path = Path(args.depth_checkpoint_path) / "config.yaml"
depth_config = OmegaConf.load(depth_checkpoint_config_path)
checkpoint_config["depth"] = {"path": args.depth_checkpoint_path, "mode": "trained", "config": depth_config.model}

try:
    normal_checkpoint_config_path = hf_hub_download(
        repo_id=args.normals_checkpoint_path,
        filename="config.yaml"
    )
except Exception as e:
    normal_checkpoint_config_path = Path(args.normals_checkpoint_path) / "config.yaml"
normal_config = OmegaConf.load(normal_checkpoint_config_path)
checkpoint_config["normal"] = {"path": args.normals_checkpoint_path, "mode": "trained", "config": normal_config.model}

# Build the model once at import time; both UNets are moved to the device
# and put in eval mode (the demo never trains).
pager = Pager(model_configs=checkpoint_config, pretrained_path = depth_config.model.pretrained_path, device=device)
pager.unet["depth"].to(device, dtype=pager.weight_dtype)
pager.unet["depth"].eval()
pager.unet["normal"].to(device, dtype=pager.weight_dtype)
pager.unet["normal"].eval()
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
# Demo layout: inputs/controls on the left, rendered panorama on the right,
# and a full-width Model3D point-cloud viewer underneath.
with gr.Blocks() as demo:
    gr.Markdown("## 📟 PaGeR: Panoramic Geometry Reconstruction")

    with gr.Row():
        with gr.Column(scale=1):
            # Upload slot; type="filepath" hands process_panorama a path.
            image_input = gr.Image(
                label="RGB ERP Image",
                type="filepath",
                height=320,
            )
            output_choice = gr.Radio(
                ["Depth", "Surface Normals"],
                value="Depth",
                label="Output Type",
            )
            pointcloud_checkbox = gr.Checkbox(
                label="Generate Point Cloud",
                value=True,
            )
            gr.Examples(
                examples=EXAMPLE_IMAGES,
                inputs=image_input,
                label="Pick an example (or upload your own above)",
                examples_per_page=8,
                cache_examples=False,
            )
            run_button = gr.Button("Run Inference")

        with gr.Column(scale=1):
            rendered_output = gr.Image(
                label="Output",
                type="numpy",
                height=320,
            )

    with gr.Row():
        pointcloud_output = gr.Model3D(
            label="Point Cloud",
            height=360,
            clear_color=[0.0, 0.0, 0.0, 0.0],
        )

    # Two-step click handler: first clear the stale point cloud immediately
    # (queue=False skips the queue), then run inference and fill both outputs.
    (
        run_button.click(
            fn=clear_pointcloud,
            outputs=pointcloud_output,
            queue=False,
        )
        .then(
            fn=process_panorama,
            inputs=[image_input, output_choice, pointcloud_checkbox],
            outputs=[rendered_output, pointcloud_output],
        )
    )

if __name__ == "__main__":
    _release_cuda_memory()
    demo.launch()
|
examples/alice.jpg
ADDED
|
Git LFS Details
|
examples/example_1.jpg
ADDED
|
Git LFS Details
|
examples/example_2.jpg
ADDED
|
Git LFS Details
|
examples/greenhouse.jpg
ADDED
|
Git LFS Details
|
requirements.txt
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch==2.2.0
|
| 2 |
+
xformers==0.0.24
|
| 3 |
+
accelerate==0.27.2
|
| 4 |
+
gradio==6.0.2
|
| 5 |
+
huggingface-hub==0.36.0
|
| 6 |
+
transformers
|
| 7 |
+
diffusers==0.30.2
|
| 8 |
+
numpy==1.26.4
|
| 9 |
+
scipy==1.15.1
|
| 10 |
+
matplotlib==3.10.0
|
| 11 |
+
tqdm==4.67.1
|
| 12 |
+
einops==0.8.1
|
| 13 |
+
datasets==3.3.0
|
| 14 |
+
python-dotenv==1.1.1
|
| 15 |
+
wandb==0.19.6
|
| 16 |
+
opencv-python==4.11.0.86
|
| 17 |
+
pytorch360convert==0.2.3
|
| 18 |
+
trimesh==4.9.0
|
src/__init__.py
ADDED
|
File without changes
|
src/pager.py
ADDED
|
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
from torch.nn import Conv2d
|
| 4 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
| 5 |
+
from diffusers import DDPMScheduler
|
| 6 |
+
from diffusers.utils.import_utils import is_xformers_available
|
| 7 |
+
from Marigold.unet.unet_2d_condition import UNet2DConditionModel
|
| 8 |
+
from Marigold.vae.autoencoder_kl import AutoencoderKL
|
| 9 |
+
from src.utils.conv_padding import PaddedConv2d, valid_pad_conv_fn
|
| 10 |
+
from src.utils.loss import L1Loss, GradL1Loss, CosineNormalLoss
|
| 11 |
+
from src.utils.geometry_utils import (
|
| 12 |
+
get_positional_encoding,
|
| 13 |
+
compute_scale_and_shift,
|
| 14 |
+
compute_shift,
|
| 15 |
+
depth_to_normals_erp,
|
| 16 |
+
cubemap_to_erp
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class Pager(nn.Module):
    """Single-step diffusion model for panoramic depth / surface-normal
    estimation, operating on the six cubemap faces of an ERP panorama."""

    def __init__(self,
                 model_configs,
                 pretrained_path,
                 train_modality=None,
                 device=torch.device("cpu"),
                 weight_dtype=torch.float32):
        """Build all submodules and cache the diffusion schedule.

        Args:
            model_configs: dict keyed by modality ("depth"/"normal"), each
                entry holding {"path", "mode", "config"} for one UNet.
            pretrained_path: base Stable-Diffusion-style checkpoint supplying
                scheduler, tokenizer, text encoder and VAE.
            train_modality: modality being trained, or None for inference.
            device: device for submodules and cached tensors.
            weight_dtype: dtype for frozen weights and cached tensors.
        """
        super().__init__()
        self.model_configs = model_configs
        self.weight_dtype = weight_dtype
        # Standard SD latent scaling; the depth branch reuses the RGB factor.
        self.rgb_latent_scale_factor = 0.18215
        self.depth_latent_scale_factor = 0.18215
        self.train_modality = train_modality
        self.device = device
        self.prepare_model_components(pretrained_path, model_configs)
        self.prepare_empty_encoding()

        # Only the cumulative alpha/beta schedules are needed (inference is a
        # single step at the final timestep), so the scheduler is freed after
        # copying them out.
        self.alpha_prod = self.noise_scheduler.alphas_cumprod.to(device, dtype=weight_dtype)
        self.beta_prod = 1 - self.alpha_prod
        self.num_timesteps = self.noise_scheduler.config.num_train_timesteps - 1
        del self.noise_scheduler
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
    def prepare_model_components(self, pretrained_path, model_configs):
        """Load scheduler, tokenizer, text encoder, VAE and the per-modality
        UNets; everything except the UNets is frozen and set to eval."""
        # All checkpoints must agree on whether the VAE uses RoPE.
        vae_use_RoPE = None
        for checkpoint_cfg in model_configs.values():
            if vae_use_RoPE is None:
                vae_use_RoPE = checkpoint_cfg['config'].vae_use_RoPE == "RoPE"
            elif vae_use_RoPE != (checkpoint_cfg['config'].vae_use_RoPE == "RoPE"):
                raise ValueError("All UNet checkpoints must use the same VAE positional encoding configuration.")

        self.noise_scheduler = DDPMScheduler.from_pretrained(pretrained_path, subfolder="scheduler", rescale_betas_zero_snr=True)
        self.tokenizer = CLIPTokenizer.from_pretrained(pretrained_path, subfolder="tokenizer", revision=None)
        self.text_encoder = CLIPTextModel.from_pretrained(pretrained_path, subfolder="text_encoder", revision=None, variant=None)
        self.vae = AutoencoderKL.from_pretrained(pretrained_path, subfolder="vae", revision=None, variant=None,
                                                 use_RoPE = vae_use_RoPE)
        # Swap zero-padded convs for cross-cube-face "valid" padding.
        self.set_valid_pad_conv(self.vae)

        self.vae.requires_grad_(False)
        self.vae.to(self.device, dtype=self.weight_dtype)
        self.vae.eval()

        self.text_encoder.requires_grad_(False)
        self.text_encoder.to(self.device, dtype=self.weight_dtype)
        self.text_encoder.eval()

        # 8 = 4 RGB latent channels + 4 (zero) noise latent channels.
        base_in_channels = 8
        pe_channels_size = 0

        self.unet = {}
        for modality, checkpoint_cfg in model_configs.items():
            if checkpoint_cfg['config'].unet_positional_encoding == "uv":
                pe_channels_size = 2
            # NOTE(review): pe_channels_size is never reset inside the loop,
            # so once one modality uses "uv" all later modalities get the
            # widened channel count too -- TODO confirm this is intended.
            target_in_channels = base_in_channels + pe_channels_size

            self.unet[modality] = UNet2DConditionModel.from_pretrained(
                checkpoint_cfg["path"],
                subfolder="unet",
                revision=None,
                in_channels=target_in_channels if checkpoint_cfg["mode"] == "trained" else base_in_channels,
                use_RoPE=checkpoint_cfg['config'].unet_positional_encoding == "RoPE"
            )

            # Untrained checkpoints need conv_in widened for the PE channels.
            if target_in_channels > base_in_channels and checkpoint_cfg["mode"] != "trained":
                self.extend_unet_conv_in(self.unet[modality], new_in_channels=target_in_channels)
            self.set_valid_pad_conv(self.unet[modality])

            if checkpoint_cfg['config'].enable_xformers:
                if is_xformers_available():
                    import xformers
                    if self.unet.get("depth"):
                        self.unet["depth"].enable_xformers_memory_efficient_attention()
                    if self.unet.get("normal"):
                        self.unet["normal"].enable_xformers_memory_efficient_attention()
                    self.vae.enable_xformers_memory_efficient_attention()
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
    def prepare_training(self, accelerator, gradient_checkpointing):
        """Wrap the trained-modality UNet with accelerate and optionally turn
        on gradient checkpointing for it and the VAE."""
        # Keep a handle to the raw (unwrapped) module for checkpoint saving.
        self.unwrapped_unet = self.unet[self.train_modality]
        self.unet[self.train_modality] = accelerator.prepare(self.unet[self.train_modality])
        self.trained_unet = self.unet[self.train_modality]

        if gradient_checkpointing:
            self.trained_unet._set_gradient_checkpointing()
            self.vae._set_gradient_checkpointing()
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def prepare_cubemap_PE(self, image_height, image_width):
|
| 110 |
+
use_uv_PE = False
|
| 111 |
+
for checkpoint_cfg in self.model_configs.values():
|
| 112 |
+
if checkpoint_cfg['config'].unet_positional_encoding == "uv":
|
| 113 |
+
use_uv_PE = True
|
| 114 |
+
if use_uv_PE:
|
| 115 |
+
PE_cubemap = get_positional_encoding(image_height, image_width)
|
| 116 |
+
self.PE_cubemap = PE_cubemap.to(device=self.device, dtype=self.weight_dtype)
|
| 117 |
+
|
| 118 |
+
    def prepare_empty_encoding(self):
        """Pre-compute the CLIP embedding of the empty prompt and drop the
        tokenizer/text encoder (the prompt never changes at inference)."""
        with torch.inference_mode():
            empty_token = self.tokenizer([""], padding="max_length", truncation=True, return_tensors="pt").input_ids
            empty_token = empty_token.to(self.device)
            empty_encoding = self.text_encoder(empty_token, return_dict=False)[0]
            self.empty_encoding = empty_encoding.to(self.device, dtype=self.weight_dtype)

        # The text pipeline was only needed once; free the memory.
        del empty_token
        del self.text_encoder
        del self.tokenizer
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
    def forward(self, batch, modality):
        """Single-step deterministic denoising pass over cubemap faces.

        Args:
            batch: dict with "rgb_cubemap" -- indexing below implies shape
                [B, 6, C, H, W] (six cube faces per panorama); values assumed
                in [-1, 1] as produced by the callers -- TODO confirm.
            modality: "depth" or "normal", selecting which UNet to run.

        Returns:
            Decoded cubemap prediction; for depth, the decoded channels are
            averaged to a single channel.
        """
        with torch.inference_mode():
            c, h, w = batch["rgb_cubemap"].shape[2:]
            # Fold the face dimension into the batch for the VAE.
            rgb_vae_input = batch["rgb_cubemap"].reshape(-1, c, h, w).to(dtype=self.weight_dtype)
            rgb_latents = self.vae.encode(rgb_vae_input, deterministic=True)
            rgb_latents = rgb_latents * self.rgb_latent_scale_factor
            del rgb_vae_input

            # Single step at the final timestep of the schedule.
            timesteps = torch.ones((rgb_latents.shape[0],), device=self.device) * self.num_timesteps
            timesteps = timesteps.long()
            alpha_prod_t = self.alpha_prod[timesteps].view(-1, 1, 1, 1)
            beta_prod_t = self.beta_prod[timesteps].view(-1, 1, 1, 1)

            # Deterministic variant: the "noisy" latent is all zeros.
            noisy_latents = torch.zeros_like(rgb_latents).to(self.device)
            encoder_hidden_states = self.empty_encoding.repeat(batch["rgb_cubemap"].shape[0] * 6, 1, 1)
            if self.model_configs[modality]['config'].unet_positional_encoding == "uv":
                # Append the cached UV positional encoding as extra channels.
                batch_PE_cubemap = self.PE_cubemap.repeat(batch["rgb_cubemap"].shape[0], 1, 1, 1)
                unet_input = torch.cat((rgb_latents, noisy_latents, batch_PE_cubemap), dim=1).to(
                    self.device
                )
            else:
                unet_input = torch.cat((rgb_latents, noisy_latents), dim=1).to(self.device)

            del rgb_latents
            model_pred = self.unet[modality](
                unet_input,
                timesteps,
                encoder_hidden_states,
                return_dict=False,
            )[0]

            # Latent estimate from the model output; with noisy_latents == 0
            # only the -sqrt(beta) * prediction term contributes.
            current_latent_estimate = (alpha_prod_t**0.5) * noisy_latents - (beta_prod_t**0.5) * model_pred
            current_scaled_latent_estimate = current_latent_estimate / self.depth_latent_scale_factor
            pred_cubemap = self.vae.decode(current_scaled_latent_estimate, deterministic=True)

            if modality == "depth":
                # Collapse the decoded channels into one depth channel.
                pred_cubemap = pred_cubemap.mean(dim=1, keepdim=True)
            return pred_cubemap
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def prepare_losses_dict(self, loss_cfg):
|
| 171 |
+
self.losses_dict = {}
|
| 172 |
+
if self.train_modality == "depth":
|
| 173 |
+
self.losses_dict["l1_loss"] = {"loss_fn": L1Loss(invalid_mask_weight=loss_cfg.invalid_mask_weight),
|
| 174 |
+
"weight": loss_cfg.l1_loss_weight}
|
| 175 |
+
if loss_cfg.grad_loss_weight > 0.0:
|
| 176 |
+
self.losses_dict["grad_loss"] = {"loss_fn": GradL1Loss(), "weight": loss_cfg.grad_loss_weight}
|
| 177 |
+
if loss_cfg.normals_consistency_loss_weight > 0.0:
|
| 178 |
+
self.losses_dict["normals_consistency_loss"] = {"loss_fn": CosineNormalLoss(),
|
| 179 |
+
"weight": loss_cfg.normals_consistency_loss_weight}
|
| 180 |
+
else:
|
| 181 |
+
self.losses_dict["cosine_normal_loss"] = {"loss_fn": CosineNormalLoss(), "weight": 1.0}
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
    def calculate_depth_loss(self, batch, pred_cubemap, min_depth, depth_range, log_scale, metric_depth):
        """Accumulate the configured depth losses on cubemap predictions.

        For non-metric depth the prediction is first aligned to the ground
        truth: shift only in log space (a multiplicative scale becomes an
        additive shift there), scale-and-shift otherwise.

        Returns:
            dict of individual loss values plus their weighted "total_loss".
        """
        loss = {"total_loss": 0.0}

        gt_depth_cubemap = batch['depth_cubemap'].squeeze(0).mean(dim=1, keepdim=True)
        mask_cubemap = batch["mask_cubemap"].squeeze(0)

        if not metric_depth:
            if log_scale:
                scale = compute_shift(pred_cubemap, gt_depth_cubemap, mask_cubemap)
            else:
                scale, shift = compute_scale_and_shift(pred_cubemap, gt_depth_cubemap, mask_cubemap)

            if log_scale:
                # In-place add: pred_cubemap is modified for the caller too.
                pred_cubemap += scale
            else:
                pred_cubemap = pred_cubemap * scale + shift

        for loss_name, loss_params in self.losses_dict.items():
            if loss_name == "normals_consistency_loss":
                # Compare normals derived from the predicted depth against the
                # GT normals on the stitched ERP panorama.
                gt = batch['normal']
                pred_depth = pred_cubemap
                mask = batch["mask"]
                pred_depth = self.process_depth_output(pred_depth, orig_size=gt.shape[2:], min_depth=min_depth,
                                                       depth_range=depth_range, log_scale=log_scale)[0]
                pred = depth_to_normals_erp(pred_depth).unsqueeze(0)
            else:
                pred = pred_cubemap
                gt = gt_depth_cubemap
                mask = mask_cubemap
            loss[loss_name] = loss_params["loss_fn"](pred, gt, mask)
            loss["total_loss"] += loss[loss_name] * loss_params["weight"]

        return loss
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def calculate_normal_loss(self, batch, pred_cubemap):
|
| 220 |
+
loss = {"total_loss": 0.0}
|
| 221 |
+
|
| 222 |
+
gt_normal_cubemap = batch['normal_cubemap'].squeeze(0)
|
| 223 |
+
mask_cubemap = batch["mask_cubemap"].squeeze(0)
|
| 224 |
+
|
| 225 |
+
for loss_name, loss_params in self.losses_dict.items():
|
| 226 |
+
pred = pred_cubemap
|
| 227 |
+
gt = gt_normal_cubemap
|
| 228 |
+
loss[loss_name] = loss_params["loss_fn"](pred, gt, mask_cubemap)
|
| 229 |
+
loss["total_loss"] += loss[loss_name] * loss_params["weight"]
|
| 230 |
+
|
| 231 |
+
return loss
|
| 232 |
+
|
| 233 |
+
    def process_depth_output(self,pred_cubemap, orig_size, min_depth, depth_range, log_scale, mask=None):
        """Stitch depth cubemap faces back into an ERP panorama and map the
        decoder output from [-1, 1] to (log-)depth.

        Args:
            pred_cubemap: decoded per-face depth prediction.
            orig_size: (height, width) of the target ERP panorama.
            min_depth: offset added after rescaling to [0, 1].
            depth_range: span multiplied into the [0, 1] prediction.
            log_scale: whether the prediction lives in log-depth space.
            mask: optional multiplicative validity mask.

        Returns:
            (pred_panorama, pred_panorama_viz): the depth map (exponentiated
            when log_scale) and a log-scaled copy for visualisation.
        """
        pred_panorama = cubemap_to_erp(pred_cubemap, *orig_size)
        pred_panorama = torch.clamp(pred_panorama, -1, 1)
        pred_panorama = (pred_panorama + 1) / 2  # [-1, 1] -> [0, 1]
        if mask is not None:
            pred_panorama *= mask
        pred_panorama = pred_panorama * depth_range + min_depth
        if log_scale:
            pred_panorama_viz = pred_panorama.clone()
            pred_panorama = torch.exp(pred_panorama)  # log depth -> depth
        else:
            # NOTE(review): values can be <= 0 here (e.g. masked pixels),
            # which makes torch.log yield -inf/nan in the viz -- confirm.
            pred_panorama_viz = torch.log(pred_panorama)

        return pred_panorama, pred_panorama_viz
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def process_normal_output(self,pred_cubemap, orig_size):
|
| 250 |
+
pred_panorama = cubemap_to_erp(pred_cubemap, *orig_size)
|
| 251 |
+
pred_panorama = torch.clamp(pred_panorama, -1, 1)
|
| 252 |
+
return pred_panorama
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def extend_unet_conv_in(self, unet, new_in_channels: int):
|
| 256 |
+
if new_in_channels < unet.conv_in.in_channels:
|
| 257 |
+
raise ValueError(
|
| 258 |
+
f"new_in_channels ({new_in_channels}) must be >= current "
|
| 259 |
+
f"{unet.conv_in.in_channels}"
|
| 260 |
+
)
|
| 261 |
+
if new_in_channels == unet.conv_in.in_channels:
|
| 262 |
+
return
|
| 263 |
+
|
| 264 |
+
old_conv = unet.conv_in
|
| 265 |
+
old_in = old_conv.in_channels
|
| 266 |
+
device, dtype = old_conv.weight.device, old_conv.weight.dtype
|
| 267 |
+
bias_flag = old_conv.bias is not None
|
| 268 |
+
|
| 269 |
+
new_conv = Conv2d(
|
| 270 |
+
new_in_channels,
|
| 271 |
+
old_conv.out_channels,
|
| 272 |
+
kernel_size=old_conv.kernel_size,
|
| 273 |
+
stride=old_conv.stride,
|
| 274 |
+
padding=old_conv.padding,
|
| 275 |
+
bias=bias_flag,
|
| 276 |
+
padding_mode=old_conv.padding_mode,
|
| 277 |
+
).to(device=device, dtype=dtype)
|
| 278 |
+
|
| 279 |
+
new_conv.weight.zero_()
|
| 280 |
+
new_conv.weight[:, :old_in].copy_(old_conv.weight)
|
| 281 |
+
if bias_flag:
|
| 282 |
+
new_conv.bias.copy_(old_conv.bias)
|
| 283 |
+
|
| 284 |
+
unet.conv_in = new_conv
|
| 285 |
+
unet.config["in_channels"] = new_in_channels
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
    def set_valid_pad_conv(self, module: nn.Module):
        """Recursively replace padded Conv2d layers with PaddedConv2d, which
        fills the halo from neighbouring cube faces instead of zeros."""
        for name, child in list(module.named_children()):
            if isinstance(child, nn.Conv2d):
                if child.padding != (0, 0):
                    setattr(module, name, PaddedConv2d.from_existing(child, valid_pad_conv_fn))
                # Downsample2D convs (stride-2, unpadded) get one-sided
                # padding so the spatial maths still works out.
                elif module.__class__.__name__ == "Downsample2D" and module.use_conv:
                    setattr(module, name, PaddedConv2d.from_existing(child, valid_pad_conv_fn, one_side_pad=True))
            else:
                self.set_valid_pad_conv(child)
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def save_model(self, ema_unet, model_save_dir):
|
| 300 |
+
self.unwrapped_unet.save_pretrained(model_save_dir / "original")
|
| 301 |
+
if ema_unet is not None:
|
| 302 |
+
ema_unet.store(self.unwrapped_unet.parameters())
|
| 303 |
+
ema_unet.copy_to(self.unwrapped_unet.parameters())
|
| 304 |
+
self.unwrapped_unet.save_pretrained(model_save_dir / f"EMA")
|
| 305 |
+
ema_unet.restore(self.unwrapped_unet.parameters())
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
|
src/utils/__init__.py
ADDED
|
File without changes
|
src/utils/conv_padding.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
|
| 4 |
+
# Cube-face adjacency tables for cross-face "valid" padding.
# orderings[i] = [i, right, left, top, bottom]: for face i, the indices of
# the neighbouring faces whose edge strips supply the 1-pixel halo.
# Faces 0..3 form a ring and 4/5 are the remaining pair -- the exact
# up/down assignment follows erp_to_cubemap's face order; TODO confirm.
orderings = [
    [0, 1, 3, 4, 5],
    [1, 2, 0, 4, 5],
    [2, 3, 1, 4, 5],
    [3, 0, 2, 4, 5],
    [4, 1, 3, 2, 0],
    [5, 1, 3, 0, 2],
]
# rotations[i] mirrors orderings[i]: the rotation code (0, 1, 2, -1 -- see
# the _take_* helpers) applied to each neighbour's edge before pasting.
rotations = [
    [0, 0, 0, 0, 0],
    [0, 0, 0,-1, 1],
    [0, 0, 0, 2, 2],
    [0, 0, 0, 1,-1],
    [0, 1,-1, 2, 0],
    [0,-1, 1, 0, 2]
]
|
| 20 |
+
|
| 21 |
+
def _take_right(face, rot):
|
| 22 |
+
if rot == 0:
|
| 23 |
+
return face[:, :, 0]
|
| 24 |
+
elif rot == 1:
|
| 25 |
+
return face[:, 0, :].flip(1)
|
| 26 |
+
elif rot == 2:
|
| 27 |
+
return face[:, :, -1].flip(1)
|
| 28 |
+
elif rot == -1:
|
| 29 |
+
return face[:, -1, :]
|
| 30 |
+
|
| 31 |
+
def _take_left(face, rot):
|
| 32 |
+
if rot == 0:
|
| 33 |
+
return face[:, :, -1]
|
| 34 |
+
elif rot == 1:
|
| 35 |
+
return face[:, -1, :].flip(1)
|
| 36 |
+
elif rot == 2:
|
| 37 |
+
return face[:, :, 0].flip(1)
|
| 38 |
+
elif rot == -1:
|
| 39 |
+
return face[:, 0, :]
|
| 40 |
+
|
| 41 |
+
def _take_top(face, rot):
|
| 42 |
+
if rot == 0:
|
| 43 |
+
return face[:, -1, :]
|
| 44 |
+
elif rot == 1:
|
| 45 |
+
return face[:, :, 0]
|
| 46 |
+
elif rot == 2:
|
| 47 |
+
return face[:, 0, :].flip(1)
|
| 48 |
+
elif rot == -1:
|
| 49 |
+
return face[:, :, -1].flip(1)
|
| 50 |
+
|
| 51 |
+
def _take_bottom(face, rot):
|
| 52 |
+
if rot == 0:
|
| 53 |
+
return face[:, 0, :]
|
| 54 |
+
elif rot == 1:
|
| 55 |
+
return face[:, :, -1]
|
| 56 |
+
elif rot == 2:
|
| 57 |
+
return face[:, -1, :].flip(1)
|
| 58 |
+
elif rot == -1:
|
| 59 |
+
return face[:, :, 0].flip(1)
|
| 60 |
+
|
| 61 |
+
def valid_pad_conv_fn(x, one_side_pad=False):
    """Pad a stack of 6 cubemap faces by 1 pixel using the neighboring faces.

    Each face's border is filled with the adjacent edge of its right/left/
    top/bottom neighbor (per the `orderings`/`rotations` tables), so a
    following valid (padding=0) convolution sees seam-consistent context.

    Args:
        x: tensor of shape (6, C, H, W) — the six cubemap faces.
        one_side_pad: if True, the input's last row/column are first dropped
            and only the top/left side of the padding is kept, yielding an
            output of shape (6, C, H, W); otherwise output is (6, C, H+2, W+2).
    """
    if one_side_pad:
        # Drop the bottom row and right column; they are re-created below.
        x = x[:, :, :-1, :-1]
    assert x.ndim == 4 and x.shape[0] == 6
    _, C, H, W = x.shape
    y = x.new_empty(6, C, H+2, W+2)
    # Interior is the original face; only the 1-pixel border is synthesized.
    y[..., 1:-1, 1:-1] = x

    for i in range(6):
        # Neighbor face indices and their relative quarter-turn rotations
        # (slot 0 of each table row is the face itself and is skipped).
        r_idx, l_idx, t_idx, b_idx = orderings[i][1:5]
        r_rot, l_rot, t_rot, b_rot = rotations[i][1:5]

        # Edge strips taken from the neighbors, rotated into this face's frame.
        r_edge = _take_right (x[r_idx], r_rot)
        l_edge = _take_left (x[l_idx], l_rot)
        t_edge = _take_top (x[t_idx], t_rot)
        b_edge = _take_bottom(x[b_idx], b_rot)

        y[i, :, 1:-1, 0 ] = l_edge
        y[i, :, 1:-1, -1 ] = r_edge
        y[i, :, 0, 1:-1] = t_edge
        y[i, :, -1, 1:-1] = b_edge

        # Corners have no single neighbor; average the two adjacent border pixels.
        y[i, :, 0, 0 ] = 0.5*(y[i, :, 0, 1] + y[i, :, 1, 0])
        y[i, :, 0, -1 ] = 0.5*(y[i, :, 0, -2] + y[i, :, 1, -1])
        y[i, :, -1, 0 ] = 0.5*(y[i, :, -2, 0] + y[i, :, -1, 1])
        y[i, :, -1,-1 ] = 0.5*(y[i, :, -2, -1] + y[i, :, -1, -2])

    if one_side_pad:
        # Keep only the top/left padding so the spatial size matches the input.
        return y[:, :, 1:, 1:]

    return y
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class PaddedConv2d(nn.Conv2d):
    """Conv2d whose spatial padding comes from a user-supplied function.

    The built-in zero padding is disabled and `pad_fn` (e.g. cubemap-aware
    padding) is applied to the input before a valid convolution.
    """

    def __init__(self, *args, pad_fn=None, one_side_pad=False, **kwargs):
        conv_kwargs = {**kwargs, "padding": 0}  # padding is delegated to pad_fn
        super().__init__(*args, **conv_kwargs)
        self.pad_fn = pad_fn
        self.one_side_pad = one_side_pad

    def forward(self, x):
        padded = self.pad_fn(x, one_side_pad=self.one_side_pad)
        return F.conv2d(
            padded,
            self.weight,
            self.bias,
            stride=self.stride,
            padding=0,
            dilation=self.dilation,
            groups=self.groups,
        )

    @classmethod
    def from_existing(cls, conv: nn.Conv2d, pad_fn, one_side_pad=False):
        """Wrap an existing Conv2d, sharing its weight/bias parameters."""
        wrapped = cls(
            conv.in_channels,
            conv.out_channels,
            conv.kernel_size,
            stride=conv.stride,
            padding=0,
            dilation=conv.dilation,
            groups=conv.groups,
            bias=conv.bias is not None,
            padding_mode="zeros",
            pad_fn=pad_fn,
            one_side_pad=one_side_pad,
        )
        # Share (not copy) the parameters so the wrapper tracks the original.
        wrapped.weight = conv.weight
        if conv.bias is not None:
            wrapped.bias = conv.bias
        return wrapped
|
| 122 |
+
|
| 123 |
+
|
src/utils/geometry_utils.py
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import trimesh
|
| 4 |
+
from pytorch360convert import e2c, c2e
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def erp_to_cubemap(erp_tensor, face_w = 768, cube_format = "stack", mode = "bilinear", **kwargs):
    """Project an equirectangular tensor to cubemap faces via pytorch360convert's e2c.

    Thin wrapper; `cube_format="stack"` returns the faces stacked along a new
    leading dimension. Extra kwargs are forwarded to `e2c` unchanged.
    """
    return e2c(erp_tensor, face_w=face_w, cube_format=cube_format, mode=mode, **kwargs)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def cubemap_to_erp(cube_tensor, erp_h = 1024, erp_w = 2048, cube_format = "stack", mode = "bilinear", **kwargs):
    """Re-project cubemap faces to an (erp_h x erp_w) equirectangular tensor.

    Thin wrapper over pytorch360convert's c2e; extra kwargs are forwarded.
    """
    return c2e(cube_tensor, h=erp_h, w=erp_w, cube_format=cube_format, mode=mode, **kwargs)
|
| 13 |
+
|
| 14 |
+
def roll_augment(data, shift_x):
    """Horizontally roll (wrap around) an ERP image or map by `shift_x` pixels.

    Accepts (H, W), (H, W, C) or (3, H, W) arrays and returns the rolled
    array in the same layout it was given.
    """
    squeeze_back = data.ndim == 2
    if squeeze_back:
        data = data[:, :, np.newaxis]

    # Anything 3D that is not already channel-first (3, H, W) gets moved to CHW.
    to_chw = data.ndim == 3 and data.shape[0] != 3
    if to_chw:
        data = np.moveaxis(data, -1, 0)

    rolled = np.roll(data, int(shift_x), axis=2)

    if to_chw:
        rolled = np.moveaxis(rolled, 0, -1)
    if squeeze_back:
        rolled = rolled[:, :, 0]
    return rolled
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def roll_normal(normal, shift_x):
    """Rotate per-pixel normal vectors to match a horizontal ERP roll.

    A roll of `shift_x` pixels on a W-wide panorama corresponds to a yaw of
    -2*pi*shift_x/W; each 3-vector is rotated about the Y axis accordingly.
    Note: only the vector *values* are rotated — pixels are not shifted here.
    """
    squeeze_back = normal.ndim == 2
    if squeeze_back:
        normal = normal[:, :, np.newaxis]
    to_chw = normal.ndim == 3 and normal.shape[0] != 3
    if to_chw:
        normal = np.moveaxis(normal, -1, 0)

    _, H, W = normal.shape

    angle = - 2.0 * np.pi * (shift_x / float(W))
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    # Rotation about the vertical (Y) axis.
    rot = np.array([
        [cos_a, 0.0, -sin_a],
        [0.0,   1.0,  0.0  ],
        [sin_a, 0.0,  cos_a],
    ], dtype=normal.dtype)

    normal = (rot @ normal.reshape(3, -1)).reshape(3, H, W)

    if to_chw:
        normal = np.moveaxis(normal, 0, -1)
    if squeeze_back:
        normal = normal[:, :, 0]
    return normal
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def compute_scale_and_shift(pred_g, targ_g, mask_g = None, eps = 0.0, fit_shift = True):
    """Closed-form least-squares alignment of `pred_g` to `targ_g`.

    Solves per-sample for scale (and optionally shift) minimizing
    ||scale * pred + shift - targ||^2 over the masked pixels.

    Returns:
        (scale, shift) tensors with one entry per batch sample; shift is
        zeros when `fit_shift` is False or the system is degenerate.
    """
    if mask_g is None:
        mask_g = torch.ones_like(pred_g, dtype=torch.bool)

    # Fold a 6-face cubemap stack into a single "sample" so the sums below
    # aggregate over all faces at once.
    if pred_g.shape[0] == 6:
        pred_g = pred_g.view(1, 6, pred_g.shape[2], pred_g.shape[3])
        targ_g = targ_g.view(1, 6, targ_g.shape[2], targ_g.shape[3])
        mask_g = mask_g.view(1, 6, mask_g.shape[2], mask_g.shape[3])
    elif pred_g.shape[0] == 1 and pred_g.dim() == 3:
        pred_g = pred_g.unsqueeze(0)
        targ_g = targ_g.unsqueeze(0)
        mask_g = mask_g.unsqueeze(0)

    w = mask_g.to(dtype=pred_g.dtype)
    dims = (1, 2, 3)

    # Normal-equation entries of the 2x2 system [[a00, a01], [a01, a11]].
    a_00 = (w * pred_g * pred_g).sum(dim=dims)
    a_01 = (w * pred_g).sum(dim=dims)
    a_11 = w.sum(dim=dims)
    b_0 = (w * pred_g * targ_g).sum(dim=dims)
    b_1 = (w * targ_g).sum(dim=dims)

    if not fit_shift:
        # Scale-only: single-variable least squares.
        scale = b_0 / (a_00 + eps)
        return scale, torch.zeros_like(scale)

    det = a_00 * a_11 - a_01 * a_01 + eps
    scale = torch.zeros_like(b_0)
    shift = torch.zeros_like(b_1)
    solvable = det > 0  # degenerate samples keep scale=shift=0
    scale[solvable] = (a_11[solvable] * b_0[solvable] - a_01[solvable] * b_1[solvable]) / det[solvable]
    shift[solvable] = (-a_01[solvable] * b_0[solvable] + a_00[solvable] * b_1[solvable]) / det[solvable]
    return scale, shift
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def compute_shift(pred, targ, mask, eps = 1e-6):
    """Per-sample mean residual (targ - pred) over the masked pixels.

    Returns the additive shift `beta` that best aligns pred to targ in the
    least-squares sense when only a shift is fitted.
    """
    # Fold a 6-face cubemap stack into one sample, mirroring compute_scale_and_shift.
    if pred.shape[0] == 6:
        pred = pred.view(1, 6, *pred.shape[2:])
        targ = targ.view(1, 6, *targ.shape[2:])
        mask = mask.view(1, 6, *mask.shape[2:])

    weights = mask.float()
    residual_sum = (weights * (targ - pred)).sum(dim=(1, 2, 3))
    weight_sum = weights.sum(dim=(1, 2, 3)).clamp_min(eps)
    return residual_sum / weight_sum
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def get_positional_encoding(H, W, pixel_center = True, hw = 96):
    """Build a 2-channel UV grid in [-1, 1] for an HxW ERP image and project
    it to cubemap faces of side `hw`.

    Returns the cubemap faces produced by `erp_to_cubemap` (stacked format).
    """
    cols = np.arange(W, dtype=np.float64)
    rows = np.arange(H, dtype=np.float64)
    if pixel_center:
        cols = cols + 0.5
        rows = rows + 0.5

    # Normalize pixel coordinates to [-1, 1].
    u = (cols / W) * 2.0 - 1.0
    v = (rows / H) * 2.0 - 1.0
    uu, vv = np.meshgrid(u, v, indexing='xy')

    erp = np.stack([uu, vv], axis=-1)
    erp_tensor = torch.from_numpy(erp).permute(2, 0, 1).float()
    return erp_to_cubemap(erp_tensor, face_w=hw)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def unit_normals(n, eps = 1e-6):
    """L2-normalize normal vectors stored along the channel axis (dim -3).

    The norm is clamped at `eps` to avoid division by zero for degenerate
    (all-zero) normals.
    """
    assert n.dim() >= 3 and n.size(-3) == 3, "normals must have channel=3 at dim -3"
    lengths = torch.linalg.norm(n, dim=-3, keepdim=True).clamp(min=eps)
    return n / lengths
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def _erp_dirs(H, W, device=None, dtype=None):
|
| 142 |
+
u = (torch.arange(W, device=device, dtype=dtype) + 0.5) / W
|
| 143 |
+
v = (torch.arange(H, device=device, dtype=dtype) + 0.5) / H
|
| 144 |
+
theta = u * (2.0 * torch.pi) - torch.pi
|
| 145 |
+
phi = (0.5 - v) * torch.pi
|
| 146 |
+
|
| 147 |
+
theta = theta.view(1, W).expand(H, W)
|
| 148 |
+
phi = phi.view(H, 1).expand(H, W)
|
| 149 |
+
|
| 150 |
+
cosphi = torch.cos(phi)
|
| 151 |
+
sinphi = torch.sin(phi)
|
| 152 |
+
costhe = torch.cos(theta)
|
| 153 |
+
sinthe = torch.sin(theta)
|
| 154 |
+
|
| 155 |
+
x = cosphi * costhe
|
| 156 |
+
y = sinphi
|
| 157 |
+
z = -cosphi * sinthe
|
| 158 |
+
|
| 159 |
+
dirs = torch.stack([x, y, z], dim=0)
|
| 160 |
+
return dirs
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def depth_to_normals_erp(depth, eps = 1e-6):
    """Estimate per-pixel surface normals from an ERP depth map.

    Back-projects depth along the per-pixel view directions, takes central
    differences in longitude/latitude, and returns the normalized cross
    product of the two tangents.

    Args:
        depth: tensor of shape (1, H, W) — per-pixel ray depth.
        eps: clamp for the normalization denominator.

    Returns:
        Tensor of shape (3, H, W) with unit normals.
    """
    # NOTE: the old message claimed "(B,1,H,W)" but the check enforces (1,H,W).
    assert depth.dim() == 3 and depth.size(0) == 1, "depth must be (1,H,W)"
    _, H, W = depth.shape
    device, dtype = depth.device, depth.dtype

    dirs = _erp_dirs(H, W, device=device, dtype=dtype)
    P = depth * dirs  # 3D point per pixel, shape (3, H, W)

    # Angular step sizes of the ERP grid.
    dtheta = 2.0 * torch.pi / W
    dphi = torch.pi / H

    # Longitude wraps around, so central differences can roll.
    P_l = torch.roll(P, shifts=+1, dims=-1)
    P_r = torch.roll(P, shifts=-1, dims=-1)
    dP_dtheta = (P_r - P_l) / (2.0 * dtheta)

    # Latitude is clamped at the poles (replicate the first/last row).
    P_u = torch.cat([P[:, :1, :], P[:, :-1, :]], dim=-2)
    P_d = torch.cat([P[:, 1:, :], P[:, -1:, :]], dim=-2)
    dP_dphi = (P_d - P_u) / (2.0 * dphi)

    n = torch.cross(dP_dtheta, dP_dphi, dim=0)
    n = unit_normals(n, eps=eps)

    return n
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def compute_edge_mask(depth, abs_thresh = 0.1, rel_thresh = 0.1):
    """Keep-mask for depth pixels that are valid and not on a discontinuity.

    A pixel is marked as an edge if it participates in any horizontal or
    vertical neighbor pair whose depth difference exceeds both the absolute
    and the relative threshold (relative to the nearer of the two depths).

    Returns:
        Boolean array of the same shape: True = valid, non-edge pixel.
    """
    assert depth.ndim == 2
    depth = depth.astype(np.float32, copy=False)

    valid = depth > 0
    eps = 1e-6
    edge = np.zeros_like(valid, dtype=bool)

    for horizontal in (True, False):
        if horizontal:
            d1, d2 = depth[:, :-1], depth[:, 1:]
            both_valid = valid[:, :-1] & valid[:, 1:]
        else:
            d1, d2 = depth[:-1, :], depth[1:, :]
            both_valid = valid[:-1, :] & valid[1:, :]

        jump = np.abs(d1 - d2)
        rel_jump = jump / (np.minimum(d1, d2) + eps)
        edge_pair = both_valid & (jump > abs_thresh) & (rel_jump > rel_thresh)

        # Mark BOTH pixels of each offending pair.
        if horizontal:
            edge[:, :-1] |= edge_pair
            edge[:, 1:] |= edge_pair
        else:
            edge[:-1, :] |= edge_pair
            edge[1:, :] |= edge_pair

    return valid & (~edge)
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def erp_to_pointcloud(rgb, depth, mask = None):
    """Back-project an ERP RGB+depth pair into a colored 3D point cloud.

    Args:
        rgb: (H, W, 3) array; values are assumed in [-1, 1] and are mapped
            to uint8 colors.
        depth: (H, W) per-pixel ray depth; non-positive depths are dropped.
        mask: optional (H, W) boolean keep-mask, ANDed with depth > 0.

    Returns:
        (points, colors): float32 (N, 3) positions and uint8 (N, 3) colors.
    """
    assert rgb.ndim == 3 and rgb.shape[-1] == 3, "rgb must be (H, W, 3)"
    assert depth.ndim == 2 and depth.shape[:2] == rgb.shape[:2], "depth must be (H, W) and match rgb H,W"

    H, W, _ = rgb.shape
    depth = depth.astype(np.float32, copy=False)

    # Pixel-center spherical coordinates.
    u = (np.arange(W, dtype=np.float32) + 0.5) / W
    v = (np.arange(H, dtype=np.float32) + 0.5) / H
    theta = u * (2.0 * np.pi) - np.pi
    phi = (1 - v) * np.pi - (np.pi / 2.0)
    theta, phi = np.meshgrid(theta, phi, indexing="xy")

    cos_phi = np.cos(phi)
    rays = np.stack(
        [cos_phi * np.cos(theta), np.sin(phi), cos_phi * np.sin(theta)],
        axis=-1,
    )
    all_points = depth[..., np.newaxis] * rays

    if mask is None:
        keep = depth > 0
    else:
        keep = mask.astype(bool) & (depth > 0)

    points = all_points[keep]

    # Map [-1, 1] floats to uint8 colors.
    colors = ((np.clip(rgb, -1.0, 1.0) * 0.5 + 0.5) * 255.0).astype(np.uint8)
    colors = colors.reshape(H, W, 3)[keep]

    return points.astype(np.float32, copy=False), colors
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def erp_to_point_cloud_glb(rgb, depth, mask=None, export_path=None):
    """Export an ERP RGB+depth pair as a trimesh point-cloud scene.

    Back-projects via `erp_to_pointcloud` and writes the scene to
    `export_path` (format inferred from the extension, e.g. .glb).
    NOTE(review): with the default export_path=None, trimesh's export
    returns raw bytes instead of writing a file — confirm callers always
    pass a path.
    """
    points, colors = erp_to_pointcloud(rgb, depth, mask)
    scene = trimesh.Scene()
    scene.add_geometry(trimesh.PointCloud(vertices=points, colors=colors))
    scene.export(export_path)
    return scene
|
src/utils/loss.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from src.utils.geometry_utils import unit_normals
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class L1Loss(nn.Module):
    """Masked mean absolute error.

    Optionally adds a down-weighted L1 term computed over the pixels
    OUTSIDE the mask (controlled by `invalid_mask_weight`).
    """

    def __init__(self, invalid_mask_weight=0.0):
        super().__init__()
        self.name = 'L1'
        self.invalid_mask_weight = invalid_mask_weight

    def forward(self, pred, target, mask):
        total = nn.functional.l1_loss(pred[mask], target[mask])
        if self.invalid_mask_weight > 0.0:
            outside = ~mask
            if outside.sum() > 0:
                extra = nn.functional.l1_loss(pred[outside], target[outside])
                total = total + self.invalid_mask_weight * extra
        return total
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class GradL1Loss(nn.Module):
    """L1 loss on first-order spatial gradients (forward differences)."""

    def __init__(self):
        super().__init__()
        self.name = 'GradL1'

    def grad(self, x):
        # Forward differences, cropped so dx and dy share an (H-1, W-1) grid.
        dx = x[..., :-1, 1:] - x[..., :-1, :-1]
        dy = x[..., 1:, :-1] - x[..., :-1, :-1]
        return dx, dy

    def grad_mask(self, mask):
        # A gradient sample is valid only if all four contributing pixels are.
        return (mask[..., :-1, :-1] & mask[..., :-1, 1:] &
                mask[..., 1:, :-1] & mask[..., 1:, 1:])

    def forward(self, pred, target, mask):
        pred_dx, pred_dy = self.grad(pred)
        targ_dx, targ_dy = self.grad(target)
        valid = self.grad_mask(mask)

        loss_dx = nn.functional.l1_loss(pred_dx[valid], targ_dx[valid], reduction='mean')
        loss_dy = nn.functional.l1_loss(pred_dy[valid], targ_dy[valid], reduction='mean')
        return 0.5 * (loss_dx + loss_dy)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class CosineNormalLoss(nn.Module):
    """Angular loss 1 - cos(angle) between predicted and target normals.

    Both inputs are re-normalized before the dot product; the mean is taken
    over the masked entries when a mask is provided.
    """

    def __init__(self):
        super().__init__()
        self.name = "CosineNormalLoss"

    def forward(self, pred: torch.Tensor,
                target: torch.Tensor,
                mask: torch.Tensor) -> torch.Tensor:
        assert pred.shape == target.shape, "pred and target must have same shape"

        pred_unit = unit_normals(pred)
        target_unit = unit_normals(target)

        # Cosine similarity per pixel, clamped against numerical drift.
        cos_sim = (pred_unit * target_unit).sum(dim=1, keepdim=True).clamp(-1.0, 1.0)
        penalty = 1.0 - cos_sim

        if mask is not None:
            return penalty[mask].mean()
        return penalty.mean()
|
src/utils/lr_scheduler.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @Vukasin Bozic 2026
|
| 2 |
+
# This file contains the modified version of Marigold's exponential LR scheduler.
|
| 3 |
+
# https://github.com/prs-eth/Marigold/blob/main/src/util/lr_scheduler.py
|
| 4 |
+
|
| 5 |
+
# Author: Bingxin Ke
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
class IterExponential:
    """Exponential LR-multiplier schedule with an optional linear warmup.

    `warmup_steps` is given as a FRACTION of `total_iter_length`, not a
    step count. After warmup, the multiplier decays from 1 to `final_ratio`
    over the remaining iterations, then stays clamped at `final_ratio`.
    """

    def __init__(self, total_iter_length, final_ratio, warmup_steps=0) -> None:
        self.total_length = total_iter_length
        # Number of iterations over which the exponential decay runs.
        self.effective_length = int(total_iter_length * (1 - warmup_steps))
        self.final_ratio = final_ratio
        self.warmup_steps = int(total_iter_length * warmup_steps)

    def __call__(self, n_iter) -> float:
        # Linear ramp 0 -> 1 during warmup.
        if n_iter < self.warmup_steps:
            return 1.0 * n_iter / self.warmup_steps
        # Clamp at the floor once the schedule is exhausted.
        if n_iter >= self.total_length:
            return self.final_ratio
        progress = n_iter - self.warmup_steps
        return np.exp(progress / self.effective_length * np.log(self.final_ratio))
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class IterConstant:
    """Constant LR multiplier of 1.0 with an optional linear warmup ramp.

    As in `IterExponential`, `warmup_steps` is a fraction of the total
    schedule length.
    """

    def __init__(self, total_iter_length: int, warmup_steps: float = 0.0) -> None:
        self.total_length = int(total_iter_length)
        self.warmup_steps = int(total_iter_length * warmup_steps)

    def __call__(self, n_iter: int) -> float:
        # Ramp linearly from 1/warmup to 1.0, then hold at 1.0 forever.
        if self.warmup_steps > 0 and n_iter < self.warmup_steps:
            return float(n_iter + 1) / float(self.warmup_steps)
        return 1.0
|
src/utils/utils.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import cv2
|
| 4 |
+
import random
|
| 5 |
+
import wandb
|
| 6 |
+
from tqdm.auto import tqdm
|
| 7 |
+
from omegaconf import OmegaConf, DictConfig
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
def args_to_omegaconf(args, base_cfg=None):
    """Overlay non-None argparse values onto a (one-level-nested) OmegaConf config.

    For every key (or section.subkey) already present in `base_cfg`, the
    value is replaced when `args` carries a non-None attribute of the same
    name; missing or None CLI values leave the config untouched.
    """
    cfg = OmegaConf.create(base_cfg)

    def _apply(section, key):
        # Only override when the CLI actually provided a value.
        value = getattr(args, key, None)
        if value is not None:
            section[key] = value

    for key in cfg.keys():
        node = cfg[key]
        if isinstance(node, DictConfig):
            for subkey in node.keys():
                _apply(node, subkey)
        else:
            _apply(cfg, key)

    return cfg
|
| 28 |
+
|
| 29 |
+
def _tb_sanitize(v):
|
| 30 |
+
if v is None:
|
| 31 |
+
return "null"
|
| 32 |
+
if isinstance(v, (bool, int, float, str, torch.Tensor)):
|
| 33 |
+
return v
|
| 34 |
+
if isinstance(v, Path):
|
| 35 |
+
return str(v)
|
| 36 |
+
return str(v)
|
| 37 |
+
|
| 38 |
+
def _flatten_dict(d, prefix=""):
|
| 39 |
+
out = {}
|
| 40 |
+
if isinstance(d, dict):
|
| 41 |
+
for k, v in d.items():
|
| 42 |
+
key = f"{prefix}.{k}" if prefix else str(k)
|
| 43 |
+
if isinstance(v, dict):
|
| 44 |
+
out.update(_flatten_dict(v, key))
|
| 45 |
+
else:
|
| 46 |
+
out[key] = _tb_sanitize(v)
|
| 47 |
+
else:
|
| 48 |
+
out[prefix or "cfg"] = _tb_sanitize(d)
|
| 49 |
+
return out
|
| 50 |
+
|
| 51 |
+
def convert_paths_to_pathlib(cfg):
    """Recursively turn every config entry whose key contains 'path' into a
    pathlib.Path (None values stay None). Mutates and returns `cfg`.
    """
    for key, value in cfg.items():
        if isinstance(value, DictConfig):
            cfg[key] = convert_paths_to_pathlib(value)
        elif 'path' in key.lower():
            cfg[key] = None if value is None else Path(value)
    return cfg
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def convert_pathlib_to_strings(cfg):
    """Recursively convert pathlib.Path config values back to plain strings
    (the inverse of convert_paths_to_pathlib). Mutates and returns `cfg`.
    """
    for key, value in cfg.items():
        if isinstance(value, DictConfig):
            cfg[key] = convert_pathlib_to_strings(value)
        elif isinstance(value, Path):
            cfg[key] = str(value)
    return cfg
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def prepare_trained_parameters(unet, cfg):
    """Select the UNet parameters to optimize and set requires_grad accordingly.

    With `cfg.training.only_train_attention_layers` set, only transformer
    blocks (plus `conv_in` when the "uv" positional encoding is used) are
    trained; everything else is frozen. Otherwise all parameters train.

    Returns:
        List of the trainable parameters.
    """
    trainable = []

    if not cfg.training.only_train_attention_layers:
        # Full fine-tuning: every parameter is trained.
        for param in unet.parameters():
            param.requires_grad_(True)
            trainable.append(param)
        return trainable

    for name, param in unet.named_parameters():
        wants_conv_in = (cfg.model.unet_positional_encoding == "uv"
                         and "conv_in" in name)
        if wants_conv_in or "transformer_blocks" in name:
            param.requires_grad_(True)
            trainable.append(param)
        else:
            param.requires_grad_(False)

    return trainable
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
@torch.no_grad()
def validation_loop(accelerator, dataloader, pager, ema_unet, cfg, epoch, global_step, val_type="val"):
    """Run one validation pass and log the loss plus a few sample panoramas.

    Args:
        accelerator: HF Accelerate object used for reduction and logging.
        dataloader: validation dataloader; its dataset exposes the depth
            normalization constants (MIN_DEPTH, DEPTH_RANGE, ...).
        pager: model wrapper providing prediction and loss/output helpers.
        ema_unet: EMA wrapper; used when cfg.training.use_EMA is set.
        epoch, global_step: current training coordinates for the log x-axis.
        val_type: "val" (per-epoch) or "tiny_val" (per-step) run.

    Returns:
        Mean validation loss over the dataloader.
    """
    if val_type == "val":
        desc = "Validation"
        x_axis_name = "epoch"
        x_axis = epoch
    elif val_type == "tiny_val":
        desc = "Tiny Validation"
        x_axis_name = "global_step"
        x_axis = global_step
    else:
        raise ValueError(f"Unknown val type {val_type}")
    if cfg.training.use_EMA:
        # Validate with the EMA weights; live weights are restored at the end.
        ema_unet.store(pager.unwrapped_unet.parameters())
        ema_unet.copy_to(pager.unwrapped_unet.parameters())
    val_epoch_loss = 0.0
    log_val_images = {"rgb": [], cfg.model.modality: []}
    # Four random batches are sampled for image logging.
    log_img_ids = random.sample(range(len(dataloader)), 4)
    progress_bar = tqdm(dataloader, desc=desc, total=len(dataloader), disable=not accelerator.is_main_process)
    for i, batch in enumerate(progress_bar):
        pred_cubemap = pager(batch, cfg.model.modality)
        if cfg.model.modality == "depth":
            # Depth normalization constants depend on log-scale mode.
            min_depth = dataloader.dataset.LOG_MIN_DEPTH if cfg.model.log_scale else dataloader.dataset.MIN_DEPTH
            depth_range = dataloader.dataset.LOG_DEPTH_RANGE if cfg.model.log_scale else dataloader.dataset.DEPTH_RANGE
            loss = pager.calculate_depth_loss(batch, pred_cubemap, min_depth, depth_range, cfg.model.log_scale, cfg.model.metric_depth)
        elif cfg.model.modality == "normal":
            loss = pager.calculate_normal_loss(batch, pred_cubemap)

        # Average the loss across processes before accumulating/logging.
        avg_loss = accelerator.reduce(loss["total_loss"].detach(), reduction="mean")
        if accelerator.is_main_process:
            progress_bar.set_postfix({"loss": avg_loss.item()})
            val_epoch_loss += avg_loss
            if i in log_img_ids:
                log_val_images["rgb"].append(prepare_image_for_logging(batch["rgb"][0].cpu().numpy()))
                if cfg.model.modality == "depth":
                    result_image = pager.process_depth_output(pred_cubemap, orig_size=batch['depth'].shape[2:4], min_depth=min_depth,
                                                              depth_range=depth_range, log_scale=cfg.model.log_scale)[1].cpu().numpy()
                elif cfg.model.modality == "normal":
                    result_image = pager.process_normal_output(pred_cubemap, orig_size=batch['normal'].shape[2:4]).cpu().numpy()
                log_val_images[cfg.model.modality].append(prepare_image_for_logging(result_image))

    val_epoch_loss = val_epoch_loss / len(dataloader)

    if accelerator.is_main_process:
        accelerator.log({x_axis_name: x_axis, f"{val_type}/loss": float(val_epoch_loss)}, step=global_step)

        img_mix_rgb = log_images_mosaic(log_val_images["rgb"])
        img_mix_depth = log_images_mosaic(log_val_images[cfg.model.modality])

        if cfg.logging.report_to == "wandb":
            accelerator.log(
                {x_axis_name: x_axis, f"{val_type}/pred_panorama_rgb": wandb.Image(img_mix_rgb)},
                step=global_step,
            )
            accelerator.log(
                {x_axis_name: x_axis, f"{val_type}/pred_panorama_{cfg.model.modality}": wandb.Image(img_mix_depth)},
                step=global_step,
            )
        elif cfg.logging.report_to == "tensorboard":
            tb_writer = accelerator.get_tracker("tensorboard").writer
            tb_writer.add_image(
                f"{val_type}/pred_panorama_rgb",
                img_mix_rgb,
                global_step,
                dataformats="HWC",
            )
            tb_writer.add_image(
                f"{val_type}/pred_panorama_{cfg.model.modality}",
                img_mix_depth,
                global_step,
                dataformats="HWC",
            )

    if cfg.training.use_EMA:
        # Put the live (non-EMA) weights back for continued training.
        ema_unet.restore(pager.unwrapped_unet.parameters())
    return val_epoch_loss
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def prepare_image_for_logging(image):
    """Min-max normalize an array into [0, 255] uint8 for image logging."""
    lo, hi = image.min(), image.max()
    scaled = (image - lo) / (hi - lo + 1e-8)  # epsilon guards constant images
    return (scaled * 255).astype("uint8")
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def log_images_mosaic(images):
    """Arrange 1-4 CHW uint8 images into one HWC mosaic of Full-HD tiles.

    Each image is resized to 1920x1080; 2 images sit side by side, 3 form a
    centered-top pyramid, 4 form a 2x2 grid.
    """
    n = len(images)
    assert 1 <= n <= 4, "Provide between 1 and 4 images (CHW uint8)."

    tiles = []
    for img in images:
        assert img.dtype == np.uint8 and img.ndim == 3 and img.shape[0] in (1, 3), \
            "Each image must be uint8 with shape (C,H,W), C in {1,3}."
        if img.shape[0] == 1:
            img = np.repeat(img, 3, axis=0)  # grayscale -> 3-channel
        hwc = np.transpose(img, (1, 2, 0))
        tiles.append(cv2.resize(hwc, (1920, 1080), interpolation=cv2.INTER_LINEAR))

    H, W, C = 1080, 1920, 3

    if n == 1:
        return tiles[0]

    if n == 2:
        mosaic = np.zeros((H, 2*W, C), dtype=np.uint8)
        mosaic[:, :W, :] = tiles[0]
        mosaic[:, W:, :] = tiles[1]
        return mosaic

    mosaic = np.zeros((2*H, 2*W, C), dtype=np.uint8)
    if n == 3:
        # First image centered on the top row, the other two below.
        x_off = W // 2
        mosaic[:H, x_off:x_off + W, :] = tiles[0]
        mosaic[H:, :W, :] = tiles[1]
        mosaic[H:, W:, :] = tiles[2]
        return mosaic

    mosaic[:H, :W, :] = tiles[0]
    mosaic[:H, W:, :] = tiles[1]
    mosaic[H:, :W, :] = tiles[2]
    mosaic[H:, W:, :] = tiles[3]
    return mosaic
|
| 213 |
+
|
| 214 |
+
|