import hmac
import logging
import os

from fastapi import FastAPI, HTTPException, Depends, Header
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer
|
|
| |
# Configure the root logger at INFO level and create this module's logger,
# which the model-loading code and the request handler use for reporting.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
| |
async def check_authorization(authorization: str = Header(..., alias="Authorization")):
    """Validate an ``Authorization: Bearer <token>`` header.

    The expected token is read from the ``AUTHORIZATION`` environment
    variable on every call.

    Args:
        authorization: Raw value of the ``Authorization`` request header.

    Returns:
        The bearer token (the header value with the ``"Bearer "`` prefix
        removed) when it matches the configured secret.

    Raises:
        HTTPException: 401 when the header does not start with ``"Bearer "``,
            when no secret is configured, or when the token does not match.
    """
    scheme_prefix = "Bearer "
    if not authorization.startswith(scheme_prefix):
        raise HTTPException(status_code=401, detail="Invalid Authorization header format")

    token = authorization[len(scheme_prefix):]
    expected = os.environ.get("AUTHORIZATION")

    # Fail closed: if the secret is unset or empty, nothing may authenticate.
    # Without this guard an empty secret would accept the header "Bearer ".
    if not expected:
        raise HTTPException(status_code=401, detail="Unauthorized access")

    # Constant-time comparison so response timing does not leak how much of
    # the token matched. Both sides are encoded because compare_digest only
    # accepts ASCII-only str, and header values may contain other characters.
    if not hmac.compare_digest(token.encode("utf-8"), expected.encode("utf-8")):
        raise HTTPException(status_code=401, detail="Unauthorized access")
    return token
|
|
app = FastAPI()

# Load the model once at import time so every request reuses the same instance.
# NOTE(review): BAAI/bge-reranker-large is published as a cross-encoder
# reranker; loading it through SentenceTransformer and comparing embeddings
# may not be the intended usage -- confirm against the model card.
try:
    model = SentenceTransformer("BAAI/bge-reranker-large")
    logger.info("Reranker model loaded successfully.")
except Exception as e:
    # logger.exception records the full traceback, not just the message.
    logger.exception("Failed to load model: %s", e)
    # NOTE(review): this runs at import time, outside any request, so the
    # HTTPException simply aborts startup. The exception type is kept as-is
    # for compatibility; consider a RuntimeError instead.
    raise HTTPException(status_code=500, detail="Model loading failed. Check logs for details.") from e
|
|
# Request body for the rerank endpoint. Documented with comments rather than
# a class docstring so the generated schema description is left untouched.
class RerankerRequest(BaseModel):
    # Query text; required, between 1 and 1000 characters.
    query: str = Field(
        ...,
        min_length=1,
        max_length=1000,
        description="The query text.",
    )
    # Candidate documents; required, at least two entries.
    documents: list[str] = Field(
        ...,
        min_items=2,
        description="A list of documents to rerank.",
    )
    # Optional truncation flag; defaults to False.
    truncate: bool = Field(
        False,
        description="Whether to truncate the documents.",
    )
|
|
# NOTE(review): `check_authorization` is defined in this module but is not
# attached to this route, so the endpoint is reachable without a bearer token.
# If that is unintended, pass `dependencies=[Depends(check_authorization)]`
# to the decorator below.
@app.post("/rerank")
async def rerank(request: RerankerRequest):
    """Score each document against the query and return them best-first.

    The query and the documents are encoded with the module-level ``model``;
    each document's score is the cosine similarity between its embedding and
    the query embedding.

    Args:
        request: Validated request body carrying ``query`` and ``documents``.
            NOTE(review): ``request.truncate`` is accepted but never read
            here -- confirm whether it should influence encoding.

    Returns:
        A dict with ``"object": "list"``, ``"data"`` (a list of
        ``{"document", "score"}`` dicts sorted by descending score) and
        ``"model"`` (the model identifier).

    Raises:
        HTTPException: 400 when the query or the document list is empty;
            500 when encoding or scoring fails.
    """
    query = request.query
    documents = request.documents

    # Validate before the try block: inside it, the broad handler below would
    # catch this 400 and report it to the client as a 500.
    if not query or not documents:
        raise HTTPException(status_code=400, detail="Query and documents must be provided.")

    try:
        from sentence_transformers import util

        # NOTE(review): these encode calls are synchronous inside an async
        # handler; presumably they block the event loop while the model runs.
        query_embedding = model.encode(query, convert_to_tensor=True)
        document_embeddings = model.encode(documents, convert_to_tensor=True)

        # cos_sim returns one row per query; there is a single query here.
        scores = util.cos_sim(query_embedding, document_embeddings)[0].tolist()

        results = [{"document": doc, "score": score} for doc, score in zip(documents, scores)]
        ranked_results = sorted(results, key=lambda item: item["score"], reverse=True)
    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception("Error processing reranking: %s", e)
        raise HTTPException(status_code=500, detail="Internal Server Error. Check logs for details.") from e

    return {
        "object": "list",
        "data": ranked_results,
        "model": "BAAI/bge-reranker-large",
    }
|
|