# models/__init__.py
from .base_model import BaseModel
from .qwen_2d5_omni_7b import Qwen2_5Omni7B
from .vllm_client import VLLMClient
from .uno_scorer_hf import UNOScorerHF
# Registry of model names accepted by get_model(), each mapped to the class it instantiates.
MODEL_REGISTRY = {
    "Qwen-2.5-Omni-7B": Qwen2_5Omni7B,
    "VLLMClient": VLLMClient,
    "UNOScorerHF": UNOScorerHF,
}
def get_model(model_name: str, model_path: str, **kwargs) -> BaseModel:
    """Instantiate a registered model by name.

    Looks ``model_name`` up in ``MODEL_REGISTRY`` and constructs the matching
    class, forwarding ``model_name``, ``model_path`` and any extra keyword
    arguments to its constructor.

    :param model_name: Registered model name, i.e. a key of ``MODEL_REGISTRY``
        (e.g. ``"Qwen-2.5-Omni-7B"``).
    :param model_path: Actual model path or API identifier (e.g. ``"qwen-turbo"``).
    :param kwargs: Other model initialization parameters, passed through unchanged.
    :return: Instance of the registered model class (annotated as ``BaseModel``).
    :raises ValueError: If ``model_name`` is not a key of ``MODEL_REGISTRY``.
    """
    try:
        model_class = MODEL_REGISTRY[model_name]
    except KeyError:
        # Surface the documented ValueError with the list of valid names;
        # "from None" keeps the internal KeyError out of the traceback.
        raise ValueError(
            f"Unknown model name: {model_name}. "
            f"Available models: {list(MODEL_REGISTRY.keys())}"
        ) from None
    return model_class(model_name=model_name, model_path=model_path, **kwargs)