Spaces:
No application file
No application file
# Path setup and access to the config.yml file, the dataset folder, and the trained models
| import sys | |
| from pathlib import Path | |
| from pydantic import BaseModel | |
| from strictyaml import YAML, load | |
| import os | |
| from typing import List | |
# Resolve this module's location once and expose the directory one level
# above the package on sys.path.
file = Path(__file__).resolve()
parent = file.parent
root = file.parents[1]
sys.path.append(str(root))

# Project directories, all derived from this module's resolved location.
PACKAGE_ROOT = parent
ROOT = PACKAGE_ROOT.parent
CONFIG_FILE_PATH = ROOT / "config.yml"
DATASET_DIR = ROOT / "dataset"
# NOTE: despite the *_DIR suffix, this path points at a file inside the
# dataset directory.
CAPTIONS_DIR = DATASET_DIR / "captions.txt"
IMAGES_DIR = DATASET_DIR / "Images"
TRAINED_MODEL_DIR = ROOT / "trained_models"

# Set as an import-time side effect, so it is in place before any later code
# reads the environment.
os.environ["WANDB_DISABLED"] = "true"
class AppConfig(BaseModel):
    """Application-level settings.

    Every field is a required string; instances are built in
    ``create_and_validate_config`` from the parsed YAML mapping.
    """

    package_name: str
    training_data_file: str
    pipeline_name: str
    pipeline_save_file: str
class ModelConfig(BaseModel):
    """Settings for model training and feature engineering.

    Every field is required and is validated by pydantic against the
    annotated type; instances are built in ``create_and_validate_config``
    from the parsed YAML mapping.
    """

    target: str
    features: str
    ENCODER: str
    DECODER: str
    TRAIN_BATCH_SIZE: int
    VAL_BATCH_SIZE: int
    VAL_EPOCHS: int
    LR: float
    SEED: int
    MAX_LEN: int
    SUMMARY_LEN: int
    WEIGHT_DECAY: float
    # NOTE(review): MEAN and STD are declared as strings although the names
    # suggest numeric statistics -- presumably parsed by the consumer; confirm.
    MEAN: str
    STD: str
    TRAIN_PCT: float
    NUM_WORKERS: int
    EPOCHS: int
    IMG_SIZE: int
    LABEL_MASK: int
    TOP_K: int
    TOP_P: float
    EARLY_STOPPING: bool
    NGRAM_SIZE: int
    LEN_PENALTY: float
    NUM_BEAMS: int
    NUM_LOGGING_STEPS: int
    # NOTE(review): n_estimators / max_depth look unrelated to the settings
    # above -- verify they are still consumed before relying on them.
    n_estimators: int
    max_depth: int
class Config(BaseModel):
    """Master config object bundling the application and model settings."""

    app_config: AppConfig
    # NOTE(review): presumably spelled ``lmodel_config`` to avoid clashing
    # with pydantic's own ``model_config`` attribute -- confirm.
    lmodel_config: ModelConfig
def find_config_file() -> Path:
    """Locate the configuration file.

    Returns:
        ``CONFIG_FILE_PATH`` when it points at an existing regular file.

    Raises:
        FileNotFoundError: If no file exists at ``CONFIG_FILE_PATH``.
    """
    if not CONFIG_FILE_PATH.is_file():
        # FileNotFoundError is an OSError (and therefore an Exception), so
        # callers that caught the previous bare ``Exception`` keep working.
        raise FileNotFoundError(f"Config not found at {CONFIG_FILE_PATH!r}")
    return CONFIG_FILE_PATH
def fetch_config_from_yaml(cfg_path: Path = None) -> YAML:
    """Parse the YAML file containing the package configuration.

    Args:
        cfg_path: Location of the YAML file. When omitted (or otherwise
            falsy), the path returned by ``find_config_file()`` is used.

    Returns:
        The strictyaml ``YAML`` object produced by ``load``.

    Raises:
        OSError: If the file cannot be opened, e.g. ``FileNotFoundError``
            when an explicitly supplied ``cfg_path`` does not exist.
            Errors raised by ``find_config_file()`` or by ``load`` while
            parsing propagate unchanged.
    """
    if not cfg_path:
        cfg_path = find_config_file()

    # Decode explicitly as UTF-8 instead of relying on the platform's
    # default encoding, which differs between operating systems.
    with open(cfg_path, "r", encoding="utf-8") as conf_file:
        return load(conf_file.read())
def create_and_validate_config(parsed_config: YAML = None) -> Config:
    """Build the master ``Config`` and let pydantic validate its values.

    Args:
        parsed_config: A parsed strictyaml document. When ``None``, the
            document is obtained from ``fetch_config_from_yaml()``.

    Returns:
        A ``Config`` holding an ``AppConfig`` and a ``ModelConfig``.
    """
    yaml_doc = fetch_config_from_yaml() if parsed_config is None else parsed_config

    # The ``data`` attribute of the strictyaml YAML object supplies the
    # keyword arguments; the same mapping is passed to both sub-models.
    app_section = AppConfig(**yaml_doc.data)
    model_section = ModelConfig(**yaml_doc.data)
    return Config(app_config=app_section, lmodel_config=model_section)
# Module-level configuration object, built once when this module is imported.
config: Config = create_and_validate_config()