import copy
import os
import random
import sys
import warnings
from pathlib import Path
from typing import Any, Dict, Optional

import hydra
import numpy as np
import swanlab as wandb
import torch
from colorama import Fore
from jaxtyping import install_import_hook
from omegaconf import DictConfig, OmegaConf

# Drop a conflicting user-local site-packages path before importing pytorch_lightning.
sys.path = [p for p in sys.path if "/mnt/data-3/users/nichaojun/.local" not in p]
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import (
    Callback,
    LearningRateMonitor,
    ModelCheckpoint,
)
# from pytorch_lightning.loggers.wandb import WandbLogger
from swanlab.integration.pytorch_lightning import SwanLabLogger
from pytorch_lightning.plugins.environments import LightningEnvironment
from pytorch_lightning.strategies import DDPStrategy
# Configure beartype and jaxtyping.
with install_import_hook(
("src",),
("beartype", "beartype"),
):
from src.config import load_typed_root_config
from src.dataset.data_module import DataModule
from src.global_cfg import set_cfg
from src.loss import get_losses
from src.misc.LocalLogger import LocalLogger
from src.misc.step_tracker import StepTracker
from src.misc.wandb_tools import update_checkpoint_path
from src.misc.resume_ckpt import find_latest_ckpt
from src.model.decoder import get_decoder
from src.model.encoder import get_encoder
from src.model.model_wrapper import ModelWrapper
def cyan(text: str) -> str:
return f"{Fore.CYAN}{text}{Fore.RESET}"
class UnfreezePretrainedCallback(Callback):
    """Re-enable gradients for the pretrained monodepth weights after a warm-up period."""

    def __init__(self, unfreeze_step: int = 20000):
        self.unfreeze_step = unfreeze_step
        self.has_unfrozen = False

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if self.has_unfrozen or trainer.global_step < self.unfreeze_step:
            return
        # Unfreeze on every rank: requires_grad must match across DDP replicas,
        # otherwise gradient synchronization breaks.
        if trainer.is_global_zero:
            print(cyan(f"Step {trainer.global_step}: Unfreezing pretrained_monodepth parameters"))
        for param in pl_module.encoder.depth_predictor.parameters():
            param.requires_grad = True
        self.has_unfrozen = True
def _set_global_seed(seed):
    """Set the global random seed for Python and NumPy."""
    random.seed(seed)
    np.random.seed(seed)
    # PyTorch seeds are intentionally not set here; they are assigned per process later.
    print(f"Global Python/NumPy seed set to: {seed}")
@hydra.main(
version_base=None,
config_path="../config",
config_name="main",
)
def train(cfg_dict: DictConfig):
    # # Cap per-process GPU memory before any CUDA operation.
    # if torch.cuda.is_available():
    #     # The fraction could be read from the config; 0.8 reserves 80% of VRAM.
    #     memory_fraction = cfg_dict.get("gpu_memory_fraction", 0.8)
    #     torch.cuda.set_per_process_memory_fraction(memory_fraction)
    #     print(f"Set GPU memory fraction to: {memory_fraction}")
if cfg_dict["mode"] == "train" and cfg_dict["train"]["eval_model_every_n_val"] > 0:
eval_cfg_dict = copy.deepcopy(cfg_dict)
dataset_dir = str(cfg_dict["dataset"]["roots"]).lower()
if "re10k" in dataset_dir:
eval_path = "assets/evaluation_index_re10k.json"
elif "dl3dv" in dataset_dir:
if cfg_dict["dataset"]["view_sampler"]["num_context_views"] == 6:
eval_path = "assets/dl3dv_start_0_distance_50_ctx_6v_tgt_8v.json"
else:
raise ValueError("unsupported number of views for dl3dv")
else:
            raise ValueError("failed to resolve an evaluation index path for this dataset")
eval_cfg_dict["dataset"]["view_sampler"] = {
"name": "evaluation",
"index_path": eval_path,
"num_context_views": cfg_dict["dataset"]["view_sampler"]["num_context_views"],
}
eval_cfg = load_typed_root_config(eval_cfg_dict)
else:
eval_cfg = None
cfg = load_typed_root_config(cfg_dict)
set_cfg(cfg_dict)
# Set up the output directory.
if cfg_dict.output_dir is None:
output_dir = Path(
hydra.core.hydra_config.HydraConfig.get()["runtime"]["output_dir"]
)
else: # for resuming
output_dir = Path(cfg_dict.output_dir)
os.makedirs(output_dir, exist_ok=True)
print(cyan(f"Saving outputs to {output_dir}."))
    # Set up experiment logging (SwanLab, reusing the wandb config section).
callbacks = []
if cfg_dict.wandb.mode != "disabled" and cfg.mode == "train":
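        # wandb_extra_kwargs feeds the commented-out WandbLogger below; SwanLabLogger ignores it.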
wandb_extra_kwargs = {}
if cfg_dict.wandb.id is not None:
            wandb_extra_kwargs.update({"id": cfg_dict.wandb.id, "resume": "must"})
# logger = WandbLogger(
# entity=cfg_dict.wandb.entity,
# project=cfg_dict.wandb.project,
# mode=cfg_dict.wandb.mode,
# name=os.path.basename(cfg_dict.output_dir),
# tags=cfg_dict.wandb.get("tags", None),
# log_model=False,
# save_dir=output_dir,
# config=OmegaConf.to_container(cfg_dict),
# **wandb_extra_kwargs,
# )
        logger = SwanLabLogger(
            project=cfg_dict.wandb.project,
            experiment_name=cfg_dict.wandb.entity,
            workspace=cfg_dict.wandb.workspace,
        )
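        # Note: swanlab is imported as `wandb` at the top, so the wandb.* config keys are reused unchanged.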
callbacks.append(LearningRateMonitor("step", True))
# if wandb.run is not None:
# wandb.run.log_code("src")
else:
logger = LocalLogger()
# Set up checkpointing.
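    # Monitoring info/global_step with mode="max" keeps the most recent
    # checkpoints rather than the best-by-metric ones.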
callbacks.append(
ModelCheckpoint(
output_dir / "checkpoints",
every_n_train_steps=cfg.checkpointing.every_n_train_steps,
save_top_k=cfg.checkpointing.save_top_k,
monitor="info/global_step",
mode="max",
)
)
    # Callback that unfreezes the pretrained parameters partway through training.
    callbacks.append(UnfreezePretrainedCallback(unfreeze_step=20000))
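    # ModelCheckpoint uses CHECKPOINT_EQUALS_CHAR when formatting filenames
    # (e.g. "epoch=0.ckpt"); use "_" so checkpoint names contain no "=".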
for cb in callbacks:
cb.CHECKPOINT_EQUALS_CHAR = '_'
# Prepare the checkpoint for loading.
if cfg.checkpointing.resume:
if not os.path.exists(output_dir / 'checkpoints'):
checkpoint_path = None
else:
checkpoint_path = find_latest_ckpt(output_dir / 'checkpoints')
print(f'resume from {checkpoint_path}')
else:
checkpoint_path = update_checkpoint_path(cfg.checkpointing.load, cfg.wandb)
# This allows the current step to be shared with the data loader processes.
step_tracker = StepTracker()
    # Choose the distributed strategy.
    if torch.cuda.device_count() > 1:
        # DDP with unused-parameter detection, which is required while some
        # parameters are frozen and therefore receive no gradients.
        ddp_strategy = DDPStrategy(
            find_unused_parameters=True,
            static_graph=False,  # the trainable-parameter set changes when unfreezing
            process_group_backend="nccl" if torch.cuda.is_available() else "gloo",
        )
    else:
        ddp_strategy = "auto"
trainer = Trainer(
max_epochs=-1,
accelerator="gpu",
logger=logger,
devices=torch.cuda.device_count(),
        strategy=ddp_strategy,
callbacks=callbacks,
val_check_interval=cfg.trainer.val_check_interval,
enable_progress_bar=cfg.mode == "test",
gradient_clip_val=cfg.trainer.gradient_clip_val,
max_steps=cfg.trainer.max_steps,
num_sanity_val_steps=cfg.trainer.num_sanity_val_steps,
num_nodes=cfg.trainer.num_nodes,
plugins=LightningEnvironment() if cfg.use_plugins else None,
)
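    # max_epochs=-1 disables the epoch limit; training length is governed by max_steps.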
    # Seed Python/NumPy globals once to avoid divergence in distributed runs.
    if torch.distributed.is_initialized():
        # In a distributed environment, only rank 0 sets the global seed.
        if torch.distributed.get_rank() == 0:
            _set_global_seed(42)
        # All ranks wait here until rank 0 has finished seeding.
        torch.distributed.barrier()
    else:
        # Single-process environment: seed directly.
        _set_global_seed(42)
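    # Give each rank its own PyTorch seed so stochastic ops differ across processes.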
torch.manual_seed(cfg_dict.seed + trainer.global_rank)
encoder, encoder_visualizer = get_encoder(cfg.model.encoder)
model_wrapper = ModelWrapper(
cfg.optimizer,
cfg.test,
cfg.train,
encoder,
encoder_visualizer,
get_decoder(cfg.model.decoder, cfg.dataset),
get_losses(cfg.loss),
step_tracker,
eval_data_cfg=(
None if eval_cfg is None else eval_cfg.dataset
),
)
    # Debugging switch consumed by ModelWrapper: disable parameter-update checks.
    model_wrapper._check_param_updates = False
data_module = DataModule(
cfg.dataset,
cfg.data_loader,
step_tracker,
global_rank=trainer.global_rank,
)
if cfg.mode == "train":
print("train:", len(data_module.train_dataloader()))
print("val:", len(data_module.val_dataloader()))
print("test:", len(data_module.test_dataloader()))
strict_load = not cfg.checkpointing.no_strict_load
if cfg.mode == "train":
# only load monodepth
if cfg.checkpointing.pretrained_monodepth is not None:
strict_load = False
pretrained_model = torch.load(cfg.checkpointing.pretrained_monodepth, map_location='cpu')
if 'state_dict' in pretrained_model:
pretrained_model = pretrained_model['state_dict']
load_result = model_wrapper.encoder.depth_predictor.load_state_dict(pretrained_model, strict=strict_load)
            # Collect the names of the parameters that were actually loaded.
            loaded_keys = set(pretrained_model.keys()) - set(load_result.unexpected_keys)
            # Targeted freezing: only freeze parameters that were successfully loaded.
            for name, param in model_wrapper.encoder.depth_predictor.named_parameters():
                if name in loaded_keys:
                    param.requires_grad = False
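            # These frozen weights are re-enabled by UnfreezePretrainedCallback
            # once training reaches unfreeze_step.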
print("\n===== 参数加载报告 =====")
print(f"✅ 成功加载参数数量: {len(loaded_keys)}")
print(f"❄️ 冻结参数数量: {len(loaded_keys)}")
print(f"⚠️ 缺失参数: {len(load_result.missing_keys)} 个")
print(f"⚠️ 多余参数: {len(load_result.unexpected_keys)} 个")
print(f"🛠️ 可训练参数数量: {len([p for p in model_wrapper.encoder.depth_predictor.parameters() if p.requires_grad])}")
print(
cyan(
f"Loaded pretrained monodepth (partial freezing): {cfg.checkpointing.pretrained_monodepth}"
)
)
# load pretrained mvdepth
if cfg.checkpointing.pretrained_mvdepth is not None:
pretrained_model = torch.load(cfg.checkpointing.pretrained_mvdepth, map_location='cpu')['model']
load_result = model_wrapper.encoder.depth_predictor.load_state_dict(pretrained_model, strict=False)
print("\n===== 参数加载报告 =====")
print(f"✅ 成功加载参数数量: {len(pretrained_model) - len(load_result.unexpected_keys)}")
print(f"⚠️ 缺失参数: {len(load_result.missing_keys)} 个")
print(f"⚠️ 多余参数: {len(load_result.unexpected_keys)} 个")
print(
cyan(
f"Loaded pretrained mvdepth: {cfg.checkpointing.pretrained_mvdepth}"
)
)
# load full model
if cfg.checkpointing.pretrained_model is not None:
strict_load = False
pretrained_model = torch.load(cfg.checkpointing.pretrained_model, map_location='cpu')
if 'state_dict' in pretrained_model:
pretrained_model = pretrained_model['state_dict']
model_wrapper.load_state_dict(pretrained_model, strict=strict_load)
print(
cyan(
f"Loaded pretrained weights: {cfg.checkpointing.pretrained_model}"
)
)
# load pretrained depth
if cfg.checkpointing.pretrained_depth is not None:
strict_load = False
pretrained_model = torch.load(cfg.checkpointing.pretrained_depth, map_location='cpu')
if 'state_dict' in pretrained_model:
pretrained_model = pretrained_model['state_dict']
# pretrained_model = torch.load(cfg.checkpointing.pretrained_depth, map_location='cpu')['model']
# strict_load = True
load_result = model_wrapper.encoder.depth_predictor.load_state_dict(pretrained_model, strict=strict_load)
print("\n===== 参数加载报告 =====")
print(f"✅ 成功加载参数数量: {len(pretrained_model) - len(load_result.unexpected_keys)}")
print(f"⚠️ 缺失参数: {len(load_result.missing_keys)} 个")
print(f"⚠️ 多余参数: {len(load_result.unexpected_keys)} 个")
print(
cyan(
f"Loaded pretrained depth: {cfg.checkpointing.pretrained_depth}"
)
)
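        # ckpt_path restores optimizer/scheduler state as well as weights when resuming.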
trainer.fit(model_wrapper, datamodule=data_module, ckpt_path=checkpoint_path)
else:
# load full model
if cfg.checkpointing.pretrained_model is not None:
pretrained_model = torch.load(cfg.checkpointing.pretrained_model, map_location='cpu')
if 'state_dict' in pretrained_model:
pretrained_model = pretrained_model['state_dict']
model_wrapper.load_state_dict(pretrained_model, strict=strict_load)
print(
cyan(
f"Loaded pretrained weights: {cfg.checkpointing.pretrained_model}"
)
)
# load pretrained depth model only
if cfg.checkpointing.pretrained_depth is not None:
pretrained_model = torch.load(cfg.checkpointing.pretrained_depth, map_location='cpu')['model']
strict_load = True
model_wrapper.encoder.depth_predictor.load_state_dict(pretrained_model, strict=strict_load)
print(
cyan(
f"Loaded pretrained depth: {cfg.checkpointing.pretrained_depth}"
)
)
trainer.test(
model_wrapper,
datamodule=data_module,
ckpt_path=checkpoint_path,
)
if __name__ == "__main__":
warnings.filterwarnings("ignore")
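    # Trade a little matmul precision for speed (enables TF32 on Ampere+ GPUs).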
torch.set_float32_matmul_precision('high')
train()