Spaces:
Sleeping
Sleeping
| import os | |
| from collections import OrderedDict | |
| from typing import Any | |
| import torch | |
| from optgs.misc.io import cyan | |
def extract_step(file_name):
    """Return the training step encoded in a checkpoint file name.

    The name is expected to look like ``<prefix>-<tag>_<step>.ckpt`` (for
    example ``epoch_3-step_1200.ckpt``): the second ``-``-separated field is
    split on ``_``, and its second piece, with any ``.ckpt`` text removed, is
    parsed as an integer.

    Raises:
        IndexError: If the name lacks the expected ``-`` or ``_`` separator.
        ValueError: If the extracted text is not a valid integer.
    """
    after_dash = file_name.split("-")[1]
    step_token = after_dash.split("_")[1]
    return int(step_token.replace(".ckpt", ""))
def find_latest_ckpt(ckpt_dir):
    """Return the path of the checkpoint with the highest step in ``ckpt_dir``.

    Args:
        ckpt_dir: Directory to scan, given as a ``str`` or a path-like object.

    Returns:
        A ``pathlib.Path`` pointing at the ``.ckpt`` file whose name yields
        the largest value from ``extract_step``.

    Raises:
        ValueError: If the directory contains no ``.ckpt`` files.
    """
    # Local import keeps this block independent of the module's import list.
    from pathlib import Path

    ckpt_files = [f for f in os.listdir(ckpt_dir) if f.endswith(".ckpt")]
    if not ckpt_files:
        raise ValueError(f"No .ckpt files found in {ckpt_dir}.")

    latest_ckpt_file = max(ckpt_files, key=extract_step)
    # Wrap in Path so a plain string directory works as well; a bare
    # `ckpt_dir / name` only works when ckpt_dir is already a Path.
    return Path(ckpt_dir) / latest_ckpt_file
def no_resume_upsampler(pretrained_state_dict):
    """Return a new ``OrderedDict`` without the upsampler entries.

    Every item whose key contains the substring ``'upsampler'`` is dropped;
    the remaining items keep their original order. The input mapping is not
    modified.
    """
    return OrderedDict(
        (name, tensor)
        for name, tensor in pretrained_state_dict.items()
        if 'upsampler' not in name
    )
def load_partial_state_dict(model, pretrained_state_dict):
    """Copy into ``model`` only the pretrained entries that fit it.

    An entry is used when its key exists in ``model.state_dict()`` and its
    ``shape`` equals the shape of the model's current value for that key.
    All other model entries keep their current values. The merged dict is
    then passed to ``model.load_state_dict``.
    """
    merged = model.state_dict()
    for name, tensor in pretrained_state_dict.items():
        # Ignore keys the model does not have.
        if name not in merged:
            continue
        # Ignore entries whose shape disagrees with the model's.
        if tensor.shape != merged[name].shape:
            continue
        merged[name] = tensor
    model.load_state_dict(merged)
def _load_state_dict(path):
    """Load a checkpoint onto the CPU and unwrap its weights.

    If the loaded object contains a ``'state_dict'`` entry, that entry is
    returned; otherwise a ``'model'`` entry is returned if present; otherwise
    the loaded object itself is returned.
    """
    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    checkpoint = torch.load(path, map_location='cpu')
    for wrapper_key in ('state_dict', 'model'):
        if wrapper_key in checkpoint:
            return checkpoint[wrapper_key]
    return checkpoint
