import copy
import os
import random
import sys
import warnings
from pathlib import Path

import hydra
import numpy as np
import swanlab as wandb  # SwanLab is used as a drop-in replacement for wandb.
import torch
from colorama import Fore
from jaxtyping import install_import_hook
from omegaconf import DictConfig, OmegaConf

# Drop conflicting user-site packages from the import path.
sys.path = [p for p in sys.path if "/mnt/data-3/users/nichaojun/.local" not in p]

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import (
    Callback,
    LearningRateMonitor,
    ModelCheckpoint,
)

# from pytorch_lightning.loggers.wandb import WandbLogger
from swanlab.integration.pytorch_lightning import SwanLabLogger
from pytorch_lightning.plugins.environments import LightningEnvironment
from pytorch_lightning.strategies import DDPStrategy

# Configure beartype and jaxtyping.
with install_import_hook(
    ("src",),
    ("beartype", "beartype"),
):
    from src.config import load_typed_root_config
    from src.dataset.data_module import DataModule
    from src.global_cfg import set_cfg
    from src.loss import get_losses
    from src.misc.LocalLogger import LocalLogger
    from src.misc.resume_ckpt import find_latest_ckpt
    from src.misc.step_tracker import StepTracker
    from src.misc.wandb_tools import update_checkpoint_path
    from src.model.decoder import get_decoder
    from src.model.encoder import get_encoder
    from src.model.model_wrapper import ModelWrapper


def cyan(text: str) -> str:
    return f"{Fore.CYAN}{text}{Fore.RESET}"


class UnfreezePretrainedCallback(Callback):
    """Unfreezes the pretrained monodepth parameters once training reaches
    `unfreeze_step`."""

    def __init__(self, unfreeze_step: int = 20000):
        self.unfreeze_step = unfreeze_step
        self.has_unfrozen = False

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # The batch-end hook is used so that `trainer.global_step` is already
        # up to date. Unfreezing must run on every rank so that requires_grad
        # stays consistent under DDP; only rank zero prints.
        if not self.has_unfrozen and trainer.global_step >= self.unfreeze_step:
            if trainer.is_global_zero:
                print(cyan(
                    f"Step {trainer.global_step}: "
                    "Unfreezing pretrained_monodepth parameters"
                ))
            # Unfreeze all parameters of the monodepth predictor.
            for param in pl_module.encoder.depth_predictor.parameters():
                param.requires_grad = True
            self.has_unfrozen = True
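

# Usage sketch for the callback above (illustrative only; in this script the
# callback is appended to the Trainer's callback list inside `train()`, and
# `unfreeze_step` should match the point at which the frozen monodepth
# weights are meant to start training):
#
#   callback = UnfreezePretrainedCallback(unfreeze_step=20000)
#   trainer = Trainer(callbacks=[callback], max_steps=100_000)
#   trainer.fit(model_wrapper, datamodule=data_module)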


def _set_global_seed(seed):
    """Set the global random seed for Python and NumPy."""
    random.seed(seed)
    np.random.seed(seed)
    # Note: the PyTorch seed is not set here because it must be set per
    # process (see the per-rank `torch.manual_seed` call in `train`).
    print(f"Global Python/NumPy seed set to: {seed}")


@hydra.main(
    version_base=None,
    config_path="../config",
    config_name="main",
)
def train(cfg_dict: DictConfig):
    # # Cap GPU memory usage (must run before any CUDA operation).
    # if torch.cuda.is_available():
    #     # The fraction could be read from the config or fixed here.
    #     memory_fraction = cfg_dict.get("gpu_memory_fraction", 0.8)  # default: 80%
    #     torch.cuda.set_per_process_memory_fraction(memory_fraction)
    #     print(f"Set GPU memory fraction to: {memory_fraction}")

    # When periodic evaluation is enabled, build a separate config whose view
    # sampler reads a fixed evaluation index for the detected dataset.
    if cfg_dict["mode"] == "train" and cfg_dict["train"]["eval_model_every_n_val"] > 0:
        eval_cfg_dict = copy.deepcopy(cfg_dict)
        dataset_dir = str(cfg_dict["dataset"]["roots"]).lower()
        if "re10k" in dataset_dir:
            eval_path = "assets/evaluation_index_re10k.json"
        elif "dl3dv" in dataset_dir:
            if cfg_dict["dataset"]["view_sampler"]["num_context_views"] == 6:
                eval_path = "assets/dl3dv_start_0_distance_50_ctx_6v_tgt_8v.json"
            else:
                raise ValueError("unsupported number of views for dl3dv")
        else:
            raise Exception("Failed to resolve an evaluation index path")
        eval_cfg_dict["dataset"]["view_sampler"] = {
            "name": "evaluation",
            "index_path": eval_path,
            "num_context_views": cfg_dict["dataset"]["view_sampler"]["num_context_views"],
        }
        eval_cfg = load_typed_root_config(eval_cfg_dict)
    else:
        eval_cfg = None
    cfg = load_typed_root_config(cfg_dict)
    set_cfg(cfg_dict)

    # Set up the output directory.
    if cfg_dict.output_dir is None:
        output_dir = Path(
            hydra.core.hydra_config.HydraConfig.get()["runtime"]["output_dir"]
        )
    else:  # for resuming
        output_dir = Path(cfg_dict.output_dir)
        os.makedirs(output_dir, exist_ok=True)
    print(cyan(f"Saving outputs to {output_dir}."))

    # Set up logging.
    callbacks = []
    if cfg_dict.wandb.mode != "disabled" and cfg.mode == "train":
        # (Only consumed by the commented-out WandbLogger below.)
        wandb_extra_kwargs = {}
        if cfg_dict.wandb.id is not None:
            wandb_extra_kwargs.update({"id": cfg_dict.wandb.id, "resume": "must"})
        # logger = WandbLogger(
        #     entity=cfg_dict.wandb.entity,
        #     project=cfg_dict.wandb.project,
        #     mode=cfg_dict.wandb.mode,
        #     name=os.path.basename(cfg_dict.output_dir),
        #     tags=cfg_dict.wandb.get("tags", None),
        #     log_model=False,
        #     save_dir=output_dir,
        #     config=OmegaConf.to_container(cfg_dict),
        #     **wandb_extra_kwargs,
        # )
        logger = SwanLabLogger(
            project=cfg_dict.wandb.project,
            experiment_name=cfg_dict.wandb.entity,
            workspace=cfg_dict.wandb.workspace,
        )
        callbacks.append(
            LearningRateMonitor(logging_interval="step", log_momentum=True)
        )
        # if wandb.run is not None:
        #     wandb.run.log_code("src")
    else:
        logger = LocalLogger()

    # Set up checkpointing.
    callbacks.append(
        ModelCheckpoint(
            output_dir / "checkpoints",
            every_n_train_steps=cfg.checkpointing.every_n_train_steps,
            save_top_k=cfg.checkpointing.save_top_k,
            monitor="info/global_step",
            mode="max",
        )
    )
    # Callback that unfreezes the pretrained monodepth parameters.
    callbacks.append(UnfreezePretrainedCallback(unfreeze_step=20000))

    # Use "_" instead of "=" in checkpoint filenames (e.g. "epoch_0-step_100.ckpt"
    # rather than "epoch=0-step=100.ckpt"), presumably to keep "=" out of paths
    # passed on the command line.
    for cb in callbacks:
        cb.CHECKPOINT_EQUALS_CHAR = "_"

    # Prepare the checkpoint for loading.
    if cfg.checkpointing.resume:
        if not os.path.exists(output_dir / "checkpoints"):
            checkpoint_path = None
        else:
            checkpoint_path = find_latest_ckpt(output_dir / "checkpoints")
            print(f"Resuming from {checkpoint_path}")
    else:
        checkpoint_path = update_checkpoint_path(cfg.checkpointing.load, cfg.wandb)
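
    # A minimal sketch of what `find_latest_ckpt` is assumed to do (the real
    # helper lives in src/misc/resume_ckpt.py; this is illustrative, not its
    # actual implementation):
    #
    #   def find_latest_ckpt(ckpt_dir):
    #       ckpts = sorted(Path(ckpt_dir).glob("*.ckpt"), key=os.path.getmtime)
    #       return ckpts[-1] if ckpts else None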

    # This allows the current step to be shared with the data loader processes.
    step_tracker = StepTracker()

    # Build the distributed strategy. Unused-parameter detection is needed
    # because frozen parameters produce no gradients.
    if torch.cuda.device_count() > 1:
        ddp_strategy = DDPStrategy(
            # Key setting: avoids DDP errors caused by frozen parameters.
            find_unused_parameters=True,
            # Keep False so the graph may change when parameters unfreeze.
            static_graph=False,
            process_group_backend="nccl" if torch.cuda.is_available() else "gloo",
        )
    else:
        ddp_strategy = "auto"

    trainer = Trainer(
        max_epochs=-1,
        accelerator="gpu",
        logger=logger,
        devices=torch.cuda.device_count(),
        strategy=ddp_strategy,  # previously: 'ddp' if torch.cuda.device_count() > 1 else "auto"
        callbacks=callbacks,
        val_check_interval=cfg.trainer.val_check_interval,
        enable_progress_bar=cfg.mode == "test",
        gradient_clip_val=cfg.trainer.gradient_clip_val,
        max_steps=cfg.trainer.max_steps,
        num_sanity_val_steps=cfg.trainer.num_sanity_val_steps,
        num_nodes=cfg.trainer.num_nodes,
        plugins=LightningEnvironment() if cfg.use_plugins else None,
    )

    # Set the global Python/NumPy seed only on the main process to avoid
    # issues in distributed runs.
    if torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            _set_global_seed(42)
        # Wait until rank 0 has finished seeding.
        torch.distributed.barrier()
    else:
        # Non-distributed: seed directly.
        _set_global_seed(42)
    # The PyTorch seed is set per rank so that each process differs.
    torch.manual_seed(cfg_dict.seed + trainer.global_rank)

    encoder, encoder_visualizer = get_encoder(cfg.model.encoder)
    model_wrapper = ModelWrapper(
        cfg.optimizer,
        cfg.test,
        cfg.train,
        encoder,
        encoder_visualizer,
        get_decoder(cfg.model.decoder, cfg.dataset),
        get_losses(cfg.loss),
        step_tracker,
        eval_data_cfg=(None if eval_cfg is None else eval_cfg.dataset),
    )
    # Debug switch: disable parameter-update checking.
    model_wrapper._check_param_updates = False

    data_module = DataModule(
        cfg.dataset,
        cfg.data_loader,
        step_tracker,
        global_rank=trainer.global_rank,
    )

    if cfg.mode == "train":
        print("train:", len(data_module.train_dataloader()))
        print("val:", len(data_module.val_dataloader()))
        print("test:", len(data_module.test_dataloader()))

    strict_load = not cfg.checkpointing.no_strict_load
    if cfg.mode == "train":
        # Load the monodepth weights only, then freeze exactly the parameters
        # that were loaded successfully.
        if cfg.checkpointing.pretrained_monodepth is not None:
            strict_load = False
            pretrained_model = torch.load(
                cfg.checkpointing.pretrained_monodepth, map_location="cpu"
            )
            if "state_dict" in pretrained_model:
                pretrained_model = pretrained_model["state_dict"]
            load_result = model_wrapper.encoder.depth_predictor.load_state_dict(
                pretrained_model, strict=strict_load
            )
            # Names of the parameters that were loaded successfully.
            loaded_keys = set(pretrained_model.keys()) - set(load_result.unexpected_keys)
            # Precise freezing: freeze only the successfully loaded parameters.
            for name, param in model_wrapper.encoder.depth_predictor.named_parameters():
                if name in loaded_keys:
                    param.requires_grad = False
            print("\n===== Parameter loading report =====")
            print(f"✅ Loaded parameters: {len(loaded_keys)}")
            print(f"❄️ Frozen parameters: {len(loaded_keys)}")
            print(f"⚠️ Missing keys: {len(load_result.missing_keys)}")
            print(f"⚠️ Unexpected keys: {len(load_result.unexpected_keys)}")
            trainable = [
                p
                for p in model_wrapper.encoder.depth_predictor.parameters()
                if p.requires_grad
            ]
            print(f"🛠️ Trainable parameters: {len(trainable)}")
            print(
                cyan(
                    "Loaded pretrained monodepth (partial freezing): "
                    f"{cfg.checkpointing.pretrained_monodepth}"
                )
            )

        # Load pretrained mvdepth weights.
        if cfg.checkpointing.pretrained_mvdepth is not None:
            pretrained_model = torch.load(
                cfg.checkpointing.pretrained_mvdepth, map_location="cpu"
            )["model"]
            load_result = model_wrapper.encoder.depth_predictor.load_state_dict(
                pretrained_model, strict=False
            )
            print("\n===== Parameter loading report =====")
            print(
                "✅ Loaded parameters: "
                f"{len(pretrained_model) - len(load_result.unexpected_keys)}"
            )
            print(f"⚠️ Missing keys: {len(load_result.missing_keys)}")
            print(f"⚠️ Unexpected keys: {len(load_result.unexpected_keys)}")
            print(
                cyan(f"Loaded pretrained mvdepth: {cfg.checkpointing.pretrained_mvdepth}")
            )
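
        # `load_state_dict(..., strict=False)` returns a named tuple of
        # (missing_keys, unexpected_keys): "missing" keys exist in the module
        # but not in the checkpoint, "unexpected" is the reverse. The reports
        # above count a checkpoint key as loaded when it is not unexpected.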

        # Load the full model.
        if cfg.checkpointing.pretrained_model is not None:
            strict_load = False
            pretrained_model = torch.load(
                cfg.checkpointing.pretrained_model, map_location="cpu"
            )
            if "state_dict" in pretrained_model:
                pretrained_model = pretrained_model["state_dict"]
            model_wrapper.load_state_dict(pretrained_model, strict=strict_load)
            print(
                cyan(f"Loaded pretrained weights: {cfg.checkpointing.pretrained_model}")
            )

        # Load pretrained depth weights.
        if cfg.checkpointing.pretrained_depth is not None:
            strict_load = False
            pretrained_model = torch.load(
                cfg.checkpointing.pretrained_depth, map_location="cpu"
            )
            if "state_dict" in pretrained_model:
                pretrained_model = pretrained_model["state_dict"]
            # pretrained_model = torch.load(
            #     cfg.checkpointing.pretrained_depth, map_location="cpu"
            # )["model"]
            # strict_load = True
            load_result = model_wrapper.encoder.depth_predictor.load_state_dict(
                pretrained_model, strict=strict_load
            )
            print("\n===== Parameter loading report =====")
            print(
                "✅ Loaded parameters: "
                f"{len(pretrained_model) - len(load_result.unexpected_keys)}"
            )
            print(f"⚠️ Missing keys: {len(load_result.missing_keys)}")
            print(f"⚠️ Unexpected keys: {len(load_result.unexpected_keys)}")
            print(
                cyan(f"Loaded pretrained depth: {cfg.checkpointing.pretrained_depth}")
            )

        trainer.fit(model_wrapper, datamodule=data_module, ckpt_path=checkpoint_path)
    else:
        # Load the full model.
        if cfg.checkpointing.pretrained_model is not None:
            pretrained_model = torch.load(
                cfg.checkpointing.pretrained_model, map_location="cpu"
            )
            if "state_dict" in pretrained_model:
                pretrained_model = pretrained_model["state_dict"]
            model_wrapper.load_state_dict(pretrained_model, strict=strict_load)
            print(
                cyan(f"Loaded pretrained weights: {cfg.checkpointing.pretrained_model}")
            )
        # Load the pretrained depth model only.
        if cfg.checkpointing.pretrained_depth is not None:
            pretrained_model = torch.load(
                cfg.checkpointing.pretrained_depth, map_location="cpu"
            )["model"]
            strict_load = True
            model_wrapper.encoder.depth_predictor.load_state_dict(
                pretrained_model, strict=strict_load
            )
            print(
                cyan(f"Loaded pretrained depth: {cfg.checkpointing.pretrained_depth}")
            )
        trainer.test(
            model_wrapper,
            datamodule=data_module,
            ckpt_path=checkpoint_path,
        )


if __name__ == "__main__":
    warnings.filterwarnings("ignore")
    torch.set_float32_matmul_precision("high")
    train()
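
# Example launch, assuming this file is the Hydra entry point; the module
# path and checkpoint path below are illustrative, not part of the repo:
#
#   python -m src.main mode=train \
#       checkpointing.pretrained_monodepth=pretrained/monodepth.ckpt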