| from log_config import logger |
|
|
| import re |
| import copy |
| import httpx |
| import secrets |
| from time import time |
| from contextlib import asynccontextmanager |
| from starlette.middleware.base import BaseHTTPMiddleware |
|
|
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi import FastAPI, HTTPException, Depends, Request, APIRouter |
| from fastapi.responses import JSONResponse |
| from fastapi.responses import StreamingResponse as FastAPIStreamingResponse |
| from starlette.responses import StreamingResponse as StarletteStreamingResponse |
| from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials |
| from fastapi.exceptions import RequestValidationError |
|
|
| from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, UnifiedRequest, EmbeddingRequest |
| from request import get_payload |
| from response import fetch_response, fetch_response_stream |
| from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder, get_model_dict, save_api_yaml |
|
|
| from collections import defaultdict |
| from typing import List, Dict, Union |
| from urllib.parse import urlparse |
|
|
| import os |
| import string |
| import json |
|
|
# Debug switch read from the DEBUG environment variable.
# Any non-empty value enables debug mode except the explicit "off" spellings
# below; a plain bool() on the raw string would treat DEBUG=0 or DEBUG=false
# as enabled because every non-empty string is truthy.
is_debug = os.getenv("DEBUG", "").strip().lower() not in {"", "0", "false", "no", "off"}
| |
|
|
| from sqlalchemy import inspect, text |
| from sqlalchemy.sql import sqltypes |
|
|
| |
# When the DISABLE_DATABASE environment variable equals "true"
# (case-insensitive), the statistics database is neither created nor written.
DISABLE_DATABASE = os.getenv("DISABLE_DATABASE", "false").lower() == "true"
|
|
async def create_tables():
    """Create the ORM tables, then add any model columns the live tables lack.

    Returns immediately when DISABLE_DATABASE is set.
    """
    if DISABLE_DATABASE:
        return

    def check_and_add_columns(connection):
        # Executed through run_sync: compare each model's declared columns
        # with what the inspector reports and ALTER in whatever is absent.
        inspector = inspect(connection)
        for model in (RequestStat, ChannelStat):
            table_name = model.__tablename__
            present = {col['name'] for col in inspector.get_columns(table_name)}
            missing = [
                (column_name, column)
                for column_name, column in model.__table__.columns.items()
                if column_name not in present
            ]
            for column_name, column in missing:
                col_type = _map_sa_type_to_sql_type(column.type)
                default = _get_default_sql(column.default)
                connection.execute(text(f"ALTER TABLE {table_name} ADD COLUMN {column_name} {col_type}{default}"))

    async with db_engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
        await conn.run_sync(check_and_add_columns)
|
|
def _map_sa_type_to_sql_type(sa_type):
    """Return the SQL type keyword used when adding a column of ``sa_type``.

    Args:
        sa_type: A SQLAlchemy type instance (e.g. the ``.type`` of a Column).

    Returns:
        The keyword mapped to the type's class, or to the nearest base class
        that has a mapping; "TEXT" when nothing in the hierarchy is mapped.
    """
    type_map = {
        sqltypes.Integer: "INTEGER",
        sqltypes.String: "TEXT",
        sqltypes.Float: "REAL",
        sqltypes.Boolean: "BOOLEAN",
        sqltypes.DateTime: "DATETIME",
        sqltypes.Text: "TEXT"
    }
    # Walk the MRO so subclasses of a mapped type resolve to their parent's
    # keyword instead of silently falling back to TEXT. The class itself is
    # first in the MRO, so exact matches behave as before.
    for cls in type(sa_type).__mro__:
        if cls in type_map:
            return type_map[cls]
    return "TEXT"
|
|
| def _get_default_sql(default): |
| if default is None: |
| return "" |
| if isinstance(default.arg, bool): |
| return f" DEFAULT {str(default.arg).upper()}" |
| if isinstance(default.arg, (int, float)): |
| return f" DEFAULT {default.arg}" |
| if isinstance(default.arg, str): |
| return f" DEFAULT '{default.arg}'" |
| return "" |
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: prepare the stats tables, then clean up on exit.

    On startup the statistics tables are created unless the database is
    disabled. On shutdown the shared HTTP client stored on ``app.state`` is
    closed if one was ever attached.
    """
    if not DISABLE_DATABASE:
        await create_tables()

    yield

    # Nothing in this function creates app.state.client, so it may be absent
    # at shutdown; closing it unconditionally would raise AttributeError.
    client = getattr(app.state, "client", None)
    if client is not None:
        await client.aclose()
|
|
# The ASGI application; startup/shutdown work is delegated to `lifespan`
# and FastAPI's debug mode follows the DEBUG environment switch.
app = FastAPI(lifespan=lifespan, debug=is_debug)
|
|
def generate_markdown_docs():
    """Render the application's OpenAPI schema as a Markdown document.

    The output has a title/version/description header followed by one
    section per (method, path) pair, each listing its parameters when the
    operation declares any and ending with a horizontal rule.
    """
    schema = app.openapi()
    info = schema['info']

    parts = [
        f"# {info['title']}\n\n",
        f"Version: {info['version']}\n\n",
        f"{info.get('description', '')}\n\n",
        "## API Endpoints\n\n",
    ]

    for path, path_info in schema['paths'].items():
        for method, operation in path_info.items():
            parts.append(f"### {method.upper()} {path}\n\n")
            parts.append(f"{operation.get('summary', '')}\n\n")
            parts.append(f"{operation.get('description', '')}\n\n")

            if 'parameters' in operation:
                parts.append("Parameters:\n")
                parts.extend(
                    f"- {param['name']} ({param['in']}): {param.get('description', '')}\n"
                    for param in operation['parameters']
                )

            parts.append("\n---\n\n")

    return "".join(parts)
|
|
@app.get("/docs/markdown")
async def get_markdown_docs():
    """Serve the Markdown rendering of the API schema as text/markdown."""
    return Response(content=generate_markdown_docs(), media_type="text/markdown")
|
|
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Turn an HTTPException into a JSON body of the form {"message": detail}.

    404 errors are additionally written to the error log.
    """
    status, detail = exc.status_code, exc.detail
    if status == 404:
        logger.error(f"404 Error: {detail}")
    return JSONResponse(status_code=status, content={"message": detail})
|
|
| import uuid |
| import asyncio |
| import contextvars |
# Per-request bookkeeping dict, set by StatsMiddleware and read by the
# request-processing code. NOTE(review): the default is a single shared dict,
# so any code that calls .get() outside a request sees (and can mutate) the
# same object.
request_info = contextvars.ContextVar('request_info', default={})
|
|
async def parse_request_body(request: Request):
    """Return the decoded JSON body of a JSON POST request, else None.

    None is returned for non-POST requests, for requests whose content type
    does not mention application/json, and for bodies that fail to decode.
    """
    content_type = request.headers.get("content-type", "")
    is_json_post = request.method == "POST" and "application/json" in content_type
    if not is_json_post:
        return None

    try:
        body = await request.json()
    except json.JSONDecodeError:
        return None
    return body
|
|
| from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession |
| from sqlalchemy.orm import declarative_base, sessionmaker |
| from sqlalchemy import Column, Integer, String, Float, DateTime, select, Boolean, Text |
| from sqlalchemy.sql import func |
|
|
| |
# Declarative base shared by the statistics ORM models below.
Base = declarative_base()
|
|
class RequestStat(Base):
    """One row per handled request, stored in the ``request_stats`` table."""
    __tablename__ = 'request_stats'
    id = Column(Integer, primary_key=True)
    # UUID string generated by StatsMiddleware for the request.
    request_id = Column(String)
    # "<METHOD> <path>" as built by StatsMiddleware.
    endpoint = Column(String)
    client_ip = Column(String)
    # Seconds between the middleware start time and the end of the response.
    process_time = Column(Float)
    first_response_time = Column(Float)
    provider = Column(String)
    model = Column(String)

    api_key = Column(String)
    # True when the moderation check flagged the request content.
    is_flagged = Column(Boolean, default=False)
    # The text that was extracted from the request for moderation.
    text = Column(Text)
    prompt_tokens = Column(Integer, default=0)
    completion_tokens = Column(Integer, default=0)
    total_tokens = Column(Integer, default=0)

    # Filled in by the database at insert time.
    timestamp = Column(DateTime(timezone=True), server_default=func.now())
|
|
class ChannelStat(Base):
    """One row per upstream provider attempt, stored in ``channel_stats``."""
    __tablename__ = 'channel_stats'
    id = Column(Integer, primary_key=True)
    # Request id of the client request this attempt belongs to.
    request_id = Column(String)
    provider = Column(String)
    model = Column(String)
    api_key = Column(String)
    # Whether this provider attempt produced a response.
    success = Column(Boolean, default=False)
    # Filled in by the database at insert time.
    timestamp = Column(DateTime(timezone=True), server_default=func.now())
|
|
|
|
if not DISABLE_DATABASE:
    # Location of the SQLite statistics file; overridable through DB_PATH.
    db_path = os.getenv('DB_PATH', './data/stats.db')

    # Make sure the parent directory exists. A bare file name has an empty
    # dirname, and os.makedirs("") raises FileNotFoundError, so skip it then.
    data_dir = os.path.dirname(db_path)
    if data_dir:
        os.makedirs(data_dir, exist_ok=True)

    # Async engine and session factory used by the stats writers.
    db_engine = create_async_engine('sqlite+aiosqlite:///' + db_path, echo=is_debug)
    async_session = sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
