import os
import random
import string
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import hydra
import omegaconf
import pytorch_lightning as pl
import torch
import torch.multiprocessing
from omegaconf import OmegaConf, listconfig
from pytorch_lightning import LightningModule
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.utilities import rank_zero_only

from boltz.data.module.training import BoltzTrainingDataModule, DataConfig


@dataclass
class TrainConfig:
    """Train configuration.

    Attributes
    ----------
    data : DataConfig
        The data configuration.
    model : LightningModule
        The instantiated model module.
    output : str
        The output directory.
    trainer : Optional[dict]
        The trainer configuration.
    resume : Optional[str]
        The checkpoint to resume from.
    pretrained : Optional[str]
        The pretrained model checkpoint.
    wandb : Optional[dict]
        The wandb configuration.
    disable_checkpoint : bool
        Whether to disable checkpointing.
    matmul_precision : Optional[str]
        The float32 matmul precision.
    find_unused_parameters : Optional[bool]
        Whether DDP should find unused parameters.
    save_top_k : Optional[int]
        The number of top checkpoints to keep.
    validation_only : bool
        Whether to run validation only.
    debug : bool
        Whether to run in debug mode.
    strict_loading : bool
        Whether to fail on mismatched checkpoint weights.
    load_confidence_from_trunk : Optional[bool]
        Whether to initialize confidence weights from the trunk.

    """

    data: DataConfig
    model: LightningModule
    output: str
    trainer: Optional[dict] = None
    resume: Optional[str] = None
    pretrained: Optional[str] = None
    wandb: Optional[dict] = None
    disable_checkpoint: bool = False
    matmul_precision: Optional[str] = None
    find_unused_parameters: Optional[bool] = False
    save_top_k: Optional[int] = 1
    validation_only: bool = False
    debug: bool = False
    strict_loading: bool = True
    load_confidence_from_trunk: Optional[bool] = False
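
# A minimal sketch of a YAML file this config could be built from. Hydra
# instantiates any nested ``_target_`` (the model here), while ``data`` stays
# a plain mapping that is unpacked into DataConfig inside train(). All class
# paths and values below are illustrative, not the project's defaults:
#
#   model:
#     _target_: path.to.YourLightningModule
#   data:
#     num_workers: 4
#   output: ./output
#   trainer:
#     max_epochs: 100
#     devices: 1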


def train(raw_config: str, args: list[str]) -> None:
    """Run training.

    Parameters
    ----------
    raw_config : str
        The path to the input YAML configuration.
    args : list[str]
        Any command line overrides, in dotlist form.

    """
    # Load the configuration
    raw_config = omegaconf.OmegaConf.load(raw_config)

    # Apply the command line overrides
    args = omegaconf.OmegaConf.from_dotlist(args)
    raw_config = omegaconf.OmegaConf.merge(raw_config, args)
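    # e.g. overrides like ["trainer.devices=2", "debug=true"] merge in as
    # {"trainer": {"devices": 2}, "debug": True} and take precedence over
    # the YAML values (the keys shown are illustrative)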

    # Instantiate the configuration
    cfg = hydra.utils.instantiate(raw_config)
    cfg = TrainConfig(**cfg)

    # Set the matmul precision
    if cfg.matmul_precision is not None:
        torch.set_float32_matmul_precision(cfg.matmul_precision)

    # Create the trainer dict
    trainer = cfg.trainer
    if trainer is None:
        trainer = {}

    devices = trainer.get("devices", 1)

    # In debug mode, fall back to a single device, disable dataloader
    # workers, and skip wandb logging
    wandb = cfg.wandb
    if cfg.debug:
        if isinstance(devices, int):
            devices = 1
        elif isinstance(devices, (list, listconfig.ListConfig)):
            devices = [devices[0]]
        trainer["devices"] = devices
        cfg.data.num_workers = 0
        if wandb:
            wandb = None

    # Create the data module and grab the instantiated model
    data_config = DataConfig(**cfg.data)
    data_module = BoltzTrainingDataModule(data_config)
    model_module = cfg.model

    if cfg.pretrained and not cfg.resume:
        if cfg.load_confidence_from_trunk:
            checkpoint = torch.load(cfg.pretrained, map_location="cpu")

            # Copy every trunk weight under a "confidence_module." prefix,
            # then keep all original keys as well
            new_state_dict = {}
            for key, value in checkpoint["state_dict"].items():
                if not key.startswith(("structure_module", "distogram_module")):
                    new_key = "confidence_module." + key
                    new_state_dict[new_key] = value
            new_state_dict.update(checkpoint["state_dict"])

            checkpoint["state_dict"] = new_state_dict

            # Save the modified checkpoint under a temporary random name
            random_string = "".join(
                random.choices(string.ascii_lowercase + string.digits, k=10)
            )
            file_path = os.path.join(
                os.path.dirname(cfg.pretrained), f"{random_string}.ckpt"
            )
            print(
                f"Saving modified checkpoint to {file_path}, created by "
                f"broadcasting the trunk of {cfg.pretrained} to the "
                "confidence module."
            )
            torch.save(checkpoint, file_path)
        else:
            file_path = cfg.pretrained

        print(f"Loading model from {file_path}")
        model_module = type(model_module).load_from_checkpoint(
            file_path, map_location="cpu", strict=False, **model_module.hparams
        )

        # Remove the temporary checkpoint
        if cfg.load_confidence_from_trunk:
            os.remove(file_path)

    # Create the checkpoint callback; the model is expected to log the
    # monitored "val/lddt" metric during validation
    callbacks = []
    dirpath = cfg.output
    if not cfg.disable_checkpoint:
        mc = ModelCheckpoint(
            monitor="val/lddt",
            save_top_k=cfg.save_top_k,
            save_last=True,
            mode="max",
            every_n_epochs=1,
        )
        callbacks = [mc]

    # Create the wandb logger
    loggers = []
    if wandb:
        wdb_logger = WandbLogger(
            name=wandb["name"],
            group=wandb["name"],
            save_dir=cfg.output,
            project=wandb["project"],
            entity=wandb["entity"],
            log_model=False,
        )
        loggers.append(wdb_logger)

        # Upload the merged run config from the rank-zero process only
        @rank_zero_only
        def save_config_to_wandb() -> None:
            config_out = Path(wdb_logger.experiment.dir) / "run.yaml"
            with config_out.open("w") as f:
                OmegaConf.save(raw_config, f)
            wdb_logger.experiment.save(str(config_out))

        save_config_to_wandb()

    # Use DDP when training on multiple devices
    strategy = "auto"
    if (isinstance(devices, int) and devices > 1) or (
        isinstance(devices, (list, listconfig.ListConfig)) and len(devices) > 1
    ):
        strategy = DDPStrategy(find_unused_parameters=cfg.find_unused_parameters)

    # Create the trainer; remaining keys in the trainer dict are forwarded
    # to pytorch_lightning.Trainer as keyword arguments
    trainer = pl.Trainer(
        default_root_dir=str(dirpath),
        strategy=strategy,
        callbacks=callbacks,
        logger=loggers,
        enable_checkpointing=not cfg.disable_checkpoint,
        reload_dataloaders_every_n_epochs=1,
        **trainer,
    )
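    # A typical trainer dict carries standard pl.Trainer kwargs, e.g.
    # {"max_epochs": 100, "accelerator": "gpu", "devices": 2,
    #  "precision": "bf16-mixed"}; these values are illustrative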

    # Allow loading checkpoints with mismatched weights if requested
    if not cfg.strict_loading:
        model_module.strict_loading = False

    # Run validation only, or full training
    if cfg.validation_only:
        trainer.validate(
            model_module,
            datamodule=data_module,
            ckpt_path=cfg.resume,
        )
    else:
        trainer.fit(
            model_module,
            datamodule=data_module,
            ckpt_path=cfg.resume,
        )


if __name__ == "__main__":
    config_path = sys.argv[1]
    overrides = sys.argv[2:]
    train(config_path, overrides)
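
# Example invocation (script name, paths, and override values are
# illustrative):
#   python train.py configs/train.yaml trainer.devices=2 output=./runs/exp1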