import asyncio
import codecs
import hmac
import json
import os
import time
from datetime import datetime, timedelta
from itertools import cycle
from typing import Dict, List, Optional

import httpx
import requests
from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
| |
|
| | |
# Application instance shared by every route and middleware in this module.
app = FastAPI()

# CORS policy: any origin, method and header may call the API, with
# credentials allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — confirm this is intended for the deployment.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
| |
|
| | |
class Config:
    """Process-wide settings, read from environment variables at import time."""

    # Upstream OpenAI-compatible API that every proxied request is sent to.
    OPENAI_API_BASE = "https://api.x.ai/v1"
    # Optional URL serving one key per line; when empty, keys come from key.txt.
    KEYS_URL = os.getenv("KEYS_URL", "")
    # Comma-separated client IPs allowed to call the proxied API. Entries are
    # stripped and blanks dropped, so "1.2.3.4, 5.6.7.8" and stray commas
    # work as expected. An empty list disables the whitelist.
    WHITELIST_IPS = [
        ip.strip() for ip in os.getenv("WHITELIST_IPS", "").split(",") if ip.strip()
    ]
    # NOTE(review): falls back to "admin" when unset — always set
    # ADMIN_PASSWORD in any non-local deployment.
    ADMIN_PASSWORD = os.getenv("ADMIN_PASSWORD", "admin")
| |
|
| | |
class KeyStatus:
    """Health record for a single upstream API key.

    ``status`` is "valid" or "cooling". While cooling, ``cooling_until``
    holds the moment after which the key may be handed out again.
    ``last_check`` is the time of the most recent upstream probe.
    """

    def __init__(self, key: str) -> None:
        self.key: str = key
        self.status: str = "valid"  # every key starts out usable
        self.cooling_until: Optional[datetime] = None  # set when the key is benched
        self.last_check: Optional[datetime] = None  # set by the health check
| |
|
| | |
# In-memory key registry, (re)built by init_keys() and mutated by the admin endpoints.
keys: List[str] = []  # keys in load / insertion order
key_cycle = None  # itertools.cycle over `keys`; rebuilt whenever the list changes
first_key = None  # first loaded key; preferred for model listing and generic proxying
key_status_map: Dict[str, KeyStatus] = {}  # per-key health record, indexed by key string
| |
|
| | |
# Serve files under ./static (including the admin page) at /static.
_static_app = StaticFiles(directory="static")
app.mount("/static", _static_app, name="static")
| |
|
| | |
@app.get("/admin")
async def admin():
    """Serve the admin console page (no password check happens at this route)."""
    admin_page = "static/admin.html"
    return FileResponse(admin_page)
| | |
def get_client_ip(request: Request) -> str:
    """Return the best-effort client IP used for whitelist checks.

    Sources are tried in this order:
    1. the first entry of ``X-Forwarded-For``;
    2. ``X-Real-IP``;
    3. the socket peer address.

    Empty header values are skipped. Returns ``""`` when no address is
    available at all.

    NOTE(review): both headers are supplied by the caller unless a trusted
    reverse proxy overwrites them, so a direct client can spoof its IP and
    bypass the whitelist. Only rely on this behind such a proxy.
    """
    forwarded_for = request.headers.get("x-forwarded-for")
    if forwarded_for:
        first_hop = forwarded_for.split(",")[0].strip()
        if first_hop:
            return first_hop

    real_ip = (request.headers.get("x-real-ip") or "").strip()
    if real_ip:
        return real_ip

    # request.client can be None when the ASGI server does not report a peer
    # address; previously that raised AttributeError (a 500).
    if request.client is None:
        return ""
    return request.client.host
