Spaces:
Sleeping
Sleeping
File size: 2,968 Bytes
78d2329 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 | import os
import sys
import warnings
from pathlib import Path
import hydra
import torch
from jaxtyping import install_import_hook
from omegaconf import DictConfig
import matplotlib.pyplot as plt
from optgs.misc.io import cyan
# Configure beartype and jaxtyping.
with install_import_hook(
("optgs",),
("beartype", "beartype"),
):
from optgs.config import setup_cfg
from optgs.dataset.data_module import DataModule
from optgs.misc.step_tracker import StepTracker
# Report the torch build and every visible CUDA device at import time.
print(cyan(f"Torch version: {torch.__version__}"))
if not torch.cuda.is_available():
    # A missing GPU is only reported here; no exception is raised.
    print(cyan("CUDA is not available."))
else:
    num_devices = torch.cuda.device_count()
    print(cyan(f"CUDA is available. Number of devices: {num_devices}"))
    for device_idx in range(num_devices):
        print(cyan(f"Device {device_idx}: {torch.cuda.get_device_name(device_idx)}"))
@hydra.main(
    version_base=None,
    config_path="config",
    config_name="main",
)
def train(cfg_dict: DictConfig):
    """Build the data module and sweep the test loader once as a sanity check.

    Despite its name, the visible body performs no optimisation: it resolves
    the configuration, prints the number of batches of each data loader,
    iterates over the test loader looking for context extrinsics with an
    unusually large norm (printing and plotting offenders), and then exits.

    Args:
        cfg_dict: Configuration object supplied by Hydra (``config/main``).

    Raises:
        SystemExit: Always, with status 0, once the sweep has completed.
    """
    # Set up configuration.
    cfg, cfg_dict, eval_cfg = setup_cfg(cfg_dict)

    # This allows the current step to be shared with the data loader processes.
    step_tracker = StepTracker()

    data_module = DataModule(
        cfg.dataset,
        cfg.data_loader,
        step_tracker,
    )

    # The train/val loaders are only reported in train mode; the test loader
    # size is reported in every mode.
    if cfg.mode == "train":
        print("train:", len(data_module.train_dataloader()))
        print("val:", len(data_module.val_dataloader()))
    print("test:", len(data_module.test_dataloader()))

    # DEBUGGING: loop over all data once to catch errors early
    for batch_idx, batch in enumerate(data_module.test_dataloader()):
        extrinsics = batch["context"]["extrinsics"]

        # One norm per sample: every axis after the first is flattened.
        # `reshape` (unlike `view`) also accepts non-contiguous tensors.
        pose_norm = extrinsics.reshape(extrinsics.shape[0], -1).norm(dim=1)

        # `pose_norm` holds one entry per sample, so it must be reduced before
        # it can be used as a condition; a bare `if pose_norm > 1e3` raises
        # for any batch size other than 1.
        if not bool((pose_norm > 1e3).any()):
            continue

        # Report the sample with the largest norm. With a batch size of 1 this
        # is index 0, which reproduces the previous output exactly.
        worst = int(pose_norm.argmax())
        print(f"Batch {batch_idx}: pose norm {pose_norm[worst].item():.4f} {extrinsics} {batch['scene']} {batch['context']['index']}")

        # First context view of the offending sample; the permute moves the
        # leading axis last (presumably CHW -> HWC for imshow -- confirm).
        image = batch["context"]["image"][worst, 0].permute(1, 2, 0).cpu().numpy()
        plt.figure()
        plt.imshow(image)
        # NOTE(review): assumes batch['scene'] is indexable per sample.
        plt.title(f"Batch {batch_idx}\n{batch['scene'][worst]}")
        plt.show()

    print(cyan("DEBUG: Completed one full pass through the data loaders without errors. Exiting now."))
    sys.exit(0)
if __name__ == "__main__":
    # Imported locally: only the start-up banner below needs it.
    # `platform.node()` is used instead of `os.uname().nodename` because
    # `os.uname` is not available on Windows.
    import platform

    # Silence all warnings for the whole run and allow the faster float32
    # matmul precision mode.
    warnings.filterwarnings("ignore")
    torch.set_float32_matmul_precision('high')

    if not torch.cuda.is_available():
        print("")
        print(cyan("=" * 80))
        print(cyan("CUDA is not available, running on CPU."))
        print(cyan("=" * 80))
        print("")

    # Print the hostname and current working directory.
    print(cyan("=" * 80))
    print(cyan(f"Starting training on {platform.node()}, slurm job id: {os.environ.get('SLURM_JOB_ID', 'N/A')}"))
    print(cyan(f"Current working directory: {Path.cwd()}"))
    print(cyan("=" * 80))

    # Hydra parses the command line and supplies the configuration argument.
    train()
|