import asyncio
import os
from functools import partial
from typing import Callable

from loguru import logger
from openai import AsyncOpenAI

CLIENTS = {
    "glm-4-plus": {
        "api_key": os.environ.get("API_GLM_4_PLUS"),
        "base_url": "https://open.bigmodel.cn/api/paas/v4",
    },
    "glm-4": {
        "api_key": os.environ.get("API_GLM_4"),
        "base_url": "https://open.bigmodel.cn/api/paas/v4",
    },
    "glm-4-airx": {
        "api_key": os.environ.get("API_GLM_4_AIRX"),
        "base_url": "https://open.bigmodel.cn/api/paas/v4",
    },
    "glm-4-flash": {
        "api_key": os.environ.get("API_GLM_4_FLASH"),
        "base_url": "https://open.bigmodel.cn/api/paas/v4",
    },
    "gpt-4o-mini": {
        "api_key": os.environ.get("API_GPT_4O_MINI"),
        "base_url": "https://api.qqslyx.com/v1",
    },
    "deepseek-chat": {
        "api_key": os.environ.get("API_DEEPSEEK_CHAT"),
        "base_url": "https://api.deepseek.com/v1",
    },
    "deepseek-v3-250324": {
        "api_key": os.environ.get("API_DEEPSEEK_V3"),
        "base_url": "https://ark.cn-beijing.volces.com/api/v3",
    },
}
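
# To register another model, add an entry keyed by its model name. A sketch,
# assuming a hypothetical endpoint and env var (both placeholders, not part of
# the original registry):
#
#     CLIENTS["my-model"] = {
#         "api_key": os.environ.get("API_MY_MODEL"),
#         "base_url": "https://example.com/v1",
#     }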


def get_chat_func(model_names: list[str]) -> list[Callable]:
    """
    Get a list of chat functions for the specified model names.

    Args:
        model_names (list[str]): A list of model names. Names not configured
            in CLIENTS are skipped with a warning.

    Returns:
        list[Callable]: A list of chat functions, each taking only a prompt.
    """
    chat_funcs = []
    for model_name in model_names:
        if model_name not in CLIENTS:
            logger.warning(f"Model {model_name} not configured, skipping")
            continue
        chat_funcs.append(partial(chat_completion, model_name=model_name))
    return chat_funcs
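
# Usage sketch: the model name is already bound via functools.partial, so each
# returned function only needs a prompt. The names below are illustrative picks
# from CLIENTS:
#
#     funcs = get_chat_func(["glm-4-flash", "deepseek-chat"])
#     completion = asyncio.run(funcs[0](prompt="Hello"))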


async def chat_completion(prompt: str, model_name: str):
    """
    Perform a chat completion using the specified model.

    Args:
        prompt (str): The prompt to send to the model.
        model_name (str): The name of the model to use; must be configured
            in CLIENTS.

    Returns:
        ChatCompletion: The full response object; the text is available at
        completion.choices[0].message.content.
    """
    assert model_name in CLIENTS, f"Model {model_name} not found"

    api_key = CLIENTS[model_name]["api_key"]
    base_url = CLIENTS[model_name]["base_url"]

    # Use the async client so concurrent calls via asyncio.gather actually
    # overlap instead of blocking the event loop.
    client = AsyncOpenAI(api_key=api_key, base_url=base_url)
    completion = await client.chat.completions.create(
        model=model_name,
        messages=[
            {"role": "user", "content": prompt}
        ],
    )
    return completion
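
# Usage sketch, assuming the matching API_* environment variable is set:
#
#     completion = asyncio.run(chat_completion("Hello", "glm-4-flash"))
#     print(completion.choices[0].message.content)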


async def retry_operation(func, task, max_retries=5, delay=0.5, *args, **kwargs):
    """
    Retry an async operation with a linearly increasing backoff.

    Args:
        func (Callable): The async function to be retried.
        task: The task object associated with the operation (currently unused).
        max_retries (int, optional): The maximum number of retries. Defaults to 5.
        delay (float, optional): The base delay; the wait after attempt n is
            delay * n. Defaults to 0.5.
        *args: Additional positional arguments to pass to the function.
        **kwargs: Additional keyword arguments to pass to the function.

    Returns:
        tuple: (result, None) on success, or (None, error_log) after all
        retries fail, where error_log joins the message from each attempt.
    """
    retries = 0
    exceptions = []
    while retries < max_retries:
        try:
            return await func(*args, **kwargs), None
        except Exception as e:
            exceptions.append(f"retry {retries}: {e}")
            retries += 1
            logger.error(e)
            await asyncio.sleep(delay * retries)
    return None, "\n".join(exceptions)
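
# Usage sketch: wrap chat_completion so transient API errors are retried.
# retry_operation never uses the task argument, so None is passed here:
#
#     result, error_log = await retry_operation(
#         chat_completion, None, 5, 0.5,
#         prompt="Hello", model_name="glm-4-flash",
#     )
#     if result is None:
#         logger.error(error_log)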


async def chat_completion_multiple_models(
    prompt: str,
    model_names: list[str] = [],
    chat_funcs: list[Callable] = []
):
    """
    Perform a chat completion across multiple models concurrently.

    Args:
        prompt (str): The prompt to send to the models.
        model_names (list[str], optional): Model names used to build chat
            functions when chat_funcs is empty. Defaults to [].
        chat_funcs (list[Callable], optional): A list of chat functions.
            Defaults to [].

    Returns:
        list: Completion responses, one per model, in input order.
    """
    if not chat_funcs:
        chat_funcs = get_chat_func(model_names)
    return await asyncio.gather(
        *(chat_func(prompt=prompt) for chat_func in chat_funcs)
    )
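
# Usage sketch: fan one prompt out to several models at once (illustrative
# names from CLIENTS); results come back in the same order as the names:
#
#     completions = asyncio.run(
#         chat_completion_multiple_models("Hello", ["glm-4", "deepseek-chat"])
#     )
#     for c in completions:
#         print(c.choices[0].message.content)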


async def func_wrap_multiple_models(
    wrap_func: Callable,
    model_names: list[str] = [],
    chat_funcs: list[Callable] = [],
    model_weights: list[float] = [],
    *args,
):
    """
    Run a wrapped async function concurrently, once per model.

    Args:
        wrap_func (Callable): The async function to run; it must accept a
            chat_func keyword argument.
        model_names (list[str], optional): A list of model names. Defaults to [].
        chat_funcs (list[Callable], optional): A list of chat functions.
            Defaults to [].
        model_weights (list[float], optional): A list of model weights; only
            validated against chat_funcs here, not otherwise used. Defaults to [].
        *args: Additional positional arguments to pass to wrap_func.

    Returns:
        list: Results of wrap_func, one per model.
    """
    if not chat_funcs:
        chat_funcs = get_chat_func(model_names)
    if not model_weights:
        model_weights = [1.0 for _ in range(len(chat_funcs))]
    assert len(chat_funcs) == len(model_weights), \
        "model_weights must be same length as chat_funcs"

    return await asyncio.gather(
        *(wrap_func(*args, chat_func=chat_func) for chat_func in chat_funcs)
    )
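
# Usage sketch: the wrapped coroutine receives each bound chat function through
# the chat_func keyword. ask_twice below is a hypothetical example, and the
# empty lists fill the positional chat_funcs/model_weights slots:
#
#     async def ask_twice(prompt, chat_func):
#         first = await chat_func(prompt=prompt)
#         return await chat_func(
#             prompt="Summarize: " + first.choices[0].message.content
#         )
#
#     results = asyncio.run(
#         func_wrap_multiple_models(ask_twice, ["glm-4", "glm-4-plus"], [], [], "Hello")
#     )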


async def compare_chat_choices(
    contents: list[str],
    model_names: list[str] = [],
    chat_funcs: list[Callable] = [],
    model_weights: list[float] = []
):
    """
    Cross-rank candidate answers: model i ranks every answer except its own
    (index i), and the ranks are accumulated into weighted scores. Rank 1 is
    the best, so a lower total score means a better answer.

    Args:
        contents (list[str]): Candidate answers, one per chat function.
        model_names (list[str], optional): A list of model names. Defaults to [].
        chat_funcs (list[Callable], optional): A list of chat functions.
            Defaults to [].
        model_weights (list[float], optional): A list of model weights.
            Defaults to [].

    Returns:
        dict[int, float]: Weighted rank score for each content index.
    """
    if not chat_funcs:
        chat_funcs = get_chat_func(model_names)
    if not model_weights:
        model_weights = [1.0 for _ in range(len(chat_funcs))]
    assert len(chat_funcs) == len(model_weights), \
        "model_weights must be same length as chat_funcs"

    prompts = []
    eval_chat_funcs = []
    for i in range(len(contents)):
        # Model i judges every answer except its own.
        prompt = f"""
You are provided with {len(contents) - 1} choices, and you are asked to rank them based on quality and relevance.
Rank 1 is the best.
Just output the index and its rank in the format Index:Rank.
Just numbers, no text. For example: "0:1" is correct, "Index 0:1" and "0: 1" are wrong.
One line for each rank.
Just output like "Index:Rank\nIndex:Rank\nIndex:Rank\n"
No other output is allowed.

"""
        for j, content in enumerate(contents):
            if i == j:
                continue
            prompt += f"""
Index {j}:
{content}
----------

"""
        prompts.append(prompt)
        eval_chat_funcs.append(chat_funcs[i])
    compares = await asyncio.gather(
        *(chat_func(prompt=prompt)
          for prompt, chat_func in zip(prompts, eval_chat_funcs))
    )

    # Accumulate weighted ranks per content index.
    rank_scores = {i: 0 for i in range(len(contents))}
    for i, comp in enumerate(compares):
        for line in comp.choices[0].message.content.strip().split("\n"):
            line = line.strip()
            if ":" not in line:
                continue  # Skip blank or malformed lines.
            index, rank = line.split(":")
            rank_scores[int(index)] += int(rank) * model_weights[i]
    return rank_scores
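

# A minimal end-to-end demo, assuming the API_* environment variables for the
# illustrative model names below are set: each model answers the prompt, then
# ranks the other models' answers.
if __name__ == "__main__":
    async def _demo():
        models = ["glm-4-flash", "deepseek-chat"]
        completions = await chat_completion_multiple_models(
            "Explain asyncio.gather in one sentence.", models
        )
        contents = [c.choices[0].message.content for c in completions]
        scores = await compare_chat_choices(contents, model_names=models)
        logger.info(f"Weighted rank scores (lower is better): {scores}")

    asyncio.run(_demo())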