Spaces:
Build error
Build error
| """Run wandb sweep.""" | |
| # Copyright (C) 2020 Intel Corporation | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, | |
| # software distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions | |
| # and limitations under the License. | |
| from argparse import ArgumentParser | |
| from pathlib import Path | |
| from typing import Union | |
| import pytorch_lightning as pl | |
| from omegaconf import DictConfig, ListConfig, OmegaConf | |
| from pytorch_lightning import seed_everything | |
| from pytorch_lightning.loggers import WandbLogger | |
| from utils import flatten_hpo_params | |
| import wandb | |
| from anomalib.config import get_configurable_parameters, update_input_size_config | |
| from anomalib.data import get_datamodule | |
| from anomalib.models import get_model | |
| from anomalib.utils.sweep import flatten_sweep_params, set_in_nested_config | |
class WandbSweep:
    """Run a Weights & Biases hyper-parameter sweep for an anomalib model.

    Args:
        config (DictConfig): Original model configuration. It is updated in place with the
            values sampled for each sweep trial.
        sweep_config (DictConfig): Sweep configuration. An optional top-level
            ``observation_budget`` key sets how many trials the agent runs. That key is
            removed from the config so it is not forwarded to ``wandb.sweep``.
    """

    def __init__(self, config: Union[DictConfig, ListConfig], sweep_config: Union[DictConfig, ListConfig]) -> None:
        self.config = config
        self.sweep_config = sweep_config
        # Forwarded to ``wandb.agent`` as ``count``. Stays ``None`` (wandb's own default)
        # when the sweep config does not define a budget. The key is only read once we
        # know it exists, so configs that omit it no longer fail here.
        self.observation_budget = None
        if "observation_budget" in self.sweep_config.keys():
            self.observation_budget = self.sweep_config.observation_budget
            # Strip the key so that ``run`` does not pass it on to ``wandb.sweep``.
            # This instance check is to silence mypy.
            if isinstance(self.sweep_config, DictConfig):
                self.sweep_config.pop("observation_budget")

    def run(self) -> None:
        """Register the sweep with wandb and start an agent that calls ``sweep`` once per trial."""
        # Flatten the search space. NOTE(review): presumably this produces dotted keys,
        # which ``sweep`` later splits on "." to locate the nested config entry.
        flattened_hpo_params = flatten_hpo_params(self.sweep_config.parameters)
        self.sweep_config.parameters = flattened_hpo_params
        sweep_id = wandb.sweep(
            OmegaConf.to_object(self.sweep_config),
            project=f"{self.config.model.name}_{self.config.dataset.name}",
        )
        wandb.agent(sweep_id, function=self.sweep, count=self.observation_budget)

    def sweep(self) -> None:
        """Run one sweep trial: apply the sampled hyper-parameters, build model and data, and fit.

        The metrics are logged to the ``wandb`` dashboard through ``WandbLogger``.
        """
        wandb_logger = WandbLogger(config=flatten_sweep_params(self.sweep_config), log_model=False)

        # Values chosen by the sweep for this run. Each key is split on "." to address the
        # nested entry in ``self.config``, which is updated in place.
        sweep_config = wandb_logger.experiment.config
        for param in sweep_config.keys():
            set_in_nested_config(self.config, param.split("."), sweep_config[param])
        config = update_input_size_config(self.config)

        model = get_model(config)
        datamodule = get_datamodule(config)

        # Disable saving checkpoints as all checkpoints from the sweep will get uploaded
        config.trainer.checkpoint_callback = False

        trainer = pl.Trainer(**config.trainer, logger=wandb_logger)
        trainer.fit(model, datamodule=datamodule)
def get_args():
    """Parse the command-line options of the sweep script.

    Returns:
        argparse.Namespace: Holds ``model`` (str, default ``"padim"``), ``model_config``
        (optional Path) and ``sweep_config`` (required Path).
    """
    cli = ArgumentParser()
    # (flag, keyword arguments for add_argument), registered in this order.
    option_specs = (
        ("--model", dict(type=str, default="padim", help="Name of the algorithm to train/test")),
        ("--model_config", dict(type=Path, required=False, help="Path to a model config file")),
        ("--sweep_config", dict(type=Path, required=True, help="Path to sweep configuration")),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    namespace = cli.parse_args()
    return namespace
if __name__ == "__main__":
    # Resolve the model configuration and the sweep definition from the command line.
    cli_args = get_args()
    model_config = get_configurable_parameters(model_name=cli_args.model, config_path=cli_args.model_config)
    hpo_config = OmegaConf.load(cli_args.sweep_config)

    # A seed of 0 skips global seeding; any other value is applied before the sweep starts.
    project_seed = model_config.project.seed
    if project_seed != 0:
        seed_everything(project_seed)

    WandbSweep(model_config, hpo_config).run()