Learn2Splat / optgs /misc /wandb_tools.py
SteEsp's picture
Add Docker-based Learn2Splat demo (viser GUI)
78d2329 verified
import os
from pathlib import Path
import wandb
from pytorch_lightning.loggers.wandb import WandbLogger
from omegaconf import OmegaConf
from optgs.misc.LocalLogger import LocalLogger
from optgs.paths import DEBUG
def version_to_int(artifact) -> int:
    """Return the numeric part of an artifact's version tag.

    The first character of ``artifact.version`` is dropped and the rest is
    parsed as an integer, so ``"v12"`` yields ``12``.
    """
    version_tag = artifact.version
    numeric_part = version_tag[1:]
    return int(numeric_part)
def download_checkpoint(
    run_id: str,
    download_dir: Path,
    version: str | None,
) -> Path:
    """Download a model checkpoint artifact logged by a wandb run.

    Args:
        run_id: Run path handed to ``wandb.Api().run``.
        download_dir: Base directory for downloads; the artifact is stored in
            ``download_dir / run_id``.
        version: Exact artifact version to fetch (e.g. ``"v3"``), or ``None``
            to pick the highest-numbered version.

    Returns:
        The expected location of ``model.ckpt`` inside the download directory.

    Raises:
        ValueError: If the run has no committed model artifact matching the
            request.
    """
    api = wandb.Api()
    run = api.run(run_id)

    # Only fully committed model artifacts are candidates. The generator keeps
    # the scan lazy so an exact-version lookup can stop at the first match.
    candidates = (
        artifact
        for artifact in run.logged_artifacts()
        if artifact.type == "model" and artifact.state == "COMMITTED"
    )

    if version is None:
        # If no version is specified, use the latest.
        chosen = max(candidates, key=version_to_int, default=None)
    else:
        # If a specific version is specified, look for it.
        chosen = next(
            (artifact for artifact in candidates if artifact.version == version),
            None,
        )

    if chosen is None:
        wanted = "latest" if version is None else version
        raise ValueError(
            f"No committed model artifact (version: {wanted}) found for run {run_id!r}."
        )

    # Download the checkpoint.
    download_dir.mkdir(exist_ok=True, parents=True)
    root = download_dir / run_id
    chosen.download(root=root)
    return root / "model.ckpt"
def _latest_local_wandb_run_id(output_dir) -> str | None:
    """Return the run id found under ``<output_dir>/wandb/latest-run``, or None."""
    # Coerce to Path so a plain string output directory works with "/".
    latest_run_dir = Path(output_dir) / "wandb" / "latest-run"
    # Look for a file named "run-<id>.wandb" and extract the id.
    run_files = list(latest_run_dir.glob("run-*.wandb"))
    if len(run_files) > 1:
        raise RuntimeError("Multiple wandb files found in the latest run directory.")
    if not run_files:
        return None
    # Strip only the leading "run-" so ids that contain hyphens stay intact.
    return run_files[0].stem.removeprefix("run-")


def setup_wandb_logger(cfg, cfg_dict) -> WandbLogger | LocalLogger:
    """Create the experiment logger for the current run.

    A ``LocalLogger`` is returned when wandb is disabled or the run is not in
    train mode; otherwise a ``WandbLogger`` is configured from ``cfg_dict``.

    Args:
        cfg: Config object; only ``cfg.mode`` is read.
        cfg_dict: OmegaConf config providing ``wandb``, ``checkpointing``,
            ``output_dir`` and ``log_slurm_id``. ``cfg_dict.wandb`` is updated
            in place (debug tags, entity override from the environment).

    Environment variables read: ``SLURM_JOB_ID``, ``WANDB_ENTITY`` and
    ``WANDB_ID_FILE``.

    Returns:
        The logger to use for this run.

    Raises:
        RuntimeError: If resuming without an id and more than one run file is
            present in the latest-run directory.
    """
    if cfg_dict.wandb.mode == "disabled" or cfg.mode != "train":
        return LocalLogger()

    wandb_extra_kwargs = {}

    # Detect the wandb run id if resuming without an explicit id.
    if cfg_dict.checkpointing.resume and cfg_dict.wandb.id is None:
        print("Resuming wandb run without id, using latest run in output directory.")
        wandb_id = _latest_local_wandb_run_id(cfg_dict.output_dir)
        if wandb_id is not None:
            wandb_extra_kwargs.update({'id': wandb_id, 'resume': "must"})

    # An id given in the config is always used.
    if cfg_dict.wandb.id is not None:
        print(f"Setting wandb run with id from cfg {cfg_dict.wandb.id}.")
        wandb_extra_kwargs.update({'id': cfg_dict.wandb.id, 'resume': "must"})

    run_name = os.path.basename(cfg_dict.output_dir)
    if cfg_dict.log_slurm_id:
        hostname = os.uname().nodename
        job_id = os.environ.get('SLURM_JOB_ID', "local run: " + hostname)
        run_name += f" ({job_id})"

    # If debugging, mark the run in both its name and its tags.
    if DEBUG:
        run_name += " DEBUG"
        cfg_dict.wandb.update({'tags': ['debug']})

    # The environment takes precedence over the configured entity.
    if os.environ.get('WANDB_ENTITY') is not None:
        cfg_dict.wandb.update({'entity': os.environ.get('WANDB_ENTITY')})

    logger = WandbLogger(
        entity=cfg_dict.wandb.entity,
        project=cfg_dict.wandb.project,
        mode=cfg_dict.wandb.mode,
        name=run_name,
        tags=cfg_dict.wandb.get("tags", None),
        log_model=False,
        save_dir=cfg_dict.output_dir,
        config=OmegaConf.to_container(cfg_dict),
        **wandb_extra_kwargs,
    )

    if logger.experiment is not None:
        # Log code
        logger.experiment.log_code("optgs")

        # Log notes
        if cfg_dict.wandb.notes is not None:
            logger.experiment.notes = cfg_dict.wandb.notes

        # Write wandb run ID to file for SLURM requeue resume
        wandb_id_file = os.environ.get("WANDB_ID_FILE")
        run_id = logger.experiment.id
        # NOTE(review): the experiment handle may be a placeholder on processes
        # that do not own the run -- confirm. Check the id before opening the
        # file so a non-string id cannot truncate a previously written one.
        if wandb_id_file and isinstance(run_id, str):
            with open(wandb_id_file, "w", encoding="utf-8") as f:
                f.write(run_id)
            print(f"Wrote wandb run ID {run_id} to {wandb_id_file}")

    return logger
def update_checkpoint_path(path: str | None, wandb_cfg: dict) -> Path | None:
if path is None:
return None
if not str(path).startswith("wandb://"):
return Path(path)
run_id, *version = path[len("wandb://") :].split(":")
if len(version) == 0:
version = None
elif len(version) == 1:
version = version[0]
else:
raise ValueError("Invalid version specifier!")
project = wandb_cfg["project"]
return download_checkpoint(
f"{project}/{run_id}",
Path("checkpoints"),
version,
)