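"""Hydra-driven inference entry point for the Loupe model.

Supports three modes, selected by the stage config: test-time adaptation (TTA)
via a short fine-tuning run, batch prediction written to disk, or a plain test pass.
"""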
import os
import sys
import hydra
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import RichProgressBar
from pytorch_lightning.loggers import TensorBoardLogger
from omegaconf import DictConfig
from pathlib import Path

# make local modules importable before they are imported below
sys.path.insert(0, ".")

from models.loupe import LoupeModel, LoupeConfig
from data_module import DataModule
from lit_model import LitModel
from utils import CustomWriter, convert_deepspeed_checkpoint
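
# repository root and output directory (CATTINO_TASK_HOME overrides the default)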
project_root = Path(__file__).resolve().parent.parent
save_path = os.environ.get("CATTINO_TASK_HOME", f"{project_root}/results")


@hydra.main(
    config_path=str(project_root / "configs"),
    config_name="infer",
    version_base=None,
)
def main(cfg: DictConfig):
    pl.seed_everything(cfg.seed)

    if cfg.stage.name != "test":
        raise ValueError("This script is for testing only. Please use the test stage.")
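
    # progress bar is always shown; fast_dev_run restricts the run to a single device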
    callbacks = [RichProgressBar()]
    trainer_overrides = dict(
        devices=1 if cfg.trainer.fast_dev_run else "auto",
    )
    if cfg.stage.enable_tta:
        # if TTA is enabled, configure the trainer for training
        trainer_overrides.update(
            logger=TensorBoardLogger(
                save_dir=save_path,
                name=cfg.stage.name,
                default_hp_metric=False,
            ),
            max_epochs=cfg.hparams.epoch,
            gradient_clip_val=cfg.hparams.grad_clip_val,
            val_check_interval=0.2,
            log_every_n_steps=2,
            accumulate_grad_batches=cfg.hparams.accumulate_grad_batches,
        )
        # unfreeze the segmentation head and enable conditional queries for adaptation
        cfg.model.freeze_seg = False
        cfg.model.enable_conditional_queries = True
    else:
        # plain inference: no logger, no checkpointing
        trainer_overrides.update(
            logger=False,
            enable_checkpointing=False,
        )
        if cfg.stage.pred_output_dir:
            callbacks.append(CustomWriter(cfg=cfg, write_interval="batch"))
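
    # user settings from the Hydra config take precedence over the defaults above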
    trainer_overrides.update(cfg.trainer)

    if trainer_overrides.get("enable_checkpointing", False):
        checkpoint_callback = hydra.utils.instantiate(
            cfg.ckpt.saver,
            dirpath=os.path.join(save_path, "checkpoints"),
        )
        callbacks.append(checkpoint_callback)

    trainer = pl.Trainer(
        callbacks=callbacks,
        **trainer_overrides,
    )
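
    # allow TF32 matmuls: slightly lower precision for higher throughput on recent GPUs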
    torch.set_float32_matmul_precision("medium")

    loupe_config = LoupeConfig(stage=cfg.stage.name, **cfg.model)
    loupe = LoupeModel(loupe_config)
    model = LitModel(cfg=cfg, loupe=loupe)
    data_module = DataModule(cfg, loupe_config)
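
    # dispatch on the stage config: TTA fine-tuning, batch prediction, or a plain test pass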
    if cfg.stage.enable_tta:
        trainer.fit(
            model,
            data_module,
        )
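        # DeepSpeed writes sharded checkpoints; consolidate the best one into a single file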
if "deepspeed" in cfg.trainer.get("strategy", "") and trainer_overrides.get(
"enable_checkpointing", False
):
convert_deepspeed_checkpoint(
cfg,
checkpoint_callback,
os.path.join(
project_root,
"checkpoints",
os.path.basename(checkpoint_callback.best_model_path),
),
)
    elif cfg.stage.pred_output_dir:
        # predictions are written to disk by CustomWriter, so don't keep them in memory
        trainer.predict(
            model,
            data_module,
            return_predictions=False,
        )
    else:
        trainer.test(
            model,
            data_module,
        )


if __name__ == "__main__":
    main()