vulus98 committed on
Commit
05d33a4
·
1 Parent(s): 4817e08

Save work before migration

Browse files
.gitattributes CHANGED
@@ -33,4 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
36
  *.jpg filter=lfs diff=lfs merge=lfs -text
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.jpg~ filter=lfs diff=lfs merge=lfs -text
37
  *.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ __pycache__/
2
+ .vscode/
README.md CHANGED
@@ -8,7 +8,24 @@ sdk_version: 6.0.2
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
- short_description: Panorama Geometry Estimation using onestep diffusion models
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
+ python_version: 3.10
12
+ models:
13
+ - prs-eth/PaGeR-depth
14
+ - prs-eth/PaGeR-normals
15
+ tags:
16
+ - computer-vision
17
+ - image-processing
18
+ - diffusion-models
19
+ - panorama
20
+ - geometry-estimation
21
+ - depth-estimation
22
+ - normal-estimation
23
+ - single-step-diffusion
24
+ preload_from_hub:
25
+ - prs-eth/PaGeR-depth
26
+ - prs-eth/PaGeR-normals
27
+ suggested_hardware: a10g-large
28
+ short_description: Panorama Geometry Estimation
29
  ---
30
 
31
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import gc
3
+ import torch
4
+ import numpy as np
5
+ import argparse
6
+ import logging
7
+ import gradio as gr
8
+ from PIL import Image
9
+ from pathlib import Path
10
+ from omegaconf import OmegaConf
11
+ from tempfile import NamedTemporaryFile
12
+ from huggingface_hub import hf_hub_download
13
+ from matplotlib import pyplot as plt
14
+ from src.pager import Pager
15
+ from src.utils.geometry_utils import compute_edge_mask, erp_to_point_cloud_glb, erp_to_cubemap
16
+ from src.utils.utils import prepare_image_for_logging
17
+
18
# Log-depth decode parameters: network outputs in [0, 1] are mapped to
# log-depths MIN_DEPTH + [0, DEPTH_RANGE] before exponentiation
# (see Pager.process_depth_output).
MIN_DEPTH = np.log(1e-2)
DEPTH_RANGE = np.log(75.0)
# Subsample the ERP grid before exporting the .glb so viewer payloads stay small.
POINTCLOUD_DOWNSAMPLE_FACTOR = 2
MAX_POINTCLOUD_POINTS = 200000  # NOTE(review): not referenced elsewhere in this file — confirm intended use
EXAMPLES_DIR = Path(__file__).parent / "examples"
# Bundled example panoramas, filtered to common image suffixes.
EXAMPLE_IMAGES = [
    str(p)
    for p in sorted(EXAMPLES_DIR.glob("*"))
    if p.suffix.lower() in {".jpg", ".jpeg", ".png", ".webp"}
]
28
+
29
def parse_args():
    """Parse command-line options for the demo: RNG seed, the two UNet
    checkpoint locations, and the optional xformers switch."""
    parser = argparse.ArgumentParser(
        description="Inference script for panorama depth estimation using diffusion models."
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="A seed for reproducibility."
    )
    parser.add_argument(
        "--depth_checkpoint_path",
        default="prs-eth/PaGeR-depth",
        type=str,
        help="UNet checkpoint to load.",
    )
    parser.add_argument(
        "--normals_checkpoint_path",
        default="prs-eth/PaGeR-normals",
        type=str,
        help="UNet checkpoint to load.",
    )
    parser.add_argument(
        "--enable_xformers",
        action="store_true",
        help="Whether or not to use xformers."
    )
    return parser.parse_args()
61
+
62
+ def _release_cuda_memory():
63
+ if torch.cuda.is_available():
64
+ torch.cuda.empty_cache()
65
+ torch.cuda.ipc_collect()
66
+ gc.collect()
67
+
68
def generate_ERP(input_rgb, modality):
    """Run single-step inference on an ERP RGB image.

    Args:
        input_rgb: (H, W, 3) uint8 numpy array (expected 1024x2048 — resized by the caller).
        modality: "depth" or "normal".

    Returns:
        (pred_image, pred): a displayable visualization array and the raw
        panorama prediction (depth map or normal map) as numpy arrays.
    """
    batch = {}
    # uint8 HWC -> float CHW in [-1, 1], the range the model pipeline expects.
    input_rgb = torch.from_numpy(input_rgb).permute(2,0,1).to(torch.float32) / 255.0
    input_rgb = input_rgb * 2.0 - 1.0
    batch['rgb_cubemap'] = erp_to_cubemap(input_rgb).unsqueeze(0).to(device)
    with torch.inference_mode():
        # Fix: the raw torch.cuda calls raise on CPU-only hosts even though
        # `device` falls back to CPU — guard them behind an availability check.
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        pred_cubemap = pager(batch, modality)
        if modality == "depth":
            pred, pred_image = pager.process_depth_output(pred_cubemap, orig_size=(1024, 2048),
                                                          min_depth=MIN_DEPTH,
                                                          depth_range=DEPTH_RANGE,
                                                          log_scale=pager.model_configs["depth"]["config"].log_scale)
            pred, pred_image = pred[0].cpu().numpy(), pred_image.cpu().numpy()
            # Clip the far 1% tail so outliers do not wash out the colormap.
            pred_image = np.clip(pred_image, pred_image.min(), np.quantile(pred_image, 0.99))
            pred_image = prepare_image_for_logging(pred_image)
            pred_image = cmap(pred_image[0,...]/255.0)
            pred_image = (pred_image[..., :3] * 255).astype(np.uint8)
        elif modality == "normal":
            pred = pager.process_normal_output(pred_cubemap, orig_size=(1024, 2048))
            pred = pred.cpu().numpy()
            pred_image = pred.copy()
            pred_image = prepare_image_for_logging(pred_image).transpose(1,2,0)

    return pred_image, pred
95
+
96
def process_panorama(image_path, output_type, include_pointcloud):
    """Gradio callback: run depth or surface-normal inference, optionally
    exporting a .glb point cloud.

    Args:
        image_path: filesystem path of the uploaded ERP image.
        output_type: radio value, "Depth" or "Surface Normals".
        include_pointcloud: whether to also build the 3D point cloud.

    Returns:
        A pair of gr.update objects for the output image and Model3D components.
    """
    # The pipeline operates at a fixed 2048x1024 ERP resolution.
    loaded_image = Image.open(image_path).convert("RGB").resize((2048, 1024))
    input_rgb = np.array(loaded_image)

    modality = "depth" if output_type.lower() == "depth" else "normal"
    is_depth = modality == "depth"
    main_label = "Depth Output" if is_depth else "Surface Normal Output"
    pc_label = (
        "RGB-colored Point Cloud" if is_depth else "Surface Normals-Colored Point Cloud"
    )
    output_image, raw_pred = generate_ERP(input_rgb, modality)

    point_cloud = None
    if include_pointcloud:
        if is_depth:
            depth = np.squeeze(np.array(raw_pred))
            # RGB colors rescaled to [-1, 1] for the glb exporter.
            color = (input_rgb.astype(np.float32) / 127.5) - 1.0
        else:
            # Color the cloud by the predicted normals; a depth pass is still
            # needed for 3D positions, so run the depth model as well.
            color = np.array(raw_pred)
            color = np.transpose(color, (1, 2, 0))
            _release_cuda_memory()
            depth = np.squeeze(generate_ERP(input_rgb, "depth", )[1])

        # Drop points across strong depth discontinuities so the cloud does
        # not contain stretched "curtains" between fore- and background.
        edge_filtered_mask = compute_edge_mask(
            depth,
            abs_thresh=0.002,
            rel_thresh=0.002,
        )

        if POINTCLOUD_DOWNSAMPLE_FACTOR > 1:
            depth = depth[::POINTCLOUD_DOWNSAMPLE_FACTOR, ::POINTCLOUD_DOWNSAMPLE_FACTOR]
            color = color[::POINTCLOUD_DOWNSAMPLE_FACTOR, ::POINTCLOUD_DOWNSAMPLE_FACTOR]
            edge_filtered_mask = edge_filtered_mask[::POINTCLOUD_DOWNSAMPLE_FACTOR, ::POINTCLOUD_DOWNSAMPLE_FACTOR]

        # delete=False is required: gradio serves the file after this returns.
        tmp = NamedTemporaryFile(suffix=".glb", delete=False)
        erp_to_point_cloud_glb(
            color, depth, edge_filtered_mask, export_path=tmp.name)

        tmp.close()
        point_cloud = tmp.name

    _release_cuda_memory()

    return (
        gr.update(value=output_image, label=main_label),
        gr.update(value=point_cloud, label=pc_label),
    )