def load_optimizer(cfg, scene_trainer, strict_load):
    """Load pretrained weights into ``scene_trainer.optimizer``.

    The checkpoint at ``cfg.checkpointing.pretrained_optimizer`` is loaded on
    the CPU and unwrapped from a ``'state_dict'`` entry if one exists. Two key
    layouts are handled:

    * keys starting with ``optimizer.`` -- that prefix is stripped;
    * otherwise keys starting with ``encoder.`` -- that prefix is stripped and
      legacy attribute names are renamed to their current names.

    Args:
        cfg: Config object; reads ``checkpointing.pretrained_optimizer`` and
            ``scene_trainer.scene_optimizer.init_state_wo_features``
            (treated as ``False`` when the attribute is absent).
        scene_trainer: Object whose ``optimizer`` exposes ``load_state_dict``.
        strict_load: Forwarded as ``strict`` to ``load_state_dict``.
    """
    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    pretrained_model = torch.load(cfg.checkpointing.pretrained_optimizer, map_location='cpu')
    if 'state_dict' in pretrained_model:
        pretrained_model = pretrained_model['state_dict']

    # Strip a leading scene_trainer. (Lightning checkpoint format). Only the
    # prefix is removed, so the same text deeper inside a key is left intact.
    pretrained_model = {k.removeprefix("scene_trainer."): v for k, v in pretrained_model.items()}

    if any(k.startswith("optimizer.") for k in pretrained_model):
        # Unified repo format: keys are optimizer.*
        optimizer_state_dict = {
            k[len("optimizer."):]: v
            for k, v in pretrained_model.items()
            if k.startswith("optimizer.")
        }
    else:
        # Resplat repo format: keys are encoder.* (before init/opt split).
        # Strip encoder. prefix; init-related keys will be ignored via strict=False.
        optimizer_state_dict = {
            k[len("encoder."):]: v
            for k, v in pretrained_model.items()
            if k.startswith("encoder.")
        }

        # Rename module attributes that changed when the encoder was split.
        # NOTE(review): this step is assumed to apply to the legacy format
        # only -- confirm against the original indentation.
        _ORIG_OPTIMIZER_ATTR_RENAMES = {
            "render_error_mv_attn": "update_error_attn",
        }
        renamed = {}
        for k, v in optimizer_state_dict.items():
            for old, new in _ORIG_OPTIMIZER_ATTR_RENAMES.items():
                # Match the attribute itself or anything nested under it.
                if k == old or k.startswith(old + "."):
                    k = new + k[len(old):]
                    break
            renamed[k] = v
        optimizer_state_dict = renamed

    # When init_state_wo_features is set, drop every key containing
    # "update_proj". getattr supplies the default so configs that lack the
    # flag are handled (no unguarded attribute access before this check).
    if getattr(cfg.scene_trainer.scene_optimizer, "init_state_wo_features", False):
        optimizer_state_dict = {k: v for k, v in optimizer_state_dict.items() if "update_proj" not in k}

    scene_trainer.optimizer.load_state_dict(optimizer_state_dict, strict=strict_load)
    print(cyan(f"Loaded pretrained optimizer: {cfg.checkpointing.pretrained_optimizer}"))
def load_initializer(cfg, scene_trainer, strict_load):
    """Load pretrained weights into ``scene_trainer.initializer``.

    The checkpoint at ``cfg.checkpointing.pretrained_initializer`` is loaded
    on the CPU and unwrapped from a ``'state_dict'`` entry if one exists. Two
    key layouts are handled:

    * keys starting with ``initializer.`` -- every key must carry that prefix,
      which is then stripped;
    * otherwise keys starting with ``encoder.`` -- that prefix is stripped and
      all other keys are discarded.

    Args:
        cfg: Config object; reads ``checkpointing.pretrained_initializer``.
        scene_trainer: Object whose ``initializer`` exposes ``load_state_dict``.
        strict_load: Forwarded as ``strict`` to ``load_state_dict``.

    Raises:
        AssertionError: If some, but not all, keys start with ``initializer.``.
    """
    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    pretrained_model = torch.load(cfg.checkpointing.pretrained_initializer, map_location='cpu')
    if 'state_dict' in pretrained_model:
        pretrained_model = pretrained_model['state_dict']

    # Strip a leading scene_trainer. (Lightning checkpoint format). Only the
    # prefix is removed, so the same text deeper inside a key is left intact.
    pretrained_model = {k.removeprefix("scene_trainer."): v for k, v in pretrained_model.items()}

    if any(k.startswith("initializer.") for k in pretrained_model):
        assert all(k.startswith("initializer.") for k in pretrained_model), \
            "initializer checkpoint mixes 'initializer.' keys with other keys"
        # Current repo format: keys are initializer.*
        initializer_state_dict = {
            k[len("initializer."):]: v
            for k, v in pretrained_model.items()
            if k.startswith("initializer.")
        }
    else:
        # Resplat repo format: keys are encoder.* (before init/opt split)
        initializer_state_dict = {
            k[len("encoder."):]: v
            for k, v in pretrained_model.items()
            if k.startswith("encoder.")
        }

    scene_trainer.initializer.load_state_dict(initializer_state_dict, strict=strict_load)
    print(cyan(f"Loaded pretrained initializer: {cfg.checkpointing.pretrained_initializer}"))