|
|
| from starlette.types import Scope, Receive, Send |
| from starlette.responses import Response |
|
|
| from decimal import Decimal, getcontext |
|
|
| |
# Set the precision of the current decimal context to 17 significant digits;
# Decimal arithmetic performed under this context (e.g. calculate_cost) is
# rounded accordingly.
getcontext().prec = 17
|
|
def calculate_cost(model: str, input_tokens: int, output_tokens: int) -> Decimal:
    """Compute the cost of a request from its token counts.

    Args:
        model: Model name; must be a key of the built-in price table.
        input_tokens: Number of prompt tokens.
        output_tokens: Number of completion tokens.

    Returns:
        The total cost rounded to 15 decimal places, or Decimal 0 when the
        model is not in the price table (an error is logged in that case).
    """
    million = Decimal('1000000')
    # Per-token prices, expressed as a price per million tokens.
    costs = {
        "gpt-4": {"input": Decimal('5.0') / million, "output": Decimal('15.0') / million},
        "claude-3-sonnet": {"input": Decimal('3.0') / million, "output": Decimal('15.0') / million}
    }

    if model not in costs:
        logger.error(f"Unknown model: {model}")
        # Return a Decimal so every code path honours the annotated type.
        return Decimal('0')

    model_costs = costs[model]
    input_cost = Decimal(input_tokens) * model_costs["input"]
    output_cost = Decimal(output_tokens) * model_costs["output"]
    total_cost = input_cost + output_cost

    # NOTE(review): quantize signals InvalidOperation when the result needs
    # more digits than the context precision; with 15 decimals that limits
    # the integer part under a 17-digit context - confirm this is acceptable.
    return total_cost.quantize(Decimal('0.000000000000001'))
|
|
async def update_stats(current_info):
    """Persist one request's statistics as a RequestStat row.

    Keys of ``current_info`` that are not RequestStat columns are ignored.
    Failures are logged and rolled back, never raised to the caller.
    Does nothing when the database is disabled.
    """
    if DISABLE_DATABASE:
        return

    async with async_session() as session:
        async with session.begin():
            try:
                # Keep only the entries that correspond to table columns.
                columns = [column.key for column in RequestStat.__table__.columns]
                filtered_info = {k: v for k, v in current_info.items() if k in columns}
                new_request_stat = RequestStat(**filtered_info)
                session.add(new_request_stat)
                await session.commit()
            except Exception as e:
                await session.rollback()
                logger.error(f"Error updating stats: {str(e)}")
|
|
async def update_channel_stats(request_id, provider, model, api_key, success):
    """Persist the outcome of one provider attempt as a ChannelStat row.

    Failures are logged and rolled back, never raised to the caller.
    Does nothing when the database is disabled.
    """
    if DISABLE_DATABASE:
        return
    async with async_session() as session:
        async with session.begin():
            try:
                channel_stat = ChannelStat(
                    request_id=request_id,
                    provider=provider,
                    model=model,
                    api_key=api_key,
                    success=success,
                )
                session.add(channel_stat)
                await session.commit()
            except Exception as e:
                await session.rollback()
                logger.error(f"Error updating channel stats: {str(e)}")
|
|
class LoggingStreamingResponse(Response):
    """Streaming response that forwards chunks and records token usage.

    Each chunk from the wrapped iterator is inspected for usage figures,
    which are written into ``current_info``; once the stream ends the
    elapsed time is stored and the stats row is written.
    """

    def __init__(self, content, status_code=200, headers=None, media_type=None, current_info=None):
        """Wrap ``content`` (an async iterator of str/bytes chunks).

        Args:
            content: The body iterator to forward.
            status_code: HTTP status of the response.
            headers: Headers to send with the response.
            media_type: Media type of the response.
            current_info: Per-request stats dict; must contain "start_time".
        """
        super().__init__(content=None, status_code=status_code, headers=headers, media_type=media_type)
        self.body_iterator = content
        self._closed = False
        self.current_info = current_info

        # The body length is unknown up front, so drop any Content-Length
        # and declare chunked transfer instead.
        if 'content-length' in self.headers:
            del self.headers['content-length']
        self.headers['transfer-encoding'] = 'chunked'

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """ASGI entry point: stream the body, then finalise stats."""
        await send({
            'type': 'http.response.start',
            'status': self.status_code,
            'headers': self.raw_headers,
        })

        try:
            async for chunk in self._logging_iterator():
                await send({
                    'type': 'http.response.body',
                    'body': chunk,
                    'more_body': True,
                })
        finally:
            # Always terminate the body and release the upstream iterator,
            # even when streaming was interrupted.
            await send({
                'type': 'http.response.body',
                'body': b'',
                'more_body': False,
            })
            if hasattr(self.body_iterator, 'aclose') and not self._closed:
                await self.body_iterator.aclose()
                self._closed = True

            process_time = time() - self.current_info["start_time"]
            self.current_info["process_time"] = process_time
            await update_stats(self.current_info)

    def _record_usage(self, resp):
        """Copy token usage found in a decoded chunk into ``current_info``."""
        # Try the nested "message.usage.input_tokens" location first, then
        # the top-level "usage.prompt_tokens" one. Previously the second
        # lookup unconditionally overwrote the first.
        input_tokens = (
            safe_get(resp, "message", "usage", "input_tokens", default=0)
            or safe_get(resp, "usage", "prompt_tokens", default=0)
        )
        output_tokens = safe_get(resp, "usage", "completion_tokens", default=0)

        # Only overwrite a counter when this chunk actually carries a value,
        # so chunks without usage do not reset figures seen earlier.
        if input_tokens:
            self.current_info["prompt_tokens"] = input_tokens
        if output_tokens:
            self.current_info["completion_tokens"] = output_tokens
        self.current_info["total_tokens"] = (
            self.current_info.get("prompt_tokens", 0)
            + self.current_info.get("completion_tokens", 0)
        )

    async def _logging_iterator(self):
        """Yield chunks as bytes while extracting usage data from each one."""
        try:
            async for chunk in self.body_iterator:
                if isinstance(chunk, str):
                    chunk = chunk.encode('utf-8')
                line = chunk.decode('utf-8')
                if is_debug:
                    logger.info(f"{line}")
                if line.startswith("data:"):
                    # Remove the SSE field name as a prefix; str.lstrip with
                    # "data: " would strip any leading run of those characters.
                    line = line.removeprefix("data:").lstrip(" ")
                if not line.startswith("[DONE]") and not line.startswith("OK"):
                    try:
                        resp: dict = json.loads(line)
                        self._record_usage(resp)
                    except Exception as e:
                        logger.error(f"Error parsing response: {str(e)}, line: {repr(line)}")
                        # NOTE(review): a chunk that cannot be parsed is not
                        # forwarded to the client - confirm this is intended.
                        continue
                yield chunk
        finally:
            logger.debug("_logging_iterator finished")

    async def close(self):
        """Close the wrapped iterator once, if it supports ``aclose``."""
        if not self._closed:
            self._closed = True
            if hasattr(self.body_iterator, 'aclose'):
                await self.body_iterator.aclose()