| |
|
| | |
@app.middleware("http")
async def ip_whitelist(request: Request, call_next):
    """Reject ``/api/`` requests from clients outside ``Config.WHITELIST_IPS``.

    The admin endpoints (``/api/admin/...``) and the key-management endpoints
    (``/api/keys...``) are exempt; those routes check the admin password
    themselves. When the whitelist is empty, every client is allowed.
    """
    path = request.url.path
    # Prefix matching: a substring test would let any path that merely
    # contains "/api/keys" or "/api/admin/" (e.g. /api/v1/x/api/keys, which is
    # routed to the upstream proxy) skip the whitelist.
    exempt = path.startswith(("/api/admin/", "/api/keys"))
    if path.startswith("/api/") and not exempt:
        if Config.WHITELIST_IPS and Config.WHITELIST_IPS[0]:
            client_ip = get_client_ip(request)
            if client_ip not in Config.WHITELIST_IPS:
                # Build the 403 directly: an HTTPException raised inside an HTTP
                # middleware is not handled by FastAPI's exception handlers and
                # reaches the client as a 500.
                return Response(
                    content=json.dumps({"detail": "IP not allowed"}),
                    status_code=403,
                    media_type="application/json",
                )
    return await call_next(request)
| | |
| | |
def init_keys():
    """Load the API key pool and reset all rotation state.

    Keys are read one per line from ``Config.KEYS_URL`` when it is set, or
    from ``key.txt`` otherwise. Surrounding whitespace and blank lines are
    ignored.

    On success, ``keys``, ``first_key``, ``key_cycle`` and ``key_status_map``
    are all rebuilt, with every key starting as "valid". On any error, the
    error is printed and the pool is left empty.
    """
    global keys, key_cycle, first_key, key_status_map
    try:
        if Config.KEYS_URL:
            # Bounded wait so an unreachable key server cannot hang startup, and
            # a status check so an HTTP error page is not parsed as keys.
            response = requests.get(Config.KEYS_URL, timeout=30)
            response.raise_for_status()
            raw_lines = response.text.splitlines()
        else:
            # utf-8-sig drops a leading BOM, which str.strip() would otherwise
            # leave glued to the first key.
            with open("key.txt", "r", encoding="utf-8-sig") as f:
                raw_lines = f.readlines()

        keys = [k.strip() for k in raw_lines if k.strip()]
        # Reset explicitly so an empty reload does not leave stale values behind.
        first_key = keys[0] if keys else None
        key_cycle = cycle(keys) if keys else None
        key_status_map = {key: KeyStatus(key) for key in keys}
        print(f"Loaded {len(keys)} API keys")
    except Exception as e:
        print(f"Error loading keys: {e}")
        keys = []
        key_cycle = None
        first_key = None
        key_status_map = {}
| |
|
| | |
def get_valid_key():
    """Return the next usable key from the round-robin rotation.

    Walks at most one full lap of ``keys``. A key is usable when either:
    - its status is "valid"; or
    - it is "cooling" and its cooling deadline has passed, in which case it
      is restored to "valid" before being returned.

    Keys with no entry in ``key_status_map`` receive a fresh record.

    Raises:
        HTTPException(500): no rotation exists, or every key is still cooling.
    """
    if not key_cycle:
        raise HTTPException(status_code=500, detail="No API keys available")

    remaining = len(keys)
    while remaining > 0:
        remaining -= 1
        candidate = next(key_cycle)

        info = key_status_map.get(candidate)
        if info is None:
            info = KeyStatus(candidate)
            key_status_map[candidate] = info

        if info.status == "valid":
            return candidate

        cooled_down = (
            info.status == "cooling"
            and info.cooling_until
            and datetime.now() > info.cooling_until
        )
        if cooled_down:
            info.status = "valid"
            info.cooling_until = None
            return candidate

    raise HTTPException(status_code=500, detail="No valid API keys available")
| |
|
| | |
def mark_key_cooling(key: str):
    """Bench ``key`` for 30 days; keys without a status record are ignored."""
    info = key_status_map.get(key)
    if info is None:
        return
    info.status = "cooling"
    info.cooling_until = datetime.now() + timedelta(days=30)