def load_full_model(cfg, scene_trainer, strict_load):
    """Load a whole-model checkpoint into ``scene_trainer``.

    Reads ``cfg.checkpointing.pretrained_model`` on the CPU and unwraps a
    ``'state_dict'`` entry if one exists. When
    ``cfg.checkpointing.partial_load`` is truthy, only entries with matching
    keys and shapes are copied (via ``load_partial_state_dict``); otherwise
    the weights go through ``scene_trainer.load_state_dict`` with
    ``strict=strict_load``.
    """
    weights = torch.load(cfg.checkpointing.pretrained_model, map_location='cpu')
    weights = weights['state_dict'] if 'state_dict' in weights else weights

    if not cfg.checkpointing.partial_load:
        scene_trainer.load_state_dict(weights, strict=strict_load)
    else:
        print('partial load')
        load_partial_state_dict(scene_trainer, weights)

    print(cyan(f"Loaded pretrained weights: {cfg.checkpointing.pretrained_model}"))
| def load_base_model(cfg, scene_trainer, strict_load: bool | Any): | |
| if cfg.checkpointing.pretrained_model is not None: | |
| load_full_model(cfg, scene_trainer, strict_load) | |
| else: | |
| # Load pretrained initializer if available | |
| if cfg.checkpointing.pretrained_initializer is not None: | |
| load_initializer(cfg, scene_trainer, strict_load) | |
| if cfg.checkpointing.pretrained_optimizer is not None and scene_trainer.optimizer is not None: | |
| load_optimizer(cfg, scene_trainer, strict_load) | |
def load_model_weights(cfg, scene_trainer, strict_load, mode: str):
    """Load every pretrained checkpoint configured under ``cfg.checkpointing``.

    Steps, in order:

    1. (train only) monodepth weights, optional freezing of the mono ViT, and
       mvdepth weights;
    2. the full model, or separate initializer/optimizer checkpoints, via
       ``load_base_model``;
    3. depth-predictor weights (``pretrained_depth``);
    4. (train only) scale-predictor weights, which are then frozen;
    5. update-module weights (``resume_update_module``);
    6. (train only) the config-driven freezes in ``apply_freezes``.

    Args:
        cfg: Config object; reads ``checkpointing.*`` and ``model.encoder.*``.
        scene_trainer: Model container receiving the weights.
        strict_load: Default ``strict`` flag for ``load_state_dict`` calls.
            It is forced to ``False`` for the rest of the call once a
            monodepth checkpoint is loaded, or once upsampler entries are
            filtered out of the depth checkpoint.
        mode: Either ``"train"`` or ``"test"``.

    NOTE(review): the nesting of the branches below was reconstructed from a
    copy of the file that had lost its indentation -- confirm it against the
    original, in particular the freeze-mono-vit and mvdepth sections and the
    train/test split of the depth loading.
    """
    assert mode in ("train", "test")

    if mode == "train":
        # only load monodepth
        if cfg.checkpointing.pretrained_monodepth is not None:
            # This override stays in effect for all later loads in this call.
            strict_load = False
            pretrained_model = torch.load(cfg.checkpointing.pretrained_monodepth, map_location='cpu')
            if 'state_dict' in pretrained_model:
                pretrained_model = pretrained_model['state_dict']
            if cfg.model.encoder.separate_depth_color or cfg.model.encoder.separate_depth_gaussian_scale:
                scene_trainer.encoder.feature_extractor.load_state_dict(pretrained_model, strict=strict_load)
            else:
                scene_trainer.encoder.depth_predictor.load_state_dict(pretrained_model, strict=strict_load)
            print(cyan(f"Loaded pretrained monodepth: {cfg.checkpointing.pretrained_monodepth}"))

        # freeze mono vit
        if cfg.checkpointing.freeze_mono_vit:
            print('freeze mono vit')
            for params in scene_trainer.encoder.depth_predictor.pretrained.parameters():
                params.requires_grad = False

        # load pretrained mvdepth
        if cfg.checkpointing.pretrained_mvdepth is not None:
            # This checkpoint must contain a 'model' entry (KeyError otherwise).
            pretrained_model = torch.load(cfg.checkpointing.pretrained_mvdepth, map_location='cpu')['model']
            if cfg.model.encoder.separate_depth_color or cfg.model.encoder.separate_depth_gaussian_scale:
                scene_trainer.encoder.feature_extractor.load_state_dict(pretrained_model, strict=False)
            else:
                scene_trainer.encoder.depth_predictor.load_state_dict(pretrained_model, strict=False)
            print(cyan(f"Loaded pretrained mvdepth: {cfg.checkpointing.pretrained_mvdepth}"))

    # load full model (or separate initializer/optimizer checkpoints)
    load_base_model(cfg, scene_trainer, strict_load)

    # load pretrained depth
    if cfg.checkpointing.pretrained_depth is not None:
        pretrained_model = _load_state_dict(cfg.checkpointing.pretrained_depth)
        if mode == "train":
            if cfg.checkpointing.partial_load:
                print('partial load depth')
                load_partial_state_dict(scene_trainer.initializer.depth_predictor, pretrained_model)
            else:
                if cfg.checkpointing.no_resume_upsampler:
                    # Upsampler entries are removed, so a strict load cannot succeed.
                    pretrained_model = no_resume_upsampler(pretrained_model)
                    strict_load = False
                scene_trainer.initializer.depth_predictor.load_state_dict(pretrained_model, strict=strict_load)
        else:
            # In test mode the depth checkpoint must match exactly.
            scene_trainer.initializer.depth_predictor.load_state_dict(pretrained_model, strict=True)
        print(cyan(f"Loaded pretrained depth: {cfg.checkpointing.pretrained_depth}"))

    # load pretrained scale predictor
    if mode == "train" and cfg.checkpointing.pretrained_scale_predictor is not None:
        pretrained_model = _load_state_dict(cfg.checkpointing.pretrained_scale_predictor)
        scene_trainer.encoder.scale_predictor.load_state_dict(pretrained_model, strict=strict_load)
        print(cyan(f"Loaded pretrained scale predictor: {cfg.checkpointing.pretrained_scale_predictor}"))
        print('freeze scale predictor')
        for params in scene_trainer.encoder.scale_predictor.parameters():
            params.requires_grad = False

    # load pretrained update module
    if cfg.checkpointing.resume_update_module is not None:
        pretrained_model = _load_state_dict(cfg.checkpointing.resume_update_module)
        # Build the model's state dict once; rebuilding it for every
        # checkpoint key inside the comprehension is needless repeated work.
        current_state = scene_trainer.state_dict()
        # Keep only keys containing "encoder.update" that exist in the model
        # with an identical shape.
        # NOTE(review): confirm "encoder.update" is still the right key
        # pattern for current checkpoints.
        filtered_dict = {
            k: v for k, v in pretrained_model.items()
            if "encoder.update" in k and k in current_state
            and v.shape == current_state[k].shape
        }
        # Load them using strict=False so it skips missing/unmatched keys
        scene_trainer.load_state_dict(filtered_dict, strict=False)
        print(cyan(f"Loaded pretrained update module: {cfg.checkpointing.resume_update_module}"))

    if mode == "train":
        apply_freezes(cfg, scene_trainer)
