# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Model loading and state dict conversion utilities.
"""

from typing import Dict, Tuple

import torch

from depth_anything_3.utils.logger import logger

# Substring renames applied to every checkpoint key BEFORE the stale camera
# token is dropped. Each rule is a plain ``str.replace`` (all occurrences),
# applied in the order listed.
_PRE_DROP_RENAMES: Tuple[Tuple[str, str], ...] = (
    ("module.", "model."),
    (".net.", ".backbone."),
)

# Key discarded from converted checkpoints. It is dropped before the
# ``.camera_token_extra`` -> ``.camera_token`` rename below, so a
# ``...camera_token_extra`` key (if present) takes its place.
_DROPPED_KEY = "model.backbone.pretrained.camera_token"

# Substring renames applied AFTER the drop, in order. Order matters: e.g. the
# ``camera_cond_head`` / ``camera_head`` rules run before the generic
# ``model.all_heads.head`` rule.
_POST_DROP_RENAMES: Tuple[Tuple[str, str], ...] = (
    # Camera token naming
    (".camera_token_extra", ".camera_token"),
    # Head naming
    ("model.all_heads.camera_cond_head", "model.cam_enc"),
    ("model.all_heads.camera_head", "model.cam_dec"),
    (".more_mlps.", ".backbone."),
    (".fc_rot.", ".fc_qvec."),
    ("model.all_heads.head", "model.head"),
    # Output naming
    ("output_conv2_additional.sky_mask", "sky_output_conv2"),
    ("_ray.", "_aux."),
    # GS-DPT head naming
    ("gaussian_param_head.", "gs_head."),
)


def _apply_renames(key: str, rules: Tuple[Tuple[str, str], ...]) -> str:
    """Apply each ``(old, new)`` substring replacement in ``rules`` to ``key`` in order."""
    for old, new in rules:
        key = key.replace(old, new)
    return key


def _replace_prefix(
    state_dict: Dict[str, torch.Tensor], old: str, new: str
) -> Dict[str, torch.Tensor]:
    """Return a copy of ``state_dict`` with a leading ``old`` replaced by ``new``.

    Only the prefix is rewritten; occurrences of ``old`` later in a key, and
    keys that do not start with ``old``, are left untouched.
    """
    return {
        (new + key[len(old):] if key.startswith(old) else key): value
        for key, value in state_dict.items()
    }


def _load_checkpoint(model_path: str) -> Dict[str, torch.Tensor]:
    """Load the raw checkpoint at ``model_path`` with all tensors mapped to CPU."""
    # NOTE(security): ``torch.load`` unpickles the file and can execute
    # arbitrary code from an untrusted checkpoint. Consider passing
    # ``weights_only=True`` (PyTorch >= 1.13) once every supported checkpoint
    # is confirmed to contain only tensors / plain containers.
    return torch.load(model_path, map_location="cpu")


def _log_load_result(missed, unexpected) -> None:
    """Report the key mismatches returned by ``load_state_dict``."""
    # Pre-formatted f-strings work with both stdlib-style (``msg % args``)
    # and print-style loggers; passing the list as a bare extra argument does
    # not.
    logger.info(f"Missed keys: {missed}")
    logger.info(f"Unexpected keys: {unexpected}")


def convert_general_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Convert general model state dict to match current model architecture.

    Every key is passed through ``_PRE_DROP_RENAMES``. The result is skipped
    if it equals ``_DROPPED_KEY``; otherwise it is passed through
    ``_POST_DROP_RENAMES``. Values are carried over unchanged and the input
    dict is not modified. If two keys map to the same converted key, the
    value of the later one wins.

    Args:
        state_dict: Original state dictionary

    Returns:
        Converted state dictionary
    """
    converted: Dict[str, torch.Tensor] = {}
    for key, value in state_dict.items():
        key = _apply_renames(key, _PRE_DROP_RENAMES)
        if key == _DROPPED_KEY:
            continue
        converted[_apply_renames(key, _POST_DROP_RENAMES)] = value
    return converted


def convert_metric_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Convert metric model state dict to match current model architecture.

    Metric checkpoints lack the ``module.`` prefix, so it is prepended to every
    key before the general conversion. That conversion then turns ``module.``
    into ``model.``.

    Args:
        state_dict: Original metric state dictionary

    Returns:
        Converted state dictionary
    """
    prefixed = {f"module.{key}": value for key, value in state_dict.items()}
    return convert_general_state_dict(prefixed)


def load_pretrained_weights(model, model_path: str, is_metric: bool = False) -> Tuple[list, list]:
    """
    Load pretrained weights for a single model.

    Loading is non-strict (``strict=False``); mismatched keys are logged and
    returned rather than raised.

    Args:
        model: Model instance to load weights into
        model_path: Path to the pretrained weights
        is_metric: Whether this is a metric model

    Returns:
        Tuple of (missed_keys, unexpected_keys)
    """
    state_dict = _load_checkpoint(model_path)
    converter = convert_metric_state_dict if is_metric else convert_general_state_dict
    state_dict = converter(state_dict)

    missed, unexpected = model.load_state_dict(state_dict, strict=False)
    _log_load_result(missed, unexpected)
    return missed, unexpected


def load_pretrained_nested_weights(
    model, main_model_path: str, metric_model_path: str
) -> Tuple[list, list]:
    """
    Load pretrained weights for a nested model with both main and metric branches.

    Main-branch keys are re-rooted from ``model.`` to ``model.da3.`` and
    metric-branch keys from ``model.`` to ``model.da3_metric.``. Only the
    leading prefix is rewritten. The two dicts are merged; a metric key
    overrides an identical main key. Loading is non-strict.

    Args:
        model: Nested model instance
        main_model_path: Path to main model weights
        metric_model_path: Path to metric model weights

    Returns:
        Tuple of (missed_keys, unexpected_keys)
    """
    # Main branch
    main_state = convert_general_state_dict(_load_checkpoint(main_model_path))
    main_state = _replace_prefix(main_state, "model.", "model.da3.")

    # Metric branch
    metric_state = convert_metric_state_dict(_load_checkpoint(metric_model_path))
    metric_state = _replace_prefix(metric_state, "model.", "model.da3_metric.")

    combined_state_dict = {**main_state, **metric_state}

    missed, unexpected = model.load_state_dict(combined_state_dict, strict=False)
    _log_load_result(missed, unexpected)
    return missed, unexpected