# Learn2Splat / optgs/misc/checkpointing.py
# SteEsp: Add Docker-based Learn2Splat demo (viser GUI)
# Commit 78d2329 (verified)
import os
from collections import OrderedDict
from pathlib import Path
from typing import Any

import torch

from optgs.misc.io import cyan
def extract_step(file_name):
    """Return the integer step number encoded in a checkpoint file name.

    The name is split on ``-`` and its second field is split on ``_``; the
    second piece of that, with every ``.ckpt`` substring removed, is parsed as
    an ``int``. Example: ``"epoch_3-step_1200.ckpt"`` -> ``1200``.

    Raises:
        IndexError: the name lacks a ``-`` or the second field lacks a ``_``.
        ValueError: the extracted piece is not an integer.
    """
    after_dash = file_name.split("-")[1]
    step_piece = after_dash.split("_")[1]
    digits = step_piece.replace(".ckpt", "")
    return int(digits)
def find_latest_ckpt(ckpt_dir):
    """Return the path of the ``.ckpt`` file in ``ckpt_dir`` with the highest step.

    Step numbers are parsed from file names with ``extract_step``. Files ending
    in ``.ckpt`` whose name does not follow that pattern (for example a
    ``last.ckpt``) are skipped instead of aborting the whole search.

    Args:
        ckpt_dir: directory to scan, as a ``str`` or ``pathlib.Path``.

    Returns:
        ``pathlib.Path`` of the checkpoint with the largest step.

    Raises:
        ValueError: no ``.ckpt`` file with a parsable step number exists.
    """
    # Accept plain strings too: the `/` join below needs a Path.
    ckpt_dir = Path(ckpt_dir)

    steps = {}
    for file_name in os.listdir(ckpt_dir):
        if not file_name.endswith(".ckpt"):
            continue
        try:
            steps[file_name] = extract_step(file_name)
        except (IndexError, ValueError):
            # Name does not match the expected "<x>-<tag>_<step>.ckpt" pattern.
            continue

    if not steps:
        raise ValueError(f"No .ckpt files with a parsable step number found in {ckpt_dir}.")

    latest_ckpt_file = max(steps, key=steps.get)
    return ckpt_dir / latest_ckpt_file
def no_resume_upsampler(pretrained_state_dict):
    """Return a copy of ``pretrained_state_dict`` without any upsampler weights.

    Every entry whose key contains the substring ``'upsampler'`` is dropped; the
    remaining entries keep their original order in a new ``OrderedDict``.
    """
    return OrderedDict(
        (name, tensor)
        for name, tensor in pretrained_state_dict.items()
        if 'upsampler' not in name
    )
def load_partial_state_dict(model, pretrained_state_dict):
    """Copy into ``model`` only the pretrained entries whose name and shape match.

    Pretrained entries that are absent from the model, or whose ``.shape``
    differs from the model's entry, are ignored. Model entries without a usable
    pretrained counterpart keep their current values.

    Args:
        model: module providing ``state_dict()`` / ``load_state_dict()``.
        pretrained_state_dict: mapping from state-dict key to tensor.

    Returns:
        List of the model's state-dict keys that were NOT overwritten, in model
        order. Callers that ignore the return value are unaffected.
    """
    model_state_dict = model.state_dict()
    filtered_state_dict = {
        k: v for k, v in pretrained_state_dict.items()
        if k in model_state_dict and v.shape == model_state_dict[k].shape
    }
    kept_keys = [k for k in model_state_dict if k not in filtered_state_dict]

    model_state_dict.update(filtered_state_dict)
    # Every model key is present (its own value fills any gap), so the default
    # strict load cannot fail on missing or unexpected keys.
    model.load_state_dict(model_state_dict)
    return kept_keys
def _load_state_dict(path):
ckpt = torch.load(path, map_location='cpu')
if 'state_dict' in ckpt:
return ckpt['state_dict']
if 'model' in ckpt:
return ckpt['model']
return ckpt
# Optimizer attributes renamed when the Resplat encoder was split into a
# separate initializer and optimizer: {old attribute name: new attribute name}.
_ORIG_OPTIMIZER_ATTR_RENAMES = {
    "render_error_mv_attn": "update_error_attn",
}


def _rename_legacy_optimizer_key(key):
    """Map a pre-split state-dict key onto the current attribute name.

    A key is renamed when it equals an old name or starts with ``"<old>."``;
    only the first matching rename applies. Other keys are returned unchanged.
    """
    for old, new in _ORIG_OPTIMIZER_ATTR_RENAMES.items():
        if key == old or key.startswith(old + "."):
            return new + key[len(old):]
    return key


