project_02_DS / config.py
griddev's picture
first push
c374021
"""
config.py
=========
Backward-compatible configuration wrapper.
This file now delegates to the per-model configs in configs/.
Existing code that does `from config import CFG` will continue to work.
Usage:
from config import CFG
cfg = CFG.load_from_env() # loads default (BLIP) config
cfg = CFG.load_for_model("git") # loads GIT-specific config
cfg.get_model_dir("blip") # → "./outputs/blip"
"""
import os
from dataclasses import dataclass, field
from typing import Literal
from configs import get_config
from configs.base_config import BaseConfig
@dataclass
class CFG(BaseConfig):
    """
    Master config that merges all fields across all model types.

    Inherits every field of ``BaseConfig`` and adds the model selector, the
    model identifiers, the custom-decoder settings and the output root
    declared below. This exists for backward compatibility with app.py,
    eval.py, etc.
    """

    # ─── Model Selection ────────────────────────────────────────────────────
    vlm_type: Literal["blip", "vit_gpt2", "git", "custom"] = "blip"

    # ─── Model IDs (all models so app.py can reference any) ─────────────────
    model_id: str = "Salesforce/blip-image-captioning-base"
    vit_gpt2_model_id: str = "nlpconnect/vit-gpt2-image-captioning"
    git_model_id: str = "microsoft/git-base-coco"
    vit_encoder_id: str = "google/vit-base-patch16-224-in21k"

    # ─── Custom VLM (Shakespeare Decoder) ───────────────────────────────────
    shakespeare_file: str = "./input.txt"
    shakespeare_weights_path: str = "./shakespeare_transformer.pt"
    text_embed_dim: int = 384
    n_heads: int = 8
    n_layers: int = 8
    block_size: int = 256
    dropout: float = 0.1

    # ─── Unified Output ─────────────────────────────────────────────────────
    # All checkpoints go under: outputs/{model}/best/ and outputs/{model}/latest/
    output_root: str = "./outputs"

    def get_model_dir(self, model_name: str) -> str:
        """
        Return the output directory for a specific model.

        Args:
            model_name: Sub-directory name, e.g. ``"blip"``.

        Returns:
            ``output_root`` joined with ``model_name``
            (``outputs/{model_name}`` with the default root). The directory
            is not created here.
        """
        return os.path.join(self.output_root, model_name)

    @classmethod
    def load_from_env(cls) -> "CFG":
        """
        Load the default (backward-compat) config.

        Despite the name, this method itself does not read any environment
        variable: it simply calls the constructor with no arguments, so every
        field keeps its declared default.
        """
        return cls()

    @classmethod
    def load_for_model(cls, model_type: str) -> "CFG":
        """
        Load a model-specific config from configs/ and merge it into a CFG.

        This lets train.py use optimized per-model hyperparameters while
        keeping the CFG dataclass compatible with the rest of the codebase.

        Args:
            model_type: Key passed unchanged to ``configs.get_config``.

        Returns:
            A fresh ``CFG`` whose attributes are overwritten by every field
            of the model config that ``CFG`` also exposes. Fields unknown to
            ``CFG`` are ignored.

        Raises:
            TypeError: If ``get_config`` returns something that is not a
                dataclass (raised by ``dataclasses.fields``).
        """
        # Local import: the module header only imports `dataclass` and `field`.
        from dataclasses import fields

        model_cfg = get_config(model_type)
        merged = cls()

        # `fields()` is the public API and yields real fields only; the raw
        # `__dataclass_fields__` mapping also lists ClassVar/InitVar
        # pseudo-fields, which must not be copied onto the instance.
        for spec in fields(model_cfg):
            if hasattr(merged, spec.name):
                setattr(merged, spec.name, getattr(model_cfg, spec.name))

        # NOTE(review): `vlm_type` only changes if the model config declares
        # that field itself — confirm the configs in configs/ do so.
        return merged