File size: 964 Bytes
f1f682e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# models/__init__.py

from .base_model import BaseModel
from .qwen_2d5_omni_7b import Qwen2_5Omni7B
from .vllm_client import VLLMClient
from .uno_scorer_hf import UNOScorerHF

# Lookup table used by get_model(): maps the public model name a caller
# passes in to the class that get_model() instantiates for it.
MODEL_REGISTRY = {
    "Qwen-2.5-Omni-7B": Qwen2_5Omni7B,
    "VLLMClient": VLLMClient,
    "UNOScorerHF": UNOScorerHF,
}

def get_model(model_name: str, model_path: str, **kwargs) -> BaseModel:
    """
    Model factory function: instantiate a registered model class by name.

    :param model_name: Key in ``MODEL_REGISTRY`` (e.g., 'Qwen-2.5-Omni-7B',
        'VLLMClient', 'UNOScorerHF'). It is also forwarded to the model's
        constructor as the ``model_name`` keyword argument.
    :param model_path: Actual model path or API identifier (e.g., 'qwen-turbo'),
        forwarded to the constructor as ``model_path``.
    :param kwargs: Other model initialization parameters, forwarded unchanged
        to the model class constructor.
    :return: A new instance of the registered class (expected to be a
        BaseModel subclass).
    :raises ValueError: If ``model_name`` is not a key in ``MODEL_REGISTRY``;
        the message lists the available names.
    """
    try:
        model_class = MODEL_REGISTRY[model_name]
    except KeyError:
        # Translate the lookup failure into the documented ValueError;
        # `from None` hides the internal KeyError from the traceback.
        raise ValueError(
            f"Unknown model name: {model_name}. "
            f"Available models: {list(MODEL_REGISTRY.keys())}"
        ) from None

    return model_class(model_name=model_name, model_path=model_path, **kwargs)