VLM-Lens / src /main.py
marstin's picture
[martin-dev] add demo v1 test
d425e71
raw
history blame
2.86 kB
"""main.py.
This module here is the entrypoint to the VLM Lens toolkit.
"""
import logging
from src.models.base import ModelBase
from src.models.config import Config, ModelSelection
def get_model(
model_arch: ModelSelection,
config: Config
) -> ModelBase:
"""Returns the model based on the selection enum chosen.
Args:
model_arch (ModelSelection): ModelSelection enum chosen for the specific architecture.
config (Config): The configuration object.
Returns:
ModelBase: A model of type ModelBase which implements the runtime
"""
if model_arch == ModelSelection.LLAVA:
from src.models.llava import LlavaModel
return LlavaModel(config)
elif model_arch == ModelSelection.QWEN:
from src.models.qwen import QwenModel
return QwenModel(config)
elif model_arch == ModelSelection.CLIP:
from src.models.clip import ClipModel
return ClipModel(config)
elif model_arch == ModelSelection.GLAMM:
from src.models.glamm import GlammModel
return GlammModel(config)
elif model_arch == ModelSelection.JANUS:
from src.models.janus import JanusModel
return JanusModel(config)
elif model_arch == ModelSelection.BLIP2:
from src.models.blip2 import Blip2Model
return Blip2Model(config)
elif model_arch == ModelSelection.MOLMO:
from src.models.molmo import MolmoModel
return MolmoModel(config)
elif model_arch == ModelSelection.PALIGEMMA:
from src.models.paligemma import PaligemmaModel
return PaligemmaModel(config)
elif model_arch == ModelSelection.INTERNLM_XC:
from src.models.internlm_xc import InternLMXComposerModel
return InternLMXComposerModel(config)
elif model_arch == ModelSelection.INTERNVL:
from src.models.internvl import InternVLModel
return InternVLModel(config)
elif model_arch == ModelSelection.MINICPM:
from src.models.minicpm import MiniCPMModel
return MiniCPMModel(config)
elif model_arch == ModelSelection.COGVLM:
from src.models.cogvlm import CogVLMModel
return CogVLMModel(config)
elif model_arch == ModelSelection.PIXTRAL:
from src.models.pixtral import PixtralModel
return PixtralModel(config)
elif model_arch == ModelSelection.AYA_VISION:
from src.models.aya_vision import AyaVisionModel
return AyaVisionModel(config)
elif model_arch == ModelSelection.PLM:
from src.models.plm import PlmModel
return PlmModel(config)
if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
config = Config()
logging.debug(
f'Config is set to '
f'{[(key, value) for key, value in config.__dict__.items()]}'
)
model = get_model(config.architecture, config)
model.run()