File size: 3,042 Bytes
c374021
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
"""
config.py
=========
Backward-compatible configuration wrapper.

This file now delegates to the per-model configs in configs/.
Existing code that does `from config import CFG` will continue to work.

Usage:
    from config import CFG
    cfg = CFG.load_from_env()          # loads default (BLIP) config
    cfg = CFG.load_for_model("git")    # loads GIT-specific config
    cfg.get_model_dir("blip")          # → "./outputs/blip"
"""

import os
from dataclasses import dataclass, field
from typing import Literal

from configs import get_config
from configs.base_config import BaseConfig


@dataclass
class CFG(BaseConfig):
    """
    Backward-compatible master configuration.

    Extends ``BaseConfig`` with the union of fields used by the different
    model types, so that existing callers (app.py, eval.py, etc.) can keep
    working against a single config object.

    Construct it via one of the classmethods:
      * ``CFG.load_from_env()``      -> instance with the declared defaults
      * ``CFG.load_for_model(name)`` -> defaults overlaid with the values of
                                        the per-model config from ``configs/``
    """
    # ─── Model Selection ─────────────────────────────────────────────────
    # Which vision-language model this config is meant for.
    vlm_type: Literal["blip", "vit_gpt2", "git", "custom"] = "blip"

    # ─── Model IDs ───────────────────────────────────────────────────────
    # IDs for every model are kept here so app.py can reference any of them.
    model_id: str = "Salesforce/blip-image-captioning-base"
    vit_gpt2_model_id: str = "nlpconnect/vit-gpt2-image-captioning"
    git_model_id: str = "microsoft/git-base-coco"
    vit_encoder_id: str = "google/vit-base-patch16-224-in21k"

    # ─── Custom VLM (Shakespeare Decoder) ────────────────────────────────
    shakespeare_file: str = "./input.txt"
    shakespeare_weights_path: str = "./shakespeare_transformer.pt"
    text_embed_dim: int = 384
    n_heads: int = 8
    n_layers: int = 8
    block_size: int = 256
    dropout: float = 0.1

    # ─── Unified Output ──────────────────────────────────────────────────
    # Checkpoints are laid out as outputs/{model}/best/ and
    # outputs/{model}/latest/ beneath this root.
    output_root: str = "./outputs"

    def get_model_dir(self, model_name: str) -> str:
        """
        Build the output directory of one model.

        Args:
            model_name: Name of the model; used verbatim as the
                sub-directory name.

        Returns:
            ``output_root`` joined with ``model_name``.
        """
        root_dir = self.output_root
        return os.path.join(root_dir, model_name)

    @classmethod
    def load_from_env(cls):
        """
        Create the default (backward-compat) config.

        This method itself reads no environment variables; it returns an
        instance built with no arguments, i.e. with the declared defaults.
        """
        default_cfg = cls()
        return default_cfg

    @classmethod
    def load_for_model(cls, model_type: str):
        """
        Create a CFG carrying the values of a model-specific config.

        The per-model config is obtained from ``configs.get_config`` and
        overlaid on a default CFG. This lets train.py use per-model
        hyperparameters while the rest of the codebase keeps receiving a
        CFG instance.

        Args:
            model_type: Identifier handed to ``get_config``.

        Returns:
            A CFG instance in which every dataclass field of the model
            config that also exists as an attribute on CFG holds the model
            config's value; all other attributes keep their defaults.
        """
        # Resolve the per-model config first so a failing lookup surfaces
        # before any CFG instance is built.
        per_model = get_config(model_type)
        merged = cls()

        for name in per_model.__dataclass_fields__:
            # Fields unknown to CFG are skipped rather than added.
            if not hasattr(merged, name):
                continue
            value = getattr(per_model, name)
            setattr(merged, name, value)

        return merged