"""
Framework factory utilities.

Importing this package auto-imports every public submodule (names not
starting with "_") so that each framework can register itself;
`build_framework` then instantiates the framework named in the config.

Each framework module (e.g., M1.py, QwenFast.py) should register a builder
with FRAMEWORK_REGISTRY:
    from starVLA.model.tools import FRAMEWORK_REGISTRY

    @FRAMEWORK_REGISTRY.register("InternVLA-M1")
    def build_model_framework(config):
        return InternVLA_M1(config=config)
"""

import pkgutil
import importlib
from omegaconf import OmegaConf
from starVLA.model.tools import FRAMEWORK_REGISTRY

from starVLA.training.trainer_utils import initialize_overwatch

logger = initialize_overwatch(__name__)
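

# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical, not part of starVLA): a minimal stand-in
# for the decorator-based registration described in the module docstring.
# `_SketchRegistry` is an assumed name; the real FRAMEWORK_REGISTRY in
# `starVLA.model.tools` may be implemented differently. It assumes only what
# this module relies on: `register(name)` returns a decorator that stores a
# builder callable, and `registry[name]` returns that callable.
# ---------------------------------------------------------------------------
class _SketchRegistry:
    """Minimal name -> builder mapping, for illustration only."""

    def __init__(self):
        self._builders = {}

    def register(self, name):
        """Return a decorator that records the decorated callable under `name`."""

        def decorator(fn):
            # Refuse silent overwrites so two modules cannot claim one name.
            if name in self._builders:
                raise KeyError(f"Framework `{name}` is already registered.")
            self._builders[name] = fn
            return fn  # leave the decorated function usable as-is

        return decorator

    def __getitem__(self, name):
        try:
            return self._builders[name]
        except KeyError:
            known = ", ".join(sorted(self._builders)) or "<none>"
            raise KeyError(f"Unknown framework `{name}`. Registered: {known}") from None

    def __contains__(self, name):
        return name in self._builders


# Usage (sketch):
#     registry = _SketchRegistry()
#
#     @registry.register("Toy")
#     def build_toy(config):
#         return {"framework": "Toy", "config": config}
#
#     registry["Toy"]("cfg")  # returns {"framework": "Toy", "config": "cfg"}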

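# `__path__` is only defined when this module is a package's __init__; fall
# back to None (and skip auto-import) when it is loaded some other way.
try:
    pkg_path = __path__
except NameError:
    pkg_path = None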

# Auto-import all framework submodules to trigger registration
if pkg_path is not None:
    for _, module_name, _ in pkgutil.iter_modules(pkg_path):
        if module_name.startswith("_"):
            continue
        try:
            importlib.import_module(f"{__name__}.{module_name}")
        except Exception as e:
            logger.warning(f"Failed to auto-import framework submodule `{module_name}`: {e}")
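

# Illustrative sketch (hypothetical helper, not called anywhere in this
# module): the loop above asks `pkgutil.iter_modules` for the direct children
# of this package and skips names starting with "_". The assumed helper
# `_list_framework_candidates` reproduces only that discovery-and-filter step
# as a pure function, so the submodules that would be auto-imported can be
# inspected without importing them.
def _list_framework_candidates(search_path):
    """Return sorted public module/package names directly under `search_path`."""
    import pkgutil  # local import keeps the sketch self-contained

    if search_path is None:
        # Mirrors the `pkg_path is None` case: nothing is discovered.
        return []
    return sorted(
        info.name
        for info in pkgutil.iter_modules(search_path)
        if not info.name.startswith("_")
    )


# Usage (sketch): `_list_framework_candidates(pkg_path)` lists, in sorted
# order, the submodule names that the auto-import loop will try.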
        
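
def build_framework(cfg):
    """
    Build a framework model from config.
    `ActionModelFM`, `QwenLatent` and `QwenLatent_history` are built directly;
    every other identifier is looked up in FRAMEWORK_REGISTRY.
    Args:
        cfg: Config object (OmegaConf / namespace) containing:
             cfg.framework.name: Identifier string (e.g. "InternVLA-M1").
             cfg.framework.framework_py: Legacy alias, used only when `name`
                 is unset or empty; its value is copied into `cfg.framework.name`.
             cfg.framework.action_model: Optional OmegaConf node of
                 ActionModelConfig kwargs (ActionModelFM only).
    Returns:
        nn.Module: Instantiated framework model.
    Raises:
        ValueError: If `cfg.framework` or the framework identifier is missing.
    """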

    if not hasattr(cfg, "framework"):
        raise ValueError("Missing `cfg.framework` in configuration.")

    framework_id = getattr(cfg.framework, "name", None)
    if not framework_id:
        framework_id = getattr(cfg.framework, "framework_py", None)  # Backward compatibility for legacy config yaml
        if framework_id:
            cfg.framework.name = framework_id

    if not framework_id:
        raise ValueError("Missing framework identifier. Set `cfg.framework.name` (or legacy `framework_py`).")

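    # These frameworks are imported lazily and built directly instead of
    # going through FRAMEWORK_REGISTRY.
    if framework_id == "ActionModelFM":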
        from starVLA.model.modules.action_model.ActionModel_FM import ActionModelFM
        from starVLA.model.modules.action_model.configuration_actionmodel import ActionModelConfig
        action_model_cfg = getattr(cfg.framework, "action_model", None)
        if action_model_cfg is None:
            config = ActionModelConfig()
        else:
            action_model_kwargs = OmegaConf.to_container(action_model_cfg, resolve=True)
            config = ActionModelConfig(**action_model_kwargs)
        return ActionModelFM(config)
    elif framework_id == "QwenLatent":
        from starVLA.model.framework.QwenLatent import QwenLatent
        return QwenLatent(cfg)
    elif framework_id == "QwenLatent_history":
        from starVLA.model.framework.QwenLatent_history import QwenLatentHistory
        return QwenLatentHistory(cfg)
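    # Every other framework must have registered itself during auto-import.
    # The registered entry may be a class or a builder function; either is
    # called with `cfg`.
    model_class = FRAMEWORK_REGISTRY[framework_id]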
    return model_class(cfg)
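

# Illustrative sketch (hypothetical helper, not called by build_framework):
# the identifier-resolution rule used above, factored into a pure function so
# the precedence can be unit-tested without building a model. `name` wins;
# the legacy `framework_py` key is used only when `name` is unset or empty.
# Unlike build_framework, this sketch does not write the resolved name back
# into the config.
def _resolve_framework_id(framework_cfg):
    """Return the framework identifier, or raise ValueError if none is set."""
    framework_id = getattr(framework_cfg, "name", None) or getattr(
        framework_cfg, "framework_py", None
    )
    if not framework_id:
        raise ValueError(
            "Missing framework identifier. Set `cfg.framework.name` "
            "(or legacy `framework_py`)."
        )
    return framework_id


# Usage (sketch), with types.SimpleNamespace standing in for cfg.framework:
#     from types import SimpleNamespace
#     _resolve_framework_id(SimpleNamespace(name="QwenLatent"))            # returns "QwenLatent"
#     _resolve_framework_id(SimpleNamespace(framework_py="InternVLA-M1"))  # returns "InternVLA-M1"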

__all__ = ["build_framework", "FRAMEWORK_REGISTRY"]