|
|
import time
|
|
|
import torch
|
|
|
import hydra
|
|
|
import pytorch_lightning as pl
|
|
|
from typing import Any
|
|
|
|
|
|
from hydra.core.config_store import ConfigStore
|
|
|
from omegaconf import OmegaConf
|
|
|
from pytorch_lightning.loggers import WandbLogger
|
|
|
from pytorch_lightning.callbacks import ModelCheckpoint
|
|
|
|
|
|
from pathlib import Path
|
|
|
from dataclasses import dataclass
|
|
|
|
|
|
from .module import GenericModule
|
|
|
from .data.module import GenericDataModule
|
|
|
from .callbacks import EvalSaveCallback, ImageLoggerCallback
|
|
|
from .models.schema import ModelConfiguration, DINOConfiguration, ResNetConfiguration
|
|
|
from .data.schema import MIADataConfiguration, KITTIDataConfiguration, NuScenesDataConfiguration
|
|
|
|
|
|
|
|
|
@dataclass()
class ExperimentConfiguration:
    """Structured-config schema for the ``experiment`` section.

    Attributes:
        name: Human-readable experiment name. The training entry point
            suffixes it with a launch timestamp to build the run name.
    """

    name: str
|
|
|
|
|
|
@dataclass()
class Configuration:
    """Top-level structured-config schema registered for every primary config.

    ``model`` and ``experiment`` are typed against their own schemas, while
    ``data`` and ``training`` are declared ``Any`` so they accept whatever
    structure the composed YAML supplies.
    """

    # Model definition, validated against ModelConfiguration.
    model: ModelConfiguration
    # Run metadata (currently only the experiment name).
    experiment: ExperimentConfiguration
    # Untyped here; NOTE(review): presumably validated through the
    # ``schema/data`` config group instead -- confirm against conf/.
    data: Any
    # Untyped training options (trainer kwargs, checkpointing, eval flags, ...).
    training: Any
|
|
|
|
|
|
|
|
|
# Register the structured-config schemas with Hydra's ConfigStore so the
# configs composed from ``conf/`` can refer to (and be checked against) them.
cs = ConfigStore.instance()

# Primary configs: every entry-point config shares the same top-level schema.
for _config_name in ("pretrain", "mapper_nuscenes", "mapper_kitti"):
    cs.store(name=_config_name, node=Configuration)

# Dataset schemas, placed at the ``data`` package of the composed config.
for _config_name, _schema in (
    ("mia", MIADataConfiguration),
    ("kitti", KITTIDataConfiguration),
    ("nuscenes", NuScenesDataConfiguration),
):
    cs.store(group="schema/data", name=_config_name, node=_schema, package="data")

# Image-encoder backbone schemas, placed at ``model.image_encoder.backbone``.
for _config_name, _schema in (
    ("dino", DINOConfiguration),
    ("resnet", ResNetConfiguration),
):
    cs.store(
        group="model/schema/backbone",
        name=_config_name,
        node=_schema,
        package="model.image_encoder.backbone",
    )

# Drop the loop temporaries so only ``cs`` remains at module level.
del _config_name, _schema
|
|
|
|
|
|
|
|
|
def _build_callbacks_and_logger(cfg: Configuration, model: GenericModule, run_name: str):
    """Create the trainer callbacks and logger for the selected mode.

    Args:
        cfg: Resolved experiment configuration.
        model: Module to be trained/evaluated (watched by the W&B logger).
        run_name: Name used for the W&B run (also reused as its id).

    Returns:
        ``(callbacks, logger)``. In evaluation mode (``cfg.training.eval``
        truthy) the only callback is :class:`EvalSaveCallback`, writing into
        ``cfg.training.save_dir`` (created if needed), and ``logger`` is
        ``None``. Otherwise image-logging and checkpointing callbacks are
        returned together with a :class:`WandbLogger`.
    """
    if cfg.training.eval:
        save_dir = Path(cfg.training.save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)
        return [EvalSaveCallback(save_dir=save_dir)], None

    checkpointing = cfg.training.checkpointing
    callbacks: list[pl.Callback] = [
        ImageLoggerCallback(num_classes=cfg.training.num_classes),
        ModelCheckpoint(
            monitor=checkpointing.monitor,
            save_last=checkpointing.save_last,
            save_top_k=checkpointing.save_top_k,
        ),
    ]

    logger = WandbLogger(
        name=run_name,
        # The timestamped name doubles as the run id, so each launch gets its own run.
        id=run_name,
        entity="mappred-large",
        project="map-pred-full-v3",
    )
    # Log gradients and parameters every 500 steps.
    logger.watch(model, log="all", log_freq=500)
    return callbacks, logger