| |
|
| | |
async def check_key_status(key: str) -> bool:
    """Probe ``key`` against the upstream ``/models`` endpoint and record the result.

    When the key has a record in ``key_status_map``:
    - a 200 response marks it "valid";
    - any other status benches it via ``mark_key_cooling``;
    - ``last_check`` is stamped in both cases.

    Returns:
        True when upstream answered 200. False for any other status, and for
        request errors (which are logged and leave the record unchanged).
    """
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(
                f"{Config.OPENAI_API_BASE}/models",
                headers={"Authorization": f"Bearer {key}"},
            )
        is_valid = response.status_code == 200
        key_info = key_status_map.get(key)
        if key_info:
            if is_valid:
                key_info.status = "valid"
                key_info.cooling_until = None
            else:
                # Same benching rule as the 429 handling elsewhere in the proxy.
                mark_key_cooling(key)
            key_info.last_check = datetime.now()
        return is_valid
    except Exception as e:
        # The key is a credential: log only a short prefix, never the full value.
        shown = f"{key[:8]}..." if len(key) > 16 else "<redacted>"
        print(f"Error checking key {shown}: {e}")
        return False
| |
|
| | |
async def stream_generator(response):
    """Relay an upstream server-sent-event body as normalized ``data:`` events.

    ``response`` only needs an async ``aiter_bytes()`` method. The decoded
    text is split into events on blank lines, and each event is handled as
    follows:
    - ``data: <json>`` is re-emitted with its JSON re-serialized;
    - ``data: [DONE]`` is passed through;
    - anything else (comments, other fields, unparsable JSON) is dropped.

    A final event that is not followed by a blank line is still emitted. If
    reading or decoding the body fails, one ``{"error": ...}`` event is
    yielded and the stream ends.
    """
    # Decode incrementally: a multi-byte UTF-8 character may be split across
    # two network chunks, which a per-chunk bytes.decode() would reject.
    decoder = codecs.getincrementaldecoder("utf-8")()
    buffer = ""

    def render(event):
        """Return the normalized SSE text for one raw event, or None to drop it."""
        if not event.startswith("data: "):
            return None
        data = event[6:]
        if data.strip() == "[DONE]":
            return "data: [DONE]\n\n"
        try:
            return f"data: {json.dumps(json.loads(data))}\n\n"
        except json.JSONDecodeError:
            print(f"JSON decode error for data: {data}")
            return None

    try:
        async for chunk in response.aiter_bytes():
            buffer += decoder.decode(chunk)
            while "\n\n" in buffer:
                event, buffer = buffer.split("\n\n", 1)
                rendered = render(event)
                if rendered is not None:
                    yield rendered

        # Flush the decoder and emit a trailing event that had no blank line after it.
        buffer = (buffer + decoder.decode(b"", final=True)).strip()
        if buffer:
            rendered = render(buffer)
            if rendered is not None:
                yield rendered
    except Exception as e:
        print(f"Stream Error: {str(e)}")
        yield f"data: {json.dumps({'error': str(e)})}\n\n"
| | |
def verify_admin(password: str):
    """Raise 403 unless ``password`` matches ``Config.ADMIN_PASSWORD``.

    The comparison is constant-time, so response timing does not reveal how
    much of a guess was correct. Values that are not strings (e.g. a JSON
    number or null from a request body), or are empty, are rejected.

    Raises:
        HTTPException(403): missing, empty, non-string or wrong password.
    """
    if (
        not isinstance(password, str)
        or not password
        or not hmac.compare_digest(
            password.encode("utf-8"), Config.ADMIN_PASSWORD.encode("utf-8")
        )
    ):
        raise HTTPException(status_code=403, detail="Invalid admin password")
| |
|
| | |
@app.post("/api/admin/login")
async def admin_login(request: Request):
    """Check the admin password sent as ``{"password": ...}`` in the JSON body."""
    payload = await request.json()
    verify_admin(payload.get("password"))
    return {"status": "success"}
