File size: 2,861 Bytes
d425e71 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
"""main.py.
This module is the entry point of the VLM Lens toolkit.
"""
import logging
from src.models.base import ModelBase
from src.models.config import Config, ModelSelection
def get_model(
    model_arch: ModelSelection,
    config: Config
) -> ModelBase:
    """Returns the model based on the selection enum chosen.

    Each model class is imported inside its own branch, so the import is
    deferred until that architecture is actually selected.

    Args:
        model_arch (ModelSelection): ModelSelection enum chosen for the specific architecture.
        config (Config): The configuration object passed to the model's constructor.

    Returns:
        ModelBase: The model instance for ``model_arch``, constructed from ``config``.

    Raises:
        ValueError: If ``model_arch`` does not match any supported architecture.
    """
    if model_arch == ModelSelection.LLAVA:
        from src.models.llava import LlavaModel
        return LlavaModel(config)
    elif model_arch == ModelSelection.QWEN:
        from src.models.qwen import QwenModel
        return QwenModel(config)
    elif model_arch == ModelSelection.CLIP:
        from src.models.clip import ClipModel
        return ClipModel(config)
    elif model_arch == ModelSelection.GLAMM:
        from src.models.glamm import GlammModel
        return GlammModel(config)
    elif model_arch == ModelSelection.JANUS:
        from src.models.janus import JanusModel
        return JanusModel(config)
    elif model_arch == ModelSelection.BLIP2:
        from src.models.blip2 import Blip2Model
        return Blip2Model(config)
    elif model_arch == ModelSelection.MOLMO:
        from src.models.molmo import MolmoModel
        return MolmoModel(config)
    elif model_arch == ModelSelection.PALIGEMMA:
        from src.models.paligemma import PaligemmaModel
        return PaligemmaModel(config)
    elif model_arch == ModelSelection.INTERNLM_XC:
        from src.models.internlm_xc import InternLMXComposerModel
        return InternLMXComposerModel(config)
    elif model_arch == ModelSelection.INTERNVL:
        from src.models.internvl import InternVLModel
        return InternVLModel(config)
    elif model_arch == ModelSelection.MINICPM:
        from src.models.minicpm import MiniCPMModel
        return MiniCPMModel(config)
    elif model_arch == ModelSelection.COGVLM:
        from src.models.cogvlm import CogVLMModel
        return CogVLMModel(config)
    elif model_arch == ModelSelection.PIXTRAL:
        from src.models.pixtral import PixtralModel
        return PixtralModel(config)
    elif model_arch == ModelSelection.AYA_VISION:
        from src.models.aya_vision import AyaVisionModel
        return AyaVisionModel(config)
    elif model_arch == ModelSelection.PLM:
        from src.models.plm import PlmModel
        return PlmModel(config)

    # Previously an unmatched selection fell through and returned None, which
    # only surfaced later as an AttributeError on `model.run()`. Fail fast here
    # with a message naming the offending value instead.
    raise ValueError(f'Unsupported model architecture: {model_arch!r}')
if __name__ == '__main__':
    # Install the default stderr handler up front rather than relying on the
    # implicit basicConfig() triggered by the first module-level logging call.
    # basicConfig() is a no-op if the root logger already has handlers.
    logging.basicConfig()
    logging.getLogger().setLevel(logging.INFO)

    config = Config()

    # DEBUG is below the INFO level set above, so this record is only emitted
    # if the root level is lowered. Lazy %-formatting avoids building the
    # string when the record is discarded.
    logging.debug('Config is set to %s', list(config.__dict__.items()))

    model = get_model(config.architecture, config)
    model.run()
|