import os
import time
import threading
from contextlib import asynccontextmanager
from typing import List, Optional, Union, Any

import httpx
from fastapi import FastAPI, HTTPException, Request, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel

# ──────────────────────────────────────────
# Configuration
# ──────────────────────────────────────────
API_KEY = os.environ.get("API_KEY", "")
if not API_KEY:
    print("[WARNING] API_KEY environment variable is not set, all requests will be rejected!")

EMBED_MODEL_ID = "BAAI/bge-m3"
RERANK_MODEL_ID = "BAAI/bge-reranker-v2-m3"
SELF_URL = "http://localhost:7860"  # keep-alive ping target
KEEPALIVE_SEC = 240  # ping every 4 minutes (HF 5-minute timeout)
HF_HOME = os.environ.get("HF_HOME", "/app/hf_cache")
os.environ["HF_HOME"] = HF_HOME

# ──────────────────────────────────────────
# Global model state
# ──────────────────────────────────────────
models: dict = {
    "embed": None,
    "reranker": None,
    "embed_status": "loading",  # loading | ready | error
    "rerank_status": "loading",
    "start_time": time.time(),
}

# ──────────────────────────────────────────
# Model loading (background thread, does not block startup)
# ──────────────────────────────────────────
def load_models():
    try:
        from FlagEmbedding import BGEM3FlagModel

        models["embed_status"] = "loading"
        models["embed"] = BGEM3FlagModel(
            EMBED_MODEL_ID,
            use_fp16=False,  # fp16 is not supported on CPU
        )
        models["embed_status"] = "ready"
        print("[INFO] Embedding model loaded ✓")
    except Exception as e:
        models["embed_status"] = f"error: {e}"
        print(f"[ERROR] Embedding model failed: {e}")

    try:
        from FlagEmbedding import FlagReranker

        models["rerank_status"] = "loading"
        models["reranker"] = FlagReranker(
            RERANK_MODEL_ID,
            use_fp16=False,
        )
        models["rerank_status"] = "ready"
        print("[INFO] Reranker model loaded ✓")
    except Exception as e:
        models["rerank_status"] = f"error: {e}"
        print(f"[ERROR] Reranker model failed: {e}")


def keepalive_loop():
    """Background thread: ping ourselves periodically so the HF Space does not go to sleep."""
    time.sleep(60)  # wait for startup to finish
    while True:
        try:
            httpx.get(f"{SELF_URL}/health", timeout=10)
            print(f"[KEEPALIVE] ping ok @ {time.strftime('%H:%M:%S')}")
        except Exception as e:
            print(f"[KEEPALIVE] ping failed: {e}")
        time.sleep(KEEPALIVE_SEC)


# ──────────────────────────────────────────
# Lifespan
# ──────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    # On startup: load models in a background thread and start the keep-alive thread
    threading.Thread(target=load_models, daemon=True).start()
    threading.Thread(target=keepalive_loop, daemon=True).start()
    yield


