"""Helpers for sending one prompt to several OpenAI-compatible chat models.

The module keeps a registry of model endpoints (``CLIENTS``), builds
per-model chat callables, fans a prompt out to several models concurrently,
and lets the models cross-rank a list of candidate answers.
"""

import asyncio
import os
from functools import partial
from typing import Any, Callable, Optional

from loguru import logger
from openai import OpenAI

# Registry of supported models: model name -> credentials and endpoint.
# API keys are read from the environment once, at import time.
CLIENTS = {
    "glm-4-plus": {
        "api_key": os.environ.get("API_GLM_4_PLUS"),
        "base_url": "https://open.bigmodel.cn/api/paas/v4",
    },
    "glm-4": {
        "api_key": os.environ.get("API_GLM_4"),
        "base_url": "https://open.bigmodel.cn/api/paas/v4",
    },
    "glm-4-airx": {
        "api_key": os.environ.get("API_GLM_4_AIRX"),
        "base_url": "https://open.bigmodel.cn/api/paas/v4",
    },
    "glm-4-flash": {
        "api_key": os.environ.get("API_GLM_4_FLASH"),
        "base_url": "https://open.bigmodel.cn/api/paas/v4",
    },
    "gpt-4o-mini": {
        "api_key": os.environ.get("API_GPT_4O_MINI"),
        "base_url": "https://api.qqslyx.com/v1",
    },
    "deepseek-chat": {
        "api_key": os.environ.get("API_DEEPSEEK_CHAT"),
        "base_url": "https://api.deepseek.com/v1",
    },
    "deepseek-v3-250324": {
        "api_key": os.environ.get("API_DEEPSEEK_V3"),
        "base_url": "https://ark.cn-beijing.volces.com/api/v3",
    },
}


def get_chat_func(model_names: list[str]) -> list[Callable]:
    """
    Get a list of chat functions for the specified model names.

    Names that are not registered in ``CLIENTS`` are skipped (with a logged
    warning), so the returned list may be shorter than ``model_names``.

    Args:
        model_names (list[str]): A list of model names.

    Returns:
        list[Callable]: One ``chat_completion`` partial per known model, in
            input order. Each accepts ``prompt`` and returns a coroutine.
    """
    chat_funcs = []
    for model_name in model_names:
        if model_name not in CLIENTS:
            logger.warning("Skipping unknown model: {}", model_name)
            continue
        chat_funcs.append(partial(chat_completion, model_name=model_name))
    return chat_funcs


async def chat_completion(prompt: str, model_name: str) -> Any:
    """
    Perform a chat completion using the specified model.

    Args:
        prompt (str): The prompt to send to the model, as a single user
            message.
        model_name (str): The name of the model to use; must be a key of
            ``CLIENTS``.

    Returns:
        Any: The completion object returned by the OpenAI client (callers
            read the text from ``.choices[0].message.content``).

    Raises:
        AssertionError: If ``model_name`` is not registered in ``CLIENTS``.
    """
    assert model_name in CLIENTS, f"Model {model_name} not found"
    config = CLIENTS[model_name]
    client = OpenAI(api_key=config["api_key"], base_url=config["base_url"])
    # The OpenAI client is synchronous; running it in a worker thread keeps
    # the event loop free so that gathered requests really run concurrently.
    return await asyncio.to_thread(
        client.chat.completions.create,
        model=model_name,
        messages=[{"role": "user", "content": prompt}],
    )


async def retry_operation(func, task, max_retries=5, delay=0.5, *args, **kwargs):
    """
    Retry an async operation with a linearly growing delay between attempts.

    The wait after the n-th failed attempt (1-based) is ``delay * n``. No
    wait happens after the final attempt.

    Args:
        func (Callable): The coroutine function to be retried.
        task: Not used by this function; kept for interface compatibility.
        max_retries (int, optional): The maximum number of attempts.
            Defaults to 5.
        delay (float, optional): The base delay in seconds. Defaults to 0.5.
        *args: Additional positional arguments to pass to the function.
        **kwargs: Additional keyword arguments to pass to the function.

    Returns:
        tuple: ``(result, None)`` on success, or ``(None, errors)`` when every
            attempt failed, where ``errors`` is a newline-joined string with
            one ``"retry <n>: <exception>"`` entry per failed attempt.
    """
    exceptions = []
    for attempt in range(max_retries):
        try:
            return await func(*args, **kwargs), None
        except Exception as e:  # broad on purpose: any failure is retried
            exceptions.append(f"retry {attempt}: {e}")
            logger.error(e)
            # Sleeping after the last attempt would only delay the failure.
            if attempt + 1 < max_retries:
                await asyncio.sleep(delay * (attempt + 1))
    return None, "\n".join(exceptions)


async def chat_completion_multiple_models(
    prompt: str,
    model_names: Optional[list[str]] = None,
    chat_funcs: Optional[list[Callable]] = None,
):
    """
    Perform a chat completion using multiple models asynchronously.

    Args:
        prompt (str): The prompt to send to the models.
        model_names (list[str], optional): Model names used to build the chat
            functions when ``chat_funcs`` is empty. Defaults to None.
        chat_funcs (list[Callable], optional): Chat functions to call; takes
            precedence over ``model_names``. Defaults to None.

    Returns:
        list[Any]: The results of the chat completions, in the order of the
            chat functions.
    """
    if not chat_funcs:
        chat_funcs = get_chat_func(model_names or [])
    return await asyncio.gather(
        *(chat_func(prompt=prompt) for chat_func in chat_funcs)
    )


