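"""Model registry for lmms_eval.

Maps short model names to their implementation classes under
``lmms_eval.models.simple`` and ``lmms_eval.models.chat``. ``get_model``
resolves a registered name to its class, and the ``LMMS_EVAL_PLUGINS``
environment variable lets external packages register additional models.
"""
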
import importlib
import os
import sys
from typing import Literal

from loguru import logger

# os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

logger.remove()
# Configure logger with detailed format including file path, function name, and line number
log_format = (
    "<green>{time:YYYY-MM-DD HH:mm:ss}</green> | "
    "<level>{level: <8}</level> | "
    "<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - "
    "<level>{message}</level>"
)
logger.add(sys.stdout, level="WARNING", format=log_format)


AVAILABLE_SIMPLE_MODELS = {
    "aero": "Aero",
    "plm": "PerceptionLM",
    "aria": "Aria",
    "auroracap": "AuroraCap",
    "batch_gpt4": "BatchGPT4",
    "claude": "Claude",
    "cogvlm2": "CogVLM2",
    "from_log": "FromLog",
    "fuyu": "Fuyu",
    "gemini_api": "GeminiAPI",
    "gpt4o_audio": "GPT4OAudio",
    "gemma3": "Gemma3",
    "gpt4v": "GPT4V",
    "idefics2": "Idefics2",
    "instructblip": "InstructBLIP",
    "internvideo2": "InternVideo2",
    "internvl": "InternVLChat",
    "internvl2": "InternVL2",
    "llama_vid": "LLaMAVid",
    "llama_vision": "LlamaVision",
    "llava": "Llava",
    "llava_hf": "LlavaHf",
    "llava_onevision": "Llava_OneVision",
    "llava_onevision1_5": "Llava_OneVision1_5",
    "llava_onevision_moviechat": "Llava_OneVision_MovieChat",
    "llava_sglang": "LlavaSglang",
    "llava_vid": "LlavaVid",
    "longva": "LongVA",
    "mantis": "Mantis",
    "minicpm_v": "MiniCPM_V",
    "minimonkey": "MiniMonkey",
    "moviechat": "MovieChat",
    "mplug_owl_video": "mplug_Owl",
    "ola": "Ola",
    "openai_compatible": "OpenAICompatible",
    "oryx": "Oryx",
    "phi3v": "Phi3v",
    "phi4_multimodal": "Phi4",
    "qwen2_5_omni": "Qwen2_5_Omni",
    "qwen2_5_vl": "Qwen2_5_VL",
    "qwen2_5_vl_interleave": "Qwen2_5_VL_Interleave",
    "qwen2_audio": "Qwen2_Audio",
    "qwen2_vl": "Qwen2_VL",
    "qwen_vl": "Qwen_VL",
    "qwen_vl_api": "Qwen_VL_API",
    "reka": "Reka",
    "ross": "Ross",
    "slime": "Slime",
    "srt_api": "SRT_API",
    "tinyllava": "TinyLlava",
    "videoChatGPT": "VideoChatGPT",
    "videochat2": "VideoChat2",
    "videollama3": "VideoLLaMA3",
    "video_llava": "VideoLLaVA",
    "vila": "VILA",
    "vita": "VITA",
    "vllm": "VLLM",
    "xcomposer2_4KHD": "XComposer2_4KHD",
    "xcomposer2d5": "XComposer2D5",
    "egogpt": "EgoGPT",
    "internvideo2_5": "InternVideo2_5",
    "videochat_flash": "VideoChat_Flash",
    "whisper": "Whisper",
    "whisper_vllm": "WhisperVllm",
    "vora": "VoRA",
}

AVAILABLE_CHAT_TEMPLATE_MODELS = {
    "llava_hf": "LlavaHf",
    "qwen2_5_vl": "Qwen2_5_VL",
    "thyme": "Thyme",
    "openai_compatible": "OpenAICompatible",
    "vllm": "VLLM",
    "vllm_generate": "VLLMGenerate",
    "sglang": "Sglang",
    "huggingface": "Huggingface",
    "async_openai": "AsyncOpenAIChat",
    "longvila": "LongVila",
}


def get_model(model_name, force_simple: bool = False):
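    """Resolve a registered model name to its implementation class.

    Chat-template models take precedence over simple models; pass
    ``force_simple=True`` to prefer the simple implementation when the name
    also exists in ``AVAILABLE_SIMPLE_MODELS``. Registry values without a dot
    are expanded to ``lmms_eval.models.<model_type>.<model_name>.<ClassName>``
    before being imported.

    Example (illustrative):
        model_cls = get_model("qwen2_5_vl")
        # resolves to lmms_eval.models.chat.qwen2_5_vl.Qwen2_5_VL
    """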
    if model_name not in AVAILABLE_SIMPLE_MODELS and model_name not in AVAILABLE_CHAT_TEMPLATE_MODELS:
        raise ValueError(f"Model {model_name} not found in available models.")

    if model_name in AVAILABLE_CHAT_TEMPLATE_MODELS:
        model_type = "chat"
        AVAILABLE_MODELS = AVAILABLE_CHAT_TEMPLATE_MODELS
    else:
        model_type = "simple"
        AVAILABLE_MODELS = AVAILABLE_SIMPLE_MODELS

    # Override with force_simple if needed, but only if the model exists in AVAILABLE_SIMPLE_MODELS
    if force_simple and model_name in AVAILABLE_SIMPLE_MODELS:
        model_type = "simple"
        AVAILABLE_MODELS = AVAILABLE_SIMPLE_MODELS

    model_class = AVAILABLE_MODELS[model_name]
    if "." not in model_class:
        model_class = f"lmms_eval.models.{model_type}.{model_name}.{model_class}"

    try:
        model_module, model_class = model_class.rsplit(".", 1)
        module = importlib.import_module(model_module)
        return getattr(module, model_class)
    except Exception as e:
        logger.error(f"Failed to import {model_class} from {model_name}: {e}")
        raise


if os.environ.get("LMMS_EVAL_PLUGINS", None):
    # Allow specifying other packages to import models from
    for plugin in os.environ["LMMS_EVAL_PLUGINS"].split(","):
        m = importlib.import_module(f"{plugin}.models")
        # For plugin users, these simple-model entries may later be overridden by chat-template models
        for model_name, model_class in getattr(m, "AVAILABLE_MODELS").items():
            AVAILABLE_SIMPLE_MODELS[model_name] = f"{plugin}.models.{model_name}.{model_class}"
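
# Illustrative plugin layout (hypothetical package name "my_plugin" and class "MyModel"):
#   LMMS_EVAL_PLUGINS=my_plugin
#   my_plugin/models/__init__.py  ->  AVAILABLE_MODELS = {"my_model": "MyModel"}
#   my_plugin/models/my_model.py  ->  class MyModel: ...
# The loop above would then register "my_model" as
# "my_plugin.models.my_model.MyModel" in AVAILABLE_SIMPLE_MODELS.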