|
|
from fastapi import FastAPI |
|
|
from reranker import RankLLM, RankListwiseOSLLM, Result, RankingExecInfo |
|
|
from pydantic import BaseModel |
|
|
from typing import Optional, List, Tuple |
|
|
|
|
|
|
|
|
# Model identifier and device handed to the listwise reranker below.
_MODEL_NAME = "Salesforce/SweRankLLM-small"
_DEVICE = "cpu"

# Constructed once at module import and shared by every /rerank request served
# by this process (presumably this loads model weights, so import is slow).
reranker = RankListwiseOSLLM(_MODEL_NAME, device=_DEVICE)
|
|
|
|
|
# Request body model for the /rerank endpoint.
class RerankRequest(BaseModel):
    # Query text the hits are reranked against.
    query: str

    # Candidate hits as (integer key, text) pairs; the /rerank handler orders
    # them by the integer key before reranking.
    hits: List[Tuple[int, str]]
|
|
|
|
|
# Response model describing reranked hits as (int, text) pairs.
# NOTE(review): the /rerank handler in this file does not reference this model
# and returns a plain dict keyed "reranked" instead of "hits" — confirm whether
# it is meant to be wired in as a response_model or can be removed.
class RerankResponse(BaseModel):
    # Reranked hits as (int, text) pairs.
    hits: List[Tuple[int, str]]
|
|
|
|
|
# The ASGI application object; the route decorators below register on it.
app: FastAPI = FastAPI()
|
|
|
|
|
@app.get("/")
def hello_world():
    """Root endpoint: always answers with a fixed success message."""
    payload = dict(msg="Success")
    return payload
|
|
|
|
|
|
|
|
@app.post("/rerank")
@app.get("/rerank")
def rerank(request: RerankRequest):
    """Rerank the request's hits for its query with the listwise LLM reranker.

    Hits are first ordered by their integer key (the first tuple element),
    then the whole list is passed to ``reranker.permutation_pipeline`` as a
    single window.

    The route is served for both POST and GET. POST is the conventional
    method for a request that carries a JSON body; GET is kept so existing
    callers continue to work.

    Args:
        request: Parsed body holding ``query`` and a list of ``(key, text)``
            hits.

    Returns:
        ``{"reranked": [(position, text), ...]}``, where ``position`` is the
        0-based index of the hit in the reranked order. An empty ``hits``
        list yields ``{"reranked": []}`` without calling the model.
    """
    hits = request.hits
    if not hits:
        # Nothing to reorder; skip the model call entirely.
        return {"reranked": []}

    # Presumably the integer key is the hit's initial retrieval rank — TODO
    # confirm with callers.
    sorted_hits = sorted(hits, key=lambda hit: hit[0])

    result = Result(
        # Pydantic models expose fields as attributes; the previous
        # request["query"] raised TypeError (BaseModel is not subscriptable).
        query=request.query,
        # Store only the text as "content", not the whole (key, text) tuple,
        # so the output below is (position, text) pairs as RerankResponse's
        # field type describes.
        hits=[{"content": text} for _, text in sorted_hits],
    )

    # NOTE(review): assumes rank_start/rank_end follow rank_llm's 0-based
    # slice semantics (result.hits[rank_start:rank_end]); the previous start
    # of 1 left the first hit outside the reranking window. Verify against
    # the reranker module.
    reranked_result = reranker.permutation_pipeline(
        result,
        0,
        len(sorted_hits),
        logging=True,
    )

    response = [
        (position, item["content"])
        for position, item in enumerate(reranked_result.hits)
    ]
    return {"reranked": response}
|
|
|
|
|
|