| | import json |
| | from fastapi import HTTPException |
| | import httpx |
| |
|
| | from log_config import logger |
| | from collections import defaultdict |
| |
|
| | import asyncio |
| |
|
class ThreadSafeCircularList:
    """Round-robin cursor over a list of items.

    Each call to :meth:`next` returns the item at the cursor and advances the
    cursor, wrapping to the start after the last item.  Access is serialized
    with an ``asyncio.Lock``; despite the class name, that lock coordinates
    coroutines on an event loop rather than OS threads.

    Attributes:
        items: The underlying list (stored by reference, not copied).
        index: Position of the item the next call will return.
        lock: The ``asyncio.Lock`` guarding ``index``.
    """

    def __init__(self, items=None):
        # ``items`` is optional so the class can be called with no arguments,
        # which is what a ``defaultdict`` factory does.
        self.items = [] if items is None else items
        self.index = 0
        self.lock = asyncio.Lock()

    async def next(self):
        """Return the current item and advance the cursor.

        Raises:
            IndexError: If the list holds no items.
        """
        async with self.lock:
            if not self.items:
                raise IndexError("ThreadSafeCircularList is empty")
            item = self.items[self.index]
            self.index = (self.index + 1) % len(self.items)
            return item

    def to_dict(self):
        """Return a plain-dict snapshot (``items`` copy and ``index``)."""
        return {"items": list(self.items), "index": self.index}
| |
|
def circular_list_encoder(obj):
    """Convert a ``ThreadSafeCircularList`` into a JSON-serializable dict.

    The signature and the ``TypeError`` for unknown types match what
    ``json.dumps(..., default=...)`` expects from a fallback encoder
    (presumably how callers use it -- confirm at the call sites).

    Args:
        obj: The object that could not be serialized natively.

    Returns:
        The object's ``to_dict()`` result when it provides that method,
        otherwise ``{"items": [...], "index": n}`` built from its attributes.

    Raises:
        TypeError: If ``obj`` is not a ``ThreadSafeCircularList``.
    """
    if isinstance(obj, ThreadSafeCircularList):
        to_dict = getattr(obj, "to_dict", None)
        if callable(to_dict):
            return to_dict()
        # Fall back to the public attributes so serialization never fails
        # with AttributeError when the class has no to_dict().
        return {"items": list(obj.items), "index": obj.index}
    raise TypeError(f'Object of type {obj.__class__.__name__} is not JSON serializable')
| |
|
# Registry of API-key rotations, keyed by provider name; update_config()
# stores one ThreadSafeCircularList per provider that declares an `api`.
provider_api_circular_list = defaultdict(ThreadSafeCircularList)
| |
|
def get_model_dict(provider):
    """Map every model name a provider exposes to its upstream model name.

    Each entry of ``provider['model']`` is either:

    * a string -- exposed under its own name, or
    * a dict ``{upstream_name: exposed_name}`` -- inverted here so the
      exposed name becomes the key.

    Args:
        provider: Provider mapping; a missing or empty ``model`` entry yields
            an empty result.

    Returns:
        dict of ``exposed_name -> upstream_name``.
    """
    model_dict = {}
    for model in provider.get('model') or []:
        # isinstance (not an exact type check) so str subclasses count too:
        # YAML loaders that preserve quoting hand back such subclasses.
        if isinstance(model, str):
            model_dict[model] = model
        elif isinstance(model, dict):
            model_dict.update({new: old for old, new in model.items()})
    return model_dict
| |
|
def update_initial_model(api_url, api):
    """Ask a provider's ``/v1/models`` endpoint which model ids it offers.

    Args:
        api_url: Any URL of the provider; the models endpoint is derived
            from it via ``BaseAPI``.
        api: Bearer token, or a list of tokens (only the first is used).

    Returns:
        Unique model ids in the order the provider listed them, or ``[]`` if
        anything goes wrong (the traceback is printed, nothing is raised).
    """
    try:
        endpoint = BaseAPI(api_url=api_url)
        endpoint_models_url = endpoint.v1_models
        if isinstance(api, list):
            api = api[0]
        response = httpx.get(
            endpoint_models_url,
            headers={"Authorization": f"Bearer {api}"},
        )
        models = response.json()
        if models.get("error"):
            # NOTE(review): this payload includes the API key and ends up in
            # the printed traceback -- consider masking it.
            raise Exception({"error": models.get("error"), "endpoint": endpoint_models_url, "api": api})

        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # result is stable across runs (a set would iterate in hash order).
        return list(dict.fromkeys(model["id"] for model in models["data"]))
    except Exception:
        # Best-effort lookup: report the failure and fall back to "no models".
        import traceback
        traceback.print_exc()
        return []
