Spaces:
Sleeping
Sleeping
import copy

from omegaconf import OmegaConf
# Config schema version produced by migrate(); bump it together with a new
# migrate_vN_to_vN+1 step.
CURRENT_CFG_VERSION = 2


def _get_nested(mapping, *keys):
    """Return ``mapping[k1][k2]...``, or a fresh empty dict if any level is missing or None.

    When every level exists, the live nested mapping is returned, so callers
    can mutate it in place.
    """
    node = mapping
    for key in keys:
        node = node.get(key)
        if node is None:
            return {}
    return node


def _apply_code_level_fixups(cfg):
    """Apply, in place on a plain container, renames/cleanups that had no version bump."""
    so = _get_nested(cfg, "scene_trainer", "scene_optimizer")
    si = _get_nested(cfg, "scene_trainer", "scene_initializer")
    # Code-level rename that did not require a version bump: resplat -> resplat_v1.
    for module_cfg in (so, si):
        if module_cfg.get("name") == "resplat":
            module_cfg["name"] = "resplat_v1"
    # Strip stale postprocessing fields from old checkpoint configs.
    pp = _get_nested(cfg, "meta_trainer", "test", "postprocessing")
    if isinstance(pp, dict):
        for stale_key in ("__target__", "enabled", "lr"):
            pp.pop(stale_key, None)
    # Strip stale foundationstereo fields (encoder removed).
    for stale_key in ("foundationstereo", "fstereo_num_refine"):
        si.pop(stale_key, None)


def migrate(cfg_dict):
    """Bring a config up to ``CURRENT_CFG_VERSION`` and apply code-level fixups.

    Versioned migrations are chained (v0 -> v1 -> v2). Afterwards, renames that
    were made in code without a version bump are applied and stale fields are
    removed.

    Args:
        cfg_dict: A plain ``dict`` or an OmegaConf config. It is not modified.

    Returns:
        A migrated copy: an OmegaConf config if the input was not a plain
        ``dict``, otherwise a plain ``dict``.

    Raises:
        ValueError: If the config's ``version`` cannot be migrated to
            ``CURRENT_CFG_VERSION``.
    """
    was_omega = not isinstance(cfg_dict, dict)
    version = cfg_dict.get("version", 0)
    # null means a fresh run from main.yaml — treat as current version.
    if version is None:
        version = CURRENT_CFG_VERSION
    if version == 0:
        # Heuristic: configs that were partially migrated may have version=0 but a
        # non-depthsplat optimizer name (already renamed during v0→v1), so skip v0→v1.
        so = _get_nested(cfg_dict, "scene_trainer", "scene_optimizer")
        if so.get("name", "") == "depthsplat":
            print("Migrating config from version 0 (cvpr submission) to version 1 (cvpr rebuttal)...")
            cfg_dict = migrate_v0_to_v1(cfg_dict)
        version = 1
    if version == 1:
        print("Migrating config from version 1 to version 2 (train/test moved under meta_trainer)...")
        cfg_dict = migrate_v1_to_v2(cfg_dict)
        version = 2
    if version != CURRENT_CFG_VERSION:
        raise ValueError(f"Unsupported config version: {version}")
    # Work on a private plain container. OmegaConf configs are converted, and
    # plain dicts are deep-copied, so the caller's (possibly nested) dicts are
    # never mutated by the fixups below.
    if isinstance(cfg_dict, dict):
        cfg_container = copy.deepcopy(cfg_dict)
    else:
        cfg_container = OmegaConf.to_container(cfg_dict, resolve=False)
    _apply_code_level_fixups(cfg_container)
    if was_omega:
        return OmegaConf.create(cfg_container)
    return cfg_container
def migrate_v1_to_v2(cfg_dict):
    """
    Migration from v1 to v2: move top-level 'train' and 'test' under 'meta_trainer'.

    Args:
        cfg_dict: A plain ``dict`` or an OmegaConf config. It is not modified.

    Returns:
        A new plain ``dict`` with ``version`` set to 2. If ``meta_trainer``
        already defines ``train``/``test``, the corresponding top-level key is
        left where it is rather than overwriting the nested one.
    """
    if isinstance(cfg_dict, dict):
        # Deep copy: a shallow dict() copy would share the nested 'meta_trainer'
        # dict with the caller, and the moves below would then mutate it.
        cfg = copy.deepcopy(cfg_dict)
    else:
        cfg = OmegaConf.to_container(cfg_dict, resolve=False)
    # An explicit `meta_trainer: null` would otherwise break the membership test below.
    if cfg.get("meta_trainer") is None:
        cfg["meta_trainer"] = {}
    meta_trainer = cfg["meta_trainer"]
    for key in ("train", "test"):
        if key in cfg and key not in meta_trainer:
            meta_trainer[key] = cfg.pop(key)
    cfg["version"] = 2
    return cfg
