File size: 3,637 Bytes
90528a8
e3f7ce0
90528a8
e3f7ce0
 
90528a8
 
 
e3f7ce0
90528a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e3f7ce0
90528a8
 
 
e3f7ce0
90528a8
 
 
 
e3f7ce0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90528a8
 
 
 
e3f7ce0
 
 
 
90528a8
 
 
 
 
 
 
 
e3f7ce0
90528a8
 
 
 
e3f7ce0
90528a8
 
 
 
e3f7ce0
90528a8
 
e3f7ce0
 
 
 
 
 
 
90528a8
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
"""
Rerank Endpoint Module

This module provides routes for reranking documents based on a query.
It accepts a list of documents and returns a ranked list based on relevance to the query.
"""

import time
from typing import Union, List
from fastapi import APIRouter, Depends, HTTPException, status
from loguru import logger

from src.models.schemas import RerankRequest, RerankResponse, RerankResult
from src.core.manager import ModelManager
from src.core.exceptions import (
    ModelNotFoundError,
    ModelNotLoadedError,
    RerankingDocumentError,
    ValidationError,
)

from src.api.dependencies import get_model_manager
from src.utils.validators import extract_embedding_kwargs

# Router that the rerank endpoint(s) in this module are registered on; routes are tagged "rerank".
router = APIRouter(tags=["rerank"]) 


@router.post(
    "/rerank",
    response_model=RerankResponse,
    summary="Rerank documents",
    description="Reranks the provided documents based on the given query.",
)
async def rerank_documents(
    request: RerankRequest,
    manager: ModelManager = Depends(get_model_manager),
) -> RerankResponse:
    """
    Rerank documents based on a query.

    Blank or whitespace-only documents are dropped before scoring. Each result
    carries the (stripped) document text, its score, and the index the document
    had in the original ``request.documents`` list. Results are returned sorted
    by score, highest first.

    Args:
        request (RerankRequest): The request object containing the query and documents to rank.
        manager (ModelManager): The model manager dependency to access the model.

    Returns:
        RerankResponse: The response containing the ranked documents and processing time.

    Raises:
        HTTPException: 400 when no non-blank document is provided or when the
            requested model's configured type is not ``"rerank"``; the status
            code carried by the exception for ``ValidationError``,
            ``ModelNotFoundError``, ``ModelNotLoadedError`` and
            ``RerankingDocumentError``; 500 for any other unexpected failure.
    """
    # Strip each document once and remember where every survivor came from.
    valid_docs = [
        (idx, text)
        for idx, text in enumerate(doc.strip() for doc in request.documents)
        if text
    ]

    if not valid_docs:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="No valid documents provided.",
        )

    try:
        kwargs = extract_embedding_kwargs(request)
        model = manager.get_model(request.model_id)
        config = manager.model_configs[request.model_id]

        if config.type != "rerank":
            # Without this guard the function would fall through and return
            # None, which cannot satisfy the declared response model.
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Model '{request.model_id}' is not a rerank model.",
            )

        # perf_counter is monotonic, so the measured duration cannot be skewed
        # by wall-clock adjustments.
        start = time.perf_counter()
        scores = model.rank_document(
            request.query, [doc for _, doc in valid_docs], request.top_k, **kwargs
        )
        processing_time = time.perf_counter() - start

        # NOTE(review): assumes rank_document returns one score per input
        # document, in input order — confirm against the model wrapper,
        # especially with respect to top_k.
        results: List[RerankResult] = [
            RerankResult(text=doc, score=scores[pos], index=orig_idx)
            for pos, (orig_idx, doc) in enumerate(valid_docs)
        ]

        # Sort results by score in descending order
        results.sort(key=lambda result: result.score, reverse=True)

        logger.info(f"Reranked documents in {processing_time:.3f} seconds")
        return RerankResponse(
            model_id=request.model_id,
            processing_time=processing_time,
            query=request.query,
            results=results,
        )

    except HTTPException:
        # Deliberate HTTP errors raised above must not be rewrapped as 500s.
        raise
    except (
        ValidationError,
        ModelNotFoundError,
        ModelNotLoadedError,
        RerankingDocumentError,
    ) as e:
        # Domain errors already carry the HTTP status and message to expose.
        raise HTTPException(status_code=e.status_code, detail=e.message) from e
    except Exception as e:
        logger.exception("Unexpected error in rerank_documents")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to rerank documents: {str(e)}",
        ) from e