from typing import Any
import os
def create_llm_engine(model_string: str, use_cache: bool = False, is_multimodal: bool = False, **kwargs) -> Any:
    """
    Factory function to create the appropriate LLM engine instance.

    For supported models and model_string examples, see:
    https://github.com/lupantech/AgentFlow/blob/main/assets/doc/llm_engine.md

    Parameter handling:
    - Uses kwargs.get() instead of setdefault, so caller-supplied values win
    - Passes only the parameters each backend supports
    - Maps frequency_penalty, presence_penalty, and repetition_penalty per backend
    - External sampling parameters (temperature, top_p) are respected if provided
    """
    original_model_string = model_string
    print(f"Creating LLM engine for model: {model_string} (is_multimodal={is_multimodal}, kwargs={kwargs})")
# === Azure OpenAI ===
if "azure" in model_string:
from .azure import ChatAzureOpenAI
model_string = model_string.replace("azure-", "")
# Azure supports: temperature, top_p, frequency_penalty, presence_penalty
config = {
"model_string": model_string,
"use_cache": use_cache,
"is_multimodal": is_multimodal,
"temperature": kwargs.get("temperature", 0.7),
"top_p": kwargs.get("top_p", 0.9),
"frequency_penalty": kwargs.get("frequency_penalty", 0.5),
"presence_penalty": kwargs.get("presence_penalty", 0.5),
}
return ChatAzureOpenAI(**config)
# === OpenAI (GPT) ===
elif any(x in model_string for x in ["gpt", "o1", "o3", "o4"]):
from .openai import ChatOpenAI
config = {
"model_string": model_string,
"use_cache": use_cache,
"is_multimodal": is_multimodal,
"temperature": kwargs.get("temperature", 0.7),
"top_p": kwargs.get("top_p", 0.9),
"frequency_penalty": kwargs.get("frequency_penalty", 0.5),
"presence_penalty": kwargs.get("presence_penalty", 0.5),
}
return ChatOpenAI(**config)
# === DashScope (Qwen) ===
elif "dashscope" in model_string:
from .dashscope import ChatDashScope
# DashScope uses temperature, top_p — but not frequency/presence_penalty
config = {
"model_string": model_string,
"use_cache": use_cache,
"is_multimodal": is_multimodal,
"temperature": kwargs.get("temperature", 0.7),
"top_p": kwargs.get("top_p", 0.9),
}
return ChatDashScope(**config)
# === Anthropic (Claude) ===
elif "claude" in model_string:
from .anthropic import ChatAnthropic
if "ANTHROPIC_API_KEY" not in os.environ:
raise ValueError("Please set the ANTHROPIC_API_KEY environment variable.")
# Anthropic supports: temperature, top_p, top_k — NOT frequency/presence_penalty
config = {
"model_string": model_string,
"use_cache": use_cache,
"is_multimodal": is_multimodal,
"temperature": kwargs.get("temperature", 0.7),
"top_p": kwargs.get("top_p", 0.9),
"top_k": kwargs.get("top_k", 50), # optional
}
return ChatAnthropic(**config)
# === DeepSeek ===
elif any(x in model_string for x in ["deepseek-chat", "deepseek-reasoner"]):
from .deepseek import ChatDeepseek
        # DeepSeek uses repetition_penalty rather than frequency/presence_penalty;
        # sampling parameters are left to the backend defaults here
config = {
"model_string": model_string,
"use_cache": use_cache,
"is_multimodal": is_multimodal,
}
return ChatDeepseek(**config)
# === Gemini ===
elif "gemini" in model_string:
print("gemini model found")
from .gemini import ChatGemini
# Gemini uses repetition_penalty
config = {
"model_string": model_string,
"use_cache": use_cache,
"is_multimodal": is_multimodal,
}
return ChatGemini(**config)
# === Grok (xAI) ===
elif "grok" in model_string:
from .xai import ChatGrok
if "GROK_API_KEY" not in os.environ:
raise ValueError("Please set the GROK_API_KEY environment variable.")
# Assume Grok uses repetition_penalty
config = {
"model_string": model_string,
"use_cache": use_cache,
"is_multimodal": is_multimodal,
"temperature": kwargs.get("temperature", 0.7),
"top_p": kwargs.get("top_p", 0.9),
"repetition_penalty": kwargs.get("repetition_penalty", 1.2),
}
return ChatGrok(**config)
# === vLLM ===
elif "vllm" in model_string:
from .vllm import ChatVLLM
model_string = model_string.replace("vllm-", "")
config = {
"model_string": model_string,
"base_url": kwargs.get("base_url", "http://localhost:8000/v1"), # TODO: check the RL training initialized port and name
"use_cache": use_cache,
"is_multimodal": is_multimodal,
"temperature": kwargs.get("temperature", 0.7),
"top_p": kwargs.get("top_p", 0.9),
"frequency_penalty": kwargs.get("frequency_penalty", 1.2),
"max_model_len": kwargs.get("max_model_len", 15200),
"max_seq_len_to_capture": kwargs.get("max_seq_len_to_capture", 15200),
}
print("serving ")
return ChatVLLM(**config)
# === LiteLLM ===
elif "litellm" in model_string:
from .litellm import ChatLiteLLM
model_string = model_string.replace("litellm-", "")
        # LiteLLM passes frequency/presence_penalty through to the underlying provider
config = {
"model_string": model_string,
"use_cache": use_cache,
"is_multimodal": is_multimodal,
"temperature": kwargs.get("temperature", 0.7),
"top_p": kwargs.get("top_p", 0.9),
"frequency_penalty": kwargs.get("frequency_penalty", 0.5),
"presence_penalty": kwargs.get("presence_penalty", 0.5),
}
return ChatLiteLLM(**config)
# === Together AI ===
elif "together" in model_string:
from .together import ChatTogether
if "TOGETHER_API_KEY" not in os.environ:
raise ValueError("Please set the TOGETHER_API_KEY environment variable.")
model_string = model_string.replace("together-", "")
config = {
"model_string": model_string,
"use_cache": use_cache,
"is_multimodal": is_multimodal,
}
return ChatTogether(**config)
# === Ollama ===
elif "ollama" in model_string:
from .ollama import ChatOllama
model_string = model_string.replace("ollama-", "")
config = {
"model_string": model_string,
"use_cache": use_cache,
"is_multimodal": is_multimodal,
"temperature": kwargs.get("temperature", 0.7),
"top_p": kwargs.get("top_p", 0.9),
"repetition_penalty": kwargs.get("repetition_penalty", 1.2),
}
return ChatOllama(**config)
else:
raise ValueError(
f"Engine {original_model_string} not supported. "
"If you are using Azure OpenAI models, please ensure the model string has the prefix 'azure-'. "
"For Together models, use 'together-'. For VLLM models, use 'vllm-'. For LiteLLM models, use 'litellm-'. "
"For Ollama models, use 'ollama-'. "
"For other custom engines, you can edit the factory.py file and add its interface file. "
"Your pull request will be warmly welcomed!"
) |
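

# --- Example usage: a minimal illustrative sketch. The model strings below are
# assumptions for demonstration; see the llm_engine.md doc linked above for the
# strings supported in your installation, and set the relevant API keys first.
if __name__ == "__main__":
    # OpenAI: matched by the "gpt"/"o1"/"o3"/"o4" substrings in the model string
    engine = create_llm_engine("gpt-4o", temperature=0.2, top_p=0.95)

    # Azure OpenAI: the "azure-" prefix selects the Azure backend and is
    # stripped before the engine is constructed
    # engine = create_llm_engine("azure-gpt-4o")

    # vLLM: the "vllm-" prefix is stripped; base_url defaults to the local server
    # engine = create_llm_engine(
    #     "vllm-Qwen/Qwen2.5-7B-Instruct",
    #     base_url="http://localhost:8000/v1",
    # )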
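
# --- Extending the factory: a hypothetical sketch, not shipped code. To register
# a custom engine, mirror the branches above: add an interface module exposing a
# Chat* class, then insert an elif before the final else. Every name below
# ("mybackend", ChatMyBackend) is hypothetical.
#
#     elif "mybackend" in model_string:
#         from .mybackend import ChatMyBackend
#         model_string = model_string.replace("mybackend-", "")
#         return ChatMyBackend(
#             model_string=model_string,
#             use_cache=use_cache,
#             is_multimodal=is_multimodal,
#             temperature=kwargs.get("temperature", 0.7),
#         )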