# Path setup, and access the config.yml file, datasets folder & trained models
import os
import sys
from pathlib import Path
from typing import List, Optional  # noqa: F401  (List kept: may be re-exported)

from pydantic import BaseModel
from strictyaml import YAML, load

file = Path(__file__).resolve()
parent, root = file.parent, file.parents[1]
# Make the directory that contains this package importable.
sys.path.append(str(root))

# Project Directories
PACKAGE_ROOT = Path(__file__).resolve().parent
ROOT = PACKAGE_ROOT.parent
CONFIG_FILE_PATH = ROOT / "config.yml"
DATASET_DIR = ROOT / "dataset"
CAPTIONS_DIR = ROOT / "dataset" / "captions.txt"
IMAGES_DIR = ROOT / "dataset" / "Images"
TRAINED_MODEL_DIR = ROOT / "trained_models"

# Set at import time, before any caller can initialise Weights & Biases.
os.environ["WANDB_DISABLED"] = "true"


class AppConfig(BaseModel):
    """
    Application-level config.
    """

    package_name: str
    training_data_file: str
    pipeline_name: str
    pipeline_save_file: str


class ModelConfig(BaseModel):
    """
    All configuration relevant to model
    training and feature engineering.
    """

    target: str
    features: str
    ENCODER: str
    DECODER: str
    TRAIN_BATCH_SIZE: int
    VAL_BATCH_SIZE: int
    VAL_EPOCHS: int
    LR: float
    SEED: int
    MAX_LEN: int
    SUMMARY_LEN: int
    WEIGHT_DECAY: float
    # NOTE(review): MEAN/STD are declared as str, not numeric -- presumably
    # parsed by the consumer; confirm against config.yml.
    MEAN: str
    STD: str
    TRAIN_PCT: float
    NUM_WORKERS: int
    EPOCHS: int
    IMG_SIZE: int
    LABEL_MASK: int
    TOP_K: int
    TOP_P: float
    EARLY_STOPPING: bool
    NGRAM_SIZE: int
    LEN_PENALTY: float
    NUM_BEAMS: int
    NUM_LOGGING_STEPS: int
    n_estimators: int
    max_depth: int


class Config(BaseModel):
    """Master config object."""

    app_config: AppConfig
    # NOTE(review): presumably spelled ``lmodel_config`` to stay clear of
    # pydantic's reserved ``model_`` attribute prefix -- confirm before renaming.
    lmodel_config: ModelConfig


def find_config_file() -> Path:
    """Locate the configuration file.

    Returns:
        ``CONFIG_FILE_PATH`` when it points at an existing regular file.

    Raises:
        FileNotFoundError: If there is no file at ``CONFIG_FILE_PATH``.
    """
    if CONFIG_FILE_PATH.is_file():
        return CONFIG_FILE_PATH
    # FileNotFoundError is still an Exception (and an OSError), so callers
    # with broad handlers keep working while new ones can be specific.
    raise FileNotFoundError(f"Config not found at {CONFIG_FILE_PATH!r}")


def fetch_config_from_yaml(cfg_path: Optional[Path] = None) -> YAML:
    """Parse YAML containing the package configuration.

    Args:
        cfg_path: Path of the YAML file to read. Any falsy value (the
            default ``None`` included) falls back to ``find_config_file()``.

    Returns:
        The strictyaml ``YAML`` object produced by ``load``.

    Raises:
        FileNotFoundError: If the default config file is missing, or if an
            explicitly supplied ``cfg_path`` does not exist.
        OSError: If the file cannot be opened for another reason.
    """
    if not cfg_path:
        cfg_path = find_config_file()

    # Read as UTF-8 explicitly so parsing does not depend on the platform's
    # locale default encoding.
    with open(cfg_path, "r", encoding="utf-8") as conf_file:
        return load(conf_file.read())


def create_and_validate_config(parsed_config: Optional[YAML] = None) -> Config:
    """Run validation on config values.

    Args:
        parsed_config: An already-parsed strictyaml document. When ``None``
            the document is obtained from ``fetch_config_from_yaml()``.

    Returns:
        A ``Config`` whose ``app_config`` and ``lmodel_config`` are both built
        from the same parsed mapping.
    """
    if parsed_config is None:
        parsed_config = fetch_config_from_yaml()

    # specify the data attribute from the strictyaml YAML type.
    # Both sub-models receive the full flat mapping; this relies on pydantic's
    # handling of keys a model does not declare (ignored by default).
    _config = Config(
        app_config=AppConfig(**parsed_config.data),
        lmodel_config=ModelConfig(**parsed_config.data),
    )

    return _config


# Built once at import time; importing this module therefore requires a valid
# config.yml at CONFIG_FILE_PATH.
config = create_and_validate_config()