def load_optimizer(cfg, scene_trainer, strict_load):
    """Load ``cfg.checkpointing.pretrained_optimizer`` into ``scene_trainer.optimizer``.

    Two checkpoint layouts are accepted (after unwrapping ``'state_dict'`` and
    stripping a leading ``scene_trainer.``):

    * unified repo format: keys ``optimizer.*``; the prefix is stripped;
    * Resplat format (before the init/opt split): keys ``encoder.*``; the prefix
      is stripped and legacy attribute names are renamed. Initializer-related
      keys remain and are only tolerated when ``strict_load`` is False.

    When ``cfg.scene_trainer.scene_optimizer.init_state_wo_features`` is truthy
    (a missing attribute counts as False), every entry whose key contains
    ``"update_proj"`` is dropped before loading.

    Args:
        cfg: config; reads ``checkpointing.pretrained_optimizer`` and
            ``scene_trainer.scene_optimizer.init_state_wo_features``.
        scene_trainer: object whose ``optimizer`` module receives the weights.
        strict_load: forwarded as ``strict`` to ``load_state_dict``.
    """
    pretrained_model = torch.load(cfg.checkpointing.pretrained_optimizer, map_location='cpu')
    if 'state_dict' in pretrained_model:
        pretrained_model = pretrained_model['state_dict']
    # Strip the scene_trainer. prefix if present (Lightning checkpoint format).
    # removeprefix only touches the start of a key, unlike str.replace.
    pretrained_model = {k.removeprefix("scene_trainer."): v for k, v in pretrained_model.items()}

    if any(k.startswith("optimizer.") for k in pretrained_model):
        # Unified repo format: keys are optimizer.*
        optimizer_state_dict = {
            k.removeprefix("optimizer."): v
            for k, v in pretrained_model.items()
            if k.startswith("optimizer.")
        }
    else:
        # Resplat repo format: keys are encoder.* (before init/opt split).
        # Rename module attributes that changed when the encoder was split.
        optimizer_state_dict = {
            _rename_legacy_optimizer_key(k.removeprefix("encoder.")): v
            for k, v in pretrained_model.items()
            if k.startswith("encoder.")
        }

    # If init_state_wo_features is set, remove the feature-related
    # ("update_proj") parameters. getattr keeps a missing option at False.
    if getattr(cfg.scene_trainer.scene_optimizer, "init_state_wo_features", False):
        optimizer_state_dict = {k: v for k, v in optimizer_state_dict.items() if "update_proj" not in k}

    scene_trainer.optimizer.load_state_dict(optimizer_state_dict, strict=strict_load)
    print(cyan(f"Loaded pretrained optimizer: {cfg.checkpointing.pretrained_optimizer}"))
def load_initializer(cfg, scene_trainer, strict_load):
    """Load ``cfg.checkpointing.pretrained_initializer`` into ``scene_trainer.initializer``.

    After unwrapping ``'state_dict'`` and stripping a leading ``scene_trainer.``,
    two layouts are accepted:

    * current repo format: keys ``initializer.*``. If any key has that prefix,
      all keys must have it (asserted), so a checkpoint that also holds other
      modules is rejected here;
    * Resplat format (before the init/opt split): keys ``encoder.*``, loaded with
      the prefix stripped.

    Args:
        cfg: config; reads ``checkpointing.pretrained_initializer``.
        scene_trainer: object whose ``initializer`` module receives the weights.
        strict_load: forwarded as ``strict`` to ``load_state_dict``.
    """
    pretrained_model = torch.load(cfg.checkpointing.pretrained_initializer, map_location='cpu')
    if 'state_dict' in pretrained_model:
        pretrained_model = pretrained_model['state_dict']
    # Strip the scene_trainer. prefix if present (Lightning checkpoint format).
    # removeprefix only touches the start of a key, unlike str.replace.
    pretrained_model = {k.removeprefix("scene_trainer."): v for k, v in pretrained_model.items()}

    if any(k.startswith("initializer.") for k in pretrained_model):
        assert all(k.startswith("initializer.") for k in pretrained_model)
        # Current repo format: keys are initializer.*
        # The filter still matters when asserts are disabled (python -O).
        initializer_state_dict = {
            k.removeprefix("initializer."): v
            for k, v in pretrained_model.items()
            if k.startswith("initializer.")
        }
    else:
        # Resplat repo format: keys are encoder.* (before init/opt split)
        initializer_state_dict = {
            k.removeprefix("encoder."): v
            for k, v in pretrained_model.items()
            if k.startswith("encoder.")
        }

    scene_trainer.initializer.load_state_dict(initializer_state_dict, strict=strict_load)
    print(cyan(f"Loaded pretrained initializer: {cfg.checkpointing.pretrained_initializer}"))
