ashwml's picture
Upload 233 files
5aa312d
# Path setup, plus access to the config.yml file, the dataset folder, and the trained models.
import sys
from pathlib import Path
from pydantic import BaseModel
from strictyaml import YAML, load
import os
from typing import List
# Resolve this module's own location, then make the parent of the directory
# that contains it importable by appending that directory to sys.path.
file = Path(__file__).resolve()
parent = file.parent
root = file.parents[1]
sys.path.append(str(root))

# Project directories: everything below is anchored on ROOT, the directory
# one level above the one holding this file.
PACKAGE_ROOT = file.parent
ROOT = PACKAGE_ROOT.parent
CONFIG_FILE_PATH = ROOT / "config.yml"
DATASET_DIR = ROOT / "dataset"
CAPTIONS_DIR = DATASET_DIR / "captions.txt"  # a file path, despite the _DIR suffix
IMAGES_DIR = DATASET_DIR / "Images"
TRAINED_MODEL_DIR = ROOT / "trained_models"

# Sets WANDB_DISABLED for this process and its children.
# NOTE(review): presumably switches off Weights & Biases logging in the
# training code -- confirm against the trainer that reads it.
os.environ["WANDB_DISABLED"] = "true"
class AppConfig(BaseModel):
    """Application-level settings.

    A pydantic model with four string fields. None of them declares a
    default, so each must be supplied when the model is constructed.
    """

    package_name: str
    training_data_file: str
    pipeline_name: str
    pipeline_save_file: str
class ModelConfig(BaseModel):
    """Settings for model training and feature engineering.

    A pydantic model whose fields all lack defaults, so every one of them
    must be supplied when the model is constructed. The meaning of each
    value is defined by the configuration file and the code that consumes
    it; only the names and types are fixed here.
    """

    target: str
    features: str
    ENCODER: str
    DECODER: str
    TRAIN_BATCH_SIZE: int
    VAL_BATCH_SIZE: int
    VAL_EPOCHS: int
    LR: float
    SEED: int
    MAX_LEN: int
    SUMMARY_LEN: int
    WEIGHT_DECAY: float
    # NOTE(review): MEAN and STD are declared as strings, not numbers --
    # presumably parsed further by the consumer; confirm.
    MEAN: str
    STD: str
    TRAIN_PCT: float
    NUM_WORKERS: int
    EPOCHS: int
    IMG_SIZE: int
    LABEL_MASK: int
    TOP_K: int
    TOP_P: float
    EARLY_STOPPING: bool
    NGRAM_SIZE: int
    LEN_PENALTY: float
    NUM_BEAMS: int
    NUM_LOGGING_STEPS: int
    n_estimators: int
    max_depth: int
class Config(BaseModel):
    """Master config object bundling the application and model settings."""

    app_config: AppConfig
    # NOTE(review): presumably spelled `lmodel_config` rather than
    # `model_config` to avoid pydantic's reserved attribute -- confirm.
    lmodel_config: ModelConfig
def find_config_file() -> Path:
    """Locate the configuration file.

    Returns:
        CONFIG_FILE_PATH, when it points at an existing regular file.

    Raises:
        FileNotFoundError: if CONFIG_FILE_PATH does not exist or is not a
            regular file.
    """
    if not CONFIG_FILE_PATH.is_file():
        # FileNotFoundError is more specific than a bare Exception, yet it is
        # still caught by `except Exception` and `except OSError` handlers.
        raise FileNotFoundError(f"Config not found at {CONFIG_FILE_PATH!r}")
    return CONFIG_FILE_PATH
def fetch_config_from_yaml(cfg_path: Path = None) -> YAML:
    """Parse YAML containing the package configuration.

    Args:
        cfg_path: Location of the YAML file. When omitted (or otherwise
            falsy) the path returned by find_config_file() is used.

    Returns:
        The strictyaml YAML object produced by load() for the file's text.

    Raises:
        OSError: if the file cannot be opened or read; this includes the
            error raised when a given cfg_path does not exist.

    Errors raised by find_config_file() or by load() are propagated
    unchanged.
    """
    if not cfg_path:
        # find_config_file() either returns a path or raises, so cfg_path is
        # always usable past this point and no further check is needed.
        cfg_path = find_config_file()

    # Decode explicitly as UTF-8 instead of relying on the platform's
    # locale-dependent default encoding.
    with open(cfg_path, "r", encoding="utf-8") as conf_file:
        return load(conf_file.read())
def create_and_validate_config(parsed_config: YAML = None) -> Config:
    """Build the master Config, validating the values through the models.

    Args:
        parsed_config: A parsed strictyaml document. When None, it is
            obtained from fetch_config_from_yaml().

    Returns:
        A Config holding an AppConfig and a ModelConfig, both constructed
        from the same parsed data.
    """
    if parsed_config is None:
        parsed_config = fetch_config_from_yaml()

    # `.data` is the attribute of the strictyaml YAML type holding the parsed
    # content. The whole mapping is handed to both models.
    # NOTE(review): this relies on each model ignoring keys it does not
    # declare (pydantic's default behaviour) -- confirm.
    raw_settings = parsed_config.data
    app_section = AppConfig(**raw_settings)
    model_section = ModelConfig(**raw_settings)
    return Config(app_config=app_section, lmodel_config=model_section)
# Module-level configuration object, built at import time from the default
# config file; importing this module therefore raises if that file is missing.
config = create_and_validate_config()