import asyncio
import os
import secrets
import threading
import time
from contextlib import asynccontextmanager
from typing import Any, List, Optional, Union

import httpx
import numpy as np
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel, Field
|
|
| |
| |
| |
# --- Configuration (environment-driven) ---

# Bearer token required by the /v1/* endpoints; when empty, no token matches.
API_KEY = os.environ.get("API_KEY", "")
if not API_KEY:
    print("[WARNING] API_KEY environment variable is not set, all requests will be rejected!")
EMBED_MODEL_ID = "BAAI/bge-m3"                # HF hub id of the embedding model
RERANK_MODEL_ID = "BAAI/bge-reranker-v2-m3"   # HF hub id of the reranker model
SELF_URL = "http://localhost:7860"            # address the keepalive thread pings
KEEPALIVE_SEC = 240                           # seconds between keepalive pings
HF_HOME = os.environ.get("HF_HOME", "/app/hf_cache")  # HF model cache directory


# Export so HF libraries loaded later pick up the cache location.
os.environ["HF_HOME"] = HF_HOME
|
|
| |
| |
| |
# Shared mutable state: model handles plus load status.  Written by the
# background loader thread, read by request handlers and health checks.
models: dict = {
    "embed": None,               # embedding model instance once loaded
    "reranker": None,            # reranker model instance once loaded
    "embed_status": "loading",   # "loading" | "ready" | "error: ..."
    "rerank_status": "loading",  # "loading" | "ready" | "error: ..."
    "start_time": time.time(),   # process start, used for uptime reporting
}
|
|
|
|
| |
| |
| |
def load_models():
    """Load the embedding and reranker models (runs in a daemon thread).

    Mutates ``models`` in place: stores each model handle and flips the
    matching ``*_status`` field to "ready" or "error: <reason>".  The two
    models load independently, so a failure in one does not block the other.
    """
    try:
        from FlagEmbedding import BGEM3FlagModel
        models["embed_status"] = "loading"
        models["embed"] = BGEM3FlagModel(
            EMBED_MODEL_ID,
            use_fp16=False,  # presumably a CPU-only host — TODO confirm
        )
        models["embed_status"] = "ready"
        # BUGFIX: the original message ended in a mojibake character ("โ");
        # normalized to plain ASCII.
        print("[INFO] Embedding model loaded OK")
    except Exception as e:
        models["embed_status"] = f"error: {e}"
        print(f"[ERROR] Embedding model failed: {e}")

    try:
        from FlagEmbedding import FlagReranker
        models["rerank_status"] = "loading"
        models["reranker"] = FlagReranker(
            RERANK_MODEL_ID,
            use_fp16=False,
        )
        models["rerank_status"] = "ready"
        print("[INFO] Reranker model loaded OK")
    except Exception as e:
        models["rerank_status"] = f"error: {e}"
        print(f"[ERROR] Reranker model failed: {e}")