| |
|
| | |
@app.get("/api/keys")
async def list_keys(password: str):
    """List every key with its status and ISO-8601 timestamps (admin only).

    Keys without a status record are reported as "valid" with no timestamps.
    """
    verify_admin(password)

    def describe(k):
        info = key_status_map.get(k)
        if info is None:
            return {"key": k, "status": "valid", "cooling_until": None, "last_check": None}
        cooling_until = info.cooling_until.isoformat() if info.cooling_until else None
        last_check = info.last_check.isoformat() if info.last_check else None
        return {
            "key": k,
            "status": info.status,
            "cooling_until": cooling_until,
            "last_check": last_check,
        }

    return {"keys": [describe(k) for k in keys]}
| |
|
| | |
@app.post("/api/keys/add")
async def add_key(request: Request):
    """Add one API key to the pool (admin only).

    Expects the JSON body ``{"password": ..., "key": ...}``. The new key
    starts as "valid", and the rotation is rebuilt so it is picked up. When
    keys are not sourced from ``Config.KEYS_URL``, the list is persisted to
    key.txt.

    Raises:
        HTTPException(400): body is not a JSON object; the key is missing,
            blank or not a string; or the key already exists.
        HTTPException(403): wrong admin password.
    """
    global key_cycle
    try:
        data = await request.json()
    except ValueError:
        raise HTTPException(status_code=400, detail="Invalid JSON body") from None
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")
    verify_admin(data.get("password"))

    # A non-string value (null, number, ...) used to crash on .strip() with a 500.
    raw_key = data.get("key")
    new_key = raw_key.strip() if isinstance(raw_key, str) else ""
    if not new_key:
        raise HTTPException(status_code=400, detail="Key is required")
    if new_key in keys:
        raise HTTPException(status_code=400, detail="Key already exists")

    keys.append(new_key)
    key_status_map[new_key] = KeyStatus(new_key)

    # Rebuild the rotation so it includes the new key (restarts the round-robin).
    key_cycle = cycle(keys)

    if not Config.KEYS_URL:
        with open("key.txt", "w", encoding="utf-8") as f:
            f.write("\n".join(keys))

    return {"status": "success"}
| |
|
| | |
@app.delete("/api/keys/{key}")
async def delete_key(key: str, password: str):
    """Remove ``key`` from the pool (admin only); unknown keys are a no-op.

    The rotation is rebuilt afterwards. Unless keys come from
    ``Config.KEYS_URL``, key.txt is rewritten with the remaining keys.
    """
    global key_cycle
    verify_admin(password)

    if key in keys:
        keys.remove(key)
        key_status_map.pop(key, None)

    key_cycle = cycle(keys)

    if Config.KEYS_URL:
        return {"status": "success"}
    with open("key.txt", "w") as f:
        f.write("\n".join(keys))
    return {"status": "success"}
| |
|
| | |
@app.post("/api/keys/delete-batch")
async def delete_keys_batch(request: Request):
    """Remove several keys at once (admin only).

    Expects the JSON body ``{"password": ..., "keys": [...]}``. Entries not
    in the pool are skipped. The rotation is rebuilt afterwards. Unless keys
    come from ``Config.KEYS_URL``, key.txt is rewritten.
    """
    global key_cycle
    payload = await request.json()
    verify_admin(payload.get("password"))

    for doomed in payload.get("keys", []):
        if doomed not in keys:
            continue
        keys.remove(doomed)
        key_status_map.pop(doomed, None)

    key_cycle = cycle(keys)

    if not Config.KEYS_URL:
        with open("key.txt", "w") as f:
            f.write("\n".join(keys))

    return {"status": "success"}
| |
|
| | |
@app.get("/api/keys/check/{key}")
async def check_single_key(key: str, password: str):
    """Probe one pooled key upstream (admin only) and report whether it is valid.

    Raises:
        HTTPException(404): the key is not in the pool.
    """
    verify_admin(password)
    if key not in keys:
        raise HTTPException(status_code=404, detail="Key not found")
    return {"status": "success", "valid": await check_key_status(key)}