143
+
144
+
145
def clear_pointcloud():
    """Blank the Model3D component before a new inference run starts."""
    empty_update = gr.update(value=None)
    return empty_update
147
+
148
+
149
# ---- One-time startup: CLI args, device, logging, model loading, UI graph ----
args = parse_args()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Plain stdout logger; propagation disabled to avoid duplicate log lines.
logger = logging.getLogger("simple")
handler = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter("%(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.setLevel(logging.INFO)
logger.propagate = False
# Colormap used for depth visualizations.
cmap = plt.get_cmap("Spectral")


# Resolve each checkpoint's config.yaml: try the Hub first, otherwise treat
# the checkpoint path as a local directory.
checkpoint_config = {}
try:
    depth_checkpoint_config_path = hf_hub_download(
        repo_id=args.depth_checkpoint_path,
        filename="config.yaml"
    )
except Exception as e:
    depth_checkpoint_config_path = Path(args.depth_checkpoint_path) / "config.yaml"
depth_config = OmegaConf.load(depth_checkpoint_config_path)
checkpoint_config["depth"] = {"path": args.depth_checkpoint_path, "mode": "trained", "config": depth_config.model}

try:
    normal_checkpoint_config_path = hf_hub_download(
        repo_id=args.normals_checkpoint_path,
        filename="config.yaml"
    )
except Exception as e:
    normal_checkpoint_config_path = Path(args.normals_checkpoint_path) / "config.yaml"
normal_config = OmegaConf.load(normal_checkpoint_config_path)
checkpoint_config["normal"] = {"path": args.normals_checkpoint_path, "mode": "trained", "config": normal_config.model}

# Both UNets share the base pipeline referenced by the depth config.
pager = Pager(model_configs=checkpoint_config, pretrained_path = depth_config.model.pretrained_path, device=device)
pager.unet["depth"].to(device, dtype=pager.weight_dtype)
pager.unet["depth"].eval()
pager.unet["normal"].to(device, dtype=pager.weight_dtype)
pager.unet["normal"].eval()


# ---- Gradio UI: inputs on the left, rendered output on the right, 3D below ----
with gr.Blocks() as demo:
    gr.Markdown("## 📟 PaGeR: Panoramic Geometry Reconstruction")

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(
                label="RGB ERP Image",
                type="filepath",
                height=320,
            )
            output_choice = gr.Radio(
                ["Depth", "Surface Normals"],
                value="Depth",
                label="Output Type",
            )
            pointcloud_checkbox = gr.Checkbox(
                label="Generate Point Cloud",
                value=True,
            )
            gr.Examples(
                examples=EXAMPLE_IMAGES,
                inputs=image_input,
                label="Pick an example (or upload your own above)",
                examples_per_page=8,
                cache_examples=False,
            )
            run_button = gr.Button("Run Inference")

        with gr.Column(scale=1):
            rendered_output = gr.Image(
                label="Output",
                type="numpy",
                height=320,
            )

    with gr.Row():
        pointcloud_output = gr.Model3D(
            label="Point Cloud",
            height=360,
            clear_color=[0.0, 0.0, 0.0, 0.0],
        )

    # Clear the stale point cloud immediately, then run inference.
    (
        run_button.click(
            fn=clear_pointcloud,
            outputs=pointcloud_output,
            queue=False,
        )
        .then(
            fn=process_panorama,
            inputs=[image_input, output_choice, pointcloud_checkbox],
            outputs=[rendered_output, pointcloud_output],
        )
    )

if __name__ == "__main__":
    _release_cuda_memory()
    demo.launch()
examples/alice.jpg ADDED

Git LFS Details

  • SHA256: 08bc06d4f11394aba2ed22a211186cda79979449fde0a6216549e547c430c3e5
  • Pointer size: 131 Bytes
  • Size of remote file: 287 kB
examples/example_1.jpg ADDED

Git LFS Details

  • SHA256: 875fc6107ae7b2d666767e2847e5652efab9d069805487c33c26fdc3897f6210
  • Pointer size: 131 Bytes
  • Size of remote file: 246 kB
examples/example_2.jpg ADDED

Git LFS Details

  • SHA256: 540fbc57b080fada086be2e3bed0bc5d7b58f9eb55d4637853f6772c355a42b6
  • Pointer size: 131 Bytes
  • Size of remote file: 144 kB
examples/greenhouse.jpg ADDED

Git LFS Details

  • SHA256: 2f17cebe56b3df5c02e30ff1ae2ad1771bfd10441ed6842543dae1e2f09a540a
  • Pointer size: 131 Bytes
  • Size of remote file: 517 kB
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.2.0
2
+ xformers==0.0.24
3
+ accelerate==0.27.2
4
+ gradio==6.0.2
5
+ huggingface-hub==0.36.0
6
+ transformers
7
+ diffusers==0.30.2
8
+ numpy==1.26.4
9
+ scipy==1.15.1
10
+ matplotlib==3.10.0
11
+ tqdm==4.67.1
12
+ einops==0.8.1
13
+ datasets==3.3.0
14
+ python-dotenv==1.1.1
15
+ wandb==0.19.6
16
+ opencv-python==4.11.0.86
17
+ pytorch360convert==0.2.3
18
+ trimesh==4.9.0
src/__init__.py ADDED
File without changes
src/pager.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import Conv2d
4
+ from transformers import CLIPTextModel, CLIPTokenizer
5
+ from diffusers import DDPMScheduler
6
+ from diffusers.utils.import_utils import is_xformers_available
7
+ from Marigold.unet.unet_2d_condition import UNet2DConditionModel
8
+ from Marigold.vae.autoencoder_kl import AutoencoderKL
9
+ from src.utils.conv_padding import PaddedConv2d, valid_pad_conv_fn
10
+ from src.utils.loss import L1Loss, GradL1Loss, CosineNormalLoss
11
+ from src.utils.geometry_utils import (
12
+ get_positional_encoding,
13
+ compute_scale_and_shift,
14
+ compute_shift,
15
+ depth_to_normals_erp,
16
+ cubemap_to_erp
17
+ )
18
+
19
+
20
class Pager(nn.Module):
    """One-step diffusion model for panoramic geometry estimation.

    Holds one fine-tuned UNet per modality ("depth" / "normal") plus a shared,
    frozen VAE and cached empty-prompt CLIP embedding. Inference runs a single
    deterministic denoising step on six cubemap faces.
    """

    def __init__(self,
                 model_configs,
                 pretrained_path,
                 train_modality=None,
                 device=torch.device("cpu"),
                 weight_dtype=torch.float32):
        """
        Args:
            model_configs: dict modality -> {"path", "mode", "config"} describing each UNet checkpoint.
            pretrained_path: base SD-style checkpoint providing scheduler/tokenizer/text encoder/VAE.
            train_modality: modality being trained, or None for pure inference.
            device: target device for frozen components and cached tensors.
            weight_dtype: dtype for frozen weights and cached tensors.
        """
        super().__init__()
        self.model_configs = model_configs
        self.weight_dtype = weight_dtype
        # Stable-Diffusion latent scaling constant, applied to both RGB input
        # latents and decoded geometry latents.
        self.rgb_latent_scale_factor = 0.18215
        self.depth_latent_scale_factor = 0.18215
        self.train_modality = train_modality
        self.device = device
        self.prepare_model_components(pretrained_path, model_configs)
        self.prepare_empty_encoding()

        # Cache cumulative alphas, then drop the scheduler: single-step
        # inference only needs alpha/beta at the final timestep.
        self.alpha_prod = self.noise_scheduler.alphas_cumprod.to(device, dtype=weight_dtype)
        self.beta_prod = 1 - self.alpha_prod
        self.num_timesteps = self.noise_scheduler.config.num_train_timesteps - 1
        del self.noise_scheduler


    def prepare_model_components(self, pretrained_path, model_configs):
        """Load scheduler/tokenizer/text encoder/VAE and one UNet per modality."""
        # A single VAE is shared across modalities, so all checkpoints must
        # agree on the VAE positional-encoding variant.
        vae_use_RoPE = None
        for checkpoint_cfg in model_configs.values():
            if vae_use_RoPE is None:
                vae_use_RoPE = checkpoint_cfg['config'].vae_use_RoPE == "RoPE"
            elif vae_use_RoPE != (checkpoint_cfg['config'].vae_use_RoPE == "RoPE"):
                raise ValueError("All UNet checkpoints must use the same VAE positional encoding configuration.")

        self.noise_scheduler = DDPMScheduler.from_pretrained(pretrained_path, subfolder="scheduler", rescale_betas_zero_snr=True)
        self.tokenizer = CLIPTokenizer.from_pretrained(pretrained_path, subfolder="tokenizer", revision=None)
        self.text_encoder = CLIPTextModel.from_pretrained(pretrained_path, subfolder="text_encoder", revision=None, variant=None)
        self.vae = AutoencoderKL.from_pretrained(pretrained_path, subfolder="vae", revision=None, variant=None,
                                                 use_RoPE = vae_use_RoPE)
        # Swap every zero-padded conv for cubemap-aware "valid" padding.
        self.set_valid_pad_conv(self.vae)

        self.vae.requires_grad_(False)
        self.vae.to(self.device, dtype=self.weight_dtype)
        self.vae.eval()

        self.text_encoder.requires_grad_(False)
        self.text_encoder.to(self.device, dtype=self.weight_dtype)
        self.text_encoder.eval()

        # 8 = 4 RGB-latent + 4 noisy-latent channels; +2 optional UV-PE channels.
        base_in_channels = 8
        pe_channels_size = 0

        # Plain dict (not nn.ModuleDict): UNets are moved/trained explicitly by callers.
        self.unet = {}
        for modality, checkpoint_cfg in model_configs.items():
            if checkpoint_cfg['config'].unet_positional_encoding == "uv":
                pe_channels_size = 2
            target_in_channels = base_in_channels + pe_channels_size

            self.unet[modality] = UNet2DConditionModel.from_pretrained(
                checkpoint_cfg["path"],
                subfolder="unet",
                revision=None,
                in_channels=target_in_channels if checkpoint_cfg["mode"] == "trained" else base_in_channels,
                use_RoPE=checkpoint_cfg['config'].unet_positional_encoding == "RoPE"
            )

            # Untrained base checkpoints need conv_in widened for the extra
            # positional-encoding channels.
            if target_in_channels > base_in_channels and checkpoint_cfg["mode"] != "trained":
                self.extend_unet_conv_in(self.unet[modality], new_in_channels=target_in_channels)
            self.set_valid_pad_conv(self.unet[modality])

            if checkpoint_cfg['config'].enable_xformers:
                if is_xformers_available():
                    import xformers
                    if self.unet.get("depth"):
                        self.unet["depth"].enable_xformers_memory_efficient_attention()
                    if self.unet.get("normal"):
                        self.unet["normal"].enable_xformers_memory_efficient_attention()
                    self.vae.enable_xformers_memory_efficient_attention()


    def prepare_training(self, accelerator, gradient_checkpointing):
        """Wrap the trained-modality UNet with accelerate; optionally enable
        gradient checkpointing on it and the VAE."""
        self.unwrapped_unet = self.unet[self.train_modality]
        self.unet[self.train_modality] = accelerator.prepare(self.unet[self.train_modality])
        self.trained_unet = self.unet[self.train_modality]

        if gradient_checkpointing:
            self.trained_unet._set_gradient_checkpointing()
            self.vae._set_gradient_checkpointing()


    def prepare_cubemap_PE(self, image_height, image_width):
        """Precompute and cache the UV positional-encoding cubemap if any
        configured modality uses "uv" positional encoding."""
        use_uv_PE = False
        for checkpoint_cfg in self.model_configs.values():
            if checkpoint_cfg['config'].unet_positional_encoding == "uv":
                use_uv_PE = True
        if use_uv_PE:
            PE_cubemap = get_positional_encoding(image_height, image_width)
            self.PE_cubemap = PE_cubemap.to(device=self.device, dtype=self.weight_dtype)

    def prepare_empty_encoding(self):
        """Cache the CLIP embedding of the empty prompt, then free the text
        stack — prompts are always empty at inference time."""
        with torch.inference_mode():
            empty_token = self.tokenizer([""], padding="max_length", truncation=True, return_tensors="pt").input_ids
            empty_token = empty_token.to(self.device)
            empty_encoding = self.text_encoder(empty_token, return_dict=False)[0]
            self.empty_encoding = empty_encoding.to(self.device, dtype=self.weight_dtype)

        del empty_token
        del self.text_encoder
        del self.tokenizer


    def forward(self, batch, modality):
        """Single deterministic denoising step conditioned on the RGB cubemap.

        Args:
            batch: dict with "rgb_cubemap" — assumed (B, 6, 3, H, W) in [-1, 1];
                TODO confirm against the caller in app.py.
            modality: "depth" or "normal".

        Returns:
            Decoded prediction cubemap; for depth the channels are averaged
            down to one.
        """
        with torch.inference_mode():
            c, h, w = batch["rgb_cubemap"].shape[2:]
            # Flatten batch and face dims so the VAE sees (B*6, C, H, W).
            rgb_vae_input = batch["rgb_cubemap"].reshape(-1, c, h, w).to(dtype=self.weight_dtype)
            rgb_latents = self.vae.encode(rgb_vae_input, deterministic=True)
            rgb_latents = rgb_latents * self.rgb_latent_scale_factor
            del rgb_vae_input

            # Always evaluate at the final (noisiest) timestep.
            timesteps = torch.ones((rgb_latents.shape[0],), device=self.device) * self.num_timesteps
            timesteps = timesteps.long()
            alpha_prod_t = self.alpha_prod[timesteps].view(-1, 1, 1, 1)
            beta_prod_t = self.beta_prod[timesteps].view(-1, 1, 1, 1)

            # Zero "noise" makes the single step deterministic.
            noisy_latents = torch.zeros_like(rgb_latents).to(self.device)
            encoder_hidden_states = self.empty_encoding.repeat(batch["rgb_cubemap"].shape[0] * 6, 1, 1)
            if self.model_configs[modality]['config'].unet_positional_encoding == "uv":
                batch_PE_cubemap = self.PE_cubemap.repeat(batch["rgb_cubemap"].shape[0], 1, 1, 1)
                unet_input = torch.cat((rgb_latents, noisy_latents, batch_PE_cubemap), dim=1).to(
                    self.device
                )
            else:
                unet_input = torch.cat((rgb_latents, noisy_latents), dim=1).to(self.device)

            del rgb_latents
            model_pred = self.unet[modality](
                unet_input,
                timesteps,
                encoder_hidden_states,
                return_dict=False,
            )[0]

            # Latent estimate from the model output at x_t = 0; the formula
            # matches a v-prediction-style update — confirm against the
            # scheduler's prediction_type before relying on this reading.
            current_latent_estimate = (alpha_prod_t**0.5) * noisy_latents - (beta_prod_t**0.5) * model_pred
            current_scaled_latent_estimate = current_latent_estimate / self.depth_latent_scale_factor
            pred_cubemap = self.vae.decode(current_scaled_latent_estimate, deterministic=True)

            if modality == "depth":
                pred_cubemap = pred_cubemap.mean(dim=1, keepdim=True)
            return pred_cubemap


    def prepare_losses_dict(self, loss_cfg):
        """Build the per-modality loss table: L1 (+ optional gradient and
        normal-consistency terms) for depth; cosine loss for normals."""
        self.losses_dict = {}
        if self.train_modality == "depth":
            self.losses_dict["l1_loss"] = {"loss_fn": L1Loss(invalid_mask_weight=loss_cfg.invalid_mask_weight),
                                           "weight": loss_cfg.l1_loss_weight}
            if loss_cfg.grad_loss_weight > 0.0:
                self.losses_dict["grad_loss"] = {"loss_fn": GradL1Loss(), "weight": loss_cfg.grad_loss_weight}
            if loss_cfg.normals_consistency_loss_weight > 0.0:
                self.losses_dict["normals_consistency_loss"] = {"loss_fn": CosineNormalLoss(),
                                                                "weight": loss_cfg.normals_consistency_loss_weight}
        else:
            self.losses_dict["cosine_normal_loss"] = {"loss_fn": CosineNormalLoss(), "weight": 1.0}


    def calculate_depth_loss(self, batch, pred_cubemap, min_depth, depth_range, log_scale, metric_depth):
        """Compute the weighted depth losses after aligning the prediction's
        scale/shift to GT (skipped for metric-depth training)."""
        loss = {"total_loss": 0.0}

        gt_depth_cubemap = batch['depth_cubemap'].squeeze(0).mean(dim=1, keepdim=True)
        mask_cubemap = batch["mask_cubemap"].squeeze(0)

        if not metric_depth:
            # In log space only a shift (global scale in linear space) is fitted;
            # otherwise a full affine scale+shift alignment.
            if log_scale:
                scale = compute_shift(pred_cubemap, gt_depth_cubemap, mask_cubemap)
            else:
                scale, shift = compute_scale_and_shift(pred_cubemap, gt_depth_cubemap, mask_cubemap)

            # NOTE(review): alignment application placed inside the
            # non-metric branch — `scale`/`shift` are only defined here.
            if log_scale:
                pred_cubemap += scale
            else:
                pred_cubemap = pred_cubemap * scale + shift

        for loss_name, loss_params in self.losses_dict.items():
            if loss_name == "normals_consistency_loss":
                # Derive normals from the predicted depth panorama and compare
                # against GT normals.
                gt = batch['normal']
                pred_depth = pred_cubemap
                mask = batch["mask"]
                pred_depth = self.process_depth_output(pred_depth, orig_size=gt.shape[2:], min_depth=min_depth,
                                                       depth_range=depth_range, log_scale=log_scale)[0]
                pred = depth_to_normals_erp(pred_depth).unsqueeze(0)
            else:
                pred = pred_cubemap
                gt = gt_depth_cubemap
                mask = mask_cubemap
            loss[loss_name] = loss_params["loss_fn"](pred, gt, mask)
            loss["total_loss"] += loss[loss_name] * loss_params["weight"]

        return loss


    def calculate_normal_loss(self, batch, pred_cubemap):
        """Compute the weighted surface-normal losses on the cubemap faces."""
        loss = {"total_loss": 0.0}

        gt_normal_cubemap = batch['normal_cubemap'].squeeze(0)
        mask_cubemap = batch["mask_cubemap"].squeeze(0)

        for loss_name, loss_params in self.losses_dict.items():
            pred = pred_cubemap
            gt = gt_normal_cubemap
            loss[loss_name] = loss_params["loss_fn"](pred, gt, mask_cubemap)
            loss["total_loss"] += loss[loss_name] * loss_params["weight"]

        return loss

    def process_depth_output(self,pred_cubemap, orig_size, min_depth, depth_range, log_scale, mask=None):
        """Convert a predicted depth cubemap into (metric panorama, log-space
        visualization panorama).

        The raw prediction in [-1, 1] is remapped to [0, 1] and then to
        log-depths min_depth + [0, depth_range].
        """
        pred_panorama = cubemap_to_erp(pred_cubemap, *orig_size)
        pred_panorama = torch.clamp(pred_panorama, -1, 1)
        pred_panorama = (pred_panorama + 1) / 2
        if mask is not None:
            pred_panorama *= mask
        pred_panorama = pred_panorama * depth_range + min_depth
        if log_scale:
            pred_panorama_viz = pred_panorama.clone()
            pred_panorama = torch.exp(pred_panorama)
        else:
            # NOTE(review): zero depths (e.g. masked pixels) give -inf here.
            pred_panorama_viz = torch.log(pred_panorama)

        return pred_panorama, pred_panorama_viz


    def process_normal_output(self,pred_cubemap, orig_size):
        """Convert a predicted normal cubemap to an ERP panorama clamped to [-1, 1]."""
        pred_panorama = cubemap_to_erp(pred_cubemap, *orig_size)
        pred_panorama = torch.clamp(pred_panorama, -1, 1)
        return pred_panorama


    def extend_unet_conv_in(self, unet, new_in_channels: int):
        """Widen conv_in to accept extra input channels; new channels start at
        zero weight so the pretrained behavior is preserved initially."""
        if new_in_channels < unet.conv_in.in_channels:
            raise ValueError(
                f"new_in_channels ({new_in_channels}) must be >= current "
                f"{unet.conv_in.in_channels}"
            )
        if new_in_channels == unet.conv_in.in_channels:
            return

        old_conv = unet.conv_in
        old_in = old_conv.in_channels
        device, dtype = old_conv.weight.device, old_conv.weight.dtype
        bias_flag = old_conv.bias is not None

        new_conv = Conv2d(
            new_in_channels,
            old_conv.out_channels,
            kernel_size=old_conv.kernel_size,
            stride=old_conv.stride,
            padding=old_conv.padding,
            bias=bias_flag,
            padding_mode=old_conv.padding_mode,
        ).to(device=device, dtype=dtype)

        # Copy pretrained weights into the first old_in channels; the rest stay zero.
        new_conv.weight.zero_()
        new_conv.weight[:, :old_in].copy_(old_conv.weight)
        if bias_flag:
            new_conv.bias.copy_(old_conv.bias)

        unet.conv_in = new_conv
        unet.config["in_channels"] = new_in_channels


    def set_valid_pad_conv(self, module: nn.Module):
        """Recursively replace padded Conv2d layers with cubemap-aware
        PaddedConv2d (one-sided padding for Downsample2D's conv)."""
        for name, child in list(module.named_children()):
            if isinstance(child, nn.Conv2d):
                if child.padding != (0, 0):
                    setattr(module, name, PaddedConv2d.from_existing(child, valid_pad_conv_fn))
                elif module.__class__.__name__ == "Downsample2D" and module.use_conv:
                    setattr(module, name, PaddedConv2d.from_existing(child, valid_pad_conv_fn, one_side_pad=True))
            else:
                self.set_valid_pad_conv(child)


    def save_model(self, ema_unet, model_save_dir):
        """Save the unwrapped UNet; when EMA weights exist, also save an EMA
        snapshot and restore the live weights afterwards."""
        self.unwrapped_unet.save_pretrained(model_save_dir / "original")
        if ema_unet is not None:
            ema_unet.store(self.unwrapped_unet.parameters())
            ema_unet.copy_to(self.unwrapped_unet.parameters())
            self.unwrapped_unet.save_pretrained(model_save_dir / f"EMA")
            ema_unet.restore(self.unwrapped_unet.parameters())
306
+
307
+
308
+
src/utils/__init__.py ADDED
File without changes
src/utils/conv_padding.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+
4
# Cube-face adjacency tables for a 6-face stack (face order assumed to match
# pytorch360convert's "stack" layout — TODO confirm).
# orderings[i] = [i, right-neighbour, left-neighbour, top-neighbour, bottom-neighbour]
orderings = [
    [0, 1, 3, 4, 5],
    [1, 2, 0, 4, 5],
    [2, 3, 1, 4, 5],
    [3, 0, 2, 4, 5],
    [4, 1, 3, 2, 0],
    [5, 1, 3, 0, 2],
]
# rotations[i][k]: relative orientation (in quarter turns; -1 = -90°) of
# neighbour k's edge before it can be stitched onto face i's border.
rotations = [
    [0, 0, 0, 0, 0],
    [0, 0, 0,-1, 1],
    [0, 0, 0, 2, 2],
    [0, 0, 0, 1,-1],
    [0, 1,-1, 2, 0],
    [0,-1, 1, 0, 2]
]
20
+
21
+ def _take_right(face, rot):
22
+ if rot == 0:
23
+ return face[:, :, 0]
24
+ elif rot == 1:
25
+ return face[:, 0, :].flip(1)
26
+ elif rot == 2:
27
+ return face[:, :, -1].flip(1)
28
+ elif rot == -1:
29
+ return face[:, -1, :]
30
+
31
+ def _take_left(face, rot):
32
+ if rot == 0:
33
+ return face[:, :, -1]
34
+ elif rot == 1:
35
+ return face[:, -1, :].flip(1)
36
+ elif rot == 2:
37
+ return face[:, :, 0].flip(1)
38
+ elif rot == -1:
39
+ return face[:, 0, :]
40
+
41
+ def _take_top(face, rot):
42
+ if rot == 0:
43
+ return face[:, -1, :]
44
+ elif rot == 1:
45
+ return face[:, :, 0]
46
+ elif rot == 2:
47
+ return face[:, 0, :].flip(1)
48
+ elif rot == -1:
49
+ return face[:, :, -1].flip(1)
50
+
51
+ def _take_bottom(face, rot):
52
+ if rot == 0:
53
+ return face[:, 0, :]
54
+ elif rot == 1:
55
+ return face[:, :, -1]
56
+ elif rot == 2:
57
+ return face[:, -1, :].flip(1)
58
+ elif rot == -1:
59
+ return face[:, :, 0].flip(1)
60
+
61
def valid_pad_conv_fn(x, one_side_pad=False):
    """Pad a stack of six cube faces with a 1-pixel halo taken from the
    geometrically adjacent faces (instead of zeros), so convolutions are
    seam-free across the cubemap.

    Args:
        x: (6, C, H, W) tensor of cube-face feature maps.
        one_side_pad: emulate one-sided padding for stride-2 downsampling
            convs — the last row/column is cropped before padding and the
            top/left halo row/column is dropped afterwards, so the output
            keeps the original (H, W).

    Returns:
        (6, C, H+2, W+2) padded tensor, or (6, C, H, W) when one_side_pad.
    """
    if one_side_pad:
        x = x[:, :, :-1, :-1]
    assert x.ndim == 4 and x.shape[0] == 6
    _, C, H, W = x.shape
    y = x.new_empty(6, C, H+2, W+2)
    y[..., 1:-1, 1:-1] = x

    for i in range(6):
        # Neighbouring face indices and their relative rotations — see the
        # orderings/rotations tables above.
        r_idx, l_idx, t_idx, b_idx = orderings[i][1:5]
        r_rot, l_rot, t_rot, b_rot = rotations[i][1:5]

        r_edge = _take_right (x[r_idx], r_rot)
        l_edge = _take_left (x[l_idx], l_rot)
        t_edge = _take_top (x[t_idx], t_rot)
        b_edge = _take_bottom(x[b_idx], b_rot)

        y[i, :, 1:-1, 0 ] = l_edge
        y[i, :, 1:-1, -1 ] = r_edge
        y[i, :, 0, 1:-1] = t_edge
        y[i, :, -1, 1:-1] = b_edge

        # Corner pixels have no single source face; average the two adjacent
        # halo pixels instead.
        y[i, :, 0, 0 ] = 0.5*(y[i, :, 0, 1] + y[i, :, 1, 0])
        y[i, :, 0, -1 ] = 0.5*(y[i, :, 0, -2] + y[i, :, 1, -1])
        y[i, :, -1, 0 ] = 0.5*(y[i, :, -2, 0] + y[i, :, -1, 1])
        y[i, :, -1,-1 ] = 0.5*(y[i, :, -2, -1] + y[i, :, -1, -2])

    if one_side_pad:
        return y[:, :, 1:, 1:]

    return y
92
+
93
+
94
class PaddedConv2d(nn.Conv2d):
    """Conv2d whose padding comes from an external `pad_fn` (e.g. cubemap-aware
    halo padding) rather than the built-in zero padding."""

    def __init__(self, *args, pad_fn=None, one_side_pad=False, **kwargs):
        # Built-in padding is forced to 0; pad_fn supplies the halo instead.
        super().__init__(*args, **{**kwargs, "padding": 0})
        self.pad_fn = pad_fn
        self.one_side_pad = one_side_pad

    def forward(self, x):
        """Pad with pad_fn, then run an unpadded convolution."""
        padded = self.pad_fn(x, one_side_pad=self.one_side_pad)
        return F.conv2d(
            padded,
            self.weight,
            self.bias,
            stride=self.stride,
            padding=0,
            dilation=self.dilation,
            groups=self.groups,
        )

    @classmethod
    def from_existing(cls, conv: nn.Conv2d, pad_fn, one_side_pad=False):
        """Clone an existing Conv2d's geometry and share its parameters,
        swapping zero padding for pad_fn-based padding."""
        clone = cls(
            conv.in_channels,
            conv.out_channels,
            conv.kernel_size,
            stride=conv.stride,
            padding=0,
            dilation=conv.dilation,
            groups=conv.groups,
            bias=(conv.bias is not None),
            padding_mode="zeros",
            pad_fn=pad_fn,
            one_side_pad=one_side_pad,
        )
        # Share (not copy) the original parameters.
        clone.weight = conv.weight
        if conv.bias is not None:
            clone.bias = conv.bias
        return clone
122
+
123
+
src/utils/geometry_utils.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import trimesh
4
+ from pytorch360convert import e2c, c2e
5
+
6
+
7
def erp_to_cubemap(erp_tensor, face_w = 768, cube_format = "stack", mode = "bilinear", **kwargs):
    """Project an equirectangular (ERP) tensor onto six cube faces.

    Thin wrapper over pytorch360convert.e2c; extra kwargs pass through.
    """
    return e2c(erp_tensor, face_w=face_w, cube_format=cube_format, mode=mode, **kwargs)
9
+
10
+
11
def cubemap_to_erp(cube_tensor, erp_h = 1024, erp_w = 2048, cube_format = "stack", mode = "bilinear", **kwargs):
    """Reproject six cube faces back to an equirectangular (ERP) panorama.

    Thin wrapper over pytorch360convert.c2e; extra kwargs pass through.
    """
    return c2e(cube_tensor, h=erp_h, w=erp_w, cube_format=cube_format, mode=mode, **kwargs)
13
+
14
def roll_augment(data, shift_x):
    """Horizontally roll ERP data by `shift_x` pixels (wrap-around augmentation).

    Accepts a 2D (H, W) map, a channel-first (3, H, W) image, or a
    channel-last (H, W, C) image; the output keeps the input layout.

    Args:
        data: numpy array, 2D or 3D as above.
        shift_x: pixel shift along the width axis.

    Returns:
        The rolled array, same shape and layout as `data`.
    """
    originally_2d = data.ndim == 2
    if originally_2d:
        data = data[:, :, np.newaxis]
    # Fix: a 2D input with exactly 3 rows previously skipped the axis move and
    # rolled the size-1 channel axis (a no-op) instead of the width. Every
    # originally-2D input is now treated as channel-last so width is always axis 2.
    moved_axis = originally_2d or (data.ndim == 3 and data.shape[0] != 3)
    if moved_axis:
        data = np.moveaxis(data, -1, 0)

    data_rolled = np.roll(data, int(shift_x), axis=2)

    if moved_axis:
        data_rolled = np.moveaxis(data_rolled, 0, -1)
    if originally_2d:
        data_rolled = data_rolled[:, :, 0]
    return data_rolled
33
+
34
+
35
def roll_normal(normal, shift_x):
    """Rotate normal vectors about the vertical (y) axis by the yaw that
    corresponds to a horizontal ERP roll of `shift_x` pixels.

    Accepts channel-first (3, H, W) or channel-last (H, W, 3) input and
    returns the same layout. Note: only the vector directions are rotated;
    the pixel grid itself is not shifted (pair with roll_augment for that).
    """
    originally_2d = normal.ndim == 2
    if originally_2d:
        normal = normal[:, :, np.newaxis]
    moved_axis = normal.ndim == 3 and normal.shape[0] != 3
    if moved_axis:
        normal = np.moveaxis(normal, -1, 0)

    _, H, W = normal.shape

    # One full image width corresponds to a 2*pi yaw.
    angle = - 2.0 * np.pi * (shift_x / float(W))
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    R = np.array([
        [ cos_a, 0.0, -sin_a],
        [ 0.0, 1.0, 0.0 ],
        [ sin_a, 0.0, cos_a]
    ], dtype=normal.dtype)

    normal = (R @ normal.reshape(3, -1)).reshape(3, H, W)

    if moved_axis:
        normal = np.moveaxis(normal, 0, -1)
    if originally_2d:
        normal = normal[:, :, 0]
    return normal
66
+
67
+
68
def compute_scale_and_shift(pred_g, targ_g, mask_g = None, eps = 0.0, fit_shift = True):
    """Least-squares affine alignment of pred_g to targ_g under mask_g.

    Solves min_s,t || m * (s * pred + t - targ) ||^2 per batch element via the
    2x2 normal equations. With fit_shift=False only the scale is fitted and
    the shift is zero.

    Returns:
        (scale, shift) tensors of shape (batch,).
    """
    if mask_g is None:
        mask_g = torch.ones_like(pred_g, dtype=torch.bool)

    # Fold a stack of 6 cube faces (or a single CHW map) into a batch of one.
    if pred_g.shape[0] == 6:
        pred_g = pred_g.view(1, 6, pred_g.shape[2], pred_g.shape[3])
        targ_g = targ_g.view(1, 6, targ_g.shape[2], targ_g.shape[3])
        mask_g = mask_g.view(1, 6, mask_g.shape[2], mask_g.shape[3])
    elif pred_g.shape[0] == 1 and pred_g.dim() == 3:
        pred_g, targ_g, mask_g = (t.unsqueeze(0) for t in (pred_g, targ_g, mask_g))

    weights = mask_g.to(dtype=pred_g.dtype)
    reduce_dims = (1, 2, 3)

    a_00 = torch.sum(weights * pred_g * pred_g, dim=reduce_dims)
    a_01 = torch.sum(weights * pred_g, dim=reduce_dims)
    a_11 = torch.sum(weights, dim=reduce_dims)
    b_0 = torch.sum(weights * pred_g * targ_g, dim=reduce_dims)
    b_1 = torch.sum(weights * targ_g, dim=reduce_dims)

    if not fit_shift:
        # Scale-only closed form; shift pinned to zero.
        scale = b_0 / (a_00 + eps)
        return scale, torch.zeros_like(scale)

    # 2x2 normal equations; entries with non-positive determinant get (0, 0).
    det = a_00 * a_11 - a_01 * a_01
    det = det + eps
    scale = torch.zeros_like(b_0)
    shift = torch.zeros_like(b_1)
    solvable = det > 0
    scale[solvable] = (a_11[solvable] * b_0[solvable] - a_01[solvable] * b_1[solvable]) / det[solvable]
    shift[solvable] = (-a_01[solvable] * b_0[solvable] + a_00[solvable] * b_1[solvable]) / det[solvable]
    return scale, shift
102
+
103
+
104
def compute_shift(pred, targ, mask, eps = 1e-6):
    """Weighted mean residual (targ - pred) over the mask, per batch item.

    A leading dim of 6 is treated as a stacked cubemap and folded into a
    single six-channel batch entry. `eps` guards against empty masks.
    """
    if pred.shape[0] == 6:
        pred = pred.view(1, 6, *pred.shape[2:])
        targ = targ.view(1, 6, *targ.shape[2:])
        mask = mask.view(1, 6, *mask.shape[2:])

    weights = mask.float()
    residual_sum = (weights * (targ - pred)).sum(dim=(1, 2, 3))
    weight_sum = weights.sum(dim=(1, 2, 3)).clamp_min(eps)
    return residual_sum / weight_sum
115
+
116
+
117
def get_positional_encoding(H, W, pixel_center = True, hw = 96):
    """Build a per-pixel (u, v) positional map in [-1, 1] for an H x W ERP
    image and project it onto cubemap faces of side `hw` via erp_to_cubemap.
    """
    cols = np.arange(W, dtype=np.float64)
    rows = np.arange(H, dtype=np.float64)
    if pixel_center:
        cols = cols + 0.5
        rows = rows + 0.5

    # Normalise pixel coordinates to [-1, 1].
    u_axis = (cols / W) * 2.0 - 1.0
    v_axis = (rows / H) * 2.0 - 1.0
    u_grid, v_grid = np.meshgrid(u_axis, v_axis, indexing='xy')

    uv = np.stack([u_grid, v_grid], axis=-1)                  # (H, W, 2)
    uv_chw = torch.from_numpy(uv).permute(2, 0, 1).float()    # (2, H, W)
    return erp_to_cubemap(uv_chw, face_w=hw)
133
+
134
+
135
def unit_normals(n, eps = 1e-6):
    """L2-normalise normal vectors along the channel axis (dim -3),
    clamping the norm at `eps` to avoid division by zero."""
    assert n.dim() >= 3 and n.size(-3) == 3, "normals must have channel=3 at dim -3"
    length = torch.linalg.norm(n, dim=-3, keepdim=True).clamp(min=eps)
    return n / length
139
+
140
+
141
+ def _erp_dirs(H, W, device=None, dtype=None):
142
+ u = (torch.arange(W, device=device, dtype=dtype) + 0.5) / W
143
+ v = (torch.arange(H, device=device, dtype=dtype) + 0.5) / H
144
+ theta = u * (2.0 * torch.pi) - torch.pi
145
+ phi = (0.5 - v) * torch.pi
146
+
147
+ theta = theta.view(1, W).expand(H, W)
148
+ phi = phi.view(H, 1).expand(H, W)
149
+
150
+ cosphi = torch.cos(phi)
151
+ sinphi = torch.sin(phi)
152
+ costhe = torch.cos(theta)
153
+ sinthe = torch.sin(theta)
154
+
155
+ x = cosphi * costhe
156
+ y = sinphi
157
+ z = -cosphi * sinthe
158
+
159
+ dirs = torch.stack([x, y, z], dim=0)
160
+ return dirs
161
+
162
+
163
def depth_to_normals_erp(depth, eps = 1e-6):
    """Estimate surface normals from an ERP depth map by finite differences.

    Args:
        depth: (1, H, W) tensor of radial depths along the ERP view directions.
        eps: clamp floor used when normalising the cross products.

    Returns:
        (3, H, W) tensor of unit normals, computed as the (normalised) cross
        product of the longitude/latitude tangents of the back-projected surface.
    """
    # Fix: the previous message claimed "(B,1,H,W)" but the check (and the
    # arithmetic below) actually require a single (1, H, W) map.
    assert depth.dim() == 3 and depth.size(0) == 1, "depth must be (1, H, W)"
    _, H, W = depth.shape
    device, dtype = depth.device, depth.dtype

    # Back-project each pixel to a 3D point P = depth * direction.
    dirs = _erp_dirs(H, W, device=device, dtype=dtype)
    P = depth * dirs

    # Angular step sizes of the ERP grid.
    dtheta = 2.0 * torch.pi / W
    dphi = torch.pi / H

    # Longitude derivative: central differences with horizontal wrap-around
    # (an ERP image is periodic in theta).
    P_l = torch.roll(P, shifts=+1, dims=-1)
    P_r = torch.roll(P, shifts=-1, dims=-1)
    dP_dtheta = (P_r - P_l) / (2.0 * dtheta)

    # Latitude derivative: central differences, clamped (not wrapped) at the
    # poles by repeating the first/last rows.
    P_u = torch.cat([P[:, :1, :], P[:, :-1, :]], dim=-2)
    P_d = torch.cat([P[:, 1:, :], P[:, -1:, :]], dim=-2)
    dP_dphi = (P_d - P_u) / (2.0 * dphi)

    n = torch.cross(dP_dtheta, dP_dphi, dim=0)
    return unit_normals(n, eps=eps)
187
+
188
+
189
def compute_edge_mask(depth, abs_thresh = 0.1, rel_thresh = 0.1):
    """Mask out depth discontinuities.

    A pixel is dropped when it borders a neighbour whose depth differs by more
    than `abs_thresh` absolutely AND by more than `rel_thresh` relative to the
    nearer of the two depths. Depth <= 0 is treated as invalid.

    Returns a bool (H, W) array: True where depth is valid and not on an edge.
    """
    assert depth.ndim == 2
    depth = depth.astype(np.float32, copy=False)

    valid = depth > 0
    eps = 1e-6
    edge = np.zeros_like(valid, dtype=bool)

    def flag_pairs(a_sl, b_sl):
        # Compare two shifted views of the image and mark both sides of a jump.
        da, db = depth[a_sl], depth[b_sl]
        both_valid = valid[a_sl] & valid[b_sl]
        diff = np.abs(da - db)
        rel = diff / (np.minimum(da, db) + eps)
        jump = both_valid & (diff > abs_thresh) & (rel > rel_thresh)
        edge[a_sl] |= jump
        edge[b_sl] |= jump

    flag_pairs(np.s_[:, :-1], np.s_[:, 1:])   # horizontal neighbours
    flag_pairs(np.s_[:-1, :], np.s_[1:, :])   # vertical neighbours

    return valid & ~edge
224
+
225
+
226
def erp_to_pointcloud(rgb, depth, mask = None):
    """Back-project an ERP RGB-D image to a coloured point cloud.

    Args:
        rgb: (H, W, 3) image with values in [-1, 1].
        depth: (H, W) radial depth map; zeros are treated as invalid.
        mask: optional (H, W) bool-like keep mask, combined with depth > 0.

    Returns:
        (points, colors): float32 (N, 3) coordinates and uint8 (N, 3) colours.
    """
    assert rgb.ndim == 3 and rgb.shape[-1] == 3, "rgb must be (H, W, 3)"
    assert depth.ndim == 2 and depth.shape[:2] == rgb.shape[:2], "depth must be (H, W) and match rgb H,W"

    H, W, _ = rgb.shape
    depth = depth.astype(np.float32, copy=False)

    # Pixel-center spherical angles.
    u = (np.arange(W, dtype=np.float32) + 0.5) / W
    v = (np.arange(H, dtype=np.float32) + 0.5) / H
    theta = u * (2.0 * np.pi) - np.pi
    phi = (1 - v) * np.pi - (np.pi / 2.0)
    theta, phi = np.meshgrid(theta, phi, indexing="xy")

    # Unit view directions, scaled by depth to 3D points.
    cos_phi = np.cos(phi)
    dirs = np.stack(
        [cos_phi * np.cos(theta), np.sin(phi), cos_phi * np.sin(theta)],
        axis=-1,
    )
    xyz = depth[..., None] * dirs

    keep = depth > 0
    if mask is not None:
        keep &= mask.astype(bool)

    points = xyz[keep]

    # Map rgb from [-1, 1] to uint8 [0, 255].
    colors = ((np.clip(rgb, -1.0, 1.0) * 0.5 + 0.5) * 255.0).astype(np.uint8)[keep]

    return points.astype(np.float32, copy=False), colors
263
+
264
+
265
def erp_to_point_cloud_glb(rgb, depth, mask=None, export_path=None):
    """Back-project an ERP RGB-D image and export it as a trimesh scene.

    NOTE(review): export_path defaults to None but is forwarded straight to
    Scene.export — callers appear expected to supply a real path; confirm.
    """
    vertices, vertex_colors = erp_to_pointcloud(rgb, depth, mask)
    cloud = trimesh.PointCloud(vertices=vertices, colors=vertex_colors)
    scene = trimesh.Scene()
    scene.add_geometry(cloud)
    scene.export(export_path)
    return scene
src/utils/loss.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from src.utils.geometry_utils import unit_normals
4
+
5
+
6
class L1Loss(nn.Module):
    """Masked L1 loss, optionally adding a down-weighted L1 term over the
    invalid (masked-out) pixels."""

    def __init__(self, invalid_mask_weight=0.0):
        super(L1Loss, self).__init__()
        self.name = 'L1'
        # Weight applied to the loss computed outside the valid mask.
        self.invalid_mask_weight = invalid_mask_weight

    def forward(self, pred, target, mask):
        """`mask` is a bool tensor selecting the valid pixels."""
        total = nn.functional.l1_loss(pred[mask], target[mask])
        if self.invalid_mask_weight > 0.0:
            outside = ~mask
            if outside.sum() > 0:
                invalid_term = nn.functional.l1_loss(pred[outside], target[outside])
                total = total + self.invalid_mask_weight * invalid_term
        return total
20
+
21
+
22
+
23
class GradL1Loss(nn.Module):
    """L1 loss on first-order spatial gradients (forward differences)."""

    def __init__(self):
        super().__init__()
        self.name = 'GradL1'

    def grad(self, x):
        # Forward differences, both evaluated on the top-left (H-1, W-1) crop.
        dx = x[..., :-1, 1:] - x[..., :-1, :-1]
        dy = x[..., 1:, :-1] - x[..., :-1, :-1]
        return dx, dy

    def grad_mask(self, mask):
        # A gradient sample is valid only when its whole 2x2 neighbourhood is.
        return (mask[..., :-1, :-1] & mask[..., :-1, 1:] &
                mask[..., 1:, :-1] & mask[..., 1:, 1:])

    def forward(self, pred, target, mask):
        valid = self.grad_mask(mask)
        per_axis = [
            nn.functional.l1_loss(gp[valid], gt[valid], reduction='mean')
            for gp, gt in zip(self.grad(pred), self.grad(target))
        ]
        return 0.5 * (per_axis[0] + per_axis[1])
46
+
47
+
48
class CosineNormalLoss(nn.Module):
    """1 - cos(angle) between predicted and target normals, averaged over
    the valid mask (or everything when mask is None)."""

    def __init__(self):
        super().__init__()
        self.name = "CosineNormalLoss"

    def forward(self, pred: torch.Tensor,
                target: torch.Tensor,
                mask: torch.Tensor) -> torch.Tensor:
        assert pred.shape == target.shape, "pred and target must have same shape"

        # Normalise both fields so the channel-wise dot product is a cosine.
        pred_n = unit_normals(pred)
        target_n = unit_normals(target)

        cosine = (pred_n * target_n).sum(dim=1, keepdim=True).clamp(-1.0, 1.0)
        penalty = 1.0 - cosine

        if mask is not None:
            return penalty[mask].mean()
        return penalty.mean()
src/utils/lr_scheduler.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @Vukasin Bozic 2026
2
+ # This file contains the modified version of Marigold's exponential LR scheduler.
3
+ # https://github.com/prs-eth/Marigold/blob/main/src/util/lr_scheduler.py
4
+
5
+ # Author: Bingxin Ke
6
+
7
+ import numpy as np
8
+
9
class IterExponential:
    """Per-iteration LR multiplier: linear warmup from 0 toward 1, then
    exponential decay to `final_ratio` over the remaining iterations.

    `warmup_steps` is given as a FRACTION of `total_iter_length`.
    """

    def __init__(self, total_iter_length, final_ratio, warmup_steps=0) -> None:
        self.total_length = total_iter_length
        # Number of post-warmup iterations the decay is spread over.
        self.effective_length = int(total_iter_length * (1 - warmup_steps))
        self.final_ratio = final_ratio
        self.warmup_steps = int(total_iter_length * warmup_steps)

    def __call__(self, n_iter) -> float:
        if n_iter < self.warmup_steps:
            # Linear ramp during warmup.
            return 1.0 * n_iter / self.warmup_steps
        if n_iter >= self.total_length:
            return self.final_ratio
        decay_iter = n_iter - self.warmup_steps
        # exp(t * ln r) interpolates 1 -> final_ratio geometrically.
        return np.exp(decay_iter / self.effective_length * np.log(self.final_ratio))
28
+
29
+
30
class IterConstant:
    """Constant LR multiplier (1.0) with an optional linear warmup.

    `warmup_steps` is given as a FRACTION of `total_iter_length`.
    """

    def __init__(self, total_iter_length: int, warmup_steps: float = 0.0) -> None:
        self.total_length = int(total_iter_length)
        self.warmup_steps = int(total_iter_length * warmup_steps)

    def __call__(self, n_iter: int) -> float:
        if self.warmup_steps > 0 and n_iter < self.warmup_steps:
            # Ramp (n+1)/warmup so the very first step already gets a non-zero LR.
            return float(n_iter + 1) / float(self.warmup_steps)
        return 1.0
src/utils/utils.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import cv2
4
+ import random
5
+ import wandb
6
+ from tqdm.auto import tqdm
7
+ from omegaconf import OmegaConf, DictConfig
8
+ from pathlib import Path
9
+
10
def args_to_omegaconf(args, base_cfg=None):
    """Overlay non-None argparse values onto an OmegaConf config.

    Only keys already present in `base_cfg` are considered; nesting is
    honoured one level deep (section -> key).
    """
    cfg = OmegaConf.create(base_cfg)

    def apply_override(node, key):
        # Copy args.<key> into node[key] only when the flag was actually set.
        value = getattr(args, key, None)
        if value is not None:
            node[key] = value

    for key in cfg.keys():
        child = cfg[key]
        if isinstance(child, DictConfig):
            for subkey in child.keys():
                apply_override(child, subkey)
        else:
            apply_override(cfg, key)

    return cfg
28
+
29
+ def _tb_sanitize(v):
30
+ if v is None:
31
+ return "null"
32
+ if isinstance(v, (bool, int, float, str, torch.Tensor)):
33
+ return v
34
+ if isinstance(v, Path):
35
+ return str(v)
36
+ return str(v)
37
+
38
def _flatten_dict(d, prefix=""):
    """Flatten nested dicts into {"a.b.c": value} with sanitised leaf values.

    A non-dict input is returned as a single entry keyed by `prefix`
    (or "cfg" when no prefix is given).
    """
    if not isinstance(d, dict):
        return {prefix or "cfg": _tb_sanitize(d)}

    flat = {}
    for k, v in d.items():
        full_key = f"{prefix}.{k}" if prefix else str(k)
        if isinstance(v, dict):
            flat.update(_flatten_dict(v, full_key))
        else:
            flat[full_key] = _tb_sanitize(v)
    return flat
50
+
51
def convert_paths_to_pathlib(cfg):
    """Recursively wrap every config value whose key contains 'path' in a
    pathlib.Path; None values stay None. Mutates and returns `cfg`."""
    for key, value in cfg.items():
        if isinstance(value, DictConfig):
            cfg[key] = convert_paths_to_pathlib(value)
        elif 'path' in key.lower():
            cfg[key] = None if value is None else Path(value)
    return cfg
58
+
59
+
60
def convert_pathlib_to_strings(cfg):
    """Recursively turn pathlib.Path config values back into plain strings
    (e.g. before serialising the config). Mutates and returns `cfg`."""
    for key in list(cfg.keys()):
        value = cfg[key]
        if isinstance(value, DictConfig):
            cfg[key] = convert_pathlib_to_strings(value)
        elif isinstance(value, Path):
            cfg[key] = str(value)
    return cfg
67
+
68
+
69
def prepare_trained_parameters(unet, cfg):
    """Select which UNet parameters to train and set requires_grad accordingly.

    When cfg.training.only_train_attention_layers is set, only transformer
    blocks (plus conv_in when cfg.model.unet_positional_encoding == "uv")
    are trained and everything else is frozen; otherwise every parameter
    is unfrozen.

    Returns the list of trainable parameters.
    """
    if not cfg.training.only_train_attention_layers:
        trainable = list(unet.parameters())
        for p in trainable:
            p.requires_grad_(True)
        return trainable

    trainable = []
    train_conv_in = cfg.model.unet_positional_encoding == "uv"
    for name, param in unet.named_parameters():
        selected = ("transformer_blocks" in name) or (train_conv_in and "conv_in" in name)
        param.requires_grad_(selected)
        if selected:
            trainable.append(param)
    return trainable
86
+
87
+
88
@torch.no_grad()
def validation_loop(accelerator, dataloader, pager, ema_unet, cfg, epoch, global_step, val_type="val"):
    """Run one pass over a validation dataloader, logging loss and sample images.

    Args:
        accelerator: HF Accelerate object used for cross-process reduction and logging.
        dataloader: validation DataLoader; its dataset exposes MIN_DEPTH / DEPTH_RANGE
            (and LOG_* variants) used to de-normalise depth predictions.
        pager: model wrapper — callable for prediction, plus loss and
            output-post-processing helpers.
        ema_unet: EMA weight container, used only when cfg.training.use_EMA.
        cfg: experiment config (model / training / logging sections).
        epoch, global_step: current training progress for the log x-axis.
        val_type: "val" (x-axis = epoch) or "tiny_val" (x-axis = global_step).

    Returns:
        Mean (cross-process-averaged) validation loss over the dataloader.
    """
    if val_type == "val":
        desc = "Validation"
        x_axis_name = "epoch"
        x_axis = epoch
    elif val_type == "tiny_val":
        desc = "Tiny Validation"
        x_axis_name = "global_step"
        x_axis = global_step
    else:
        raise ValueError(f"Unknown val type {val_type}")
    if cfg.training.use_EMA:
        # Validate with EMA weights; the originals are restored at the end.
        ema_unet.store(pager.unwrapped_unet.parameters())
        ema_unet.copy_to(pager.unwrapped_unet.parameters())
    val_epoch_loss = 0.0
    log_val_images = {"rgb": [], cfg.model.modality: []}
    # Pick 4 random batch indices whose predictions get logged as images.
    log_img_ids = random.sample(range(len(dataloader)), 4)
    progress_bar = tqdm(dataloader, desc=desc, total=len(dataloader), disable=not accelerator.is_main_process)
    for i, batch in enumerate(progress_bar):
        pred_cubemap = pager(batch, cfg.model.modality)
        if cfg.model.modality == "depth":
            # Depth normalisation constants depend on whether log-scale depth is used.
            min_depth = dataloader.dataset.LOG_MIN_DEPTH if cfg.model.log_scale else dataloader.dataset.MIN_DEPTH
            depth_range = dataloader.dataset.LOG_DEPTH_RANGE if cfg.model.log_scale else dataloader.dataset.DEPTH_RANGE
            loss = pager.calculate_depth_loss(batch, pred_cubemap, min_depth, depth_range, cfg.model.log_scale, cfg.model.metric_depth)
        elif cfg.model.modality == "normal":
            loss = pager.calculate_normal_loss(batch, pred_cubemap)

        # Average the batch loss across processes before accumulating.
        avg_loss = accelerator.reduce(loss["total_loss"].detach(), reduction="mean")
        if accelerator.is_main_process:
            progress_bar.set_postfix({"loss": avg_loss.item()})
        val_epoch_loss += avg_loss
        if i in log_img_ids:
            log_val_images["rgb"].append(prepare_image_for_logging(batch["rgb"][0].cpu().numpy()))
            if cfg.model.modality == "depth":
                result_image = pager.process_depth_output(pred_cubemap, orig_size=batch['depth'].shape[2:4], min_depth=min_depth,
                                                          depth_range=depth_range, log_scale=cfg.model.log_scale)[1].cpu().numpy()
            elif cfg.model.modality == "normal":
                result_image = pager.process_normal_output(pred_cubemap, orig_size=batch['normal'].shape[2:4]).cpu().numpy()
            log_val_images[cfg.model.modality].append(prepare_image_for_logging(result_image))

    val_epoch_loss = val_epoch_loss / len(dataloader)

    # NOTE(review): indentation reconstructed — image logging is assumed to be
    # main-process-only, matching the scalar log above; confirm against history.
    if accelerator.is_main_process:
        accelerator.log({x_axis_name: x_axis, f"{val_type}/loss": float(val_epoch_loss)}, step=global_step)

        # Tile the sampled predictions into one mosaic per modality.
        img_mix_rgb = log_images_mosaic(log_val_images["rgb"])
        img_mix_depth = log_images_mosaic(log_val_images[cfg.model.modality])

        if cfg.logging.report_to == "wandb":
            accelerator.log(
                {x_axis_name: x_axis, f"{val_type}/pred_panorama_rgb": wandb.Image(img_mix_rgb)},
                step=global_step,
            )
            accelerator.log(
                {x_axis_name: x_axis, f"{val_type}/pred_panorama_{cfg.model.modality}": wandb.Image(img_mix_depth)},
                step=global_step,
            )
        elif cfg.logging.report_to == "tensorboard":
            tb_writer = accelerator.get_tracker("tensorboard").writer
            tb_writer.add_image(
                f"{val_type}/pred_panorama_rgb",
                img_mix_rgb,
                global_step,
                dataformats="HWC",
            )
            tb_writer.add_image(
                f"{val_type}/pred_panorama_{cfg.model.modality}",
                img_mix_depth,
                global_step,
                dataformats="HWC",
            )

    if cfg.training.use_EMA:
        # Put the original (non-EMA) weights back for continued training.
        ema_unet.restore(pager.unwrapped_unet.parameters())
    return val_epoch_loss
164
+
165
+
166
def prepare_image_for_logging(image):
    """Min-max normalise an array into [0, 255] uint8 for image logging."""
    lo, hi = image.min(), image.max()
    scaled = (image - lo) / (hi - lo + 1e-8)
    return (scaled * 255).astype("uint8")
170
+
171
+
172
def log_images_mosaic(images):
    """Tile 1-4 CHW uint8 images into a single HWC Full-HD mosaic.

    Each image is converted to 3-channel HWC and resized to 1920x1080.
    Layouts: 1 -> single image; 2 -> side by side; 3 -> one centred on top,
    two below; 4 -> 2x2 grid.
    """
    n = len(images)
    assert 1 <= n <= 4, "Provide between 1 and 4 images (CHW uint8)."

    H, W, C = 1080, 1920, 3

    resized = []
    for img in images:
        assert img.dtype == np.uint8 and img.ndim == 3 and img.shape[0] in (1, 3), \
            "Each image must be uint8 with shape (C,H,W), C in {1,3}."
        if img.shape[0] == 1:
            img = np.repeat(img, 3, axis=0)  # grayscale -> 3 channels
        hwc = np.transpose(img, (1, 2, 0))
        resized.append(cv2.resize(hwc, (W, H), interpolation=cv2.INTER_LINEAR))

    if n == 1:
        return resized[0]

    if n == 2:
        mosaic = np.zeros((H, 2 * W, C), dtype=np.uint8)
        mosaic[:, :W] = resized[0]
        mosaic[:, W:] = resized[1]
        return mosaic

    mosaic = np.zeros((2 * H, 2 * W, C), dtype=np.uint8)
    if n == 3:
        # Centre the lone image on the top row.
        x_off = W // 2
        mosaic[:H, x_off:x_off + W] = resized[0]
        mosaic[H:, :W] = resized[1]
        mosaic[H:, W:] = resized[2]
    else:
        mosaic[:H, :W] = resized[0]
        mosaic[:H, W:] = resized[1]
        mosaic[H:, :W] = resized[2]
        mosaic[H:, W:] = resized[3]
    return mosaic
213
+
214
+