import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from .base import (
    fourier_dimension_expansion,
    flatten,
    DimensionAligner,
    AttentionSeq,
    ResidualUpsampler,
)


class _ViT_w_Esphere(nn.Module):
    def __init__(
        self,
        hidden_dim: int,
        num_heads: int = 8,
        expansion: int = 4,
        num_layers_head: int | list[int] = 4,
        dropout: float = 0.0,
        kernel_size: int = 7,
        layer_scale: float = 1.0,
        out_dim: int = 1,
        num_prompt_blocks: int = 1,
        use_norm: bool = False,
        **kwargs,
    ) -> None:
        super().__init__()
        self.out_dim = out_dim
        self.hidden_dim = hidden_dim
        self.up_sampler = nn.ModuleList([])
        self.pred_head = nn.ModuleList([])
        self.process_features = nn.ModuleList([])
        self.prompt_camera = nn.ModuleList([])
        mult = 2
        self.to_latents = nn.Linear(hidden_dim, hidden_dim)

        # One cross-attention stack per feature level: conditions the backbone
        # features on the spherical (camera ray direction) embeddings.
        for _ in range(4):
            self.prompt_camera.append(
                AttentionSeq(
                    num_blocks=num_prompt_blocks,
                    dim=hidden_dim,
                    num_heads=num_heads,
                    expansion=expansion,
                    dropout=dropout,
                    layer_scale=-1.0,
                    context_dim=hidden_dim,
                )
            )

        # Decoder stages: each stage projects a conditioned feature level to the
        # running latent width, fuses it, and upsamples; the channel width is
        # halved from one stage to the next.
        for i, depth in enumerate(num_layers_head):
            current_dim = min(hidden_dim, mult * hidden_dim // int(2**i))
            next_dim = mult * hidden_dim // int(2 ** (i + 1))
            output_dim = max(next_dim, out_dim)
            self.process_features.append(
                nn.ConvTranspose2d(
                    hidden_dim,
                    current_dim,
                    kernel_size=max(1, 2 * i),
                    stride=max(1, 2 * i),
                    padding=0,
                )
            )
            self.up_sampler.append(
                ResidualUpsampler(
                    current_dim,
                    output_dim=output_dim,
                    expansion=expansion,
                    layer_scale=layer_scale,
                    kernel_size=kernel_size,
                    num_layers=depth,
                    use_norm=use_norm,
                )
            )
            # Only the final stage gets a real prediction head; earlier stages
            # pass their features through unchanged.
            pred_head = (
                nn.Sequential(nn.LayerNorm(next_dim), nn.Linear(next_dim, output_dim))
                if i == len(num_layers_head) - 1
                else nn.Identity()
            )
            self.pred_head.append(pred_head)

        # Low-resolution projections and high-resolution refinement heads for
        # distance and confidence.
        self.to_depth_lr = nn.Conv2d(
            output_dim,
            output_dim // 2,
            kernel_size=3,
            padding=1,
            padding_mode='reflect',
        )
        self.to_confidence_lr = nn.Conv2d(
            output_dim,
            output_dim // 2,
            kernel_size=3,
            padding=1,
            padding_mode='reflect',
        )
        self.to_depth_hr = nn.Sequential(
            nn.Conv2d(
                output_dim // 2, 32, kernel_size=3, padding=1, padding_mode='reflect'
            ),
            nn.LeakyReLU(),
            nn.Conv2d(32, 1, kernel_size=1),
        )
        self.to_confidence_hr = nn.Sequential(
            nn.Conv2d(
                output_dim // 2, 32, kernel_size=3, padding=1, padding_mode='reflect'
            ),
            nn.LeakyReLU(),
            nn.Conv2d(32, 1, kernel_size=1),
        )

    def set_original_shapes(self, shapes: tuple[int, int]):
        # Full-resolution (H, W) of the input image.
        self.original_shapes = shapes

    def set_shapes(self, shapes: tuple[int, int]):
        # (h, w) of the backbone feature grid.
        self.shapes = shapes

    def embed_sphere_dirs(self, sphere_dirs):
        # Resample the per-pixel unit direction vectors from the original image
        # grid down to the feature grid.
        sphere_embedding = flatten(
            sphere_dirs, old=self.original_shapes, new=self.shapes
        )

        # Convert unit directions (x, y, z) to spherical angles: polar from the
        # y component, azimuth from x and z, with z clipped away from zero so
        # atan2 stays numerically stable.
        r1, r2, r3 = (
            sphere_embedding[..., 0],
            sphere_embedding[..., 1],
            sphere_embedding[..., 2],
        )
        polar = torch.asin(r2)

        min_val = torch.tensor([1e-5], dtype=r3.dtype, device=r3.device)
        r3_abs = r3.abs()
        r3_clipped = torch.max(r3_abs, min_val) * (2 * (r3 >= 0).to(r3.dtype) - 1)
        azimuth = torch.atan2(r1, r3_clipped)

        sphere_embedding = torch.stack([polar, azimuth], dim=-1)

        # Lift the two angles to hidden_dim channels with a Fourier encoding.
        sphere_embedding = fourier_dimension_expansion(
            sphere_embedding,
            dim=self.hidden_dim,
            max_freq=max(self.shapes) // 2,
            use_cos=False,
        )
        return sphere_embedding

    def condition(self, feat, sphere_embeddings):
        # Cross-attend each flattened feature level to the spherical embeddings.
        conditioned_features = [
            prompter(rearrange(feature, 'b h w c -> b (h w) c'), sphere_embeddings)
            for prompter, feature in zip(self.prompt_camera, feat)
        ]
        return conditioned_features

    def process(self, features_list, sphere_embeddings):
        conditioned_features = self.condition(features_list, sphere_embeddings)
        # Initialize the decoder latents from the first conditioned level.
        init_latents = self.to_latents(conditioned_features[0])
        init_latents = rearrange(
            init_latents, 'b (h w) c -> b c h w', h=self.shapes[0], w=self.shapes[1]
        ).contiguous()
        conditioned_features = [
            rearrange(
                x, 'b (h w) c -> b c h w', h=self.shapes[0], w=self.shapes[1]
            ).contiguous()
            for x in conditioned_features
        ]
        latents = init_latents

        # Progressively fuse the remaining levels and upsample.
        out_features = []
        for i, up in enumerate(self.up_sampler):
            latents = latents + self.process_features[i](conditioned_features[i + 1])
            latents = up(latents)
            out_features.append(latents)
        return out_features

    def prediction_head(self, out_features):
        depths = []
        h_out, w_out = out_features[-1].shape[-2:]
        for i, (layer, features) in enumerate(zip(self.pred_head, out_features)):
            out_depth_features = layer(features.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
            # Only the final stage carries a real head; skip the earlier ones.
            if i < len(self.pred_head) - 1:
                continue
            depths.append(out_depth_features)
        out_depth_features = F.interpolate(
            out_depth_features, size=(h_out, w_out), mode='bilinear', align_corners=True
        )
        distance = self.to_depth_lr(out_depth_features)
        distance = F.interpolate(
            distance, size=self.original_shapes, mode='bilinear', align_corners=True
        )
        distance = self.to_depth_hr(distance)
        return distance

    def forward(
        self,
        features: list[torch.Tensor],
        sphere_dirs: torch.Tensor,
    ) -> torch.Tensor:
        sphere_embeddings = self.embed_sphere_dirs(sphere_dirs)
        features = self.process(features, sphere_embeddings)
        distance = self.prediction_head(features)
        return distance


class ViT_w_Esphere(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # Project every backbone feature level to the common hidden width.
        self.dim_aligner = DimensionAligner(
            input_dims=config['input_dims'],
            hidden_dim=config['hidden_dim'],
        )
        self._vit_w_esphere = _ViT_w_Esphere(**config)

    def forward(self, images, features, sphere_dirs) -> torch.Tensor:
        _, _, H, W = images.shape
        common_shape = features[0].shape[1:3]
        features = self.dim_aligner(features)
        sphere_dirs = rearrange(sphere_dirs, 'b c h w -> b (h w) c')

        self._vit_w_esphere.set_shapes(common_shape)
        self._vit_w_esphere.set_original_shapes((H, W))
        logdistance = self._vit_w_esphere(
            features=features,
            sphere_dirs=sphere_dirs,
        )

        # Clamp the predicted log-distance to a safe range, then map it back to
        # metric distance with a fixed +2 log-offset.
        min_val = torch.tensor([-8.0], dtype=logdistance.dtype, device=logdistance.device)
        max_val = torch.tensor([8.0], dtype=logdistance.dtype, device=logdistance.device)
        clamped_logdistance = torch.min(torch.max(logdistance, min_val), max_val)
        distance = torch.exp(clamped_logdistance + 2.0)

        return distance
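

# Minimal usage sketch (illustrative only, run as a module so the relative
# imports resolve): the config keys mirror the constructor arguments above,
# while the tensor shapes and the behavior of the `.base` components
# (DimensionAligner projecting every level to `hidden_dim` channels-last,
# ResidualUpsampler doubling the spatial size) are assumptions, not guarantees
# from this module.
if __name__ == "__main__":
    config = {
        'input_dims': [256, 256, 256, 256],  # per-level backbone channels (assumed)
        'hidden_dim': 64,
        'num_heads': 4,
        'num_layers_head': [2, 2, 2],        # three decoder stages
        'out_dim': 1,
    }
    model = ViT_w_Esphere(config)

    B, H, W = 1, 448, 448   # image batch and resolution (assumed)
    h, w = 32, 32           # backbone feature grid (assumed)
    images = torch.randn(B, 3, H, W)
    # Four channels-last feature maps, matching the four prompt_camera stacks.
    features = [torch.randn(B, h, w, c) for c in config['input_dims']]
    # Per-pixel unit ray directions (x, y, z) on the full-resolution grid.
    sphere_dirs = torch.randn(B, 3, H, W)
    sphere_dirs = sphere_dirs / sphere_dirs.norm(dim=1, keepdim=True)

    distance = model(images, features, sphere_dirs)
    print(distance.shape)  # expected: (B, 1, H, W)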