|
|
class StatsMiddleware(BaseHTTPMiddleware):
    """Middleware that tracks per-request stats and optionally moderates input.

    For every request it creates the ``request_info`` context dict, extracts
    the model and text from JSON bodies, runs the moderation check when it is
    enabled for the caller, and wraps streaming ``/v1`` responses so token
    usage and timing are stored once the stream finishes.
    """

    def __init__(self, app):
        super().__init__(app)

    async def dispatch(self, request: Request, call_next):
        """Handle one request; returns the downstream or a 400 response."""
        start_time = time()

        enable_moderation = False

        config = app.state.config

        # The API key may arrive as x-api-key or as "<scheme> <token>" in the
        # Authorization header.
        if request.headers.get("x-api-key"):
            token = request.headers.get("x-api-key")
        elif request.headers.get("Authorization"):
            # A header without a second word carries no usable token;
            # indexing [1] blindly would raise IndexError.
            auth_parts = request.headers.get("Authorization").split(" ")
            token = auth_parts[1] if len(auth_parts) > 1 else None
        else:
            token = None
        if token:
            try:
                api_list = app.state.api_list
                api_index = api_list.index(token)
                enable_moderation = safe_get(config, 'api_keys', api_index, "preferences", "ENABLE_MODERATION", default=False)
            except ValueError:
                # Unknown token: keep moderation off.
                pass
        else:
            # No token at all: fall back to the global setting.
            enable_moderation = config.get('ENABLE_MODERATION', False)

        request_id = str(uuid.uuid4())

        request_info_data = {
            "request_id": request_id,
            "start_time": start_time,
            "endpoint": f"{request.method} {request.url.path}",
            # request.client may be None, so guard the attribute access.
            "client_ip": request.client.host if request.client else None,
            "process_time": 0,
            "first_response_time": -1,
            "provider": None,
            "model": None,
            "success": False,
            "api_key": token,
            "is_flagged": False,
            "text": None,
            "prompt_tokens": 0,
            "completion_tokens": 0,

            "total_tokens": 0
        }

        # Publish the dict through the context variable so downstream code
        # can update it in place.
        current_request_info = request_info.set(request_info_data)
        current_info = request_info.get()

        parsed_body = await parse_request_body(request)
        if parsed_body:
            try:
                request_model = UnifiedRequest.model_validate(parsed_body).data
                model = request_model.model
                current_info["model"] = model

                # Stay None for request types that carry no moderatable text;
                # leaving the name unbound raised UnboundLocalError below.
                moderated_content = None
                if request_model.request_type == "chat":
                    moderated_content = request_model.get_last_text_message()
                elif request_model.request_type == "image":
                    moderated_content = request_model.prompt
                if moderated_content:
                    current_info["text"] = moderated_content

                if enable_moderation and moderated_content:
                    moderation_response = await self.moderate_content(moderated_content, token)
                    is_flagged = moderation_response.get('results', [{}])[0].get('flagged', False)

                    if is_flagged:
                        logger.error(f"Content did not pass the moral check: %s", moderated_content)
                        process_time = time() - start_time
                        current_info["process_time"] = process_time
                        current_info["is_flagged"] = is_flagged
                        # update_stats is a module-level function. Calling it
                        # as self.update_stats raised AttributeError, which
                        # the handler below swallowed, so flagged requests
                        # were forwarded instead of rejected.
                        await update_stats(current_info)
                        return JSONResponse(
                            status_code=400,
                            content={"error": "Content did not pass the moral check, please modify and try again."}
                        )
            except RequestValidationError:
                logger.error(f"Invalid request body: {parsed_body}")
                pass
            except Exception as e:
                # Best effort: failures here must not block the request.
                if is_debug:
                    import traceback
                    traceback.print_exc()

                logger.error(f"处理请求或进行道德检查时出错: {str(e)}")

        try:
            response = await call_next(request)

            if request.url.path.startswith("/v1"):
                if isinstance(response, (FastAPIStreamingResponse, StarletteStreamingResponse)) or type(response).__name__ == '_StreamingResponse':
                    # Wrap the stream so usage and timing are recorded when
                    # the body has been fully sent.
                    response = LoggingStreamingResponse(
                        content=response.body_iterator,
                        status_code=response.status_code,
                        media_type=response.media_type,
                        headers=response.headers,
                        current_info=current_info,
                    )
                elif hasattr(response, 'json'):
                    logger.info(f"Response: {await response.json()}")
                else:
                    logger.info(f"Response: type={type(response).__name__}, status_code={response.status_code}, headers={response.headers}")

            return response
        finally:
            # Restore the context variable to its previous value.
            request_info.reset(current_request_info)

    async def moderate_content(self, content, token):
        """Run ``content`` through the moderation endpoint and return its JSON.

        Args:
            content: Text to check.
            token: API key used to route the moderation request.

        Returns:
            The decoded JSON body of the moderation response.
        """
        moderation_request = ModerationRequest(input=content)

        # Call the moderation endpoint function directly.
        response = await moderations(moderation_request, token)

        # The response is streamed; collect all chunks as bytes.
        moderation_result = b""
        async for chunk in response.body_iterator:
            if isinstance(chunk, str):
                moderation_result += chunk.encode('utf-8')
            else:
                moderation_result += chunk

        moderation_data = json.loads(moderation_result.decode('utf-8'))

        return moderation_data
|
|
| |
# CORS is configured to accept every origin, method and header, with
# credentials allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Per-request statistics and optional content moderation.
app.add_middleware(StatsMiddleware)
|
|
@app.middleware("http")
async def ensure_config(request: Request, call_next):
    """Load the configuration into app.state on demand before each request.

    When ``app.state.config`` is missing, the configuration, key database and
    key list are loaded and an admin API key is selected: the key whose role
    is "admin" (the last one if several), otherwise the first configured key.

    Raises:
        Exception: If no API key is configured at all.
    """
    if not hasattr(app.state, 'config'):
        logger.warning("Config not found, attempting to reload")
        app.state.config, app.state.api_keys_db, app.state.api_list = await load_config(app)

        for item in app.state.api_keys_db:
            if item.get("role") == "admin":
                app.state.admin_api_key = item.get("api")
        # No explicit admin: fall back to the first key, if there is one.
        if not hasattr(app.state, "admin_api_key"):
            if len(app.state.api_keys_db) >= 1:
                app.state.admin_api_key = app.state.api_keys_db[0].get("api")
            else:
                raise Exception("No admin API key found")

    return await call_next(request)
|
|
| |
async def process_request(request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest], provider: Dict, endpoint=None, token=None):
    """Send ``request`` to one provider and return its response.

    The upstream "engine" is chosen from the provider's base URL, then
    refined by the mapped model name and the endpoint, and finally
    overridden by an explicit ``engine`` entry on the provider. The outcome
    of the attempt is recorded through ``update_channel_stats``.

    Args:
        request: The validated request model; ``request.stream`` may be
            changed here depending on the selected engine.
        provider: Provider configuration dict (uses "base_url", "provider"
            and optionally "engine").
        endpoint: Path of the API endpoint being served, if not chat.
        token: API key of the caller, stored with the channel stats.

    Returns:
        A streaming response: an event stream when ``request.stream`` is
        set, otherwise a single JSON document.

    Raises:
        Whatever the upstream call raises; the failure is recorded first.
    """
    url = provider['base_url']
    parsed_url = urlparse(url)

    # First guess from the host name (or path, for Claude-compatible URLs).
    engine = None
    if parsed_url.netloc == 'generativelanguage.googleapis.com':
        engine = "gemini"
    elif parsed_url.netloc == 'aiplatform.googleapis.com':
        engine = "vertex"
    elif parsed_url.netloc == 'api.cloudflare.com':
        engine = "cloudflare"
    elif parsed_url.netloc == 'api.anthropic.com' or parsed_url.path.endswith("v1/messages"):
        engine = "claude"
    elif parsed_url.netloc == 'openrouter.ai':
        engine = "openrouter"
    elif parsed_url.netloc == 'api.cohere.com':
        engine = "cohere"
        request.stream = True
    else:
        engine = "gpt"

    model_dict = get_model_dict(provider)
    # The provider-side entry that the requested model name maps to.
    upstream_model = model_dict[request.model]
    if "claude" not in upstream_model \
    and "gpt" not in upstream_model \
    and "gemini" not in upstream_model \
    and parsed_url.netloc != 'api.cloudflare.com' \
    and parsed_url.netloc != 'api.cohere.com':
        engine = "openrouter"

    if "claude" in upstream_model and engine == "vertex":
        engine = "vertex-claude"

    if "gemini" in upstream_model and engine == "vertex":
        engine = "vertex-gemini"

    if "o1-preview" in upstream_model or "o1-mini" in upstream_model:
        engine = "o1"
        request.stream = False

    # Endpoint-specific engines never stream.
    if endpoint == "/v1/images/generations":
        engine = "dalle"
        request.stream = False

    if endpoint == "/v1/audio/transcriptions":
        engine = "whisper"
        request.stream = False

    if endpoint == "/v1/moderations":
        engine = "moderation"
        request.stream = False

    if endpoint == "/v1/embeddings":
        engine = "embedding"
        request.stream = False

    # An explicit engine in the provider config wins over all heuristics.
    if provider.get("engine"):
        engine = provider["engine"]

    logger.info(f"provider: {provider['provider']:<11} model: {request.model:<22} engine: {engine}")

    url, headers, payload = await get_payload(request, engine, provider)
    if is_debug:
        logger.info(json.dumps(headers, indent=4, ensure_ascii=False))
        # Payloads carrying a file are not dumped to the log.
        if payload.get("file"):
            pass
        else:
            logger.info(json.dumps(payload, indent=4, ensure_ascii=False))
    current_info = request_info.get()
    try:
        model = upstream_model
        if request.stream:
            generator = fetch_response_stream(app.state.client, url, headers, payload, engine, model)
            wrapped_generator, first_response_time = await error_handling_wrapper(generator)
            response = StarletteStreamingResponse(wrapped_generator, media_type="text/event-stream")
        else:
            generator = fetch_response(app.state.client, url, headers, payload, engine, model)
            wrapped_generator, first_response_time = await error_handling_wrapper(generator)
            first_element = await anext(wrapped_generator)
            # Remove the SSE field name as a prefix; str.lstrip("data: ")
            # strips any leading run of those characters instead.
            first_element = first_element.removeprefix("data:").lstrip(" ")
            # Response bodies go to the debug log only, not to stdout.
            if is_debug:
                logger.info("first_element: %s", first_element)
            first_element = json.loads(first_element)
            response = StarletteStreamingResponse(iter([json.dumps(first_element)]), media_type="application/json")

        await update_channel_stats(current_info["request_id"], provider['provider'], request.model, token, success=True)

        current_info["first_response_time"] = first_response_time
        current_info["success"] = True
        current_info["provider"] = provider['provider']

        return response
    except (Exception, asyncio.CancelledError):
        # CancelledError is listed explicitly because it does not derive
        # from Exception; HTTP and httpx errors are Exception subclasses.
        await update_channel_stats(current_info["request_id"], provider['provider'], request.model, token, success=False)
        raise
|
|
def weighted_round_robin(weights):
    """Build one full scheduling cycle that interleaves names by weight.

    Args:
        weights: Mapping of provider name to a non-negative integer weight.

    Returns:
        A list with ``sum(weights)`` entries in which each name appears as
        many times as its weight, spread out rather than grouped. Names with
        a weight of zero (or less) never appear; an empty or all-zero mapping
        yields an empty list.
    """
    # Drop non-positive weights up front: they can never be selected and a
    # zero weight would otherwise cause a division by zero below.
    active = {name: weight for name, weight in weights.items() if weight > 0}
    total_weight = sum(active.values())
    current_weights = dict.fromkeys(active, 0)
    schedule = []

    for _ in range(total_weight):
        max_ratio = -1
        selected = None

        # Credit every name with its weight, then pick the one whose
        # accumulated credit is largest relative to its weight. Ties go to
        # the name that comes first in the mapping.
        for name, weight in active.items():
            current_weights[name] += weight
            ratio = current_weights[name] / weight

            if ratio > max_ratio:
                max_ratio = ratio
                selected = name

        schedule.append(selected)
        # The winner pays back one whole cycle worth of credit.
        current_weights[selected] -= total_weight

    return schedule
