# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict as TyDict
from typing import List, Sequence

import torch
import torch.nn as nn

from depth_anything_3.model.dpt import DPT
from depth_anything_3.model.utils.head_utils import activate_head_gs, custom_interpolate


class GSDPT(DPT):
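    """DPT head that predicts raw Gaussian Splatting (GS) parameters.

    Relative to the base ``DPT`` head, it (a) registers its output under the
    ``"raw_gs"`` name, (b) mixes the raw RGB frames back into the fused
    feature map through a small conv stack (``images_merger``), and (c)
    allows a multi-channel confidence (``conf_dim > 1``), which is reserved
    for view-dependent opacity and requires a linear ``conf_activation``.
    """
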
    def __init__(
        self,
        dim_in: int,
        patch_size: int = 14,
        output_dim: int = 4,
        activation: str = "linear",
        conf_activation: str = "sigmoid",
        features: int = 256,
        out_channels: Sequence[int] = (256, 512, 1024, 1024),
        pos_embed: bool = True,
        feature_only: bool = False,
        down_ratio: int = 1,
        conf_dim: int = 1,
        norm_type: str = "idt",  # "idt" / "layer"; used to match the legacy GS-DPT head
        fusion_block_inplace: bool = False,
    ) -> None:
        super().__init__(
            dim_in=dim_in,
            patch_size=patch_size,
            output_dim=output_dim,
            activation=activation,
            conf_activation=conf_activation,
            features=features,
            out_channels=out_channels,
            pos_embed=pos_embed,
            down_ratio=down_ratio,
            head_name="raw_gs",
            use_sky_head=False,
            norm_type=norm_type,
            fusion_block_inplace=fusion_block_inplace,
        )
        self.conf_dim = conf_dim
        if conf_dim and conf_dim > 1:
            assert (
                conf_activation == "linear"
            ), "use linear prediction when using view-dependent opacity"
        merger_out_dim = features if feature_only else features // 2
        # Small conv stack that lifts the raw RGB frames to the fused-feature
        # width so they can be added back in during the forward pass.
        self.images_merger = nn.Sequential(
            nn.Conv2d(3, merger_out_dim // 4, 3, 1, 1),  # fewer channels first
            nn.GELU(),
            nn.Conv2d(merger_out_dim // 4, merger_out_dim // 2, 3, 1, 1),
            nn.GELU(),
            nn.Conv2d(merger_out_dim // 2, merger_out_dim, 3, 1, 1),
            nn.GELU(),
        )

    # -------------------------------------------------------------------------
    # Internal forward (single chunk)
    # -------------------------------------------------------------------------
    def _forward_impl(
        self,
        feats: List[torch.Tensor],
        H: int,
        W: int,
        patch_start_idx: int,
        images: torch.Tensor,
    ) -> TyDict[str, torch.Tensor]:
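        """Run the head on a single chunk of backbone features.

        ``feats`` holds per-layer token tensors of shape ``[B*S, N, C]``;
        tokens before ``patch_start_idx`` (e.g. leading non-patch tokens) are
        discarded. ``images`` is merged additively with the fused feature
        map, so it is expected at roughly ``[B*S, 3, H/down_ratio,
        W/down_ratio]`` resolution. Returns a dict keyed by the head name
        (``"raw_gs"``), plus a ``"raw_gs_conf"`` entry when confidence is
        enabled.
        """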
        B, _, C = feats[0].shape
        ph, pw = H // self.patch_size, W // self.patch_size

        # 1) Project each selected transformer layer into the DPT pyramid
        resized_feats = []
        for stage_idx, take_idx in enumerate(self.intermediate_layer_idx):
            x = feats[take_idx][:, patch_start_idx:]  # [B*S, N_patch, C]
            x = self.norm(x)
            x = x.permute(0, 2, 1).reshape(B, C, ph, pw)  # [B*S, C, ph, pw]
            x = self.projects[stage_idx](x)
            if self.pos_embed:
                x = self._add_pos_embed(x, W, H)
            x = self.resize_layers[stage_idx](x)  # align scale
            resized_feats.append(x)
        # 2) Fusion pyramid (main branch only)
        fused = self._fuse(resized_feats)
        fused = self.scratch.output_conv1(fused)

        # 3) Upsample to the target resolution, optionally re-adding the position encoding
        h_out = int(ph * self.patch_size / self.down_ratio)
        w_out = int(pw * self.patch_size / self.down_ratio)
        fused = custom_interpolate(fused, (h_out, w_out), mode="bilinear", align_corners=True)
        # Inject the raw image information here
        fused = fused + self.images_merger(images)
        if self.pos_embed:
            fused = self._add_pos_embed(fused, W, H)

        # 4) Shared neck (output_conv1 was already applied above)
        feat = fused
        # 5) Main head: logits -> activate_head_gs, or single-channel activation
        main_logits = self.scratch.output_conv2(feat)
        outs: TyDict[str, torch.Tensor] = {}
        if self.has_conf:
            pred, conf = activate_head_gs(
                main_logits,
                activation=self.activation,
                conf_activation=self.conf_activation,
                conf_dim=self.conf_dim,
            )
            outs[self.head_main] = pred.squeeze(1)
            outs[f"{self.head_main}_conf"] = conf.squeeze(1)
        else:
            outs[self.head_main] = self._apply_activation_single(main_logits).squeeze(1)
        return outs
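

# -----------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original module). Everything
# below is an illustrative assumption: the toy sizes, the 5 leading non-patch
# tokens, and calling the single-chunk `_forward_impl` directly all stand in
# for the real backbone/forward wiring provided elsewhere in depth_anything_3.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    B, H, W, dim_in = 1, 224, 224, 1024  # assumed toy sizes (H, W divisible by 14)
    head = GSDPT(dim_in=dim_in, output_dim=4)
    patch_start_idx = 5  # assume 5 leading non-patch tokens
    n_tokens = patch_start_idx + (H // 14) * (W // 14)
    # One dummy token tensor per backbone layer the head can index into.
    n_layers = max(head.intermediate_layer_idx) + 1
    feats = [torch.randn(B, n_tokens, dim_in) for _ in range(n_layers)]
    images = torch.randn(B, 3, H, W)  # down_ratio defaults to 1
    with torch.no_grad():
        outs = head._forward_impl(feats, H, W, patch_start_idx, images)
    for name, tensor in outs.items():
        print(name, tuple(tensor.shape))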