| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """ |
| Model loading and state dict conversion utilities. |
| """ |
|
|
| from typing import Dict, Tuple |
| import torch |
|
|
| from ..utils.logger import logger |
|
|
|
|
| def convert_general_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: |
| """ |
| Convert general model state dict to match current model architecture. |
| |
| Args: |
| state_dict: Original state dictionary |
| |
| Returns: |
| Converted state dictionary |
| """ |
| |
| state_dict = {k.replace("module.", "model."): v for k, v in state_dict.items()} |
| state_dict = {k.replace(".net.", ".backbone."): v for k, v in state_dict.items()} |
|
|
| |
| if "model.backbone.pretrained.camera_token" in state_dict: |
| del state_dict["model.backbone.pretrained.camera_token"] |
|
|
| |
| state_dict = { |
| k.replace(".camera_token_extra", ".camera_token"): v for k, v in state_dict.items() |
| } |
|
|
| |
| state_dict = { |
| k.replace("model.all_heads.camera_cond_head", "model.cam_enc"): v |
| for k, v in state_dict.items() |
| } |
| state_dict = { |
| k.replace("model.all_heads.camera_head", "model.cam_dec"): v for k, v in state_dict.items() |
| } |
| state_dict = {k.replace(".more_mlps.", ".backbone."): v for k, v in state_dict.items()} |
| state_dict = {k.replace(".fc_rot.", ".fc_qvec."): v for k, v in state_dict.items()} |
| state_dict = { |
| k.replace("model.all_heads.head", "model.head"): v for k, v in state_dict.items() |
| } |
|
|
| |
| state_dict = { |
| k.replace("output_conv2_additional.sky_mask", "sky_output_conv2"): v |
| for k, v in state_dict.items() |
| } |
| state_dict = {k.replace("_ray.", "_aux."): v for k, v in state_dict.items()} |
|
|
| |
| state_dict = {k.replace("gaussian_param_head.", "gs_head."): v for k, v in state_dict.items()} |
|
|
| return state_dict |
|
|
|
|
| def convert_metric_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: |
| """ |
| Convert metric model state dict to match current model architecture. |
| |
| Args: |
| state_dict: Original metric state dictionary |
| |
| Returns: |
| Converted state dictionary |
| """ |
| |
| state_dict = {"module." + k: v for k, v in state_dict.items()} |
| return convert_general_state_dict(state_dict) |
|
|
|
|
| def load_pretrained_weights(model, model_path: str, is_metric: bool = False) -> Tuple[list, list]: |
| """ |
| Load pretrained weights for a single model. |
| |
| Args: |
| model: Model instance to load weights into |
| model_path: Path to the pretrained weights |
| is_metric: Whether this is a metric model |
| |
| Returns: |
| Tuple of (missed_keys, unexpected_keys) |
| """ |
| state_dict = torch.load(model_path, map_location="cpu") |
|
|
| if is_metric: |
| state_dict = convert_metric_state_dict(state_dict) |
| else: |
| state_dict = convert_general_state_dict(state_dict) |
|
|
| missed, unexpected = model.load_state_dict(state_dict, strict=False) |
| logger.info("Missed keys:", missed) |
| logger.info("Unexpected keys:", unexpected) |
|
|
| return missed, unexpected |
|
|
|
|
| def load_pretrained_nested_weights( |
| model, main_model_path: str, metric_model_path: str |
| ) -> Tuple[list, list]: |
| """ |
| Load pretrained weights for a nested model with both main and metric branches. |
| |
| Args: |
| model: Nested model instance |
| main_model_path: Path to main model weights |
| metric_model_path: Path to metric model weights |
| |
| Returns: |
| Tuple of (missed_keys, unexpected_keys) |
| """ |
| |
| state_dict0 = torch.load(main_model_path, map_location="cpu") |
| state_dict0 = convert_general_state_dict(state_dict0) |
| state_dict0 = {k.replace("model.", "model.da3."): v for k, v in state_dict0.items()} |
|
|
| |
| state_dict1 = torch.load(metric_model_path, map_location="cpu") |
| state_dict1 = convert_metric_state_dict(state_dict1) |
| state_dict1 = {k.replace("model.", "model.da3_metric."): v for k, v in state_dict1.items()} |
|
|
| |
| combined_state_dict = state_dict0.copy() |
| combined_state_dict.update(state_dict1) |
|
|
| missed, unexpected = model.load_state_dict(combined_state_dict, strict=False) |
|
|
| print("Missed keys:", missed) |
| print("Unexpected keys:", unexpected) |
|
|
| return missed, unexpected |
|
|