async def func_wrap_multiple_models(
    wrap_func: Callable,
    model_names: Optional[list[str]] = None,
    chat_funcs: Optional[list[Callable]] = None,
    model_weights: Optional[list[float]] = None,
    *args,
):
    """
    Run a coroutine function once per model, concurrently.

    Args:
        wrap_func (Callable): The coroutine function to run. It is called as
            ``wrap_func(*args, chat_func=chat_func)`` for each chat function.
        model_names (list[str], optional): Model names used to build the chat
            functions when ``chat_funcs`` is empty. Defaults to None.
        chat_funcs (list[Callable], optional): Chat functions to use; takes
            precedence over ``model_names``. Defaults to None.
        model_weights (list[float], optional): One weight per chat function.
            Only its length is validated here; the weights are not applied.
            Defaults to 1.0 for every chat function.
        *args: Additional positional arguments to pass to ``wrap_func``.

    Returns:
        list[Any]: The results of ``wrap_func``, in the order of the chat
            functions.

    Raises:
        AssertionError: If ``model_weights`` and the chat functions differ in
            length.
    """
    if not chat_funcs:
        chat_funcs = get_chat_func(model_names or [])
    if not model_weights:
        model_weights = [1.0] * len(chat_funcs)
    assert len(chat_funcs) == len(model_weights), \
        "model_weights must be same length as chat_funcs"
    return await asyncio.gather(
        *(wrap_func(*args, chat_func=chat_func) for chat_func in chat_funcs)
    )


def _build_ranking_prompt(contents: list[str], skip_index: int) -> str:
    """Build the ranking prompt listing every content except ``skip_index``."""
    prompt = f"""
        You are provided with {len(contents)-1} choices, and you are asked to rank them based on the quality and relevance.
        Rank 1 is the best.
        Just Output Index and Corresponding Rank in format Index:Rank.
        Just Number, no text.
        For example: "0:1" is correct, "Index 0:1" and "0: 1" are wrong.
        One Line for Each Rank.
        Just output like "Index:Rank\nIndex:Rank\nIndex:Rank\n"
        No other output is allowed.
        """
    for j, content in enumerate(contents):
        if j == skip_index:
            # skip self evaluation
            continue
        prompt += f"""
            Index {j}:
            {content}
            ----------
            """
    return prompt


def _parse_rankings(text: Optional[str]) -> list[tuple[int, int]]:
    """Parse ``Index:Rank`` lines into ``(index, rank)`` pairs, skipping bad lines."""
    pairs = []
    for line in (text or "").strip().splitlines():
        line = line.strip()
        if not line:
            continue
        try:
            index_text, rank_text = line.split(":")
            pairs.append((int(index_text), int(rank_text)))
        except ValueError:
            # Model output is free text; one bad line should not sink the run.
            logger.warning("Ignoring malformed ranking line: {!r}", line)
    return pairs


async def compare_chat_chocies(
    contents: list[str],
    model_names: Optional[list[str]] = None,
    chat_funcs: Optional[list[Callable]] = None,
    model_weights: Optional[list[float]] = None,
):
    """
    Let the models rank each other's candidate contents and sum the ranks.

    The i-th chat function is shown every content except ``contents[i]`` and
    asked to rank them (rank 1 is the best). Each returned rank is multiplied
    by the evaluator's weight and added to the ranked content's score, so a
    lower total means a better content.
    NOTE(review): skipping ``contents[i]`` for evaluator i presumably means
    ``contents[i]`` was produced by ``chat_funcs[i]`` -- confirm with callers.

    Args:
        contents (list[str]): The candidate contents to rank.
        model_names (list[str], optional): Model names used to build the chat
            functions when ``chat_funcs`` is empty. Defaults to None.
        chat_funcs (list[Callable], optional): Chat functions acting as
            evaluators; takes precedence over ``model_names``. At least
            ``len(contents)`` are required. Defaults to None.
        model_weights (list[float], optional): One weight per chat function.
            Defaults to 1.0 for every chat function.

    Returns:
        dict[int, float]: Weighted rank sum for every content index.

    Raises:
        AssertionError: If ``model_weights`` and the chat functions differ in
            length.
        IndexError: If there are fewer chat functions than contents.
    """
    if not chat_funcs:
        chat_funcs = get_chat_func(model_names or [])
    if not model_weights:
        model_weights = [1.0] * len(chat_funcs)
    assert len(chat_funcs) == len(model_weights), \
        "model_weights must be same length as chat_funcs"
    if len(chat_funcs) < len(contents):
        # Fail before any request is sent instead of midway through.
        raise IndexError(
            f"need one chat function per content: got {len(chat_funcs)} "
            f"chat functions for {len(contents)} contents"
        )

    prompts = [
        _build_ranking_prompt(contents, skip_index=i)
        for i in range(len(contents))
    ]
    # zip stops at len(prompts), so surplus chat functions are not called.
    compares = await asyncio.gather(
        *(chat_func(prompt=prompt)
          for prompt, chat_func in zip(prompts, chat_funcs))
    )

    rank_scores = {i: 0 for i in range(len(contents))}
    for i, comp in enumerate(compares):
        for index, rank in _parse_rankings(comp.choices[0].message.content):
            if index not in rank_scores:
                logger.warning("Ignoring rank for unknown index: {}", index)
                continue
            rank_scores[index] += rank * model_weights[i]
    return rank_scores