| |
|
| | from ruamel.yaml import YAML, YAMLError |
# Module-wide ruamel.yaml instance; load_config() reads with it and
# save_api_yaml() writes with it.
yaml = YAML()
yaml.preserve_quotes = True
yaml.indent(mapping=2, sequence=4, offset=2)

# Path of the local configuration file (relative to the working directory).
API_YAML_PATH = "./api.yaml"
| |
|
def save_api_yaml(config_data):
    """Overwrite the file at ``API_YAML_PATH`` with ``config_data`` as YAML."""
    with open(API_YAML_PATH, mode="w", encoding="utf-8") as handle:
        yaml.dump(config_data, handle)
| |
|
def update_config(config_data, use_config_url=False):
    """Normalize a loaded configuration in place and derive lookup lists.

    For every provider:

    * ``project_id`` / ``cf_account_id`` force a fixed ``base_url``;
    * a configured ``api`` (string or list) is registered in
      ``provider_api_circular_list`` under the provider's name;
    * a missing ``model`` list is fetched with ``update_initial_model`` and,
      unless ``use_config_url`` is set, the config file is rewritten;
    * ``tools`` defaults to ``True`` when absent.

    For every API key, ``model`` becomes a flat list of names (``["all"]``
    when none are configured).  Dict entries ``{name: weight}`` contribute
    their name, and names containing ``/`` are also recorded with an integer
    weight under the key's ``weights`` entry.

    Args:
        config_data: Parsed configuration with ``providers`` and ``api_keys``.
        use_config_url: True when the config came from a remote URL, in which
            case fetched model lists are not written to the local file.

    Returns:
        ``(config_data, api_keys_db, api_list)`` where ``api_keys_db`` is the
        same list object as ``config_data['api_keys']`` and ``api_list`` holds
        each key's ``api`` value in order.
    """
    for provider in config_data['providers']:
        if provider.get('project_id'):
            provider['base_url'] = 'https://aiplatform.googleapis.com/'
        if provider.get('cf_account_id'):
            provider['base_url'] = 'https://api.cloudflare.com/'

        provider_api = provider.get('api', None)
        if provider_api:
            if isinstance(provider_api, str):
                provider_api_circular_list[provider['provider']] = ThreadSafeCircularList([provider_api])
            if isinstance(provider_api, list):
                provider_api_circular_list[provider['provider']] = ThreadSafeCircularList(provider_api)

        if not provider.get("model"):
            # Reuse the value read with .get() above so a provider without an
            # `api` entry does not raise KeyError here.
            model_list = update_initial_model(provider['base_url'], provider_api)
            if model_list:
                provider["model"] = model_list
                if not use_config_url:
                    save_api_yaml(config_data)

        if provider.get("tools") is None:
            provider["tools"] = True

    # Alias, not a copy: edits below are visible through config_data as well.
    api_keys_db = config_data['api_keys']

    for api_key in api_keys_db:
        requested = api_key.get('model')
        if not requested:
            api_key['model'] = ["all"]
            continue

        weights_dict = {}
        models = []
        for model in requested:
            if isinstance(model, dict):
                # Only the first pair of a dict entry is considered.
                key, value = list(model.items())[0]
                if "/" in key:
                    weights_dict[key] = int(value)
                models.append(key)
            if isinstance(model, str):
                models.append(model)
        if weights_dict:
            api_key['weights'] = weights_dict
        api_key['model'] = models

    api_list = [item["api"] for item in api_keys_db]

    return config_data, api_keys_db, api_list