def load_full_model(cfg, scene_trainer, strict_load):
    """Load ``cfg.checkpointing.pretrained_model`` into the whole ``scene_trainer``.

    The checkpoint is unwrapped from ``'state_dict'`` when present. A leading
    ``scene_trainer.`` prefix (Lightning checkpoint format) is stripped when any
    key carries it, mirroring the initializer/optimizer loaders; otherwise the
    loaded mapping is used untouched.

    Args:
        cfg: config; reads ``checkpointing.pretrained_model`` and
            ``checkpointing.partial_load``.
        scene_trainer: module receiving the weights.
        strict_load: forwarded as ``strict`` to ``load_state_dict``. It is unused
            when ``partial_load`` is set, because only name- and shape-matching
            entries are loaded then.
    """
    pretrained_model = torch.load(cfg.checkpointing.pretrained_model, map_location='cpu')
    if 'state_dict' in pretrained_model:
        pretrained_model = pretrained_model['state_dict']
    # Without this, prefixed keys would make a strict load fail and a
    # partial/non-strict load silently load nothing.
    if any(k.startswith("scene_trainer.") for k in pretrained_model):
        pretrained_model = {k.removeprefix("scene_trainer."): v for k, v in pretrained_model.items()}

    if cfg.checkpointing.partial_load:
        print('partial load')
        load_partial_state_dict(scene_trainer, pretrained_model)
    else:
        scene_trainer.load_state_dict(pretrained_model, strict=strict_load)
    print(cyan(f"Loaded pretrained weights: {cfg.checkpointing.pretrained_model}"))
def load_base_model(cfg, scene_trainer, strict_load: bool | Any):
if cfg.checkpointing.pretrained_model is not None:
load_full_model(cfg, scene_trainer, strict_load)
else:
# Load pretrained initializer if available
if cfg.checkpointing.pretrained_initializer is not None:
load_initializer(cfg, scene_trainer, strict_load)
if cfg.checkpointing.pretrained_optimizer is not None and scene_trainer.optimizer is not None:
load_optimizer(cfg, scene_trainer, strict_load)
def load_model_weights(cfg, scene_trainer, strict_load, mode: str):
    """Apply every pretrained-weight source configured under ``cfg.checkpointing``.

    Steps, in order (later steps may overwrite weights loaded by earlier ones):

    1. train only: ``pretrained_monodepth`` / ``pretrained_mvdepth`` into the
       ``scene_trainer.encoder`` depth modules, plus the ``freeze_mono_vit``
       freeze;
    2. ``load_base_model``: full model, or separate initializer/optimizer files;
    3. ``pretrained_depth`` into ``scene_trainer.initializer.depth_predictor``
       (always loaded strictly in test mode);
    4. train only: ``pretrained_scale_predictor``, which is then frozen;
    5. ``resume_update_module``: only entries whose key contains
       ``"encoder.update"`` and that match a model key and shape;
    6. train only: ``apply_freezes``.

    Args:
        cfg: experiment config (``checkpointing``, ``model`` and ``scene_trainer``).
        scene_trainer: module receiving the weights.
        strict_load: ``strict`` flag for most ``load_state_dict`` calls. Loading
            ``pretrained_monodepth``, or dropping upsampler weights via
            ``no_resume_upsampler``, sets it to False for every later step.
        mode: ``"train"`` or ``"test"``.
    """
    assert mode in ("train", "test")

    if mode == "train":
        # NOTE(review): the monodepth / mvdepth / mono-vit paths below target
        # `scene_trainer.encoder` and `cfg.model.encoder`, while the depth loader
        # further down targets `scene_trainer.initializer`. Confirm these legacy
        # paths still exist after the initializer/optimizer split.

        # only load monodepth
        if cfg.checkpointing.pretrained_monodepth is not None:
            # The relaxed strictness persists for every later loader in this call.
            strict_load = False
            pretrained_model = torch.load(cfg.checkpointing.pretrained_monodepth, map_location='cpu')
            if 'state_dict' in pretrained_model:
                pretrained_model = pretrained_model['state_dict']
            if cfg.model.encoder.separate_depth_color or cfg.model.encoder.separate_depth_gaussian_scale:
                scene_trainer.encoder.feature_extractor.load_state_dict(pretrained_model, strict=strict_load)
            else:
                scene_trainer.encoder.depth_predictor.load_state_dict(pretrained_model, strict=strict_load)
            print(cyan(f"Loaded pretrained monodepth: {cfg.checkpointing.pretrained_monodepth}"))

        # freeze mono vit
        if cfg.checkpointing.freeze_mono_vit:
            print('freeze mono vit')
            for params in scene_trainer.encoder.depth_predictor.pretrained.parameters():
                params.requires_grad = False

        # load pretrained mvdepth (weights are stored under the 'model' key)
        if cfg.checkpointing.pretrained_mvdepth is not None:
            pretrained_model = torch.load(cfg.checkpointing.pretrained_mvdepth, map_location='cpu')['model']
            if cfg.model.encoder.separate_depth_color or cfg.model.encoder.separate_depth_gaussian_scale:
                scene_trainer.encoder.feature_extractor.load_state_dict(pretrained_model, strict=False)
            else:
                scene_trainer.encoder.depth_predictor.load_state_dict(pretrained_model, strict=False)
            print(cyan(f"Loaded pretrained mvdepth: {cfg.checkpointing.pretrained_mvdepth}"))

    # load full model (or separate initializer/optimizer checkpoints)
    load_base_model(cfg, scene_trainer, strict_load)

    # load pretrained depth
    if cfg.checkpointing.pretrained_depth is not None:
        pretrained_model = _load_state_dict(cfg.checkpointing.pretrained_depth)
        if mode == "train":
            if cfg.checkpointing.partial_load:
                print('partial load depth')
                load_partial_state_dict(scene_trainer.initializer.depth_predictor, pretrained_model)
            else:
                if cfg.checkpointing.no_resume_upsampler:
                    pretrained_model = no_resume_upsampler(pretrained_model)
                    # Upsampler keys are now missing, so a strict load would fail.
                    strict_load = False
                scene_trainer.initializer.depth_predictor.load_state_dict(pretrained_model, strict=strict_load)
        else:
            scene_trainer.initializer.depth_predictor.load_state_dict(pretrained_model, strict=True)
        print(cyan(f"Loaded pretrained depth: {cfg.checkpointing.pretrained_depth}"))

    # load pretrained scale predictor (and keep it frozen)
    if mode == "train" and cfg.checkpointing.pretrained_scale_predictor is not None:
        pretrained_model = _load_state_dict(cfg.checkpointing.pretrained_scale_predictor)
        scene_trainer.encoder.scale_predictor.load_state_dict(pretrained_model, strict=strict_load)
        print(cyan(f"Loaded pretrained scale predictor: {cfg.checkpointing.pretrained_scale_predictor}"))
        print('freeze scale predictor')
        for params in scene_trainer.encoder.scale_predictor.parameters():
            params.requires_grad = False

    # load pretrained update module
    if cfg.checkpointing.resume_update_module is not None:
        pretrained_model = _load_state_dict(cfg.checkpointing.resume_update_module)
        # Build the model's state dict once, not twice per candidate key.
        model_state_dict = scene_trainer.state_dict()
        # Filter and load only matching "update_" parameters
        filtered_dict = {
            k: v for k, v in pretrained_model.items()
            if "encoder.update" in k and k in model_state_dict
            and v.shape == model_state_dict[k].shape
        }
        # Load them using strict=False so it skips missing/unmatched keys
        scene_trainer.load_state_dict(filtered_dict, strict=False)
        if filtered_dict:
            print(cyan(f"Loaded pretrained update module: {cfg.checkpointing.resume_update_module}"))
        else:
            print(f"Warning: no 'encoder.update' parameters in {cfg.checkpointing.resume_update_module} "
                  "match the model's keys and shapes; nothing was loaded.")

    if mode == "train":
        apply_freezes(cfg, scene_trainer)