def apply_freezes(cfg, scene_trainer):
    """Disable gradients on model parts according to ``cfg.scene_trainer``.

    * ``scene_initializer.freeze_depth`` (default ``False`` when absent):
      freeze ``scene_trainer.initializer.depth_predictor``.
    * ``train_scene_init`` falsy: freeze the whole initializer.
    * Only when ``num_update_steps > 0``:
        - ``train_scene_opt`` falsy: freeze the whole optimizer;
        - ``scene_optimizer.train_global_update_only`` truthy: freeze every
          optimizer parameter whose name does not contain ``'global_update'``.

    NOTE(review): the placement of both optimizer checks under the
    ``num_update_steps > 0`` guard was reconstructed from a copy of the file
    that had lost its indentation -- confirm against the original.
    """
    def _freeze(parameters):
        # Turn off gradient tracking for each given parameter.
        for param in parameters:
            param.requires_grad = False

    trainer_cfg = cfg.scene_trainer

    if getattr(trainer_cfg.scene_initializer, 'freeze_depth', False):
        print('freeze depth')
        _freeze(scene_trainer.initializer.depth_predictor.parameters())

    if not trainer_cfg.train_scene_init:
        print('train refine only, freezing scene initializer')
        _freeze(param for _name, param in scene_trainer.initializer.named_parameters())

    # The optimizer is only touched when update steps are enabled.
    if not trainer_cfg.num_update_steps > 0:
        return

    if not trainer_cfg.train_scene_opt:
        print('train refine only, freezing scene optimizer')
        _freeze(param for _name, param in scene_trainer.optimizer.named_parameters())

    if trainer_cfg.scene_optimizer.train_global_update_only:
        print('train global update only')
        _freeze(
            param
            for name, param in scene_trainer.optimizer.named_parameters()
            if 'global_update' not in name
        )