File size: 917 Bytes
c374021
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
"""
configs/__init__.py
===================
Config package — exposes a get_config() factory function.
"""

from .base_config import BaseConfig
from .blip_config import BlipConfig
from .vit_gpt2_config import ViTGPT2Config
from .git_config import GitConfig
from .custom_vlm_config import CustomVLMConfig


def get_config(model_type: str):
    """
    Build the default config object for a named model family.

    Args:
        model_type: registry key naming the model family; accepted
            values are 'blip', 'vit_gpt2', 'git' and 'custom'.

    Returns:
        A fresh instance of the matching config class, constructed
        with no arguments.

    Raises:
        ValueError: if ``model_type`` is not one of the registered keys.
    """
    # Insertion order matters: it determines the order of the choices
    # listed in the error message below.
    config_classes = {
        "blip": BlipConfig,
        "vit_gpt2": ViTGPT2Config,
        "git": GitConfig,
        "custom": CustomVLMConfig,
    }

    if model_type not in config_classes:
        known = list(config_classes)
        raise ValueError(
            f"Unknown model_type '{model_type}'. Choose from: {known}"
        )

    config_cls = config_classes[model_type]
    return config_cls()