File size: 2,933 Bytes
e94400c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 | """
Framework factory utilities.
Automatically builds registered framework implementations
based on configuration.
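Each framework module (e.g., M1.py, QwenFast.py) should register itself with
the same registry this package imports:

    from starVLA.model.tools import FRAMEWORK_REGISTRY

    @FRAMEWORK_REGISTRY.register("InternVLA-M1")
    def build_model_framework(config):
        return InternVLA_M1(config=config)

Every submodule whose name does not start with an underscore is imported
automatically when this package is imported, which runs its registration
decorator.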
"""
import pkgutil
import importlib
from omegaconf import OmegaConf
from starVLA.model.tools import FRAMEWORK_REGISTRY
from starVLA.training.trainer_utils import initialize_overwatch
logger = initialize_overwatch(__name__)


def _auto_import_submodules(search_path):
    """Import each public submodule found on *search_path*.

    Importing a framework submodule triggers its registration in
    ``FRAMEWORK_REGISTRY``. Names starting with an underscore are skipped.
    A submodule that fails to import is reported with a warning, and the
    remaining submodules are still imported (best effort).
    """
    for info in pkgutil.iter_modules(search_path):
        name = info.name
        if name.startswith("_"):
            continue
        try:
            importlib.import_module(f"{__name__}.{name}")
        except Exception as e:
            logger.warning(f"Failed to auto-import framework submodule `{name}`: {e}")


# `__path__` only exists when this file is loaded as a package `__init__`;
# without it there is no package directory to scan.
pkg_path = globals().get("__path__")
if pkg_path is not None:
    _auto_import_submodules(pkg_path)
def _resolve_framework_id(cfg):
    """Return the framework identifier stored in ``cfg.framework``.

    Prefers ``cfg.framework.name``. If that is missing or empty, falls back to
    the legacy ``cfg.framework.framework_py`` key. It then copies that value
    into ``cfg.framework.name`` (mutating ``cfg``), so later code only needs
    to read ``name``.

    Raises:
        ValueError: If ``cfg.framework`` is absent or no identifier is set.
    """
    if not hasattr(cfg, "framework"):
        raise ValueError("Missing `cfg.framework` in configuration.")
    framework_id = getattr(cfg.framework, "name", None)
    if not framework_id:
        # Backward compatibility for legacy config yaml that used `framework_py`.
        framework_id = getattr(cfg.framework, "framework_py", None)
        if framework_id:
            cfg.framework.name = framework_id
    if not framework_id:
        raise ValueError("Missing framework identifier. Set `cfg.framework.name` (or legacy `framework_py`).")
    return framework_id


def _action_model_kwargs(action_model_cfg):
    """Convert a ``cfg.framework.action_model`` section into a plain kwargs dict.

    - OmegaConf nodes are converted with interpolations resolved.
    - Plain mappings are copied.
    - Any other object is treated as an attribute namespace (``vars``).

    ``OmegaConf.to_container`` rejects non-OmegaConf input, so calling it
    unconditionally would break namespace-style configs.
    """
    from collections.abc import Mapping

    if OmegaConf.is_config(action_model_cfg):
        return OmegaConf.to_container(action_model_cfg, resolve=True)
    if isinstance(action_model_cfg, Mapping):
        return dict(action_model_cfg)
    return dict(vars(action_model_cfg))


def build_framework(cfg):
    """
    Build a framework model from config.

    Args:
        cfg: Config object (OmegaConf / namespace) containing:
            cfg.framework.name: Identifier string (e.g. "InternVLA-M1").
                The legacy key ``cfg.framework.framework_py`` is used when
                ``name`` is unset, and its value is copied into ``name``.
            cfg.framework.action_model (optional): Keyword arguments for
                ``ActionModelConfig``; only read for "ActionModelFM".

    Returns:
        nn.Module: Instantiated framework model.

    Raises:
        ValueError: If ``cfg.framework`` or the framework identifier is missing.
        KeyError: If the identifier is not one of the built-in cases below and
            the ``FRAMEWORK_REGISTRY`` lookup raises ``KeyError`` for it.
    """
    framework_id = _resolve_framework_id(cfg)

    # These frameworks are constructed directly (with lazy imports) instead of
    # being looked up in FRAMEWORK_REGISTRY.
    if framework_id == "ActionModelFM":
        from starVLA.model.modules.action_model.ActionModel_FM import ActionModelFM
        from starVLA.model.modules.action_model.configuration_actionmodel import ActionModelConfig

        action_model_cfg = getattr(cfg.framework, "action_model", None)
        if action_model_cfg is None:
            config = ActionModelConfig()
        else:
            config = ActionModelConfig(**_action_model_kwargs(action_model_cfg))
        return ActionModelFM(config)
    if framework_id == "QwenLatent":
        from starVLA.model.framework.QwenLatent import QwenLatent

        return QwenLatent(cfg)
    if framework_id == "QwenLatent_history":
        from starVLA.model.framework.QwenLatent_history import QwenLatentHistory

        return QwenLatentHistory(cfg)

    try:
        model_class = FRAMEWORK_REGISTRY[framework_id]
    except KeyError as err:
        # Submodule import failures are only logged as warnings at package import
        # time, so a missing registration is often caused by one of them.
        raise KeyError(
            f"Unknown framework `{framework_id}`: it is not registered in FRAMEWORK_REGISTRY. "
            "Check the log for warnings about framework submodules that failed to auto-import."
        ) from err
    return model_class(cfg)
# Names exported by a star-import of this package.
__all__ = [
    "build_framework",
    "FRAMEWORK_REGISTRY",
]
|