|
|
| import random |
|
|
def lottery_scheduling(weights):
    """Draw ``sum(weights)`` random winners, each chosen in proportion to weight.

    Every draw picks a ticket number between 1 and the ticket total and
    awards it to the first provider whose cumulative ticket range covers it.
    """
    total_tickets = sum(weights.values())

    # Upper ticket bound for each provider, in mapping order.
    bounds = []
    running = 0
    for name, weight in weights.items():
        running += weight
        bounds.append((running, name))

    winners = []
    for _ in range(total_tickets):
        ticket = random.randint(1, total_tickets)
        for upper, name in bounds:
            if ticket <= upper:
                winners.append(name)
                break
    return winners
|
|
def get_provider_rules(model_rule, config, request_model):
    """Expand one model rule of an API key into "provider/model" strings.

    Args:
        model_rule: A rule from the API key's model list. Handled forms:
            "all", "<name/with/slash>", "provider/*", "provider/model",
            or a plain model name.
        config: Configuration dict with a "providers" list.
        request_model: The model name requested by the client; a trailing
            "*" is treated as a prefix wildcard.

    Returns:
        A list of "provider/model" strings that the rule allows.
    """
    provider_rules = []
    if model_rule == "all":
        # "all": every model of every provider.
        for provider in config["providers"]:
            model_dict = get_model_dict(provider)
            for model in model_dict.keys():
                provider_rules.append(provider["provider"] + "/" + model)

    elif "/" in model_rule:
        if model_rule.startswith("<") and model_rule.endswith(">"):
            # "<...>" marks a model name that itself contains slashes, so it
            # is looked up as a model name rather than split into parts.
            model_rule = model_rule[1:-1]

            for provider in config['providers']:
                model_dict = get_model_dict(provider)
                if model_rule in model_dict.keys():
                    provider_rules.append(provider['provider'] + "/" + model_rule)
        else:
            # "provider/model": the first segment is the provider name, the
            # remainder (which may contain slashes) is the model part.
            provider_name = model_rule.split("/")[0]
            model_name_split = "/".join(model_rule.split("/")[1:])
            models_list = []
            for provider in config['providers']:
                model_dict = get_model_dict(provider)
                if provider['provider'] == provider_name:
                    models_list.extend(list(model_dict.keys()))

            # "provider/*": any model of that provider matches.
            if model_name_split == "*":
                if request_model in models_list:
                    provider_rules.append(provider_name + "/" + request_model)

                # A request like "gpt-4*" matches every model of the
                # provider that starts with the part before the "*".
                for models_list_model in models_list:
                    if request_model.endswith("*") and models_list_model.startswith(request_model.rstrip("*")):
                        provider_rules.append(provider_name + "/" + models_list_model)

            # "provider/model": the model must equal the request (or match
            # its prefix wildcard) and exist at the provider.
            elif model_name_split == request_model \
            or (request_model.endswith("*") and model_name_split.startswith(request_model.rstrip("*"))):
                if model_name_split in models_list:
                    provider_rules.append(provider_name + "/" + model_name_split)

    else:
        # Plain model name: every provider offering that model.
        for provider in config["providers"]:
            model_dict = get_model_dict(provider)
            if model_rule in model_dict.keys():
                provider_rules.append(provider["provider"] + "/" + model_rule)

    return provider_rules
|
|
def get_provider_list(provider_rules, config, request_model):
    """Resolve "provider/model" rules into provider configs for a request.

    Args:
        provider_rules: Strings of the form "provider/model".
        config: Configuration dict with a "providers" list.
        request_model: The model name requested by the client; a trailing
            "*" is treated as a prefix wildcard.

    Returns:
        Deep copies of the matching provider dicts, each with its "model"
        entry replaced by a single mapping for the requested model.
    """
    provider_list = []

    for item in provider_rules:
        for provider in config['providers']:
            if "/" in item and provider['provider'] == item.split("/")[0]:
                # Copy so the shared configuration is never modified.
                new_provider = copy.deepcopy(provider)
                model_dict = get_model_dict(provider)
                model_name_split = "/".join(item.split("/")[1:])

                # Map the provider-side model entry to the requested name.
                new_provider["model"] = [{model_dict[model_name_split]: request_model}]
                if request_model in model_dict.keys() and model_name_split == request_model:
                    provider_list.append(new_provider)

                elif request_model.endswith("*") and model_name_split.startswith(request_model.rstrip("*")):
                    provider_list.append(new_provider)
    return provider_list
|
|
def get_matching_providers(request_model, config, api_index):
    """Return the provider configs that API key ``api_index`` may use for a model.

    Each model rule of the key is expanded into "provider/model" strings,
    which are then resolved into provider configuration dicts.
    """
    allowed_rules = config['api_keys'][api_index]['model']
    provider_rules = [
        rule
        for model_rule in allowed_rules
        for rule in get_provider_rules(model_rule, config, request_model)
    ]
    return get_provider_list(provider_rules, config, request_model)
|
|
| import asyncio |
class ModelRequestHandler:
    """Selects providers for a request and tries them until one succeeds."""

    def __init__(self):
        # Last start index used per requested model (rotating schedulers).
        self.last_provider_indices = defaultdict(lambda: -1)
        # One lock per requested model to guard the index update.
        self.locks = defaultdict(asyncio.Lock)

    async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest], token: str, endpoint=None):
        """Route ``request`` to a matching provider, retrying on failure.

        Args:
            request: The validated request model.
            token: The caller's API key; must be present in the key list.
            endpoint: Path of the API endpoint being served, if not chat.

        Returns:
            The response produced by the first provider that succeeds.

        Raises:
            HTTPException: 404 when no provider matches the model; 500 when
                a provider fails and AUTO_RETRY is off; otherwise the last
                provider status once every attempt has failed.
        """
        config = app.state.config
        api_list = app.state.api_list
        api_index = api_list.index(token)

        if not safe_get(config, 'api_keys', api_index, 'model'):
            raise HTTPException(status_code=404, detail="No matching model found")

        request_model = request.model
        matching_providers = get_matching_providers(request_model, config, api_index)
        num_matching_providers = len(matching_providers)

        if not matching_providers:
            raise HTTPException(status_code=404, detail="No matching model found")

        # Scheduling preference of this API key.
        scheduling_algorithm = safe_get(config, 'api_keys', api_index, "preferences", "SCHEDULING_ALGORITHM", default="fixed_priority")
        if scheduling_algorithm == "random":
            matching_providers = random.sample(matching_providers, num_matching_providers)

        weights = safe_get(config, 'api_keys', api_index, "weights")

        # Names of all providers that can serve this request.
        all_providers = set(provider['provider'] for provider in matching_providers)

        intersection = None
        if weights and all_providers:
            # Resolve the weight keys into provider names and keep only the
            # ones that can actually serve this request.
            weight_keys = set(weights.keys())
            provider_rules = []
            for model_rule in weight_keys:
                provider_rules.extend(get_provider_rules(model_rule, config, request_model))
            provider_list = get_provider_list(provider_rules, config, request_model)
            weight_keys = set([provider['provider'] for provider in provider_list])

            intersection = all_providers.intersection(weight_keys)

        if weights and intersection:
            weights = dict(filter(lambda item: item[0] in intersection, weights.items()))

            if scheduling_algorithm == "weighted_round_robin":
                weighted_provider_name_list = weighted_round_robin(weights)
            elif scheduling_algorithm == "lottery":
                weighted_provider_name_list = lottery_scheduling(weights)
            else:
                weighted_provider_name_list = list(weights.keys())

            # Rebuild the provider list in the weighted order.
            new_matching_providers = []
            for provider_name in weighted_provider_name_list:
                for provider in matching_providers:
                    if provider['provider'] == provider_name:
                        new_matching_providers.append(provider)
            # Keep the unweighted list when the weighting produced nothing;
            # an empty list would make the retry loop below fail.
            if new_matching_providers:
                matching_providers = new_matching_providers
            # The weighted list can be longer or shorter than the original,
            # so the count used for indexing must be refreshed.
            num_matching_providers = len(matching_providers)

        if is_debug:
            for provider in matching_providers:
                logger.info("available provider: %s", json.dumps(provider, indent=4, ensure_ascii=False, default=circular_list_encoder))

        status_code = 500
        error_message = None

        # Rotating schedulers start one position further on each request.
        start_index = 0
        if scheduling_algorithm != "fixed_priority":
            async with self.locks[request_model]:
                self.last_provider_indices[request_model] = (self.last_provider_indices[request_model] + 1) % num_matching_providers
                start_index = self.last_provider_indices[request_model]

        auto_retry = safe_get(config, 'api_keys', api_index, "preferences", "AUTO_RETRY", default=True)

        # One attempt per provider, plus one extra that wraps around to the
        # starting provider.
        for i in range(num_matching_providers + 1):
            current_index = (start_index + i) % num_matching_providers
            provider = matching_providers[current_index]
            try:
                response = await process_request(request, provider, endpoint, token)
                return response
            except HTTPException as e:
                logger.error(f"Error with provider {provider['provider']}: {str(e)}")
                status_code = e.status_code
                error_message = e.detail

                if auto_retry:
                    continue
                else:
                    raise HTTPException(status_code=500, detail=f"Error: Current provider response failed: {error_message}")
            except (Exception, asyncio.CancelledError, httpx.ReadError, httpx.RemoteProtocolError) as e:
                logger.error(f"Error with provider {provider['provider']}: {str(e)}")
                if is_debug:
                    import traceback
                    traceback.print_exc()
                error_message = str(e)
                if auto_retry:
                    continue
                else:
                    raise HTTPException(status_code=500, detail=f"Error: Current provider response failed: {error_message}")

        # Every attempt failed: record that and report the last error.
        current_info = request_info.get()
        current_info["first_response_time"] = -1
        current_info["success"] = False
        current_info["provider"] = None
        raise HTTPException(status_code=status_code, detail=f"All {request.model} error: {error_message}")
