# Source path: Learn2Splat/optgs/scripts/dev/debug_dataset.py
# Repository page header (not part of the script): SteEsp's picture;
# commit message "Add Docker-based Learn2Splat demo (viser GUI)";
# commit 78d2329 (verified).
import os
import sys
import warnings
from pathlib import Path
import hydra
import torch
from jaxtyping import install_import_hook
from omegaconf import DictConfig
import matplotlib.pyplot as plt
from optgs.misc.io import cyan
# Configure beartype and jaxtyping: the project modules below are imported
# while the jaxtyping import hook for the `optgs` package is active.
# NOTE(review): `optgs.misc.io` is imported above, before this hook is
# installed, so that module is presumably not covered by it; confirm this is
# intended.
with install_import_hook(("optgs",), ("beartype", "beartype")):
    from optgs.config import setup_cfg
    from optgs.dataset.data_module import DataModule
    from optgs.misc.step_tracker import StepTracker
# Report the torch version and the visible CUDA devices at import time.
print(cyan(f"Torch version: {torch.__version__}"))
if not torch.cuda.is_available():
    print(cyan("CUDA is not available."))
    # A hard failure is intentionally disabled here:
    # raise ValueError("CUDA is required to run this code.")
else:
    device_count = torch.cuda.device_count()
    print(cyan(f"CUDA is available. Number of devices: {device_count}"))
    for device_index in range(device_count):
        device_name = torch.cuda.get_device_name(device_index)
        print(cyan(f"Device {device_index}: {device_name}"))
# NOTE(review): Hydra resolves a relative ``config_path`` against the directory
# of the file that declares the entry point; confirm a ``config/`` directory
# is reachable from this script's location.
@hydra.main(
    version_base=None,
    config_path="config",
    config_name="main",
)
def train(cfg_dict: DictConfig):
    """Build the data module, print loader sizes, and sanity-check the test split.

    Despite its name, this entry point does not train anything. It iterates
    once over the test dataloader and, for every batch that contains a sample
    whose flattened context extrinsics have an L2 norm above ``1e3``, prints
    the batch details and shows the first context image of the offending
    sample. It then exits the process with status 0.

    Args:
        cfg_dict: Hydra-composed configuration, forwarded to ``setup_cfg``.

    Raises:
        SystemExit: Always, with code 0, once the full pass has completed.
    """
    pose_norm_threshold = 1e3

    # Set up configuration (the other two returned values are not needed here).
    cfg, _, _ = setup_cfg(cfg_dict)

    # This allows the current step to be shared with the data loader processes.
    step_tracker = StepTracker()
    data_module = DataModule(
        cfg.dataset,
        cfg.data_loader,
        step_tracker,
    )

    if cfg.mode == "train":
        print("train:", len(data_module.train_dataloader()))
        print("val:", len(data_module.val_dataloader()))
    print("test:", len(data_module.test_dataloader()))

    # DEBUGGING: loop over all data once to catch errors early
    for batch_idx, batch in enumerate(data_module.test_dataloader()):
        extrinsics = batch["context"]["extrinsics"]
        # One norm per sample. reshape (unlike view) also accepts
        # non-contiguous tensors.
        pose_norm = extrinsics.reshape(extrinsics.shape[0], -1).norm(dim=1)

        # Reduce explicitly: truth-testing a tensor with more than one element
        # raises, which broke this check for batch sizes above 1.
        exceeds = pose_norm > pose_norm_threshold
        if not exceeds.any():
            continue

        # Pick the flagged sample with the largest norm. Non-flagged entries
        # (including NaN norms) are masked out so they cannot win the argmax.
        worst = int(pose_norm.masked_fill(~exceeds, float("-inf")).argmax())

        print(f"Batch {batch_idx}: pose norm {pose_norm[worst].item():.4f} {extrinsics} {batch['scene']} {batch['context']['index']}")

        # First context view of the offending sample; permute(1, 2, 0) moves
        # the leading axis last for imshow (assumes channel-first images; TODO confirm).
        image = batch["context"]["image"][worst, 0].permute(1, 2, 0).cpu().numpy()
        fig = plt.figure()
        plt.imshow(image)
        plt.title(f"Batch {batch_idx}\n{batch['scene'][worst]}")
        plt.show()
        # With a non-interactive backend show() returns immediately; close the
        # figure so figures do not accumulate across flagged batches.
        plt.close(fig)

    print(cyan("DEBUG: Completed one full pass through the data loaders without errors. Exiting now."))
    sys.exit(0)
# Script entry point: silence warnings, configure matmul precision, print a
# launch banner, then hand control to the Hydra-wrapped ``train``.
if __name__ == "__main__":
    # Local import: only the launch banner below needs it.
    import platform

    # Suppress all Python warnings for this run.
    warnings.filterwarnings("ignore")
    torch.set_float32_matmul_precision('high')

    if not torch.cuda.is_available():
        print("")
        print(cyan("=" * 80))
        print(cyan("CUDA is not available, running on CPU."))
        print(cyan("=" * 80))
        print("")

    # Print the hostname and current working directory.
    # platform.node() is used instead of os.uname().nodename because
    # os.uname() is unavailable on Windows; on POSIX it reports the same name.
    print(cyan("=" * 80))
    print(cyan(f"Starting training on {platform.node()}, slurm job id: {os.environ.get('SLURM_JOB_ID', 'N/A')}"))
    print(cyan(f"Current working directory: {Path.cwd()}"))
    print(cyan("=" * 80))

    train()