File size: 2,696 Bytes
5aa312d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# Path setup, and access the config.yml file, datasets folder & trained models
import os
import sys
from pathlib import Path
from typing import List, Optional

from pydantic import BaseModel
from strictyaml import YAML, load

# Resolve this module's location once; every project path below derives from it.
file = Path(__file__).resolve()
parent = file.parent
root = file.parents[1]

# Make modules located directly under the project root importable.
sys.path.append(str(root))

# Project Directories
PACKAGE_ROOT = parent
ROOT = PACKAGE_ROOT.parent
CONFIG_FILE_PATH = ROOT / "config.yml"
DATASET_DIR = ROOT / "dataset"
CAPTIONS_DIR = DATASET_DIR / "captions.txt"  # NOTE: a file path despite the _DIR suffix
IMAGES_DIR = DATASET_DIR / "Images"
TRAINED_MODEL_DIR = ROOT / "trained_models"

# Presumably disables Weights & Biases reporting in the training stack that
# reads this variable (e.g. HF transformers) — confirm against the trainer.
os.environ["WANDB_DISABLED"] = "true"


class AppConfig(BaseModel):
    """
    Application-level config.

    Validated from the same top-level mapping of ``config.yml`` as
    ``ModelConfig``. Keys not declared here are dropped (pydantic's default
    ``extra`` handling). Every field has no default and is therefore required.
    """
    package_name: str
    # presumably a file name resolved relative to DATASET_DIR — confirm with the loader
    training_data_file: str
    pipeline_name: str
    # presumably a file name saved under TRAINED_MODEL_DIR — confirm with the saver
    pipeline_save_file: str


class ModelConfig(BaseModel):
    """
    All configuration relevant to model
    training and feature engineering.

    Validated from the full top-level mapping of ``config.yml``; keys not
    declared here are dropped. No field has a default, so every one of them
    must be present in the YAML. The YAML is parsed by strictyaml without a
    schema, so values arrive as strings and pydantic coerces them to the
    annotated types (``int``, ``float``, ``bool``).
    """
    # NOTE(review): generic tabular-style fields; how they are used for this
    # model is not visible here — confirm.
    target: str
    features: str

    # presumably identifiers of the encoder / decoder models (e.g. pretrained
    # checkpoint names) — confirm with the model-building code.
    ENCODER: str
    DECODER: str
    # Training-loop settings (meaning inferred from names; not consumed in this file).
    TRAIN_BATCH_SIZE: int
    VAL_BATCH_SIZE: int
    VAL_EPOCHS: int
    LR: float
    SEED: int
    MAX_LEN: int
    SUMMARY_LEN: int
    WEIGHT_DECAY: float
    # NOTE(review): typed as str; presumably per-channel image normalisation
    # statistics that the consumer parses itself — confirm the expected format.
    MEAN: str
    STD: str
    # presumably the fraction of the dataset used for training — confirm.
    TRAIN_PCT: float
    NUM_WORKERS: int
    EPOCHS: int
    IMG_SIZE: int
    # presumably the label id ignored by the loss (e.g. padding) — confirm.
    LABEL_MASK: int
    # Text-generation / decoding settings (names suggest sampling and beam
    # search parameters — confirm with the generation call).
    TOP_K: int
    TOP_P: float
    EARLY_STOPPING: bool
    NGRAM_SIZE: int
    LEN_PENALTY: float
    NUM_BEAMS: int
    NUM_LOGGING_STEPS: int

    # NOTE(review): tree-ensemble style parameters that look unrelated to the
    # encoder/decoder fields above — possibly template leftovers. They are
    # still required, so config.yml must define them; confirm before removing.
    n_estimators: int
    max_depth: int


class Config(BaseModel):
    """Master config object.

    Groups the application-level settings and the model/training settings,
    both validated from the same ``config.yml`` mapping.
    """

    app_config: AppConfig
    # NOTE(review): presumably named ``lmodel_config`` (not ``model_config``)
    # because pydantic v2 reserves ``model_config`` for model settings —
    # confirm before renaming; callers access the field by this name.
    lmodel_config: ModelConfig


def find_config_file() -> Path:
    """Locate the project configuration file.

    Returns:
        ``CONFIG_FILE_PATH`` when it refers to an existing regular file.

    Raises:
        FileNotFoundError: if ``CONFIG_FILE_PATH`` is not an existing file.
            This is a subclass of ``Exception``, so handlers written for the
            previous bare ``Exception`` still catch it.
    """
    # Guard clause: fail with a specific, catchable error type instead of a
    # bare ``Exception``.
    if not CONFIG_FILE_PATH.is_file():
        raise FileNotFoundError(f"Config not found at {CONFIG_FILE_PATH!r}")
    return CONFIG_FILE_PATH


def fetch_config_from_yaml(cfg_path: Optional[Path] = None) -> YAML:
    """Read and parse the YAML file containing the package configuration.

    Args:
        cfg_path: Path to the YAML file. When omitted (or falsy), the default
            location is resolved with ``find_config_file()``.

    Returns:
        The parsed strictyaml ``YAML`` document. No schema is passed to
        ``load``, so scalar values are returned as strings; type coercion is
        left to the pydantic models.

    Raises:
        Whatever ``find_config_file()`` raises when the default file is
        missing; ``OSError`` (e.g. ``FileNotFoundError``) when ``cfg_path``
        cannot be opened; strictyaml's parse errors on malformed YAML.
    """
    if not cfg_path:
        # find_config_file() either returns an existing path or raises, so no
        # "not found" fallback is needed after this point.
        cfg_path = find_config_file()

    # YAML is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(cfg_path, "r", encoding="utf-8") as conf_file:
        return load(conf_file.read())


def create_and_validate_config(parsed_config: Optional[YAML] = None) -> Config:
    """Build and validate the master ``Config``.

    Args:
        parsed_config: An already-parsed strictyaml document. When ``None``,
            the default config file is loaded with ``fetch_config_from_yaml()``.

    Returns:
        A ``Config`` whose ``app_config`` and ``lmodel_config`` are both
        validated from the same top-level mapping.

    Raises:
        pydantic.ValidationError: if a required field is missing or a value
            cannot be coerced to its declared type. Errors from
            ``fetch_config_from_yaml()`` propagate unchanged.
    """
    if parsed_config is None:
        parsed_config = fetch_config_from_yaml()

    # ``.data`` is the plain-Python view of the strictyaml document. Both
    # sub-models receive the full mapping and keep only the keys they declare
    # (pydantic ignores extra keys by default).
    values = parsed_config.data
    return Config(
        app_config=AppConfig(**values),
        lmodel_config=ModelConfig(**values),
    )


# Loaded and validated once, at import time. Importing this module therefore
# fails (find_config_file raises) when no config file exists at CONFIG_FILE_PATH.
config = create_and_validate_config()