| |
|
| | |
@app.post("/api/keys/check-all")
async def check_all_keys(password: str):
    """Probe every pooled key upstream, one at a time (admin only).

    Results are recorded in each key's status record, not returned.
    """
    verify_admin(password)
    for pooled_key in keys:
        await check_key_status(pooled_key)
    return {"status": "success"}
| | |
@app.get("/api/v1/models")
async def list_models():
    """Proxy the upstream model list, preferring the first loaded key.

    Falls back to the rotation when the first key is unset, no longer in the
    pool, or cooling. A 429 benches the key and retries once with the next
    valid key. The upstream status code, body and content type are relayed
    unchanged.

    Raises:
        HTTPException(500): no usable key, or the upstream call failed.
    """
    try:
        key = first_key
        # first_key is fixed at load time; it may be None (no keys) or may
        # since have been deleted via the admin API (then it has no status record).
        if (
            key is None
            or key not in key_status_map
            or key_status_map[key].status == "cooling"
        ):
            key = get_valid_key()

        url = f"{Config.OPENAI_API_BASE}/models"
        headers = {
            "Authorization": f"Bearer {key}",
            "Content-Type": "application/json",
        }

        async with httpx.AsyncClient() as client:
            response = await client.get(url, headers=headers)
            if response.status_code == 429:
                # Rate limited: bench this key and retry once with the next one.
                mark_key_cooling(key)
                key = get_valid_key()
                headers["Authorization"] = f"Bearer {key}"
                response = await client.get(url, headers=headers)

        # Relay verbatim. Returning response.json() reported upstream errors
        # with a 200 status and turned non-JSON bodies into a 500.
        return Response(
            content=response.content,
            status_code=response.status_code,
            media_type=response.headers.get("content-type", "application/json"),
        )
    except HTTPException:
        # Keep the original status/detail (e.g. from get_valid_key).
        raise
    except Exception as e:
        print(f"Models Error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
| |
|
| | |
class _Utf8ChunkStream:
    """Wrap a streaming response so ``aiter_bytes()`` never splits a UTF-8 character.

    Network chunks can end mid-character. Decoding incrementally and
    re-encoding makes every yielded chunk decodable on its own, whichever
    SSE parser consumes it.
    """

    def __init__(self, response):
        self._response = response

    async def aiter_bytes(self):
        decoder = codecs.getincrementaldecoder("utf-8")()
        async for chunk in self._response.aiter_bytes():
            text = decoder.decode(chunk)
            if text:
                yield text.encode("utf-8")
        tail = decoder.decode(b"", final=True)
        if tail:
            yield tail.encode("utf-8")


async def _send_chat_request(client, url, key, body_json, is_stream):
    """POST ``body_json`` upstream with ``key``; the response body is left unread."""
    headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
        "Accept": "text/event-stream" if is_stream else "application/json",
    }
    upstream_request = client.build_request("POST", url, headers=headers, json=body_json)
    return await client.send(upstream_request, stream=True)