# Scene-optimizer key renames applied by the v0 (refine_*) -> v1 (input_error_*) migration.
_V0_TO_V1_OPTIMIZER_RENAMES = {
    # feature extraction
    "refine_lpips_error": "input_error_lpips_features",
    "refine_pool_vgg_features": "input_error_pool_vgg_features",
    "refine_use_all_vgg_features": "input_error_use_all_vgg_features",
    "refine_vit_feature": "input_error_vit_feature",
    "refine_resnet_feature": "input_error_resnet_feature",
    "no_freeze_resnet_feature": "input_error_no_freeze_resnet_feature",
    "shallow_resnet_feature": "input_error_shallow_resnet_feature",
    "resnet_feature_layers": "input_error_resnet_feature_layers",
    "refine_convnext_feature": "input_error_convnext_feature",
    "convnext_feature_size": "input_error_convnext_feature_size",
    "refine_concat_feature": "input_error_concat_feature",
    "refine_concat_feature_cosine": "input_error_concat_feature_cosine",
    "refine_cosine_feature": "input_error_cosine_feature",
    "refine_add_feature": "input_error_add_feature",
    "refine_concat_rgb_feature_error": "input_error_concat_rgb_feature_error",
    # render error → input error
    "render_error_no_abs": "input_error_no_abs",
    "render_error_no_shuffle": "input_error_no_shuffle",
    "render_cache_resnet_feature": "input_error_cache_resnet_feature",
    "render_view_pool_resnet_feature": "input_error_view_pool_resnet_feature",
    "render_global_pool_resnet_feature": "input_error_global_pool_resnet_feature",
    # input toggles
    "refine_input_alpha": "input_alpha",
    "refine_input_depth": "input_depth",
    "refine_input_depth_smooth_error": "input_depth_smooth_error",
    "refine_input_error": "input_error",
    # attention (input error)
    "radii_averaged_render_error": "input_error_radii_averaged",
    "cross_attn_additional_render_error": "input_error_additional_cross_attn",
    "num_intermediate_views": "input_error_num_intermediate_views",
    "render_error_mv_attn_blocks": "input_error_mv_attn_blocks",
    # context handling
    "render_error_num_views": "input_error_num_views",
    "render_error_remain_context": "input_error_remain_context",
    "render_error_merge_remain_context": "input_error_merge_remain_context",
    "render_error_warp_remain_context": "input_error_warp_remain_context",
    "render_error_random_num_remain_context": "input_error_random_num_remain_context",
    "render_error_num_remain_context_test": "input_error_num_remain_context_test",
    "render_error_warp_input_view": "input_error_warp_input_view",
    # input gradient
    "refine_input_gradient": "input_gradient",
    "refine_input_gradient_log": "input_gradient_log",
    "refine_input_gradient_log_clip_deltas": "input_gradient_log_clip_deltas",
    "refine_input_gradient_scale": "input_gradient_scale",
    # normalize input
    "normalize_update_input": "input_gradient_normalize",
    "normalize_update_input_type": "input_gradient_normalize_type",
    "normalize_state": "input_normalize_state",
    "normalize_gaussians": "input_normalize_gaussians",
    # update head
    "final_head_act": "update_head_final_act",
    "refine_output_scale_mag": "update_head_scale_mag",
    "scalar_scale_out": "update_head_scalar_scale",
    "scalar_scale_out_act": "update_head_scalar_scale_act",
}


def migrate_v0_to_v1(cfg):
    """
    Migration from submission v0 (refine_*) to rebuttal v1 (input_error_*).

    Args:
        cfg: A plain ``dict`` or an OmegaConf config containing
            ``scene_trainer.scene_optimizer`` and ``scene_trainer.scene_initializer``.
            It is not modified.

    Returns:
        The migrated config with ``version`` set to 1: an OmegaConf config for
        OmegaConf input, or a plain ``dict`` for plain-``dict`` input.

    Raises:
        NotImplementedError: If the scene optimizer (after module renames) is
            not one of ``clogs``, ``learn2splat`` or ``resplat_v1``.
    """
    # OmegaConf.to_container rejects plain dicts, so copy those directly instead.
    was_omega = not isinstance(cfg, dict)
    cfg = OmegaConf.to_container(cfg, resolve=False) if was_omega else copy.deepcopy(cfg)
    so = cfg["scene_trainer"]["scene_optimizer"]
    si = cfg["scene_trainer"]["scene_initializer"]

    # ------------------------------------------------------------------
    # Module renames (must run before key renames: reads refine_input_gradient)
    # ------------------------------------------------------------------
    if si["name"] == "depthsplat":
        si["name"] = "resplat_v1"
    if so["name"] == "depthsplat":
        so["name"] = "learn2splat" if so.get("refine_input_gradient", False) else "resplat_v1"

    # ------------------------------------------------------------------
    # Key renames (declarative)
    # ------------------------------------------------------------------
    for old, new in _V0_TO_V1_OPTIMIZER_RENAMES.items():
        if old in so:
            so[new] = so.pop(old)

    # ------------------------------------------------------------------
    # New / fixed defaults
    # ------------------------------------------------------------------
    if so["name"] not in ("clogs", "learn2splat", "resplat_v1"):
        raise NotImplementedError(
            f"v0 -> v1 migration does not support scene optimizer {so['name']!r}"
        )
    so["update_head_hidden_dim_matches"] = "output"
    # state_channels == 0 in v0 configs is replaced with 256.
    if so.get("state_channels") == 0:
        so["state_channels"] = 256

    # ------------------------------------------------------------------
    # Version bump
    # ------------------------------------------------------------------
    cfg["version"] = 1
    return OmegaConf.create(cfg) if was_omega else cfg