# utils/api_utils.py
import asyncio
import os
from functools import partial
from typing import Callable

from loguru import logger
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletion
CLIENTS = {
"glm-4-plus": {
"api_key": os.environ.get("API_GLM_4_PLUS"),
"base_url": "https://open.bigmodel.cn/api/paas/v4",
},
"glm-4": {
"api_key": os.environ.get("API_GLM_4"),
"base_url": "https://open.bigmodel.cn/api/paas/v4",
},
"glm-4-airx": {
"api_key": os.environ.get("API_GLM_4_AIRX"),
"base_url": "https://open.bigmodel.cn/api/paas/v4",
},
"glm-4-flash": {
"api_key": os.environ.get("API_GLM_4_FLASH"),
"base_url": "https://open.bigmodel.cn/api/paas/v4",
},
"gpt-4o-mini": {
"api_key": os.environ.get("API_GPT_4O_MINI"),
"base_url": "https://api.qqslyx.com/v1",
},
"deepseek-chat": {
"api_key": os.environ.get("API_DEEPSEEK_CHAT"),
"base_url": "https://api.deepseek.com/v1"
},
"deepseek-v3-250324": {
"api_key": os.environ.get("API_DEEPSEEK_V3"),
"base_url": "https://ark.cn-beijing.volces.com/api/v3"
}
}
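
# The registry above maps model names to OpenAI-compatible endpoints, with API
# keys read from environment variables. A sketch of the expected setup (the
# key values below are placeholders, not real credentials):
#
#     export API_GLM_4_FLASH="your-zhipu-api-key"
#     export API_DEEPSEEK_CHAT="your-deepseek-api-key"
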
def get_chat_func(model_names: list[str]):
"""
Get a list of chat functions for the specified model names.
Args:
model_names (list[str]): A list of model names.
Returns:
list[Callable]: A list of chat functions.
"""
    chat_funcs = []
    for model_name in model_names:
        if model_name not in CLIENTS:
            logger.warning(f"Model {model_name} not found in CLIENTS, skipping")
            continue
        chat_funcs.append(partial(chat_completion, model_name=model_name))
return chat_funcs
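
# Example usage (a sketch, assuming the corresponding API keys are set):
#
#     funcs = get_chat_func(["glm-4-flash", "deepseek-chat"])
#     completion = await funcs[0](prompt="Hello!")
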
async def chat_completion(prompt: str, model_name: str) -> ChatCompletion:
    """
    Perform a chat completion using the specified model.

    Args:
        prompt (str): The prompt to send to the model.
        model_name (str): The name of the model to use.

    Returns:
        ChatCompletion: The completion object returned by the model; the text
            lives in `completion.choices[0].message.content`.
    """
    assert model_name in CLIENTS, f"Model {model_name} not found"
    client = AsyncOpenAI(
        api_key=CLIENTS[model_name]["api_key"],
        base_url=CLIENTS[model_name]["base_url"],
    )
    # Use the async client so this coroutine does not block the event loop.
    completion = await client.chat.completions.create(
        model=model_name,
        messages=[{"role": "user", "content": prompt}],
    )
    return completion
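
# Example usage (a sketch; the completion object follows the OpenAI
# chat-completions schema):
#
#     completion = await chat_completion("Hello!", model_name="deepseek-chat")
#     print(completion.choices[0].message.content)
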
async def retry_operation(func, task, max_retries=5, delay=0.5, *args, **kwargs):
    """
    Retry an async operation with a linearly increasing delay between attempts.

    Args:
        func (Callable): The coroutine function to be retried.
        task: The task object to update the status (currently unused).
        max_retries (int, optional): The maximum number of attempts. Defaults to 5.
        delay (float, optional): The base delay between retries, multiplied by
            the attempt count. Defaults to 0.5.
        *args: Additional positional arguments to pass to the function.
        **kwargs: Additional keyword arguments to pass to the function.

    Returns:
        tuple: (result, None) on success, or (None, joined error messages)
            if all attempts fail.
    """
    retries = 0
    exceptions = []
    while retries < max_retries:
        try:
            return await func(*args, **kwargs), None
        except Exception as e:
            exceptions.append(f"retry {retries}: {e}")
            retries += 1
            logger.error(e)
            await asyncio.sleep(delay * retries)
    return None, "\n".join(exceptions)
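
# Example usage (a sketch): retry a single-model completion up to 3 times,
# falling back to the collected error messages on total failure.
#
#     result, errors = await retry_operation(
#         chat_completion, task=None, max_retries=3,
#         prompt="Hello!", model_name="glm-4-flash",
#     )
#     if result is None:
#         logger.error(errors)
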
async def chat_completion_multiple_models(
    prompt: str,
    model_names: list[str] | None = None,
    chat_funcs: list[Callable] | None = None,
):
    """
    Perform chat completions with multiple models concurrently.

    Args:
        prompt (str): The prompt to send to the models.
        model_names (list[str], optional): A list of model names, used to
            build chat functions when `chat_funcs` is not given.
        chat_funcs (list[Callable], optional): A list of chat functions.

    Returns:
        list: The completion objects, in the same order as the chat functions.
    """
    if not chat_funcs:
        chat_funcs = get_chat_func(model_names or [])
    return await asyncio.gather(
        *(chat_func(prompt=prompt) for chat_func in chat_funcs)
    )
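
# Example usage (a sketch): fan one prompt out to several models and collect
# the completions in order.
#
#     completions = await chat_completion_multiple_models(
#         "Summarize asyncio in one sentence.",
#         model_names=["glm-4-flash", "deepseek-chat"],
#     )
#     for completion in completions:
#         print(completion.choices[0].message.content)
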
async def func_wrap_multiple_models(
    wrap_func: Callable,
    model_names: list[str] | None = None,
    chat_funcs: list[Callable] | None = None,
    model_weights: list[float] | None = None,
    *args,
):
    """
    Run a wrapped function concurrently, once per model.

    Args:
        wrap_func (Callable): The coroutine to run; it receives `*args` plus
            a `chat_func` keyword argument.
        model_names (list[str], optional): A list of model names, used to
            build chat functions when `chat_funcs` is not given.
        chat_funcs (list[Callable], optional): A list of chat functions.
        model_weights (list[float], optional): A list of model weights; only
            validated here. Defaults to 1.0 per model.
        *args: Additional positional arguments to pass to `wrap_func`.

    Returns:
        list: The results of `wrap_func`, one per chat function.
    """
    if not chat_funcs:
        chat_funcs = get_chat_func(model_names or [])
    if not model_weights:
        model_weights = [1.0] * len(chat_funcs)
    assert len(chat_funcs) == len(model_weights), \
        "model_weights must be same length as chat_funcs"
    return await asyncio.gather(
        *(wrap_func(*args, chat_func=chat_func) for chat_func in chat_funcs)
    )
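
# Example usage (a sketch; `summarize` is a hypothetical coroutine that takes
# a `chat_func` keyword argument). Note that because `*args` follows the
# defaulted parameters, the lists must be passed positionally:
#
#     async def summarize(text, chat_func):
#         completion = await chat_func(prompt=f"Summarize: {text}")
#         return completion.choices[0].message.content
#
#     summaries = await func_wrap_multiple_models(
#         summarize, ["glm-4-flash", "deepseek-chat"], None, None, "some long text",
#     )
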
async def compare_chat_choices(
    contents: list[str],
    model_names: list[str] | None = None,
    chat_funcs: list[Callable] | None = None,
    model_weights: list[float] | None = None,
):
    """
    Cross-evaluate candidate responses: each model ranks the others' outputs.

    Args:
        contents (list[str]): Candidate responses, one per chat function.
        model_names (list[str], optional): A list of model names, used to
            build chat functions when `chat_funcs` is not given.
        chat_funcs (list[Callable], optional): A list of chat functions.
        model_weights (list[float], optional): Weights applied to each
            evaluator's ranks. Defaults to 1.0 per model.

    Returns:
        dict[int, float]: Weighted rank sums keyed by content index;
            lower is better (rank 1 is the best).
    """
    if not chat_funcs:
        chat_funcs = get_chat_func(model_names or [])
    if not model_weights:
        model_weights = [1.0] * len(chat_funcs)
    assert len(chat_funcs) == len(model_weights), \
        "model_weights must be same length as chat_funcs"
    assert len(contents) == len(chat_funcs), \
        "contents must be same length as chat_funcs"
    prompts = []
    eval_chat_funcs = []
    for i in range(len(contents)):
        prompt = f"""
        You are provided with {len(contents) - 1} choices, and you are asked to rank them by quality and relevance.
        Rank 1 is the best.
        Just output index and corresponding rank in the format Index:Rank.
        Just numbers, no text. For example: "0:1" is correct, "Index 0:1" and "0: 1" are wrong.
        One line for each rank.
        Just output like "Index:Rank\nIndex:Rank\nIndex:Rank\n"
        No other output is allowed.
        """
        for j, content in enumerate(contents):
            if i == j:  # skip self-evaluation
                continue
            prompt += f"""
            Index {j}:
            {content}
            ----------
            """
        prompts.append(prompt)
        eval_chat_funcs.append(chat_funcs[i])
    compares = await asyncio.gather(
        *(chat_func(prompt=prompt)
          for prompt, chat_func in zip(prompts, eval_chat_funcs))
    )
    rank_scores = {i: 0.0 for i in range(len(contents))}
    for i, comp in enumerate(compares):
        for line in comp.choices[0].message.content.strip().split("\n"):
            line = line.strip()
            if not line:
                continue
            try:
                index, rank = line.split(":")
                rank_scores[int(index)] += int(rank) * model_weights[i]
            except (ValueError, KeyError):
                logger.warning(f"Skipping malformed rank line: {line!r}")
    return rank_scores

# Backward-compatible alias for the original (misspelled) name.
compare_chat_chocies = compare_chat_choices
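
if __name__ == "__main__":
    # A minimal end-to-end sketch (assumes the relevant API keys are set):
    # generate candidate answers with several models, then let each model rank
    # the others' outputs; lower weighted rank sums indicate better answers.
    async def _demo():
        models = ["glm-4-flash", "deepseek-chat"]
        completions = await chat_completion_multiple_models(
            "Explain Python's GIL in two sentences.", model_names=models
        )
        contents = [c.choices[0].message.content for c in completions]
        scores = await compare_chat_choices(contents, model_names=models)
        print(scores)

    asyncio.run(_demo())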