@app.post("/api/v1/chat/completions")
async def chat_completions(request: Request):
    """Forward an OpenAI-style chat completion to the upstream API.

    Key handling:
    - the request uses the next valid key from the rotation;
    - a 429 benches that key, and the request is retried once with another.

    Response handling:
    - non-200 upstream replies are relayed with their status, body and
      content type;
    - with ``"stream": true``, the upstream event stream is relayed to the
      caller as it arrives, through ``stream_generator``;
    - otherwise the upstream body is returned once complete.

    Raises:
        HTTPException(400): the body is not a JSON object.
        HTTPException(500): no usable key, or the upstream call failed.
    """
    try:
        body_json = json.loads(await request.body())
    except ValueError:
        raise HTTPException(status_code=400, detail="Request body must be valid JSON") from None
    if not isinstance(body_json, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    is_stream = bool(body_json.get("stream"))
    url = f"{Config.OPENAI_API_BASE}/chat/completions"

    # A streamed reply is read after this function returns, so the client must
    # outlive it. On that path, ownership passes to relay(), which closes it;
    # every other path closes it in the finally below.
    client = httpx.AsyncClient(timeout=60.0)
    handed_off = False
    try:
        key = get_valid_key()
        response = await _send_chat_request(client, url, key, body_json, is_stream)
        if response.status_code == 429:
            # Rate limited: release the connection, bench the key, retry once.
            await response.aclose()
            mark_key_cooling(key)
            key = get_valid_key()
            response = await _send_chat_request(client, url, key, body_json, is_stream)

        media_type = response.headers.get("content-type", "application/json")

        if response.status_code != 200:
            return Response(
                content=await response.aread(),
                status_code=response.status_code,
                media_type=media_type,
            )

        if not is_stream:
            return Response(content=await response.aread(), media_type=media_type)

        async def relay():
            try:
                async for event in stream_generator(_Utf8ChunkStream(response)):
                    yield event
            finally:
                await response.aclose()
                await client.aclose()

        handed_off = True
        return StreamingResponse(
            relay(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "Content-Type": "text/event-stream",
            },
        )
    except HTTPException:
        raise
    except Exception as e:
        print(f"Chat Error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        if not handed_off:
            await client.aclose()
| |
|
| | |
@app.api_route("/api/v1/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request):
    """Generic pass-through for any other ``/api/v1/*`` endpoint.

    ``chat/completions`` is delegated to ``chat_completions``. Other requests
    are forwarded upstream:
    - with the first loaded key, or the next valid key when that one is
      unset, no longer in the pool, or cooling;
    - keeping the caller's query string and Content-Type.

    A 429 benches the key and retries once. The upstream status, body bytes
    and content type are relayed unchanged.

    Raises:
        HTTPException(500): no usable key, or the upstream call failed.
    """
    if path == "chat/completions":
        return await chat_completions(request)

    try:
        method = request.method
        body = await request.body() if method in ["POST", "PUT"] else None

        key = first_key
        # first_key is fixed at load time and may have been deleted since.
        if (
            key is None
            or key not in key_status_map
            or key_status_map[key].status == "cooling"
        ):
            key = get_valid_key()

        headers = {
            "Authorization": f"Bearer {key}",
            # Keep the caller's Content-Type (e.g. a multipart boundary);
            # JSON stays the default when none was sent.
            "Content-Type": request.headers.get("content-type", "application/json"),
        }
        url = f"{Config.OPENAI_API_BASE}/{path}"
        # The query string is not part of `path`, so forward it explicitly.
        params = request.query_params.multi_items()

        async with httpx.AsyncClient() as client:
            response = await client.request(
                method=method, url=url, headers=headers, content=body, params=params
            )
            if response.status_code == 429:
                # Rate limited: bench this key and retry once with the next one.
                mark_key_cooling(key)
                key = get_valid_key()
                headers["Authorization"] = f"Bearer {key}"
                response = await client.request(
                    method=method, url=url, headers=headers, content=body, params=params
                )

        # Relay raw bytes so binary bodies are not mangled by a text round-trip.
        return Response(
            content=response.content,
            status_code=response.status_code,
            media_type=response.headers.get("content-type", "application/json"),
        )
    except HTTPException:
        raise
    except Exception as e:
        print(f"Proxy Error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
| |
|
| | |
@app.get("/api/health")
async def health_check():
    """Liveness probe that also reports how many keys are currently loaded."""
    key_count = len(keys)
    return {"status": "healthy", "key_count": key_count}
| |
|
| | |
@app.on_event("startup")
async def startup_event():
    """Populate the key pool once, when the application starts.

    init_keys() performs synchronous I/O (an HTTP fetch or a file read) and
    is called directly here, not awaited.
    """
    init_keys()