# ──────────────────────────────────────────
# FastAPI app
# ──────────────────────────────────────────
app = FastAPI(
    title="BGE API",
    version="1.0.0",
    lifespan=lifespan,
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# ──────────────────────────────────────────
# Authentication
# ──────────────────────────────────────────
security = HTTPBearer(auto_error=False)


def verify_api_key(
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
):
    # Accept either a Bearer token or an ?api_key= query parameter
    token = None
    if credentials:
        token = credentials.credentials
    else:
        token = request.query_params.get("api_key")
    if token != API_KEY:
        raise HTTPException(
            status_code=401,
            detail={"error": {"message": "Invalid API key", "type": "invalid_request_error"}},
        )
    return token


# ──────────────────────────────────────────
# Pydantic models (OpenAI-compatible format)
# ──────────────────────────────────────────
class EmbeddingRequest(BaseModel):
    input: Union[str, List[str]]
    model: str = "bge-m3:latest"
    encoding_format: str = "float"


class RerankRequest(BaseModel):
    model: str = "BAAI/bge-reranker-v2-m3"
    query: str
    documents: List[str]
    top_n: Optional[int] = None
    return_documents: bool = False


class EmbeddingObject(BaseModel):
    object: str = "embedding"
    index: int
    embedding: List[float]


class Usage(BaseModel):
    prompt_tokens: int
    total_tokens: int


class EmbeddingResponse(BaseModel):
    object: str = "list"
    data: List[EmbeddingObject]
    model: str
    usage: Usage


class RerankResult(BaseModel):
    index: int
    relevance_score: float
    document: Optional[Any] = None


class RerankResponse(BaseModel):
    object: str = "list"
    model: str
    results: List[RerankResult]
    usage: Usage


# ──────────────────────────────────────────
# Routes
# ──────────────────────────────────────────
@app.get("/", include_in_schema=False)
async def root():
    """Root: report service and model status."""
    uptime = int(time.time() - models["start_time"])
    return JSONResponse({
        "service": "BGE Embedding & Reranker API",
        "version": "1.0.0",
        "status": "running",
        "uptime_sec": uptime,
        "models": {
            "embedding": {
                "id": "bge-m3:latest",
                "hf_id": EMBED_MODEL_ID,
                "status": models["embed_status"],
            },
            "reranker": {
                "id": "BAAI/bge-reranker-v2-m3",
                "hf_id": RERANK_MODEL_ID,
                "status": models["rerank_status"],
            },
        },
        "endpoints": [
            "GET /v1/models",
            "POST /v1/embeddings",
            "POST /v1/rerank",
            "GET /health",
        ],
    })


@app.get("/health")
async def health():
    embed_ok = models["embed_status"] == "ready"
    rerank_ok = models["rerank_status"] == "ready"
    all_ok = embed_ok and rerank_ok
    return JSONResponse(
        status_code=200 if all_ok else 503,
        content={
            "status": "ok" if all_ok else "degraded",
            "embed_status": models["embed_status"],
            "rerank_status": models["rerank_status"],
        },
    )


@app.get("/v1/models", dependencies=[Depends(verify_api_key)])
async def list_models():
    """OpenAI-compatible model list."""
    now = int(time.time())
    return {
        "object": "list",
        "data": [
            {
                "id": "bge-m3:latest",
                "object": "model",
                "created": now,
                "owned_by": "BAAI",
            },
            {
                "id": "BAAI/bge-reranker-v2-m3",
                "object": "model",
                "created": now,
                "owned_by": "BAAI",
            },
        ],
    }


@app.post("/v1/embeddings", dependencies=[Depends(verify_api_key)])
async def create_embeddings(req: EmbeddingRequest):
    if models["embed_status"] != "ready":
        raise HTTPException(
            status_code=503,
            detail=f"Embedding model not ready: {models['embed_status']}",
        )
    texts = [req.input] if isinstance(req.input, str) else req.input
    if not texts:
        raise HTTPException(status_code=400, detail="input cannot be empty")
    try:
        result = models["embed"].encode(
            texts,
            batch_size=12,
            max_length=8192,
            return_dense=True,
            return_sparse=False,
            return_colbert_vecs=False,
        )
        dense_vecs = result["dense_vecs"]  # numpy array
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

    data = [
        EmbeddingObject(index=i, embedding=vec.tolist())
        for i, vec in enumerate(dense_vecs)
    ]
    # Rough token accounting: whitespace word count, not the model tokenizer
    total_tokens = sum(len(t.split()) for t in texts)
    return EmbeddingResponse(
        data=data,
        model="bge-m3:latest",
        usage=Usage(prompt_tokens=total_tokens, total_tokens=total_tokens),
    )


@app.post("/v1/rerank", dependencies=[Depends(verify_api_key)])
async def rerank(req: RerankRequest):
    if models["rerank_status"] != "ready":
        raise HTTPException(
            status_code=503,
            detail=f"Reranker model not ready: {models['rerank_status']}",
        )
    if not req.documents:
        raise HTTPException(status_code=400, detail="documents cannot be empty")
    try:
        pairs = [[req.query, doc] for doc in req.documents]
        scores = models["reranker"].compute_score(pairs, normalize=True)
        if isinstance(scores, float):
            scores = [scores]  # compute_score may return a bare float for a single pair
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

    ranked = sorted(
        enumerate(scores),
        key=lambda x: x[1],
        reverse=True,
    )
    top_n = req.top_n or len(ranked)
    results = []
    for doc_idx, score in ranked[:top_n]:
        item = RerankResult(
            index=doc_idx,
            relevance_score=float(score),
        )
        if req.return_documents:
            item.document = {"text": req.documents[doc_idx]}
        results.append(item)

    total_tokens = len(req.query.split()) + sum(len(d.split()) for d in req.documents)
    return RerankResponse(
        model="BAAI/bge-reranker-v2-m3",
        results=results,
        usage=Usage(prompt_tokens=total_tokens, total_tokens=total_tokens),
    )