"""FreeLLMAPI: an OpenAI-style HTTP proxy that forwards requests to upstream LLM providers.

Provider credentials are read from environment variables at import time. Client
requests are authenticated against the comma-separated ``API_KEYS`` variable and
routed to a provider according to the requested model name.
"""
import asyncio
import json
import logging
import os

import httpx
import uvicorn
from fastapi import APIRouter, Depends, FastAPI, Header, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="FreeLLMAPI")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# ──────────────────────────────────────────
# Provider configuration
# ──────────────────────────────────────────

# Environment variable holding a provider's credential -> internal provider name.
PROVIDER_MAP = {
    "GOOGLE_API_KEY": "google",
    "GROQ_API_KEY": "groq",
    "GITHUB_TOKEN": "github",
    "OPENROUTER_API_KEY": "openrouter",
    "MISTRAL_API_KEY": "mistral",
    "TOGETHER_API_KEY": "together",
    "NVIDIA_API_KEY": "nvidia",
    "COHERE_API_KEY": "cohere",
    "HF_TOKEN": "huggingface",
    "CEREBRAS_API_KEY": "cerebras",
    "SAMBANOVA_API_KEY": "sambanova",
    "CLOUDFLARE_API_TOKEN": "cloudflare",
    "ZHIPU_API_KEY": "zhipu",
}

# Provider name -> base URL (to which "/chat/completions", "/embeddings" and
# "/models" are appended) and the model ids advertised for that provider.
PROVIDER_CONFIG = {
    "google": {
        "base_url": "https://generativelanguage.googleapis.com/v1beta/openai",
        "models": [
            "gemini-2.0-flash",
            "gemini-2.0-flash-lite",
            "gemini-1.5-pro",
            "gemini-1.5-flash",
            "gemini-1.5-flash-8b",
        ],
    },
    "groq": {
        "base_url": "https://api.groq.com/openai/v1",
        "models": [
            "llama-3.3-70b-versatile",
            "llama-3.1-8b-instant",
            "llama3-70b-8192",
            "llama3-8b-8192",
            "mixtral-8x7b-32768",
            "gemma2-9b-it",
        ],
    },
    "github": {
        "base_url": "https://models.inference.ai.azure.com",
        "models": [
            "gpt-4o",
            "gpt-4o-mini",
            "Phi-3.5-mini-instruct",
            "Phi-3.5-MoE-instruct",
            "Meta-Llama-3.1-70B-Instruct",
            "Meta-Llama-3.1-405B-Instruct",
        ],
    },
    "openrouter": {
        "base_url": "https://openrouter.ai/api/v1",
        "models": [
            "mistralai/mistral-7b-instruct:free",
            "meta-llama/llama-3.2-3b-instruct:free",
            "google/gemma-3-1b-it:free",
            "deepseek/deepseek-r1:free",
            "deepseek/deepseek-chat:free",
        ],
    },
    "mistral": {
        "base_url": "https://api.mistral.ai/v1",
        "models": [
            "mistral-small-latest",
            "mistral-large-latest",
            "open-mistral-7b",
            "open-mixtral-8x7b",
        ],
    },
    "together": {
        "base_url": "https://api.together.xyz/v1",
        "models": [
            "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
            "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
            "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
            "mistralai/Mixtral-8x7B-Instruct-v0.1",
        ],
    },
    "nvidia": {
        "base_url": "https://integrate.api.nvidia.com/v1",
        "models": [
            "meta/llama-3.1-70b-instruct",
            "meta/llama-3.1-8b-instruct",
            "mistralai/mixtral-8x7b-instruct",
        ],
    },
    # NOTE(review): this base URL gets "/chat/completions" appended like every
    # other provider — confirm against Cohere's OpenAI-compatibility docs that
    # the v2 API actually serves that path.
    "cohere": {
        "base_url": "https://api.cohere.com/v2",
        "models": [
            "command-r-plus",
            "command-r",
            "command",
        ],
    },
    "huggingface": {
        "base_url": "https://api-inference.huggingface.co/v1",
        "models": [
            "meta-llama/Llama-3.2-3B-Instruct",
            "mistralai/Mistral-7B-Instruct-v0.3",
        ],
    },
    "cerebras": {
        "base_url": "https://api.cerebras.ai/v1",
        "models": [
            "llama3.1-8b",
            "llama3.1-70b",
        ],
    },
    "sambanova": {
        "base_url": "https://api.sambanova.ai/v1",
        "models": [
            "Meta-Llama-3.1-8B-Instruct",
            "Meta-Llama-3.1-70B-Instruct",
            "Meta-Llama-3.1-405B-Instruct",
        ],
    },
    # "CF_ACCOUNT_ID" is a placeholder substituted by load_config().
    "cloudflare": {
        "base_url": "https://api.cloudflare.com/client/v4/accounts/CF_ACCOUNT_ID/ai/v1",
        "models": [
            "@cf/meta/llama-3.1-8b-instruct",
            "@cf/mistral/mistral-7b-instruct-v0.1",
        ],
    },
    "zhipu": {
        "base_url": "https://open.bigmodel.cn/api/paas/v4",
        "models": [
            "glm-4-flash",
            "glm-4",
            "glm-3-turbo",
        ],
    },
}


