|
|
from .logger import get_logger |
|
|
|
|
|
logger = get_logger() |
|
|
|
|
|
# Lookup table from a canonical (upper-case) model identifier to the relative
# path of the agent configuration YAML used for that model.
MODEL_NAME_TO_CONFIG = {
    # OpenAI-hosted model.
    "OPEN_AI": "../configs/agent_configs/react_agent_gpt4_async.yaml",
    # Azure-hosted models; "AZURE_OPEN_AI" and "AZURE_GPT4" point at the
    # same configuration file.
    "AZURE_OPEN_AI": "../configs/agent_configs/react_agent_azureopenai_gpt_4_async.yaml",
    "AZURE_GPT35_TURBO": "../configs/agent_configs/react_agent_azureopenai_gpt_35_turbo_async.yaml",
    "AZURE_GPT4": "../configs/agent_configs/react_agent_azureopenai_gpt_4_async.yaml",
    # Other model families.
    "LLAMA": "../configs/agent_configs/react_agent_llama_async.yaml",
    "OPT": "../configs/agent_configs/react_agent_opt_async.yaml",
}
|
|
|
|
|
|
|
|
def get_model_config_path(input_model_name):
    """Return the agent config path registered for a model name.

    Resolution order:

    1. ``input_model_name`` exactly as given, if it is a key of
       ``MODEL_NAME_TO_CONFIG``.
    2. Its upper-cased form, if that is a key.
    3. A case-insensitive keyword match on the name: ``"openai"`` selects
       the ``"AZURE_OPEN_AI"`` entry, ``"llama"`` selects ``"LLAMA"`` and
       ``"opt"`` selects ``"OPT"`` (checked in that order).
    4. Otherwise a warning is logged and the ``"AZURE_OPEN_AI"`` entry is
       returned.

    Args:
        input_model_name: Model name to resolve. ``None`` is treated as
            ``"openai"``.

    Returns:
        The config path stored in ``MODEL_NAME_TO_CONFIG`` for the
        resolved entry.
    """
    model_name = "openai" if input_model_name is None else input_model_name

    # Direct hit first, then the upper-cased spelling used by the table keys.
    for candidate in (model_name, model_name.upper()):
        if candidate in MODEL_NAME_TO_CONFIG:
            return MODEL_NAME_TO_CONFIG[candidate]

    # Keyword fallback. Compare against a lower-cased copy so that names
    # such as "Llama-2-7b" or "OPT-6.7B" are recognised just like their
    # lower-case spellings, matching the case-insensitive lookup above.
    lowered_name = model_name.lower()
    keyword_to_key = (
        ("openai", "AZURE_OPEN_AI"),
        ("llama", "LLAMA"),
        ("opt", "OPT"),
    )
    for keyword, config_key in keyword_to_key:
        if keyword in lowered_name:
            return MODEL_NAME_TO_CONFIG[config_key]

    logger.warning("unknown model name, use official.")
    return MODEL_NAME_TO_CONFIG["AZURE_OPEN_AI"]
|
|
|