|
|
# Single handler instance shared by all model endpoints, so its per-model
# rotation state persists across requests.
model_handler = ModelRequestHandler()
|
|
def parse_rate_limit(limit_string):
    """Parse a "<count>/<unit>" rate limit into ``(count, seconds)``.

    Raises:
        ValueError: If the string does not have the expected shape or the
            unit is not recognised.
    """
    # Length of each period in seconds, with the spellings accepted for it.
    alias_groups = (
        (1, ('s', 'sec', 'second')),
        (60, ('m', 'min', 'minute')),
        (3600, ('h', 'hr', 'hour')),
        (86400, ('d', 'day')),
        (2592000, ('mo', 'month')),
        (31536000, ('y', 'year')),
    )
    seconds_per_unit = {
        alias: seconds
        for seconds, aliases in alias_groups
        for alias in aliases
    }

    parsed = re.match(r'^(\d+)/(\w+)$', limit_string)
    if parsed is None:
        raise ValueError(f"Invalid rate limit format: {limit_string}")

    raw_count, unit = parsed.groups()
    try:
        period = seconds_per_unit[unit]
    except KeyError:
        raise ValueError(f"Unknown time unit: {unit}") from None

    return (int(raw_count), period)
|
|
class InMemoryRateLimiter:
    """Sliding-window rate limiter that keeps request timestamps in memory."""

    def __init__(self):
        # key -> timestamps of the requests accepted for that key
        self.requests = defaultdict(list)

    async def is_rate_limited(self, key: str, limit: int, period: int) -> bool:
        """Return True when ``key`` already has ``limit`` requests in the window.

        Timestamps older than ``period`` seconds are discarded first; when
        the request is allowed, its timestamp is recorded.
        """
        now = time()
        cutoff = now - period
        recent = [stamp for stamp in self.requests[key] if stamp > cutoff]
        self.requests[key] = recent

        limited = len(recent) >= limit
        if not limited:
            recent.append(now)
        return limited
|
|
# Single limiter instance used by the rate-limit dependency; its counters
# live in this process's memory only.
rate_limiter = InMemoryRateLimiter()
|
|
async def get_user_rate_limit(api_index: str = None):
    """Return the ``(count, seconds)`` rate limit for an API key.

    Args:
        api_index: Position of the key in the configured key list, or None
            when the caller is not a known key.

    Returns:
        The parsed RATE_LIMIT preference of the key, or the default of
        30 requests per 60 seconds when the key is unknown or has no
        preference set.

    Raises:
        ValueError: If the configured RATE_LIMIT string is malformed.
    """
    default_limit = (30, 60)

    # Compare against None explicitly: index 0 is a valid key position and
    # must not be mistaken for "no key".
    if api_index is None:
        return default_limit

    config = app.state.config
    raw_rate_limit = safe_get(config, 'api_keys', api_index, "preferences", "RATE_LIMIT")
    if not raw_rate_limit:
        return default_limit

    return parse_rate_limit(raw_rate_limit)
