| import logging |
| import os |
| import sys |
| import warnings |
|
|
| import hydra |
| from hydra.core.hydra_config import HydraConfig |
| from hydra.utils import instantiate |
| from omegaconf import DictConfig |
|
|
| from transforna import compute_cv, infer_benchmark, infer_tcga, train |
|
|
# Install a process-wide "ignore" filter, so warnings from every module
# (third-party libraries included) are suppressed.
# NOTE(review): this also hides deprecation warnings; consider narrowing the
# filter by ``category=`` or ``module=`` if that ever becomes a problem.
warnings.filterwarnings(action="ignore")

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
|
|
def _find_file_config_path(config_sources):
    """Return the ``path`` of the first config source whose ``schema`` is ``"file"``.

    Args:
        config_sources: Iterable of mapping-like entries, each with
            ``"schema"`` and ``"path"`` keys.

    Raises:
        RuntimeError: If no entry has ``schema == "file"``.
    """
    for source in config_sources:
        if source["schema"] == "file":
            return source["path"]
    raise RuntimeError(
        "No file-based config source found in Hydra's runtime.config_sources; "
        "cannot add the config directory to sys.path."
    )


def add_config_to_sys_path():
    """Append the path of the file-schema Hydra config source to ``sys.path``.

    Reads the current Hydra runtime config via ``HydraConfig.get()``. From
    ``runtime.config_sources`` it takes the first entry whose schema is
    ``"file"`` and appends that entry's path to ``sys.path``, making modules
    stored next to the config files importable.

    Raises:
        RuntimeError: If no file-schema config source is registered.
    """
    cfg = HydraConfig.get()
    config_path = _find_file_config_path(cfg.runtime.config_sources)
    # Only append once. Repeated calls in the same process would otherwise
    # keep growing sys.path with duplicates, which never affect resolution
    # because the first occurrence always wins.
    if config_path not in sys.path:
        sys.path.append(config_path)
|
|
| |
| |
@hydra.main(config_path='../conf', config_name="main_config")
def main(cfg: DictConfig) -> object:
    """Run inference, cross-validation or training as selected by ``cfg``.

    ``cfg`` is the configuration Hydra loads from ``../conf`` /
    ``main_config`` (see the decorator).

    Before dispatching, the ``train_config`` and ``model_config`` nodes are
    instantiated with ``hydra.utils.instantiate``. Each node is replaced in
    ``cfg`` by the attribute dict (``__dict__``) of the resulting object.
    ``model_config["model_input"]`` is then set to ``cfg["model_name"]``.

    Dispatch:
      * ``cfg["inference"]`` truthy: ``infer_tcga`` when ``cfg["task"]`` is
        ``"tcga"``, otherwise ``infer_benchmark``.
      * otherwise: ``compute_cv`` when ``cfg["cross_val"]`` is truthy,
        else ``train``.

    Returns:
        The value returned by the selected inference function, or ``None``
        for the cross-validation / training paths.
    """
    add_config_to_sys_path()

    # Same object as ``hydra.core.hydra_config.HydraConfig``; reuse the
    # module-level import, as ``add_config_to_sys_path`` does.
    output_dir = HydraConfig.get().runtime.output_dir

    # NOTE(review): whether this is the launch directory or Hydra's output
    # directory depends on the Hydra version / ``hydra.job.chdir`` setting;
    # confirm which one the downstream functions expect.
    path = os.getcwd()

    # Swap each config node for the plain attribute dict of its instantiated object.
    cfg['train_config'] = instantiate(cfg['train_config']).__dict__
    cfg['model_config'] = instantiate(cfg['model_config']).__dict__

    cfg['model_config']["model_input"] = cfg["model_name"]

    if cfg["inference"]:
        # Lazy %-style args: formatting only happens if INFO is enabled.
        logger.info("Started inference on %s", cfg['task'])
        if cfg['task'] == 'tcga':
            return infer_tcga(cfg, path=path)
        return infer_benchmark(cfg, path=path)

    if cfg["cross_val"]:
        compute_cv(cfg, path, output_dir=output_dir)
    else:
        train(cfg, path=path, output_dir=output_dir)
    return None
| |
if __name__ == "__main__":
    # Run the entry point when this file is executed as a script. ``main`` is
    # called without arguments because the ``@hydra.main`` decorator on it
    # supplies the ``cfg`` parameter.
    main()
|
|