|
|
|
|
| def keepalive_loop(): |
| """ๅๅฐ็บฟ็จ๏ผๅฎๆถ ping ่ช่บซ๏ผ้ฒๆญข HF Spaces ไผ็ """ |
| time.sleep(60) |
| while True: |
| try: |
| import httpx as _httpx |
| _httpx.get(f"{SELF_URL}/health", timeout=10) |
| print(f"[KEEPALIVE] ping ok @ {time.strftime('%H:%M:%S')}") |
| except Exception as e: |
| print(f"[KEEPALIVE] ping failed: {e}") |
| time.sleep(KEEPALIVE_SEC) |
|
|
|
|
| |
| |
| |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Start the background workers at startup; nothing to tear down on exit."""
    for worker in (load_models, keepalive_loop):
        threading.Thread(target=worker, daemon=True).start()
    yield
|
|
|
|
| |
| |
| |
# FastAPI application; model loading is deferred to the lifespan hook so the
# server can start answering health checks immediately.
app = FastAPI(
    title="BGE API",
    version="1.0.0",
    lifespan=lifespan,
)


# Allow browser clients from any origin — access control is enforced by the
# bearer API key, not by CORS.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


# Parses "Authorization: Bearer <token>"; auto_error=False so a missing header
# falls through to the ?api_key= query-parameter fallback instead of a 403.
security = HTTPBearer(auto_error=False)
|
|
def verify_api_key(
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
) -> str:
    """FastAPI dependency: authenticate a request against ``API_KEY``.

    Accepts the key either as a Bearer token (Authorization header) or as an
    ``api_key`` query parameter.  Raises HTTP 401 on mismatch; returns the
    presented token on success.
    """
    if credentials:
        token = credentials.credentials
    else:
        token = request.query_params.get("api_key")

    # BUGFIX: when API_KEY was unset (""), an empty ?api_key= query parameter
    # satisfied `token != API_KEY` — now an unset key always rejects, matching
    # the startup warning.  compare_digest gives a constant-time comparison.
    if not API_KEY or token is None or not secrets.compare_digest(str(token), API_KEY):
        raise HTTPException(
            status_code=401,
            detail={"error": {"message": "Invalid API key", "type": "invalid_request_error"}},
        )
    return token
|
|
|
|
| |
| |
| |
class EmbeddingRequest(BaseModel):
    """OpenAI-style request body for POST /v1/embeddings."""
    input: Union[str, List[str]]    # one text or a batch of texts
    model: str = "bge-m3:latest"    # accepted but not read by the handler
    encoding_format: str = "float"  # accepted but not read; floats are always returned
|
|
class RerankRequest(BaseModel):
    """Request body for POST /v1/rerank."""
    model: str = "BAAI/bge-reranker-v2-m3"  # accepted but not read by the handler
    query: str
    documents: List[str]            # candidate texts to score against the query
    top_n: Optional[int] = None     # None means "return all documents"
    return_documents: bool = False  # echo each document's text in the results
|
|
class EmbeddingObject(BaseModel):
    """One embedding vector in an OpenAI-style embeddings response."""
    object: str = "embedding"
    index: int              # position of the source text in the request batch
    embedding: List[float]  # dense embedding vector
|
|
class Usage(BaseModel):
    """Token accounting block (whitespace-split approximation, see handlers)."""
    prompt_tokens: int
    total_tokens: int
|
|
class EmbeddingResponse(BaseModel):
    """OpenAI-style response envelope for POST /v1/embeddings."""
    object: str = "list"
    data: List[EmbeddingObject]
    model: str
    usage: Usage
|
|
class RerankResult(BaseModel):
    """One scored document in a rerank response."""
    index: int               # position of the document in the request list
    relevance_score: float   # normalized score, higher is more relevant
    document: Optional[Any] = None  # {"text": ...} when return_documents is set
|
|
class RerankResponse(BaseModel):
    """Response envelope for POST /v1/rerank."""
    object: str = "list"
    model: str
    results: List[RerankResult]  # sorted by descending relevance_score
    usage: Usage
|
|
|
|
| |
| |
| |
|
|
| @app.get("/", include_in_schema=False) |
| async def root(): |
| """ๆ น็ฎๅฝ๏ผๆพ็คบๆจกๅ่ฟ่ก็ถๆ""" |
| uptime = int(time.time() - models["start_time"]) |
| return JSONResponse({ |
| "service": "BGE Embedding & Reranker API", |
| "version": "1.0.0", |
| "status": "running", |
| "uptime_sec": uptime, |
| "models": { |
| "embedding": { |
| "id": "bge-m3:latest", |
| "hf_id": EMBED_MODEL_ID, |
| "status": models["embed_status"], |
| }, |
| "reranker": { |
| "id": "BAAI/bge-reranker-v2-m3", |
| "hf_id": RERANK_MODEL_ID, |
| "status": models["rerank_status"], |
| }, |
| }, |
| "endpoints": [ |
| "GET /v1/models", |
| "POST /v1/embeddings", |
| "POST /v1/rerank", |
| "GET /health", |
| ], |
| }) |
|
|
|
|
| @app.get("/health") |
| async def health(): |
| embed_ok = models["embed_status"] == "ready" |
| rerank_ok = models["rerank_status"] == "ready" |
| all_ok = embed_ok and rerank_ok |
| return JSONResponse( |
| status_code=200 if all_ok else 503, |
| content={ |
| "status": "ok" if all_ok else "degraded", |
| "embed_status": models["embed_status"], |
| "rerank_status": models["rerank_status"], |
| } |
| ) |
|
|
|
|
| @app.get("/v1/models", dependencies=[Depends(verify_api_key)]) |
| async def list_models(): |
| """OpenAI ๅ
ผๅฎน็ๆจกๅๅ่กจ""" |
| now = int(time.time()) |
| return { |
| "object": "list", |
| "data": [ |
| { |
| "id": "bge-m3:latest", |
| "object": "model", |
| "created": now, |
| "owned_by": "BAAI", |
| }, |
| { |
| "id": "BAAI/bge-reranker-v2-m3", |
| "object": "model", |
| "created": now, |
| "owned_by": "BAAI", |
| }, |
| ], |
| } |
|
|
|
|
| @app.post("/v1/embeddings", dependencies=[Depends(verify_api_key)]) |
| async def create_embeddings(req: EmbeddingRequest): |
| if models["embed_status"] != "ready": |
| raise HTTPException( |
| status_code=503, |
| detail=f"Embedding model not ready: {models['embed_status']}", |
| ) |
|
|
| texts = [req.input] if isinstance(req.input, str) else req.input |
| if not texts: |
| raise HTTPException(status_code=400, detail="input cannot be empty") |
|
|
| try: |
| result = models["embed"].encode( |
| texts, |
| batch_size=12, |
| max_length=8192, |
| return_dense=True, |
| return_sparse=False, |
| return_colbert_vecs=False, |
| ) |
| dense_vecs = result["dense_vecs"] |
| except Exception as e: |
| raise HTTPException(status_code=500, detail=str(e)) |
|
|
| data = [ |
| EmbeddingObject(index=i, embedding=vec.tolist()) |
| for i, vec in enumerate(dense_vecs) |
| ] |
| total_tokens = sum(len(t.split()) for t in texts) |
|
|
| return EmbeddingResponse( |
| data=data, |
| model="bge-m3:latest", |
| usage=Usage(prompt_tokens=total_tokens, total_tokens=total_tokens), |
| ) |
|
|
|
|
| @app.post("/v1/rerank", dependencies=[Depends(verify_api_key)]) |
| async def rerank(req: RerankRequest): |
| if models["rerank_status"] != "ready": |
| raise HTTPException( |
| status_code=503, |
| detail=f"Reranker model not ready: {models['rerank_status']}", |
| ) |
| if not req.documents: |
| raise HTTPException(status_code=400, detail="documents cannot be empty") |
|
|
| try: |
| pairs = [[req.query, doc] for doc in req.documents] |
| scores = models["reranker"].compute_score(pairs, normalize=True) |
| if isinstance(scores, float): |
| scores = [scores] |
| except Exception as e: |
| raise HTTPException(status_code=500, detail=str(e)) |
|
|
| ranked = sorted( |
| enumerate(scores), |
| key=lambda x: x[1], |
| reverse=True, |
| ) |
|
|
| top_n = req.top_n or len(ranked) |
| results = [] |
| for rank_idx, (doc_idx, score) in enumerate(ranked[:top_n]): |
| item = RerankResult( |
| index=doc_idx, |
| relevance_score=float(score), |
| ) |
| if req.return_documents: |
| item.document = {"text": req.documents[doc_idx]} |
| results.append(item) |
|
|
| total_tokens = len(req.query.split()) + sum(len(d.split()) for d in req.documents) |
|
|
| return RerankResponse( |
| model="BAAI/bge-reranker-v2-m3", |
| results=results, |
| usage=Usage(prompt_tokens=total_tokens, total_tokens=total_tokens), |
| ) |