# ──────────────────────────────────────────
# Startup loading
# ──────────────────────────────────────────
def load_config():
    """Read client API keys and provider credentials from the environment.

    Returns:
        A ``(api_keys, providers)`` tuple. ``api_keys`` is the set of non-empty,
        stripped entries of the comma-separated ``API_KEYS`` variable.
        ``providers`` maps each provider whose credential variable is set to a
        dict with ``api_key``, ``base_url`` (no trailing slash) and ``models``.
    """
    raw_keys = os.getenv("API_KEYS", "")
    api_keys = {k.strip() for k in raw_keys.split(",") if k.strip()}

    providers = {}
    for env_name, provider_name in PROVIDER_MAP.items():
        key_value = os.getenv(env_name, "").strip()
        if not key_value:
            continue

        cfg = PROVIDER_CONFIG.get(provider_name, {})
        base_url = cfg.get("base_url", "")

        # Cloudflare embeds the account id in the URL path. Without it the URL
        # would contain an empty path segment and every request would fail, so
        # the provider is skipped instead of being registered in a broken state.
        if provider_name == "cloudflare":
            cf_account = os.getenv("CLOUDFLARE_ACCOUNT_ID", "").strip()
            if not cf_account:
                logger.warning(
                    "Skipping provider cloudflare: CLOUDFLARE_ACCOUNT_ID is not set"
                )
                continue
            base_url = base_url.replace("CF_ACCOUNT_ID", cf_account)

        providers[provider_name] = {
            "api_key": key_value,
            "base_url": base_url.rstrip("/"),
            "models": cfg.get("models", []),
        }
        logger.info(f"✅ 加载 Provider: {provider_name}")

    logger.info(f"✅ API Keys 数量: {len(api_keys)}")
    logger.info(f"✅ Provider 数量: {len(providers)} {list(providers.keys())}")
    return api_keys, providers


API_KEYS, PROVIDERS = load_config()

# Fast model -> provider lookup table. When two loaded providers list the same
# model id, the provider loaded later wins; the collision is logged so the
# shadowing is visible to the operator.
MODEL_PROVIDER: dict[str, str] = {}
for _pname, _pcfg in PROVIDERS.items():
    for _m in _pcfg["models"]:
        if _m in MODEL_PROVIDER:
            logger.warning(
                "Model %r is offered by both %s and %s; routing to %s",
                _m, MODEL_PROVIDER[_m], _pname, _pname,
            )
        MODEL_PROVIDER[_m] = _pname


# ──────────────────────────────────────────
# Authentication
# ──────────────────────────────────────────
def verify_api_key(authorization: str = Header(None)):
    """FastAPI dependency that validates the client's ``Authorization`` header.

    Accepts ``Bearer <key>`` (scheme matched case-insensitively) or a bare key.

    Returns:
        The validated key.

    Raises:
        HTTPException: 401 when the header is missing or the key is not one of
            the configured ``API_KEYS``.
    """
    token = (authorization or "").strip()
    # The header is optional at the framework level so that a missing header
    # produces 401 here rather than a 422 validation error.
    if token[:7].lower() == "bearer ":
        token = token[7:].strip()
    if token not in API_KEYS:
        raise HTTPException(status_code=401, detail="Invalid API key")
    return token


# ──────────────────────────────────────────
# Shared request helpers
# ──────────────────────────────────────────
async def _read_json_object(request: Request) -> dict:
    """Parse the request body as a JSON object, raising 400 on malformed input."""
    try:
        body = await request.json()
    except ValueError as exc:  # JSONDecodeError and UnicodeDecodeError
        raise HTTPException(
            status_code=400, detail="Request body must be valid JSON"
        ) from exc
    if not isinstance(body, dict):
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object"
        )
    return body


