import copy
import os
import random
import sys
import warnings
from pathlib import Path
from typing import Any, Dict, Optional

import hydra
import numpy as np
import torch
import swanlab as wandb  # SwanLab is used as a drop-in replacement for wandb.
from colorama import Fore
from jaxtyping import install_import_hook
from omegaconf import DictConfig, OmegaConf

# Drop the user-local site-packages so stale installs there cannot shadow this
# project's environment.
sys.path = [p for p in sys.path if "/mnt/data-3/users/nichaojun/.local" not in p]

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import (
    Callback,
    LearningRateMonitor,
    ModelCheckpoint,
)
from pytorch_lightning.plugins.environments import LightningEnvironment
from pytorch_lightning.strategies import DDPStrategy
from swanlab.integration.pytorch_lightning import SwanLabLogger

# Install the beartype/jaxtyping import hook so everything imported from src/
# is runtime type-checked.
with install_import_hook(
    ("src",),
    ("beartype", "beartype"),
):
    from src.config import load_typed_root_config
    from src.dataset.data_module import DataModule
    from src.global_cfg import set_cfg
    from src.loss import get_losses
    from src.misc.LocalLogger import LocalLogger
    from src.misc.step_tracker import StepTracker
    from src.misc.wandb_tools import update_checkpoint_path
    from src.misc.resume_ckpt import find_latest_ckpt
    from src.model.decoder import get_decoder
    from src.model.encoder import get_encoder
    from src.model.model_wrapper import ModelWrapper
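
# With the hook installed, jaxtyping annotations inside src/ are enforced at
# call time; e.g. an argument annotated `Float[Tensor, "batch 3"]` raises a
# beartype error if a tensor with a different shape or dtype is passed.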


def cyan(text: str) -> str:
    return f"{Fore.CYAN}{text}{Fore.RESET}"


class UnfreezePretrainedCallback(Callback):
    """Re-enables gradients for the pretrained monodepth parameters once
    training reaches `unfreeze_step`."""

    def __init__(self, unfreeze_step: int = 20000):
        self.unfreeze_step = unfreeze_step
        self.has_unfrozen = False

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # Unfreeze on every rank; flipping requires_grad on rank 0 alone would
        # desynchronize parameters under DDP. Only the print is rank-zero-only.
        if not self.has_unfrozen and trainer.global_step >= self.unfreeze_step:
            if trainer.is_global_zero:
                print(cyan(f"Step {trainer.global_step}: Unfreezing pretrained_monodepth parameters"))
            for param in pl_module.encoder.depth_predictor.parameters():
                param.requires_grad = True
            self.has_unfrozen = True


def _set_global_seed(seed):
    """Set the global random seed for Python and NumPy."""
    random.seed(seed)
    np.random.seed(seed)

    print(f"Global Python/NumPy seed set to: {seed}")


@hydra.main(
    version_base=None,
    config_path="../config",
    config_name="main",
)
def train(cfg_dict: DictConfig):
    # When periodic evaluation is enabled, build a parallel config that swaps
    # the training view sampler for a fixed evaluation index, so evaluation
    # always uses the same deterministic context/target views.
    if cfg_dict["mode"] == "train" and cfg_dict["train"]["eval_model_every_n_val"] > 0:
        eval_cfg_dict = copy.deepcopy(cfg_dict)
        dataset_dir = str(cfg_dict["dataset"]["roots"]).lower()
        if "re10k" in dataset_dir:
            eval_path = "assets/evaluation_index_re10k.json"
        elif "dl3dv" in dataset_dir:
            if cfg_dict["dataset"]["view_sampler"]["num_context_views"] == 6:
                eval_path = "assets/dl3dv_start_0_distance_50_ctx_6v_tgt_8v.json"
            else:
                raise ValueError("unsupported number of views for dl3dv")
        else:
            raise ValueError("failed to find an evaluation index for the given dataset roots")
        eval_cfg_dict["dataset"]["view_sampler"] = {
            "name": "evaluation",
            "index_path": eval_path,
            "num_context_views": cfg_dict["dataset"]["view_sampler"]["num_context_views"],
        }
        eval_cfg = load_typed_root_config(eval_cfg_dict)
    else:
        eval_cfg = None
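
    # eval_cfg mirrors the training config except for the view sampler; its
    # dataset section is handed to ModelWrapper below as eval_data_cfg.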

    cfg = load_typed_root_config(cfg_dict)
    set_cfg(cfg_dict)

    # Set up the output directory.
    if cfg_dict.output_dir is None:
        output_dir = Path(
            hydra.core.hydra_config.HydraConfig.get()["runtime"]["output_dir"]
        )
    else:
        output_dir = Path(cfg_dict.output_dir)
        os.makedirs(output_dir, exist_ok=True)
    print(cyan(f"Saving outputs to {output_dir}."))

    # Set up logging with SwanLab in training mode; otherwise log locally.
    callbacks = []
    if cfg_dict.wandb.mode != "disabled" and cfg.mode == "train":
        wandb_extra_kwargs = {}
        if cfg_dict.wandb.id is not None:
            # NOTE: these resume kwargs are collected here but are not consumed
            # by SwanLabLogger below.
            wandb_extra_kwargs.update({"id": cfg_dict.wandb.id, "resume": "must"})

        logger = SwanLabLogger(
            project=cfg_dict.wandb.project,
            experiment_name=cfg_dict.wandb.entity,
            workspace=cfg_dict.wandb.workspace,
        )

        callbacks.append(LearningRateMonitor("step", True))
    else:
        logger = LocalLogger()

    # Set up checkpointing.
    callbacks.append(
        ModelCheckpoint(
            output_dir / "checkpoints",
            every_n_train_steps=cfg.checkpointing.every_n_train_steps,
            save_top_k=cfg.checkpointing.save_top_k,
            monitor="info/global_step",
            mode="max",
        )
    )

    callbacks.append(UnfreezePretrainedCallback(unfreeze_step=20000))

    # Avoid '=' characters in checkpoint filenames.
    for cb in callbacks:
        cb.CHECKPOINT_EQUALS_CHAR = "_"
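
    # With CHECKPOINT_EQUALS_CHAR set to "_", checkpoints are written as e.g.
    # "epoch_0-step_20000.ckpt" instead of "epoch=0-step=20000.ckpt".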

    # Resolve the checkpoint to resume from, if any.
    if cfg.checkpointing.resume:
        if not os.path.exists(output_dir / "checkpoints"):
            checkpoint_path = None
        else:
            checkpoint_path = find_latest_ckpt(output_dir / "checkpoints")
            print(f"resume from {checkpoint_path}")
    else:
        checkpoint_path = update_checkpoint_path(cfg.checkpointing.load, cfg.wandb)

    # Lets the current training step be shared with the data loader processes.
    step_tracker = StepTracker()

    # Only use an explicit DDP strategy when more than one GPU is visible.
    if torch.cuda.device_count() > 1:
        ddp_strategy = DDPStrategy(
            find_unused_parameters=True,
            static_graph=False,
            process_group_backend="nccl" if torch.cuda.is_available() else "gloo",
        )
    else:
        ddp_strategy = "auto"

    trainer = Trainer(
        max_epochs=-1,
        accelerator="gpu",
        logger=logger,
        devices=torch.cuda.device_count(),
        strategy=ddp_strategy,
        callbacks=callbacks,
        val_check_interval=cfg.trainer.val_check_interval,
        enable_progress_bar=cfg.mode == "test",
        gradient_clip_val=cfg.trainer.gradient_clip_val,
        max_steps=cfg.trainer.max_steps,
        num_sanity_val_steps=cfg.trainer.num_sanity_val_steps,
        num_nodes=cfg.trainer.num_nodes,
        plugins=LightningEnvironment() if cfg.use_plugins else None,
    )
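
    # find_unused_parameters=True lets DDP tolerate parameters that receive no
    # gradient on a given step (e.g. while parts of the encoder are frozen), at
    # the cost of an extra graph traversal per step.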

    # Seed Python/NumPy; if a process group is already up, rank 0 seeds and the
    # other ranks wait at the barrier. Torch then gets a rank-dependent seed.
    if torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            _set_global_seed(42)
        torch.distributed.barrier()
    else:
        _set_global_seed(42)

    torch.manual_seed(cfg_dict.seed + trainer.global_rank)

    encoder, encoder_visualizer = get_encoder(cfg.model.encoder)

    model_wrapper = ModelWrapper(
        cfg.optimizer,
        cfg.test,
        cfg.train,
        encoder,
        encoder_visualizer,
        get_decoder(cfg.model.decoder, cfg.dataset),
        get_losses(cfg.loss),
        step_tracker,
        eval_data_cfg=(None if eval_cfg is None else eval_cfg.dataset),
    )

    # Disable ModelWrapper's internal parameter-update checking.
    model_wrapper._check_param_updates = False

    data_module = DataModule(
        cfg.dataset,
        cfg.data_loader,
        step_tracker,
        global_rank=trainer.global_rank,
    )

    if cfg.mode == "train":
        print("train:", len(data_module.train_dataloader()))
        print("val:", len(data_module.val_dataloader()))
        print("test:", len(data_module.test_dataloader()))

    strict_load = not cfg.checkpointing.no_strict_load
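
    # strict_load comes from the config and is relaxed below whenever a partial
    # (encoder- or depth-only) checkpoint is loaded.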

    if cfg.mode == "train":
        # Optionally initialize the depth predictor from a pretrained monodepth
        # checkpoint and freeze exactly the parameters that were loaded.
        if cfg.checkpointing.pretrained_monodepth is not None:
            strict_load = False
            pretrained_model = torch.load(cfg.checkpointing.pretrained_monodepth, map_location="cpu")
            if "state_dict" in pretrained_model:
                pretrained_model = pretrained_model["state_dict"]

            load_result = model_wrapper.encoder.depth_predictor.load_state_dict(pretrained_model, strict=strict_load)

            # Freeze only the parameters that were actually loaded.
            loaded_keys = set(pretrained_model.keys()) - set(load_result.unexpected_keys)
            for name, param in model_wrapper.encoder.depth_predictor.named_parameters():
                if name in loaded_keys:
                    param.requires_grad = False

            trainable = [p for p in model_wrapper.encoder.depth_predictor.parameters() if p.requires_grad]
            print("\n===== Parameter loading report =====")
            print(f"✅ Parameters loaded: {len(loaded_keys)}")
            print(f"❄️ Parameters frozen: {len(loaded_keys)}")
            print(f"⚠️ Missing keys: {len(load_result.missing_keys)}")
            print(f"⚠️ Unexpected keys: {len(load_result.unexpected_keys)}")
            print(f"🛠️ Trainable parameters: {len(trainable)}")

            print(
                cyan(
                    f"Loaded pretrained monodepth (partial freezing): {cfg.checkpointing.pretrained_monodepth}"
                )
            )
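
        # NOTE: the parameters frozen above are re-enabled at unfreeze_step by
        # UnfreezePretrainedCallback.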

        # Optionally initialize from a pretrained multi-view depth checkpoint.
        if cfg.checkpointing.pretrained_mvdepth is not None:
            pretrained_model = torch.load(cfg.checkpointing.pretrained_mvdepth, map_location="cpu")["model"]

            load_result = model_wrapper.encoder.depth_predictor.load_state_dict(pretrained_model, strict=False)

            print("\n===== Parameter loading report =====")
            print(f"✅ Parameters loaded: {len(pretrained_model) - len(load_result.unexpected_keys)}")
            print(f"⚠️ Missing keys: {len(load_result.missing_keys)}")
            print(f"⚠️ Unexpected keys: {len(load_result.unexpected_keys)}")

            print(
                cyan(
                    f"Loaded pretrained mvdepth: {cfg.checkpointing.pretrained_mvdepth}"
                )
            )

        # Optionally load full pretrained weights for the whole model.
        if cfg.checkpointing.pretrained_model is not None:
            strict_load = False
            pretrained_model = torch.load(cfg.checkpointing.pretrained_model, map_location="cpu")
            if "state_dict" in pretrained_model:
                pretrained_model = pretrained_model["state_dict"]

            model_wrapper.load_state_dict(pretrained_model, strict=strict_load)
            print(
                cyan(
                    f"Loaded pretrained weights: {cfg.checkpointing.pretrained_model}"
                )
            )

        # Optionally load a pretrained depth checkpoint (without freezing).
        if cfg.checkpointing.pretrained_depth is not None:
            strict_load = False
            pretrained_model = torch.load(cfg.checkpointing.pretrained_depth, map_location="cpu")
            if "state_dict" in pretrained_model:
                pretrained_model = pretrained_model["state_dict"]

            load_result = model_wrapper.encoder.depth_predictor.load_state_dict(pretrained_model, strict=strict_load)

            print("\n===== Parameter loading report =====")
            print(f"✅ Parameters loaded: {len(pretrained_model) - len(load_result.unexpected_keys)}")
            print(f"⚠️ Missing keys: {len(load_result.missing_keys)}")
            print(f"⚠️ Unexpected keys: {len(load_result.unexpected_keys)}")

            print(
                cyan(
                    f"Loaded pretrained depth: {cfg.checkpointing.pretrained_depth}"
                )
            )

        trainer.fit(model_wrapper, datamodule=data_module, ckpt_path=checkpoint_path)
    else:
        # Test mode: load pretrained weights if provided, then run evaluation.
        if cfg.checkpointing.pretrained_model is not None:
            pretrained_model = torch.load(cfg.checkpointing.pretrained_model, map_location="cpu")
            if "state_dict" in pretrained_model:
                pretrained_model = pretrained_model["state_dict"]

            model_wrapper.load_state_dict(pretrained_model, strict=strict_load)
            print(
                cyan(
                    f"Loaded pretrained weights: {cfg.checkpointing.pretrained_model}"
                )
            )

        if cfg.checkpointing.pretrained_depth is not None:
            pretrained_model = torch.load(cfg.checkpointing.pretrained_depth, map_location="cpu")["model"]

            # Depth-only checkpoints must match the module exactly at test time.
            model_wrapper.encoder.depth_predictor.load_state_dict(pretrained_model, strict=True)
            print(
                cyan(
                    f"Loaded pretrained depth: {cfg.checkpointing.pretrained_depth}"
                )
            )

        trainer.test(
            model_wrapper,
            datamodule=data_module,
            ckpt_path=checkpoint_path,
        )
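

# Example launch (hypothetical overrides; adjust to the actual config tree):
#   python -m src.main mode=train output_dir=outputs/run1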
if __name__ == "__main__":
    warnings.filterwarnings("ignore")
    torch.set_float32_matmul_precision("high")

    train()