| |
|
| | |
async def load_config(app=None):
    """Load the configuration from the local file, else from ``CONFIG_URL``.

    When ``app`` is given and has no ``app.state.client`` yet, a shared
    ``httpx.AsyncClient`` is created first (read timeout taken from the
    ``TIMEOUT`` environment variable, default 100).

    The local file at ``API_YAML_PATH`` is tried first.  If that yields
    nothing and the ``CONFIG_URL`` environment variable is set, the YAML is
    downloaded through ``app.state.client`` instead.  Failures are logged and
    produce empty results rather than exceptions.

    Returns:
        ``(config, api_keys_db, api_list)`` as produced by ``update_config``,
        or ``({}, {}, [])`` when no configuration could be loaded.
    """
    import os

    if app and not hasattr(app.state, 'client'):
        read_timeout = float(os.getenv("TIMEOUT", 100))
        client_timeout = httpx.Timeout(connect=15.0, read=read_timeout, write=30.0, pool=30.0)
        app.state.client = httpx.AsyncClient(
            timeout=client_timeout,
            headers={
                "User-Agent": "curl/7.68.0",
                "Accept": "*/*",
            },
            http2=True,
            verify=True,
            follow_redirects=True,
        )

    # --- 1) local file -----------------------------------------------------
    try:
        with open(API_YAML_PATH, 'r', encoding='utf-8') as file:
            conf = yaml.load(file)

        if not conf:
            logger.error("配置文件 'api.yaml' 为空。请检查文件内容。")
            config, api_keys_db, api_list = {}, {}, []
        else:
            config, api_keys_db, api_list = update_config(conf, use_config_url=False)
    except FileNotFoundError:
        logger.error("'api.yaml' not found. Please check the file path.")
        config, api_keys_db, api_list = {}, {}, []
    except YAMLError as e:
        logger.error("配置文件 'api.yaml' 格式不正确。请检查 YAML 格式。%s", e)
        config, api_keys_db, api_list = {}, {}, []
    except OSError as e:
        logger.error(f"open 'api.yaml' failed: {e}")
        config, api_keys_db, api_list = {}, {}, []

    if config != {}:
        return config, api_keys_db, api_list

    # --- 2) remote fallback ------------------------------------------------
    config_url = os.environ.get('CONFIG_URL')
    if not config_url:
        return config, api_keys_db, api_list

    try:
        response = await app.state.client.get(config_url)
        response.raise_for_status()
        config_data = yaml.load(response.text)

        if config_data:
            config, api_keys_db, api_list = update_config(config_data, use_config_url=True)
        else:
            logger.error(f"Error fetching or parsing config from {config_url}")
            config, api_keys_db, api_list = {}, {}, []
    except Exception as e:
        logger.error(f"Error fetching or parsing config from {config_url}: {str(e)}")
        config, api_keys_db, api_list = {}, {}, []
    return config, api_keys_db, api_list
| |
|
def ensure_string(item):
    """Coerce a stream chunk to ``str``.

    * ``str`` is returned unchanged;
    * ``bytes`` / ``bytearray`` are decoded as UTF-8;
    * ``dict`` is rendered as ``data: <json>`` followed by a blank line;
    * anything else goes through ``str()``.
    """
    if isinstance(item, str):
        return item
    if isinstance(item, (bytes, bytearray)):
        return item.decode("utf-8")
    if isinstance(item, dict):
        payload = json.dumps(item)
        return f"data: {payload}\n\n"
    return str(item)