def _load_weights(model: torch.nn.Module, checkpoint_path) -> None:
    """Initialise ``model`` from the ``"state_dict"`` entry of a checkpoint file.

    Loading is non-strict, so a checkpoint that only partially matches the
    model is accepted. Parameters missing from the checkpoint keep their
    current values, and extra checkpoint entries are ignored. Any such
    mismatch is reported as a warning instead of being silently discarded.
    """
    import warnings

    # Deserialize onto the CPU: load_state_dict copies the tensors into the
    # model's own parameters anyway, so restoring them onto the device they
    # were saved from only costs extra memory there (and fails outright on a
    # host without that device, e.g. a CUDA checkpoint on a CPU-only machine).
    # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True, which
    # may reject checkpoints containing pickled non-tensor objects -- confirm
    # against the installed torch version.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    incompatible = model.load_state_dict(checkpoint["state_dict"], strict=False)

    # nn.Module.load_state_dict returns (missing_keys, unexpected_keys); use
    # getattr in case a subclass override returns something else.
    missing = list(getattr(incompatible, "missing_keys", []))
    unexpected = list(getattr(incompatible, "unexpected_keys", []))
    if missing or unexpected:
        warnings.warn(
            f"Checkpoint {checkpoint_path} only partially matches the model: "
            f"{len(missing)} missing key(s) (first: {missing[:5]}), "
            f"{len(unexpected)} unexpected key(s) (first: {unexpected[:5]})",
            stacklevel=2,
        )


@hydra.main(version_base=None, config_path="conf", config_name="pretrain")
def train(cfg: Configuration):
    """Hydra entry point: build data, model and trainer, then fit or test.

    ``cfg.training.eval`` selects ``trainer.test`` instead of ``trainer.fit``.
    When ``cfg.training.checkpoint`` is set, its weights are loaded into the
    model before the trainer runs. The mapping ``cfg.training.trainer`` is
    forwarded as keyword arguments to :class:`pl.Trainer`. Its ``callbacks``
    and ``logger`` entries are replaced by the objects built here.
    """
    # Resolve interpolations once so every consumer sees final values.
    OmegaConf.resolve(cfg)

    dm = GenericDataModule(cfg.data)
    model = GenericModule(cfg)

    exp_name_with_time = cfg.experiment.name + "_" + time.strftime("%Y-%m-%d_%H-%M-%S")

    callbacks, logger = _build_callbacks_and_logger(cfg, model, exp_name_with_time)

    if cfg.training.checkpoint is not None:
        _load_weights(model, cfg.training.checkpoint)

    trainer_args = OmegaConf.to_container(cfg.training.trainer)
    # Python objects cannot come from YAML, so they are injected here.
    trainer_args["callbacks"] = callbacks
    trainer_args["logger"] = logger
    trainer = pl.Trainer(**trainer_args)

    if cfg.training.eval:
        trainer.test(model, datamodule=dm)
    else:
        trainer.fit(model, datamodule=dm)
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Let float32 matmuls use faster reduced-precision kernels where the
    # hardware supports them (see torch.set_float32_matmul_precision).
    torch.set_float32_matmul_precision("high")

    # Fixed global seed for reproducible runs.
    pl.seed_everything(42)

    # Hydra parses the command line and passes the composed config to train().
    train()
|
|
|
|