def _lookup_provider(model):
    """Return the provider name serving ``model``, or ``None`` if unknown."""
    # Guard against non-string values (e.g. a list), which are unhashable and
    # would otherwise raise TypeError inside the dict lookup.
    if not isinstance(model, str):
        return None
    return MODEL_PROVIDER.get(model)


async def _post_upstream(url: str, headers: dict, body: dict, timeout: float):
    """POST ``body`` as JSON to ``url`` and return the decoded JSON response.

    Raises:
        HTTPException: 502 when the upstream cannot be reached or answers with
            a non-JSON body; the upstream's own status code when it is not 200.
    """
    try:
        async with httpx.AsyncClient(timeout=timeout) as client:
            resp = await client.post(url, headers=headers, json=body)
    except httpx.HTTPError as exc:
        logger.error("Upstream request to %s failed: %r", url, exc)
        raise HTTPException(
            status_code=502, detail=f"Upstream request failed: {exc}"
        ) from exc

    if resp.status_code != 200:
        logger.error(f"上游错误 {resp.status_code}: {resp.text[:300]}")
        raise HTTPException(status_code=resp.status_code, detail=resp.text)

    try:
        return resp.json()
    except ValueError as exc:
        raise HTTPException(
            status_code=502, detail="Upstream returned a non-JSON response"
        ) from exc


# ──────────────────────────────────────────
# /v1 routes
# ──────────────────────────────────────────
v1 = APIRouter(prefix="/v1")


# ── GET /v1 ── return a short description instead of a 404
@v1.get("")
@v1.get("/")
async def v1_root():
    """Describe the available /v1 endpoints and the loaded providers."""
    return {
        "message": "FreeLLMAPI is running",
        "endpoints": {
            "models": "/v1/models",
            "chat_completions": "/v1/chat/completions",
        },
        "providers": list(PROVIDERS.keys()),
        "total_models": len(MODEL_PROVIDER),
    }


# ── GET /v1/models ──
@v1.get("/models")
async def list_models(_: str = Depends(verify_api_key)):
    """List every model of every loaded provider in an OpenAI-style envelope."""
    data = [
        {
            "id": m,
            "object": "model",
            "owned_by": provider_name,
            "created": 0,
        }
        for provider_name, pcfg in PROVIDERS.items()
        for m in pcfg["models"]
    ]
    return {"object": "list", "data": data}


# ── POST /v1/chat/completions ──
@v1.post("/chat/completions")
async def chat_completions(
    request: Request,
    _: str = Depends(verify_api_key),
):
    """Forward a chat-completion request to the provider serving ``model``.

    The request body is passed through unchanged. When ``stream`` is truthy the
    upstream response is relayed as ``text/event-stream``; otherwise the
    upstream JSON is returned.

    Raises:
        HTTPException: 400 for a malformed body, 404 for an unknown model, 502
            when the upstream is unreachable, or the upstream's status code
            when it answers a non-streaming request with a non-200 status.
    """
    body = await _read_json_object(request)
    model = body.get("model", "")
    stream = body.get("stream", False)

    # Resolve the provider for the requested model.
    provider_name = _lookup_provider(model)
    if not provider_name:
        raise HTTPException(
            status_code=404,
            detail=f"模型 '{model}' 不存在,可用模型: {list(MODEL_PROVIDER.keys())}",
        )

    provider = PROVIDERS[provider_name]
    url = f"{provider['base_url']}/chat/completions"
    headers = {
        "Authorization": f"Bearer {provider['api_key']}",
        "Content-Type": "application/json",
    }

    logger.info(f"转发请求 → {provider_name} | 模型: {model} | 流式: {stream}")

    # ── Streaming response ──
    if stream:
        async def event_stream():
            # The HTTP status line has already been sent by the time this
            # generator runs, so upstream failures are reported in-band as an
            # SSE "data:" frame that event-stream clients can parse.
            try:
                async with httpx.AsyncClient(timeout=120) as client:
                    async with client.stream(
                        "POST", url, headers=headers, json=body
                    ) as resp:
                        if resp.status_code != 200:
                            err = await resp.aread()
                            logger.error(f"上游流式错误 {resp.status_code}")
                            payload = json.dumps(
                                {"error": err.decode(errors="ignore")}
                            )
                            yield f"data: {payload}\n\n"
                            return
                        async for chunk in resp.aiter_text():
                            yield chunk
            except httpx.HTTPError as exc:
                logger.error("Upstream stream from %s failed: %r", url, exc)
                payload = json.dumps({"error": f"Upstream request failed: {exc}"})
                yield f"data: {payload}\n\n"

        return StreamingResponse(event_stream(), media_type="text/event-stream")

    # ── Plain (non-streaming) response ──
    return await _post_upstream(url, headers, body, timeout=120)