| |
|
| | import asyncio |
| | import time as time_module |
async def error_handling_wrapper(generator):
    """Inspect the first chunk of a response stream before passing it on.

    The first item is pulled from ``generator`` and checked.  If it looks
    healthy, a new async generator is returned that yields that first item
    followed by the rest of the stream, each passed through ``ensure_string``.

    Args:
        generator: Async generator producing ``str``, ``bytes`` or ``dict``
            chunks.

    Returns:
        ``(stream, first_response_time)`` -- the replacement async generator
        and the seconds elapsed until the first item arrived.

    Raises:
        HTTPException: 400 when the stream is empty, starts with ``[DONE]``,
            contains one of the known error phrases, or its first chunk is
            not valid JSON; otherwise the ``status_code`` from the payload
            (default 500) when the first chunk is a dict with an ``error``
            key.
    """
    start_time = time_module.time()
    try:
        first_item = await generator.__anext__()
        first_response_time = time_module.time() - start_time
        first_item_str = first_item

        if isinstance(first_item_str, (bytes, bytearray)):
            first_item_str = first_item_str.decode("utf-8")
        if isinstance(first_item_str, str):
            if first_item_str.startswith("data:"):
                # Drop the literal prefix.  str.lstrip("data: ") would strip
                # any run of those characters and eat into the payload
                # (e.g. "data: true" -> "rue").
                first_item_str = first_item_str[len("data:"):].lstrip()
            if first_item_str.startswith("[DONE]"):
                logger.error("error_handling_wrapper [DONE]!")
                raise StopAsyncIteration
            if "The bot's usage is covered by the developer" in first_item_str:
                logger.error("error const string: %s", first_item_str)
                raise StopAsyncIteration
            if "process this request due to overload or policy" in first_item_str:
                logger.error("error const string: %s", first_item_str)
                raise StopAsyncIteration
            try:
                first_item_str = json.loads(first_item_str)
            except json.JSONDecodeError:
                logger.error("error_handling_wrapper JSONDecodeError!" + repr(first_item_str))
                raise StopAsyncIteration
        if isinstance(first_item_str, dict) and 'error' in first_item_str:
            # The upstream reported an error in-band: surface it as an HTTP
            # error instead of streaming it to the client.
            status_code = first_item_str.get('status_code', 500)
            detail = first_item_str.get('details', f"{first_item_str}")
            raise HTTPException(status_code=status_code, detail=f"{detail}"[:300])

        async def new_generator():
            # Replay the item consumed above, then relay the remainder.
            yield ensure_string(first_item)
            try:
                async for item in generator:
                    yield ensure_string(item)
            except (httpx.ReadError, asyncio.CancelledError, httpx.RemoteProtocolError) as e:
                logger.error(f"Network error in new_generator: {e}")
                raise

        return new_generator(), first_response_time

    except StopAsyncIteration:
        # Reached both when the source is exhausted immediately and when one
        # of the checks above raises the sentinel deliberately.
        raise HTTPException(status_code=400, detail="data: {'error': 'No data returned'}")
| |
|
def post_all_models(token, config, api_list):
    """List the models an API key is allowed to use, in OpenAI list format.

    Each rule in the key's ``model`` list is interpreted as:

    * ``"all"`` -- every model of every provider (``get_all_models``);
    * ``"<provider>/*"`` -- every model of that provider;
    * ``"<provider>/<model>"`` -- that one model, if the provider has it.
      The rule is split on the first ``/`` only, so model names that contain
      ``/`` themselves are matched whole;
    * any other string -- listed as-is, with ``owned_by`` set to the name.

    Args:
        token: API key presented by the caller.
        config: Configuration with ``api_keys`` and ``providers``.
        api_list: Known API keys, index-aligned with ``config['api_keys']``.

    Returns:
        List of model-info dicts without duplicate ids.

    Raises:
        HTTPException: 403 when ``token`` is not in ``api_list``.
    """
    all_models = []
    unique_models = set()

    if token not in api_list:
        raise HTTPException(status_code=403, detail="Invalid or missing API Key")
    api_index = api_list.index(token)

    for model in config['api_keys'][api_index]['model'] or []:
        if model == "all":
            return get_all_models(config)

        if "/" in model:
            provider_name, wanted = model.split("/", 1)
            for provider_item in config["providers"]:
                if provider_item['provider'] != provider_name:
                    continue
                for model_item in get_model_dict(provider_item):
                    if model_item in unique_models:
                        continue
                    if wanted != "*" and model_item != wanted:
                        continue
                    unique_models.add(model_item)
                    all_models.append({
                        "id": model_item,
                        "object": "model",
                        "created": 1720524448858,
                        "owned_by": "uni-api",
                    })
            continue

        if model not in unique_models:
            unique_models.add(model)
            all_models.append({
                "id": model,
                "object": "model",
                "created": 1720524448858,
                "owned_by": model,
            })

    return all_models
