Spaces:
Running
Running
| """ | |
| config.py | |
| ========= | |
| Backward-compatible configuration wrapper. | |
| This file now delegates to the per-model configs in configs/. | |
| Existing code that does `from config import CFG` will continue to work. | |
| Usage: | |
| from config import CFG | |
| cfg = CFG.load_from_env() # loads default (BLIP) config | |
| cfg = CFG.load_for_model("git") # loads GIT-specific config | |
| cfg.get_model_dir("blip") # β "./outputs/blip" | |
| """ | |
| import os | |
| from dataclasses import dataclass, field | |
| from typing import Literal | |
| from configs import get_config | |
| from configs.base_config import BaseConfig | |
| class CFG(BaseConfig): | |
| """ | |
| Master config that merges all fields across all model types. | |
| This exists for backward compatibility with app.py, eval.py, etc. | |
| """ | |
| # βββ Model Selection ββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| vlm_type: Literal["blip", "vit_gpt2", "git", "custom"] = "blip" | |
| # βββ Model IDs (all models so app.py can reference any) βββββββββββββββββ | |
| model_id: str = "Salesforce/blip-image-captioning-base" | |
| vit_gpt2_model_id: str = "nlpconnect/vit-gpt2-image-captioning" | |
| git_model_id: str = "microsoft/git-base-coco" | |
| vit_encoder_id: str = "google/vit-base-patch16-224-in21k" | |
| # βββ Custom VLM (Shakespeare Decoder) βββββββββββββββββββββββββββββββββββ | |
| shakespeare_file: str = "./input.txt" | |
| shakespeare_weights_path: str = "./shakespeare_transformer.pt" | |
| text_embed_dim: int = 384 | |
| n_heads: int = 8 | |
| n_layers: int = 8 | |
| block_size: int = 256 | |
| dropout: float = 0.1 | |
| # βββ Unified Output βββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # All checkpoints go under: outputs/{model}/best/ and outputs/{model}/latest/ | |
| output_root: str = "./outputs" | |
| def get_model_dir(self, model_name: str) -> str: | |
| """Return the output directory for a specific model: outputs/{model_name}/""" | |
| return os.path.join(self.output_root, model_name) | |
| def load_from_env(cls): | |
| """Load the default (backward-compat) config.""" | |
| return cls() | |
| def load_for_model(cls, model_type: str): | |
| """ | |
| Load a model-specific config from configs/ and merge into CFG. | |
| This lets train.py use optimized per-model hyperparameters while | |
| keeping the CFG dataclass compatible with the rest of the codebase. | |
| """ | |
| model_cfg = get_config(model_type) | |
| base = cls() | |
| # Overwrite fields that the model config provides | |
| for field_name in model_cfg.__dataclass_fields__: | |
| if hasattr(base, field_name): | |
| setattr(base, field_name, getattr(model_cfg, field_name)) | |
| return base | |