# ── POST /v1/embeddings (supported by some providers only) ──
@v1.post("/embeddings")
async def embeddings(
    request: Request,
    _: str = Depends(verify_api_key),
):
    """Forward an embeddings request to the provider serving ``model``.

    Raises:
        HTTPException: 400 for a malformed body, 404 for an unknown model, 502
            when the upstream is unreachable, or the upstream's status code
            when it is not 200.
    """
    body = await _read_json_object(request)
    model = body.get("model", "")

    provider_name = _lookup_provider(model)
    if not provider_name:
        raise HTTPException(status_code=404, detail=f"模型 '{model}' 不存在")

    provider = PROVIDERS[provider_name]
    url = f"{provider['base_url']}/embeddings"
    headers = {"Authorization": f"Bearer {provider['api_key']}"}

    # Upstream errors are surfaced with their real status code, matching
    # chat_completions, instead of being relayed as a 200 response.
    return await _post_upstream(url, headers, body, timeout=60)


# Register the /v1 router.
app.include_router(v1)


# ──────────────────────────────────────────
# Root route
# ──────────────────────────────────────────
@app.get("/")
async def root():
    """Return service metadata and pointers to the other endpoints."""
    return {
        "message": "FreeLLMAPI",
        "version": "1.0.0",
        "docs": "/docs",
        "health": "/health",
        "api_base": "/v1",
        "providers": list(PROVIDERS.keys()),
        "total_models": len(MODEL_PROVIDER),
    }
@app.get("/health")
async def health():
    """Report liveness plus counts of configured keys, providers and models."""
    return {
        "status": "ok",
        "api_keys": len(API_KEYS),
        "providers": list(PROVIDERS.keys()),
        "total_models": len(MODEL_PROVIDER),
    }


# ──────────────────────────────────────────
# Debug routes
# ──────────────────────────────────────────
# NOTE(review): the two debug routes below carry no authentication dependency
# and expose provider configuration; consider disabling them in production.
def _mask_key(key: str) -> str:
    """Return ``key`` with everything but a short prefix hidden.

    The 6-character prefix is shown only when the key is longer than 12
    characters, so a short key is never revealed in full or in large part.
    """
    prefix = key[:6] if len(key) > 12 else ""
    return prefix + "******"


@app.get("/debug")
async def debug():
    """Dump configuration state: env vars, providers, models and routes."""
    env_check = {
        env_name: "✅ 已设置" if os.getenv(env_name, "") else "❌ 未设置"
        for env_name in PROVIDER_MAP
    }

    routes = [
        {"path": route.path, "methods": list(route.methods)}
        for route in app.routes
        if hasattr(route, "methods")
    ]

    providers_info = {
        name: {
            "base_url": cfg["base_url"],
            "models": cfg["models"],
            "api_key": _mask_key(cfg["api_key"]),
        }
        for name, cfg in PROVIDERS.items()
    }

    return {
        "① 环境变量": env_check,
        "② API_KEYS": "✅ 已设置" if os.getenv("API_KEYS") else "❌ 未设置",
        "③ Providers": providers_info,
        "④ 全部模型": list(MODEL_PROVIDER.keys()),
        "⑤ 注册路由": routes,
        "⑥ Key数量": len(API_KEYS),
    }


@app.get("/debug/test/{provider_name}")
async def test_provider(provider_name: str):
    """Check connectivity to a loaded provider by requesting its ``/models`` URL.

    Returns a status dict in every case; failures are reported in the payload
    rather than raised.
    """
    if provider_name not in PROVIDERS:
        return {
            "status": "❌ 未加载",
            "已加载": list(PROVIDERS.keys()),
        }

    cfg = PROVIDERS[provider_name]
    try:
        async with httpx.AsyncClient(timeout=10) as client:
            resp = await client.get(
                f"{cfg['base_url']}/models",
                headers={"Authorization": f"Bearer {cfg['api_key']}"},
            )
    except Exception as e:
        # Diagnostic endpoint: any failure is reported to the caller, but it is
        # also logged so the traceback is not lost.
        logger.exception("Connectivity test for provider %s failed", provider_name)
        return {"status": "❌ 失败", "error": str(e)}

    return {
        "status": "✅ 连通" if resp.status_code == 200 else f"⚠️ {resp.status_code}",
        "provider": provider_name,
        "status_code": resp.status_code,
    }


# ──────────────────────────────────────────
# Entry point
# ──────────────────────────────────────────
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)