File size: 3,785 Bytes
90528a8
e3f7ce0
90528a8
e3f7ce0
 
90528a8
 
 
 
ecfc38e
90528a8
 
31b67ca
90528a8
 
 
 
 
 
 
 
79829bb
90528a8
376886a
90528a8
 
 
3b88f19
2e5859f
 
90528a8
 
 
 
ecfc38e
e3f7ce0
 
 
2e5859f
80db4a8
2e5859f
e3f7ce0
2e5859f
 
e3f7ce0
 
2e5859f
e3f7ce0
 
2e5859f
e3f7ce0
79829bb
90528a8
 
 
e3f7ce0
 
2e5859f
 
 
 
e3f7ce0
90528a8
 
2e5859f
 
 
 
 
79829bb
 
90528a8
79829bb
90528a8
2e5859f
 
fb8f5fc
9e5acab
fb8f5fc
2e5859f
fb8f5fc
2e5859f
 
 
 
 
90528a8
fc5bdd9
376886a
fb8f5fc
376886a
fb8f5fc
 
 
376886a
31b67ca
 
 
 
 
 
2e5859f
 
 
79829bb
2e5859f
 
ecfc38e
90528a8
 
 
 
 
 
 
 
 
 
 
2e5859f
376886a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""
Rerank Endpoint Module

This module provides routes for reranking documents based on a query.
It accepts a list of documents and returns a ranked list based on relevance to the query.
"""

import time
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import JSONResponse
from loguru import logger

from src.models.schemas import RerankRequest
from src.core.manager import ModelManager
from src.core.exceptions import (
    ModelNotFoundError,
    ModelNotLoadedError,
    RerankingDocumentError,
    ValidationError,
)
from src.api.dependencies import get_model_manager
from src.utils.validators import extract_embedding_kwargs, ensure_model_type

# Router for the rerank endpoint(s); routes registered on it are served under
# the "/rerank" prefix and carry the "rerank" tag.
router = APIRouter(prefix="/rerank", tags=["rerank"])


@router.post(
    "",
    summary="Rerank documents",
    description="Reranks the provided documents based on the given query.",
)
async def rerank_documents(
    request: RerankRequest,
    manager: ModelManager = Depends(get_model_manager),
):
    """
    Rerank documents based on a query.

    Blank or whitespace-only documents are dropped before ranking. The
    remaining documents are stripped and handed to the model's
    ``rank_document`` together with the query, ``top_k`` and any extra
    keyword arguments extracted from the request.

    Args:
        request: The request object containing the model name, the query,
            the documents to rank and ``top_k``.
        manager: The model manager dependency used to look up the model and
            its configuration.

    Returns:
        JSONResponse: A JSON array, in the order produced by the model, of
        objects with the keys ``"text"`` (the stripped document),
        ``"score"`` (float) and ``"index"`` (the document's position in
        ``request.documents``). Processing time is logged, not returned.

    Raises:
        HTTPException: 400 when no non-blank document was supplied; the
            status code carried by the project exception for validation,
            model lookup/loading and reranking errors; 500 for any other
            unexpected error.
    """
    # Strip each document once and keep its position in the original request
    # so results can be reported against the caller's indices.
    valid_docs = [
        (i, text)
        for i, text in enumerate(doc.strip() for doc in request.documents)
        if text
    ]

    if not valid_docs:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="No valid documents provided.",
        )

    try:
        kwargs = extract_embedding_kwargs(request)

        # These three are passed explicitly to rank_document below; remove
        # them from the pass-through kwargs so they are not supplied twice.
        for key in ("query", "documents", "top_k"):
            kwargs.pop(key, None)

        model = manager.get_model(request.model)
        config = manager.model_configs[request.model]

        ensure_model_type(config, "rerank", request.model)

        documents_list = [text for _, text in valid_docs]

        # perf_counter is monotonic, so the measured duration cannot be
        # skewed by system clock adjustments the way time.time() can.
        start = time.perf_counter()

        # NOTE(review): this is a synchronous call inside an async handler;
        # if rank_document is CPU-bound it blocks the event loop for its
        # whole duration. Consider offloading it to a thread pool.
        ranking_results = model.rank_document(
            query=request.query,
            documents=documents_list,
            top_k=request.top_k,
            **kwargs,
        )

        processing_time = time.perf_counter() - start

        results = []
        doc_count = len(valid_docs)

        for rank_result in ranking_results:
            doc_idx = rank_result.get("corpus_id")

            # A missing id must not default to document 0, and a negative id
            # must not wrap around to the end of the list: either would
            # attribute the score to the wrong document. Skip such entries.
            if doc_idx is None or not 0 <= doc_idx < doc_count:
                logger.warning(
                    f"Skipping rank result with invalid corpus_id: {doc_idx!r}"
                )
                continue

            original_idx, doc_text = valid_docs[doc_idx]
            results.append(
                {
                    "text": doc_text,
                    "score": float(rank_result["score"]),
                    "index": int(original_idx),
                }
            )

        logger.info(
            f"Reranked {len(results)} documents in {processing_time:.3f}s "
            f"(model: {request.model})"
        )

        return JSONResponse(content=results)

    except HTTPException:
        # Let an HTTP error raised by a helper keep its own status code
        # instead of being rewritten as a 500 by the catch-all below.
        raise
    except (
        ValidationError,
        ModelNotFoundError,
        ModelNotLoadedError,
        RerankingDocumentError,
    ) as e:
        raise HTTPException(status_code=e.status_code, detail=e.message) from e
    except Exception as e:
        logger.exception("Unexpected error in rerank_documents")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to rerank documents: {str(e)}",
        ) from e