File size: 2,968 Bytes
78d2329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import os
import platform
import sys
import warnings
from pathlib import Path

import hydra
import torch
from jaxtyping import install_import_hook
from omegaconf import DictConfig
import matplotlib.pyplot as plt

from optgs.misc.io import cyan

# Configure beartype and jaxtyping.
with install_import_hook(
        ("optgs",),
        ("beartype", "beartype"),
):
    from optgs.config import setup_cfg
    from optgs.dataset.data_module import DataModule
    from optgs.misc.step_tracker import StepTracker

# Report the installed torch version and every visible CUDA device at import time.
print(cyan(f"Torch version: {torch.__version__}"))
if not torch.cuda.is_available():
    # CPU-only runs are tolerated; CUDA is intentionally not required here.
    print(cyan("CUDA is not available."))
else:
    _n_devices = torch.cuda.device_count()
    print(cyan(f"CUDA is available. Number of devices: {_n_devices}"))
    for _dev in range(_n_devices):
        print(cyan(f"Device {_dev}: {torch.cuda.get_device_name(_dev)}"))


@hydra.main(
    version_base=None,
    config_path="config",
    config_name="main",
)
def train(cfg_dict: DictConfig):
    """Debug entry point: build the data module and sweep its test loader once.

    Despite the name, no model is trained. The function resolves the Hydra
    config, builds a ``DataModule``, prints the number of batches in each
    loader, then iterates the test loader once. For each batch it flags
    unusually large context extrinsics and displays one context image.
    It always terminates the process with ``sys.exit(0)``.

    Args:
        cfg_dict: Raw Hydra config composed from ``config/main``.
    """
    # Set up configuration. Only ``cfg`` is used below.
    cfg, cfg_dict, eval_cfg = setup_cfg(cfg_dict)

    # This allows the current step to be shared with the data loader processes.
    step_tracker = StepTracker()

    data_module = DataModule(
        cfg.dataset,
        cfg.data_loader,
        step_tracker,
    )

    # Loader sizes: train/val only exist in train mode; test is always reported last.
    if cfg.mode == "train":
        print("train:", len(data_module.train_dataloader()))
        print("val:", len(data_module.val_dataloader()))
    print("test:", len(data_module.test_dataloader()))

    # Heuristic cut-off on the flattened per-sample extrinsics norm above which
    # a batch is reported as suspicious.
    pose_norm_threshold = 1e3

    # DEBUGGING: loop over all test data once to catch errors early.
    for batch_idx, batch in enumerate(data_module.test_dataloader()):
        extrinsics = batch["context"]["extrinsics"]
        # One L2 norm per leading index. ``reshape`` (unlike ``view``) accepts
        # non-contiguous tensors. Reducing with ``max`` keeps the check valid for
        # batch sizes > 1, where using the whole vector in ``if`` would raise.
        pose_norm = extrinsics.reshape(extrinsics.shape[0], -1).norm(dim=1)
        max_pose_norm = pose_norm.max()
        if max_pose_norm > pose_norm_threshold:
            print(
                f"Batch {batch_idx}: pose norm {max_pose_norm.item():.4f} "
                f"{extrinsics} {batch['scene']} {batch['context']['index']}"
            )

        # Element [0, 0] of the context images (presumably first sample, first
        # view), permuted channels-last for imshow — TODO confirm layout.
        image = batch["context"]["image"][0, 0].permute(1, 2, 0).cpu().numpy()

        fig = plt.figure()
        plt.imshow(image)
        plt.title(f"Batch {batch_idx}\n{batch['scene'][0]}")
        plt.show()
        # A figure is created per batch; close it so figures do not accumulate
        # (plt.show() does not free them under non-interactive backends).
        plt.close(fig)

    print(cyan("DEBUG: Completed one full pass through the test data loader without errors. Exiting now."))
    sys.exit(0)


if __name__ == "__main__":
    # Silence all Python warnings for the whole run.
    warnings.filterwarnings("ignore")
    # 'high' lets float32 matmuls use faster reduced-precision internal
    # formats (e.g. TF32) when the hardware supports them.
    torch.set_float32_matmul_precision('high')

    if not torch.cuda.is_available():
        print("")
        print(cyan("=" * 80))
        print(cyan("CUDA is not available, running on CPU."))
        print(cyan("=" * 80))
        print("")

    # Print the hostname and current working directory. platform.node() is used
    # instead of os.uname().nodename because os.uname() exists only on Unix;
    # on POSIX both report the same node name.
    print(cyan("=" * 80))
    print(cyan(f"Starting training on {platform.node()}, slurm job id: {os.environ.get('SLURM_JOB_ID', 'N/A')}"))
    print(cyan(f"Current working directory: {Path.cwd()}"))
    print(cyan("=" * 80))

    # Called without arguments: the @hydra.main decorator composes and injects the config.
    train()