|
|
# Bearer-token extractor used by the authentication and rate-limit
# dependencies below.
security = HTTPBearer()
|
|
async def rate_limit_dependency(request: Request, credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Reject the request with 429 when the caller exceeded its rate limit.

    Known API keys are limited per (client IP, key) with the key's own
    limit; unknown keys are limited per client IP with the default limit.
    """
    token = credentials.credentials if credentials else None
    api_list = app.state.api_list
    try:
        api_index = api_list.index(token)
    except ValueError:
        print("error: Invalid or missing API Key:", token)
        api_index, token = None, None
    limit, period = await get_user_rate_limit(api_index)

    # Build the bucket key: the IP alone for anonymous callers.
    client_ip = request.client.host
    if token:
        rate_limit_key = f"{client_ip}:{token}"
    else:
        rate_limit_key = client_ip

    limited = await rate_limiter.is_rate_limited(rate_limit_key, limit, period)
    if limited:
        raise HTTPException(status_code=429, detail="Too many requests")
|
|
def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Return the bearer token if it is a configured API key, else raise 403."""
    known_keys = app.state.api_list
    token = credentials.credentials
    if token in known_keys:
        return token
    raise HTTPException(status_code=403, detail="Invalid or missing API Key")
|
|
def verify_admin_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Return the bearer token if it belongs to an admin key.

    Raises:
        HTTPException: 403 when the token is not a configured key, or when
            its entry in the key database does not have the "admin" role.
    """
    known_keys = app.state.api_list
    token = credentials.credentials
    if token not in known_keys:
        raise HTTPException(status_code=403, detail="Invalid or missing API Key")

    own_entries = (entry for entry in app.state.api_keys_db if entry['api'] == token)
    if any(entry.get('role') != "admin" for entry in own_entries):
        raise HTTPException(status_code=403, detail="Permission denied")
    return token
|
|
@app.post("/v1/chat/completions", dependencies=[Depends(rate_limit_dependency)])
async def request_model(request: RequestModel, token: str = Depends(verify_api_key)):
    """Chat completions endpoint: hand the request to the model handler."""
    return await model_handler.request_model(request, token)
|
|
@app.options("/v1/chat/completions", dependencies=[Depends(rate_limit_dependency)])
async def options_handler():
    """Answer OPTIONS on the chat completions path with a fixed 200 body."""
    return JSONResponse(status_code=200, content={"detail": "OPTIONS allowed"})
|
|
@app.get("/v1/models", dependencies=[Depends(rate_limit_dependency)])
async def list_models(token: str = Depends(verify_api_key)):
    """Return the models computed by ``post_all_models`` for the calling key,
    wrapped as ``{"object": "list", "data": [...]}``."""
    models = post_all_models(token, app.state.config, app.state.api_list)
    return JSONResponse(content={
        "object": "list",
        "data": models
    })
|
|
@app.post("/v1/images/generations", dependencies=[Depends(rate_limit_dependency)])
async def images_generations(
    request: ImageGenerationRequest,
    token: str = Depends(verify_api_key)
):
    """Image generation endpoint: hand the request to the model handler,
    tagged with its endpoint path."""
    return await model_handler.request_model(request, token, endpoint="/v1/images/generations")
|
|
@app.post("/v1/embeddings", dependencies=[Depends(rate_limit_dependency)])
async def embeddings(request: EmbeddingRequest, token: str = Depends(verify_api_key)):
    """Hand an embeddings request to the model handler, tagged with its endpoint."""
    endpoint = "/v1/embeddings"
    return await model_handler.request_model(request, token, endpoint=endpoint)
|
|
@app.post("/v1/moderations", dependencies=[Depends(rate_limit_dependency)])
async def moderations(request: ModerationRequest, token: str = Depends(verify_api_key)):
    """Hand a moderation request to the model handler, tagged with its endpoint."""
    endpoint = "/v1/moderations"
    return await model_handler.request_model(request, token, endpoint=endpoint)
|
|
| from fastapi import UploadFile, File, Form, HTTPException |
| import io |
@app.post("/v1/audio/transcriptions", dependencies=[Depends(rate_limit_dependency)])
async def audio_transcriptions(
    file: UploadFile = File(...),
    model: str = Form(...),
    token: str = Depends(verify_api_key)
):
    """Forward an uploaded audio file to the model handler for transcription.

    The upload is read fully into memory and wrapped in a ``BytesIO`` so it can
    be passed on as a ``(filename, file object, content type)`` tuple.

    Raises:
        HTTPException: propagated unchanged when raised during handling;
            400 on ``UnicodeDecodeError``; 500 for any other exception.
    """
    try:
        content = await file.read()
        file_obj = io.BytesIO(content)

        request = AudioTranscriptionRequest(
            file=(file.filename, file_obj, file.content_type),
            model=model
        )

        return await model_handler.request_model(request, token, endpoint="/v1/audio/transcriptions")
    except HTTPException:
        # HTTPException is an Exception subclass; without this clause the
        # generic handler below would rewrite its status code to 500.
        raise
    except UnicodeDecodeError:
        raise HTTPException(status_code=400, detail="Invalid audio file encoding")
    except Exception as e:
        if is_debug:
            import traceback
            traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Error processing audio file: {str(e)}") from e
|
|
@app.get("/v1/generate-api-key", dependencies=[Depends(rate_limit_dependency)])
def generate_api_key():
    """Return a fresh ``sk-`` key made of 36 random ASCII letters and digits."""
    alphabet = string.ascii_letters + string.digits
    # secrets.choice draws from a cryptographically strong source.
    suffix = [secrets.choice(alphabet) for _ in range(36)]
    return JSONResponse(content={"api_key": "sk-" + ''.join(suffix)})
|
|
| |
| from datetime import datetime, timedelta, timezone |
| from sqlalchemy import func, desc, case |
| from fastapi import Query |
|
|
@app.get("/v1/stats", dependencies=[Depends(rate_limit_dependency)])
async def get_stats(
    request: Request,
    token: str = Depends(verify_admin_api_key),
    hours: int = Query(default=24, ge=1, le=720, description="Number of hours to look back for stats (1-720)")
):
    '''
    ## Get statistics

    Use `/v1/stats` to get usage statistics for every channel over the last 24 hours. Send your own uni-api admin API key with the request.

    The data includes:

    1. The success rate of every model under every channel, sorted from highest to lowest.
    2. The overall success rate of every channel, sorted from highest to lowest.
    3. The total number of requests per model across all channels.
    4. The number of requests per endpoint.
    5. The number of requests per IP.

    `/v1/stats?hours=48`: the `hours` parameter controls how many of the most recent hours are covered. When `hours` is omitted, the last 24 hours are reported.

    Other statistics can be queried from the database with your own SQL. They include: time to first token, total processing time per request, whether each request succeeded, whether each request passed moderation, the text content of each request, the API key of each request, and the input and output token counts of each request.
    '''
    if DISABLE_DATABASE:
        return JSONResponse(content={"stats": {}})

    def success_rate(row):
        # Avoid dividing by zero when a row reports no requests.
        if row.total > 0:
            return row.success_count / row.total
        return 0

    def by_success_rate(rows):
        return sorted(rows, key=success_rate, reverse=True)

    async with async_session() as session:
        window_start = datetime.now(timezone.utc) - timedelta(hours=hours)

        async def channel_rows(*group_columns):
            # Total and successful request counts per group of ChannelStat rows.
            # `== True` is intentional: the comparison is part of the SQL expression.
            query = (
                select(
                    *group_columns,
                    func.count().label('total'),
                    func.sum(case((ChannelStat.success == True, 1), else_=0)).label('success_count')
                )
                .where(ChannelStat.timestamp >= window_start)
                .group_by(*group_columns)
            )
            return (await session.execute(query)).fetchall()

        async def request_count_rows(column):
            # Request counts per distinct value of a RequestStat column, largest first.
            query = (
                select(column, func.count().label('count'))
                .where(RequestStat.timestamp >= window_start)
                .group_by(column)
                .order_by(desc('count'))
            )
            return (await session.execute(query)).fetchall()

        per_channel_model = await channel_rows(ChannelStat.provider, ChannelStat.model)
        per_channel = await channel_rows(ChannelStat.provider)
        per_model = await request_count_rows(RequestStat.model)
        per_endpoint = await request_count_rows(RequestStat.endpoint)
        per_ip = await request_count_rows(RequestStat.client_ip)

        stats = {
            "time_range": f"Last {hours} hours",
            "channel_model_success_rates": [
                {
                    "provider": row.provider,
                    "model": row.model,
                    "success_rate": success_rate(row),
                    "total_requests": row.total
                } for row in by_success_rate(per_channel_model)
            ],
            "channel_success_rates": [
                {
                    "provider": row.provider,
                    "success_rate": success_rate(row),
                    "total_requests": row.total
                } for row in by_success_rate(per_channel)
            ],
            "model_request_counts": [
                {"model": row.model, "count": row.count} for row in per_model
            ],
            "endpoint_request_counts": [
                {"endpoint": row.endpoint, "count": row.count} for row in per_endpoint
            ],
            "ip_request_counts": [
                {"ip": row.client_ip, "count": row.count} for row in per_ip
            ]
        }

    return JSONResponse(content=stats)
|
|
|
|
|
|
| from fastapi import FastAPI, Request |
| from fastapi import Form as FastapiForm, HTTPException, Depends |
| from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse |
| from fastapi.security import APIKeyHeader |
| from typing import Optional, List |
|
|
| from xue import HTML, Head, Body, Div, xue_initialize, Script, Ul, Li |
| from xue.components.menubar import ( |
| Menubar, MenubarMenu, MenubarTrigger, MenubarContent, |
| MenubarItem, MenubarSeparator |
| ) |
| from xue.components import input, dropdown, sheet, form, button, checkbox, sidebar |
| from xue.components.model_config_row import model_config_row |
| |
| |
| |
| from components.provider_table import data_table |
|
|
| from ruamel.yaml import YAML |
# Module-level ruamel.yaml instance, set to keep quotes and use a fixed
# mapping/sequence indentation.
yaml = YAML()
yaml.preserve_quotes = True
yaml.indent(mapping=2, sequence=4, offset=2)


# Router that collects the admin web UI routes defined below.
frontend_router = APIRouter()


# Header carrying the admin UI key. auto_error=False is passed here;
# get_api_key handles a missing header value itself.
API_KEY_NAME = "X-API-Key"
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
async def get_api_key(request: Request, x_api_key: Optional[str] = Depends(api_key_header)):
    """Resolve the admin key sent with a frontend request.

    The key is taken from the header first, then from the ``x_api_key`` cookie,
    then from the ``x_api_key`` query parameter.

    Returns:
        The supplied key when it equals ``app.state.admin_api_key``, else None.
    """
    candidate = (
        x_api_key
        or request.cookies.get("x_api_key")
        or request.query_params.get("x_api_key")
    )
    return candidate if candidate == app.state.admin_api_key else None
|
|
async def frontend_rate_limit_dependency(request: Request, x_api_key: str = Depends(get_api_key)):
    """Rate-limit frontend requests to 100 per 60 seconds.

    The limiter key is the client IP, suffixed with the admin key when one was
    resolved for the request.

    Raises:
        HTTPException: 429 when the limiter reports the key as rate limited.
    """
    max_requests, window_seconds = 100, 60

    caller_ip = request.client.host
    bucket = f"{caller_ip}:{x_api_key}" if x_api_key else caller_ip

    limited = await rate_limiter.is_rate_limited(bucket, max_requests, window_seconds)
    if limited:
        raise HTTPException(status_code=429, detail="Too many requests")
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
|
|
# Initialise the xue UI library with tailwind=True.
xue_initialize(tailwind=True)


# Column definitions passed to data_table and to the column dropdown menu.
data_table_columns = [
    {"label": "Provider", "value": "provider", "sortable": True},
    {"label": "Base url", "value": "base_url", "sortable": True},
    {"label": "Tools", "value": "tools", "sortable": True},
]
|
|
@frontend_router.get("/login", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)])
async def login_page():
    """Render the login page: one password field that posts to ``/verify-api-key``."""
    page_head = Head(title="登录")

    key_field = form.FormField("API Key", "x_api_key", type="password", placeholder="输入API密钥", required=True)
    # The htmx response from /verify-api-key is swapped into this element.
    error_slot = Div(id="error-message", class_="text-red-500 mt-2")
    submit_row = Div(
        button.button("提交", variant="primary", type="submit"),
        class_="flex justify-end mt-4"
    )
    login_form = form.Form(
        key_field,
        error_slot,
        submit_row,
        hx_post="/verify-api-key",
        hx_target="#error-message",
        hx_swap="innerHTML",
        class_="space-y-4"
    )
    container = Div(login_form, class_="container mx-auto p-4 max-w-md")

    return HTML(page_head, Body(container)).render()
|
|
|
|
@frontend_router.post("/verify-api-key", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)])
async def verify_api_key(x_api_key: str = FastapiForm(...)):
    """Check the submitted key against the admin key and start a cookie session.

    On a match, responds with an ``HX-Redirect`` to ``/`` and stores the key in
    an HTTP-only cookie for 1800 seconds. Otherwise returns an error fragment.

    NOTE(review): this definition reuses the name of the bearer-token
    dependency defined earlier in the module and rebinds it from here on.
    """
    if x_api_key != app.state.admin_api_key:
        return Div("无效的API密钥", class_="text-red-500").render()

    response = JSONResponse(content={"success": True})
    response.headers["HX-Redirect"] = "/"
    cookie_options = {
        "httponly": True,
        "max_age": 1800,
        "secure": False,
        "samesite": "lax",
    }
    response.set_cookie(key="x_api_key", value=x_api_key, **cookie_options)
    return response
|
|
| |
# Entries passed to sidebar.Sidebar. Only a Dashboard entry is defined.
# NOTE(review): the "hx" dict presumably becomes htmx attributes (GET /dashboard
# into #main-content); no /dashboard route is visible in this section. Confirm.
sidebar_items = [
    {
        "icon": "layout-dashboard",
        "label": "Dashboard",
        "value": "dashboard",
        "hx": {"get": "/dashboard", "target": "#main-content"}
    },
]
|
|
@frontend_router.get("/", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)])
async def root(x_api_key: str = Depends(get_api_key)):
    """Render the admin dashboard, or redirect to ``/login`` without a valid key.

    The page holds the sidebar and the provider table. An inline script
    re-requests ``/filter-table`` whenever the table's filter input changes.

    Other handlers in this module also call this function directly to
    re-render the page after a change.
    """
    if not x_api_key:
        return RedirectResponse(url="/login", status_code=303)

    result = HTML(
        Head(
            # The filter text is URL-encoded so characters such as '&', '#' or
            # '+' cannot corrupt the query string, and a missing input element
            # no longer throws.
            Script("""
                document.addEventListener('DOMContentLoaded', function() {
                    const filterInput = document.getElementById('users-table-filter');
                    if (!filterInput) {
                        return;
                    }
                    filterInput.addEventListener('input', function() {
                        const filterValue = encodeURIComponent(this.value);
                        htmx.ajax('GET', `/filter-table?filter=${filterValue}`, '#users-table');
                    });
                });
            """),
            title="uni-api"
        ),
        Body(
            Div(
                sidebar.Sidebar("zap", "uni-api", sidebar_items, is_collapsed=False, active_item="dashboard"),
                Div(
                    Div(
                        data_table(data_table_columns, app.state.config["providers"], "users-table"),
                        class_="p-4"
                    ),
                    Div(id="sheet-container"),
                    id="main-content",
                    class_="ml-[240px] p-6 transition-[margin] duration-200 ease-in-out"
                ),
                class_="flex"
            ),
            class_="container mx-auto",
            id="body"
        )
    ).render()

    return result
|
|
@frontend_router.get("/sidebar/toggle", response_class=HTMLResponse)
async def toggle_sidebar(is_collapsed: bool = False):
    """Render the sidebar in the opposite collapsed state to the one passed in."""
    flipped = not is_collapsed
    panel = sidebar.Sidebar("zap", "uni-api", sidebar_items, is_collapsed=flipped, active_item="dashboard")
    return panel.render()
|
|
@frontend_router.get("/dropdown-menu/{menu_id}/{row_id}", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)])
async def get_columns_menu(menu_id: str, row_id: str):
    """Render the per-row action menu (Edit / Duplicate / Delete) for a provider row.

    Each entry carries the htmx attributes that trigger the matching route for
    ``row_id``.

    NOTE(review): a second handler further down reuses this function name.
    """
    columns = [
        {
            "label": "Edit",
            "value": "edit",
            "hx-get": f"/edit-sheet/{row_id}",
            "hx-target": "#sheet-container",
            "hx-swap": "innerHTML"
        },
        {
            "label": "Duplicate",
            "value": "duplicate",
            "hx-post": f"/duplicate/{row_id}",
            "hx-target": "body",
            "hx-swap": "outerHTML"
        },
        {
            "label": "Delete",
            "value": "delete",
            "hx-delete": f"/delete/{row_id}",
            "hx-target": "body",
            "hx-swap": "outerHTML",
            "hx-confirm": "Are you sure you want to delete this configuration?"
        },
    ]
    result = dropdown.dropdown_menu_content(menu_id, columns).render()
    # The rendered HTML is logged only in debug mode, not on every request.
    if is_debug:
        logger.info(result)
    return result
|
|
@frontend_router.get("/dropdown-menu/{menu_id}", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)])
async def get_columns_menu(menu_id: str):
    """Render the dropdown menu that lists the data-table columns.

    NOTE(review): this reuses the name of the row-action menu handler defined
    just above and rebinds it from here on.
    """
    result = dropdown.dropdown_menu_content(menu_id, data_table_columns).render()
    # The rendered HTML is logged only in debug mode, not on every request.
    if is_debug:
        logger.info(result)
    return result
|
|
@frontend_router.get("/filter-table", response_class=HTMLResponse)
async def filter_table(filter: str = "", x_api_key: str = Depends(get_api_key)):
    """Render the provider table restricted to rows matching ``filter``.

    A provider matches when the lower-cased filter text occurs in the string
    form of its ``provider``, ``base_url`` or ``tools`` value. Fields absent
    from a provider entry count as empty. Rows keep their index in
    ``app.state.config["providers"]`` as row id.

    Raises:
        HTTPException: 403 when the request carries no valid admin key. The
            dashboard at ``/`` protects the same table.
    """
    if not x_api_key:
        raise HTTPException(status_code=403, detail="Invalid or missing API Key")

    needle = filter.lower()
    searched_fields = ("provider", "base_url", "tools")

    def matches(provider):
        return any(needle in str(provider.get(field, "")).lower() for field in searched_fields)

    filtered_data = [
        (i, provider) for i, provider in enumerate(app.state.config["providers"])
        if matches(provider)
    ]
    return data_table(
        data_table_columns,
        [p for _, p in filtered_data],
        "users-table",
        with_filter=False,
        row_ids=[i for i, _ in filtered_data]
    ).render()
|
|
@frontend_router.post("/add-model", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)])
async def add_model():
    """Render an empty model-config row whose id is derived from the current time."""
    id_suffix = hash(str(time()))
    blank_row = model_config_row(f"model{id_suffix}")
    return blank_row.render()
|
|
def render_api_keys(row_id, api_keys):
    """Build the ``#api-keys-container`` list: one text input and Delete button per key.

    Inputs are named ``api_key_<position>``. Each Delete button issues an htmx
    DELETE to ``/delete-api-key/<row_id>/<position>`` and replaces the container.
    """
    def key_entry(position, key_value):
        key_input = Div(
            input.input(
                type="text",
                placeholder="Enter API key",
                value=key_value,
                name=f"api_key_{position}",
                class_="flex-grow w-full"
            ),
            class_="flex-grow"
        )
        delete_button = button.button(
            "Delete",
            variant="outline",
            type="button",
            class_="ml-2",
            hx_delete=f"/delete-api-key/{row_id}/{position}",
            hx_target="#api-keys-container",
            hx_swap="outerHTML"
        )
        return Li(Div(key_input, delete_button, class_="flex items-center mb-2 w-full"))

    entries = [key_entry(position, key_value) for position, key_value in enumerate(api_keys)]
    return Ul(*entries, id="api-keys-container", class_="space-y-2 w-full")
|
|
@frontend_router.get("/edit-sheet/{row_id}", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)])
async def get_edit_sheet(row_id: str, x_api_key: str = Depends(get_api_key)):
    """Render the edit sheet for the provider at ``row_id``.

    The sheet is pre-filled with the provider's name, base URL, API keys,
    models, tools flag and notes, and posts to ``/submit/<row_id>``.

    Raises:
        HTTPException: 403 when the request carries no valid admin key. The
            sheet contains the provider's API keys, so it must not be served
            unauthenticated.
    """
    if not x_api_key:
        raise HTTPException(status_code=403, detail="Invalid or missing API Key")

    # row_data is not printed or logged: it contains the provider's API keys.
    row_data = get_row_data(row_id)

    model_list = []
    for index, model in enumerate(row_data.get("model") or []):
        if isinstance(model, str):
            model_list.append(model_config_row(f"model{index}", model, "", True))
        elif isinstance(model, dict) and model:
            # A dict entry maps the upstream model name to its rename;
            # only the first pair is used.
            key, value = next(iter(model.items()))
            model_list.append(model_config_row(f"model{index}", key, value, True))

    # Normalise the "api" value to a list; it may be a list, a single key, or absent.
    api_value = row_data.get("api")
    if isinstance(api_value, list):
        api_keys = api_value
    elif api_value is None:
        api_keys = []
    else:
        api_keys = [api_value]
    api_key_inputs = render_api_keys(row_id, api_keys)

    sheet_id = "edit-sheet"
    edit_sheet_content = sheet.SheetContent(
        sheet.SheetHeader(
            sheet.SheetTitle("Edit Item"),
            sheet.SheetDescription("Make changes to your item here.")
        ),
        sheet.SheetBody(
            Div(
                form.Form(
                    form.FormField("Provider", "provider", value=row_data["provider"], placeholder="Enter provider name", required=True),
                    form.FormField("Base URL", "base_url", value=row_data.get("base_url", ""), placeholder="Enter base URL", required=True),
                    Div(
                        Div("API Keys", class_="text-lg font-semibold mb-2"),
                        api_key_inputs,
                        button.button(
                            "Add API Key",
                            class_="mt-2",
                            hx_post=f"/add-api-key/{row_id}",
                            hx_target="#api-keys-container",
                            hx_swap="outerHTML"
                        ),
                        class_="mb-4"
                    ),
                    Div(
                        Div("Models", class_="text-lg font-semibold mb-2"),
                        Div(
                            *model_list,
                            id="models-container",
                            class_="space-y-2 max-h-[40vh] overflow-y-auto"
                        ),
                        button.button(
                            "Add Model",
                            class_="mt-2",
                            hx_post="/add-model",
                            hx_target="#models-container",
                            hx_swap="beforeend"
                        ),
                        class_="mb-4"
                    ),
                    Div(
                        checkbox.checkbox("tools", "Enable Tools", checked=row_data.get("tools", False), name="tools"),
                        class_="mb-4"
                    ),
                    form.FormField("Notes", "notes", value=row_data.get("notes", ""), placeholder="Enter any additional notes"),
                    Div(
                        button.button("Submit", variant="primary", type="submit"),
                        button.button("Cancel", variant="outline", type="button", class_="ml-2", onclick=f"toggleSheet('{sheet_id}')"),
                        class_="flex justify-end mt-4"
                    ),
                    hx_post=f"/submit/{row_id}",
                    hx_swap="outerHTML",
                    hx_target="body",
                    class_="space-y-4"
                ),
                class_="container mx-auto p-4 max-w-2xl"
            )
        ),
        class_="max-h-[90vh] overflow-y-auto"
    )

    result = sheet.Sheet(
        sheet_id,
        Div(),
        edit_sheet_content,
        width="80%",
        max_width="800px"
    ).render()
    return result
|
|
@frontend_router.post("/add-api-key/{row_id}", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)])
async def add_api_key(row_id: str, x_api_key: str = Depends(get_api_key)):
    """Re-render a provider's API-key inputs with one extra empty input.

    Raises:
        HTTPException: 403 when the request carries no valid admin key. The
            rendered fragment contains the provider's existing API keys.
    """
    if not x_api_key:
        raise HTTPException(status_code=403, detail="Invalid or missing API Key")

    row_data = get_row_data(row_id)
    api_keys = row_data["api"] if isinstance(row_data["api"], list) else [row_data["api"]]
    # NOTE(review): when "api" is already a list, this appends the empty key to
    # the live in-memory config, not only to the rendered form. Confirm intended.
    api_keys.append("")

    return render_api_keys(row_id, api_keys).render()
|
|
@frontend_router.delete("/delete-api-key/{row_id}/{index}", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)])
async def delete_api_key(row_id: str, index: int, x_api_key: str = Depends(get_api_key)):
    """Remove the API key at ``index`` from a provider and re-render the inputs.

    The last remaining key is never removed.

    Raises:
        HTTPException: 403 when the request carries no valid admin key;
            404 when ``index`` is outside the provider's key list.
    """
    if not x_api_key:
        raise HTTPException(status_code=403, detail="Invalid or missing API Key")

    row_data = get_row_data(row_id)
    api_keys = row_data["api"] if isinstance(row_data["api"], list) else [row_data["api"]]
    if len(api_keys) > 1:
        # Reject negative and too-large indexes: a negative one would silently
        # delete from the end, a too-large one would surface as a 500.
        if not 0 <= index < len(api_keys):
            raise HTTPException(status_code=404, detail="API key index out of range")
        del api_keys[index]

    return render_api_keys(row_id, api_keys).render()
|
|
@frontend_router.get("/add-provider-sheet", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)])
async def get_add_provider_sheet():
    """Render the empty "Add New Provider" sheet, which posts to ``/submit/new``."""
    sheet_id = "add-provider-sheet"
    edit_sheet_content = sheet.SheetContent(
        sheet.SheetHeader(
            sheet.SheetTitle("Add New Provider"),
            sheet.SheetDescription("Enter details for the new provider.")
        ),
        sheet.SheetBody(
            Div(
                form.Form(
                    form.FormField("Provider", "provider", placeholder="Enter provider name", required=True),
                    form.FormField("Base URL", "base_url", placeholder="Enter base URL", required=True),
                    # The submit handler collects fields whose names start with
                    # "api_key_". A field named plain "api_key" was dropped.
                    form.FormField("API Key", "api_key_0", type="text", placeholder="Enter API key"),
                    Div(
                        Div("Models", class_="text-lg font-semibold mb-2"),
                        Div(id="models-container"),
                        button.button(
                            "Add Model",
                            class_="mt-2",
                            hx_post="/add-model",
                            hx_target="#models-container",
                            hx_swap="beforeend"
                        ),
                        class_="mb-4"
                    ),
                    Div(
                        checkbox.checkbox("tools", "Enable Tools", name="tools"),
                        class_="mb-4"
                    ),
                    form.FormField("Notes", "notes", placeholder="Enter any additional notes"),
                    Div(
                        button.button("Submit", variant="primary", type="submit"),
                        button.button("Cancel", variant="outline", type="button", class_="ml-2", onclick=f"toggleSheet('{sheet_id}')"),
                        class_="flex justify-end mt-4"
                    ),
                    hx_post="/submit/new",
                    hx_swap="outerHTML",
                    hx_target="body",
                    class_="space-y-4"
                ),
                class_="container mx-auto p-4 max-w-2xl"
            )
        )
    )

    result = sheet.Sheet(
        sheet_id,
        Div(),
        edit_sheet_content,
        width="80%",
        max_width="800px"
    ).render()
    return result
|
|
def get_row_data(row_id):
    """Return the provider entry stored at position ``row_id`` in the config.

    Args:
        row_id: Zero-based index into ``app.state.config["providers"]``,
            as an int or a string of digits.

    Raises:
        HTTPException: 404 when ``row_id`` is not an integer or is outside the
            provider list. Negative values are rejected so they cannot
            silently select entries from the end.
    """
    providers = app.state.config["providers"]
    try:
        index = int(row_id)
    except (TypeError, ValueError):
        raise HTTPException(status_code=404, detail="Provider not found") from None
    if not 0 <= index < len(providers):
        raise HTTPException(status_code=404, detail="Provider not found")
    return providers[index]
|
|
def update_row_data(row_id, updated_data):
    """Apply ``updated_data`` to the provider entry at position ``row_id``.

    The existing entry is updated in place, so keys the edit form does not
    cover are kept rather than dropped. ``updated_data`` is not printed or
    logged because it contains provider API keys.

    Raises:
        HTTPException: 404 when ``row_id`` is not an integer or is outside the
            provider list.
    """
    providers = app.state.config["providers"]
    try:
        index = int(row_id)
    except (TypeError, ValueError):
        raise HTTPException(status_code=404, detail="Provider not found") from None
    if not 0 <= index < len(providers):
        raise HTTPException(status_code=404, detail="Provider not found")
    providers[index].update(updated_data)
|
|
@frontend_router.post("/submit/{row_id}", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)])
async def submit_form(
    row_id: str,
    request: Request,
    provider: str = FastapiForm(...),
    base_url: str = FastapiForm(...),
    tools: Optional[str] = FastapiForm(None),
    notes: Optional[str] = FastapiForm(None),
    x_api_key: str = Depends(get_api_key)
):
    """Create (``row_id == "new"``) or update a provider from the submitted form.

    API keys come from the non-empty ``api_key`` / ``api_key_*`` fields.
    Models come from the ``model_name_*`` fields; a model with a matching
    non-empty ``model_rename_*`` field is stored as ``{name: rename}``. The
    config is then saved unless ``DISABLE_DATABASE`` is set, and the dashboard
    is re-rendered.

    Raises:
        HTTPException: 403 when the request carries no valid admin key. This
            handler rewrites the provider configuration.
    """
    if not x_api_key:
        raise HTTPException(status_code=403, detail="Invalid or missing API Key")

    form_data = await request.form()

    # Plain "api_key" is accepted as well as the indexed "api_key_<n>" fields.
    api_keys = [
        value for key, value in form_data.items()
        if (key == "api_key" or key.startswith("api_key_")) and value
    ]

    models = []
    for key, value in form_data.items():
        if not key.startswith("model_name_") or not value:
            continue
        model_id = key.split("_")[-1]
        rename = form_data.get(f"model_rename_{model_id}")
        models.append({value: rename} if rename else value)

    updated_data = {
        "provider": provider,
        "base_url": base_url,
        # A single key is stored as a scalar, anything else as a list.
        "api": api_keys[0] if len(api_keys) == 1 else api_keys,
        "model": models,
        "tools": tools == "on",
        "notes": notes,
    }

    # updated_data is not printed or logged: it contains provider API keys.
    if row_id == "new":
        app.state.config["providers"].append(updated_data)
    else:
        update_row_data(row_id, updated_data)

    if not DISABLE_DATABASE:
        save_api_yaml(app.state.config)

    return await root(x_api_key)
|
|
@frontend_router.post("/duplicate/{row_id}", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)])
async def duplicate_row(row_id: str, x_api_key: str = Depends(get_api_key)):
    """Insert a copy of the provider at ``row_id`` right after it.

    The copy's name gets a ``-copy`` suffix. The config is saved unless
    ``DISABLE_DATABASE`` is set, and the dashboard is re-rendered.

    Raises:
        HTTPException: 403 when the request carries no valid admin key;
            404 when ``row_id`` is not a valid provider index.
    """
    if not x_api_key:
        raise HTTPException(status_code=403, detail="Invalid or missing API Key")

    providers = app.state.config["providers"]
    try:
        index = int(row_id)
    except ValueError:
        raise HTTPException(status_code=404, detail="Provider not found") from None
    if not 0 <= index < len(providers):
        raise HTTPException(status_code=404, detail="Provider not found")

    # Deep copy: a shallow copy would share the nested "model" / "api" lists,
    # so editing one entry in place would also change the other.
    new_data = copy.deepcopy(providers[index])
    new_data["provider"] += "-copy"
    providers.insert(index + 1, new_data)

    if not DISABLE_DATABASE:
        save_api_yaml(app.state.config)

    return await root(x_api_key)
|
|
@frontend_router.delete("/delete/{row_id}", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)])
async def delete_row(row_id: str, x_api_key: str = Depends(get_api_key)):
    """Delete the provider at ``row_id``.

    The config is saved unless ``DISABLE_DATABASE`` is set, and the dashboard
    is re-rendered.

    Raises:
        HTTPException: 403 when the request carries no valid admin key;
            404 when ``row_id`` is not a valid provider index. Negative values
            are rejected so they cannot delete entries from the end.
    """
    if not x_api_key:
        raise HTTPException(status_code=403, detail="Invalid or missing API Key")

    providers = app.state.config["providers"]
    try:
        index = int(row_id)
    except ValueError:
        raise HTTPException(status_code=404, detail="Provider not found") from None
    if not 0 <= index < len(providers):
        raise HTTPException(status_code=404, detail="Provider not found")

    del providers[index]

    if not DISABLE_DATABASE:
        save_api_yaml(app.state.config)

    return await root(x_api_key)
|
|
# Register the admin UI routes on the main application under the "frontend" tag.
app.include_router(frontend_router, tags=["frontend"])
|
|
| |
| |
| |
|
|
if __name__ == '__main__':
    import uvicorn

    # Server settings used when this file is run directly: listen on all
    # interfaces, port 8000, and reload on changes to *.py files and api.yaml.
    server_options = {
        "host": "0.0.0.0",
        "port": 8000,
        "reload": True,
        "reload_dirs": ["./"],
        "reload_includes": ["*.py", "api.yaml"],
        "ws": "none",
    }
    uvicorn.run("__main__:app", **server_options)