def apply_freezes(cfg, scene_trainer):
    """Disable gradients for the sub-modules the training config asks to freeze.

    - ``scene_initializer.freeze_depth`` (missing counts as False): freeze
      ``scene_trainer.initializer.depth_predictor``.
    - ``train_scene_init`` falsy: freeze all of ``scene_trainer.initializer``.
    - Only when ``num_update_steps > 0``:
        - ``train_scene_opt`` falsy: freeze all of ``scene_trainer.optimizer``;
        - ``scene_optimizer.train_global_update_only``: freeze every optimizer
          parameter whose name does not contain ``"global_update"``.
    """
    trainer_cfg = cfg.scene_trainer
    initializer = scene_trainer.initializer

    if getattr(trainer_cfg.scene_initializer, 'freeze_depth', False):
        print('freeze depth')
        initializer.depth_predictor.requires_grad_(False)

    if not trainer_cfg.train_scene_init:
        print('train refine only, freezing scene initializer')
        initializer.requires_grad_(False)

    # The optimizer-side switches only apply when refinement steps are run.
    if not trainer_cfg.num_update_steps > 0:
        return

    optimizer = scene_trainer.optimizer
    if not trainer_cfg.train_scene_opt:
        print('train refine only, freezing scene optimizer')
        optimizer.requires_grad_(False)

    if trainer_cfg.scene_optimizer.train_global_update_only:
        print('train global update only')
        to_freeze = [
            param for param_name, param in optimizer.named_parameters()
            if 'global_update' not in param_name
        ]
        for param in to_freeze:
            param.requires_grad = False