| |
|
def get_all_models(config):
    """Return one model-info dict per distinct model name across providers.

    Providers are visited in configuration order and, within each, names in
    the order ``get_model_dict`` reports them; a name already seen is skipped.
    """
    seen = set()
    catalog = []

    for provider in config["providers"]:
        for name in get_model_dict(provider).keys():
            if name in seen:
                continue
            seen.add(name)
            catalog.append({
                "id": name,
                "object": "model",
                "created": 1720524448858,
                "owned_by": "uni-api",
            })

    return catalog
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
# Round-robin pools of cloud region identifiers.
# NOTE(review): the short names look like abbreviations of model families
# (e.g. c35s / c3s / c3o / c3h / gem) -- confirm against the code that
# consumes them.
c35s = ThreadSafeCircularList([
    "us-east5",
    "europe-west1",
])
c3s = ThreadSafeCircularList([
    "us-east5",
    "us-central1",
    "asia-southeast1",
])
c3o = ThreadSafeCircularList([
    "us-east5",
])
c3h = ThreadSafeCircularList([
    "us-east5",
    "us-central1",
    "europe-west1",
    "europe-west4",
])
gem = ThreadSafeCircularList([
    "us-central1",
    "us-east4",
    "us-west1",
    "us-west4",
    "europe-west1",
    "europe-west2",
])
| |
|
class BaseAPI:
    """Derive OpenAI-style endpoint URLs from a single provider URL.

    Everything in the URL path from the first ``/v1`` onward is discarded;
    what precedes it (without a trailing slash) is the prefix under which the
    ``/v1/...`` endpoints are rebuilt.  Query string and fragment are dropped.

    Attributes:
        source_api_url: The URL as given (or the default when empty).
        base_url: Scheme, host and path prefix.
        v1_url: ``<base>/v1``.
        v1_models: ``<base>/v1/models``.
        chat_url: ``<base>/v1/chat/completions``; for host
            ``api.deepseek.com`` it is ``/chat/completions`` at the host root.
        image_url: ``<base>/v1/images/generations``.
        audio_transcriptions: ``<base>/v1/audio/transcriptions``.
        moderations: ``<base>/v1/moderations``.
        embeddings: ``<base>/v1/embeddings``.

    Raises:
        Exception: If the URL has no scheme.
    """

    def __init__(
        self,
        api_url: str = "https://api.openai.com/v1/chat/completions",
    ):
        if api_url == "":
            api_url = "https://api.openai.com/v1/chat/completions"
        self.source_api_url: str = api_url
        from urllib.parse import urlparse, urlunparse
        parsed_url = urlparse(self.source_api_url)
        if parsed_url.scheme == "":
            raise Exception("Error: API_URL is not set")

        # Strip the trailing slash for every path, not just the bare "/",
        # so "https://host/proxy/" does not produce "/proxy//v1/...".
        before_v1 = parsed_url.path.split("/v1")[0].rstrip("/")
        origin = parsed_url[:2]  # (scheme, netloc)

        def build(path: str) -> str:
            # params, query and fragment are intentionally left empty.
            return urlunparse(origin + (path,) + ("",) * 3)

        self.base_url: str = build(before_v1)
        self.v1_url: str = build(before_v1 + "/v1")
        self.v1_models: str = build(before_v1 + "/v1/models")
        if parsed_url.netloc == "api.deepseek.com":
            self.chat_url: str = build("/chat/completions")
        else:
            self.chat_url: str = build(before_v1 + "/v1/chat/completions")
        self.image_url: str = build(before_v1 + "/v1/images/generations")
        self.audio_transcriptions: str = build(before_v1 + "/v1/audio/transcriptions")
        self.moderations: str = build(before_v1 + "/v1/moderations")
        self.embeddings: str = build(before_v1 + "/v1/embeddings")
| |
|
def safe_get(data, *keys, default=None):
    """Follow ``keys`` through nested containers, returning ``default`` on a miss.

    Dicts and lists are indexed with ``[]``; any other object is queried via
    its ``.get(key)``.  A lookup that raises ``KeyError``, ``IndexError``,
    ``AttributeError`` or ``TypeError`` ends the walk with ``default``.
    With no keys, ``data`` itself is returned.
    """
    current = data
    for step in keys:
        try:
            if isinstance(current, (dict, list)):
                current = current[step]
            else:
                current = current.get(step)
        except (KeyError, IndexError, AttributeError, TypeError):
            return default
    return current
| |
|