File size: 5,281 Bytes
90528a8
e3f7ce0
90528a8
e3f7ce0
 
90528a8
 
 
2e5859f
90528a8
 
 
 
 
 
 
 
 
 
 
 
 
 
2e5859f
90528a8
 
 
2e5859f
 
 
 
90528a8
 
 
 
e3f7ce0
 
 
 
2e5859f
 
 
e3f7ce0
2e5859f
 
e3f7ce0
 
2e5859f
e3f7ce0
 
2e5859f
e3f7ce0
90528a8
 
 
 
e3f7ce0
 
2e5859f
 
 
 
e3f7ce0
90528a8
2e5859f
90528a8
2e5859f
 
 
 
 
 
 
90528a8
 
 
2e5859f
 
 
 
90528a8
 
fb8f5fc
 
 
 
 
 
 
2e5859f
 
fb8f5fc
 
 
 
 
2e5859f
fb8f5fc
2e5859f
 
 
 
 
90528a8
fb8f5fc
 
 
 
9958d9a
fb8f5fc
 
fc5bdd9
fb8f5fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2e5859f
 
 
 
 
 
 
 
 
 
 
 
90528a8
 
 
 
 
 
 
 
 
 
 
2e5859f
fb8f5fc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
"""
Rerank Endpoint Module

This module provides routes for reranking documents based on a query.
It accepts a list of documents and returns a ranked list based on relevance to the query.
"""

import time
from typing import List
from fastapi import APIRouter, Depends, HTTPException, status
from loguru import logger

from src.models.schemas import RerankRequest, RerankResponse, RerankResult
from src.core.manager import ModelManager
from src.core.exceptions import (
    ModelNotFoundError,
    ModelNotLoadedError,
    RerankingDocumentError,
    ValidationError,
)
from src.api.dependencies import get_model_manager
from src.utils.validators import extract_embedding_kwargs

# Router that holds the rerank endpoint(s) defined in this module, tagged "rerank".
router = APIRouter(tags=["rerank"])


@router.post(
    "/rerank",
    response_model=RerankResponse,
    summary="Rerank documents",
    description="Reranks the provided documents based on the given query.",
)
async def rerank_documents(
    request: RerankRequest,
    manager: ModelManager = Depends(get_model_manager),
) -> RerankResponse:
    """
    Rerank documents based on a query.

    Blank or whitespace-only documents are dropped before ranking; the
    remaining documents are stripped and passed to the model's
    ``rank_document`` method. Each returned result carries the document's
    index in the ORIGINAL ``request.documents`` list, so clients can map
    results back even after blank documents were filtered out.

    Args:
        request: The request object containing the query, the documents to
            rank, the target ``model_id`` and ``top_k``.
        manager: The model manager dependency used to fetch the model and
            its configuration.

    Returns:
        RerankResponse: The ranked documents, in the order returned by the
        model, together with the ranking time in seconds.

    Raises:
        HTTPException: 400 if no non-blank document is provided or the model
            is not of type ``"rerank"``; the domain exception's own status
            code for validation, model-not-found, model-not-loaded and
            reranking errors; 500 for any other unexpected error.
    """
    # Drop blank documents but remember each survivor's original position.
    valid_docs = [
        (i, doc.strip()) for i, doc in enumerate(request.documents) if doc.strip()
    ]

    if not valid_docs:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="No valid documents provided.",
        )

    try:
        kwargs = extract_embedding_kwargs(request)

        # These are passed explicitly to rank_document below; leaving them in
        # kwargs would raise "got multiple values for argument".
        for key in ("query", "documents", "top_k"):
            kwargs.pop(key, None)

        model = manager.get_model(request.model_id)
        config = manager.model_configs[request.model_id]

        if config.type != "rerank":
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Model '{request.model_id}' is not a rerank model. Type: {config.type}",
            )

        documents_list = [doc for _, doc in valid_docs]

        logger.debug(f"Rerank request - Query: '{request.query}'")
        logger.debug(f"Documents to rank: {len(valid_docs)}")
        # valid_docs is guaranteed non-empty by the check above.
        logger.debug(f"First document: {documents_list[0][:100]}...")
        logger.debug(f"Top K: {request.top_k}")

        # perf_counter is monotonic, so system clock adjustments cannot
        # produce negative or skewed durations (unlike time.time()).
        start = time.perf_counter()
        ranking_results = model.rank_document(
            query=request.query,
            documents=documents_list,
            top_k=request.top_k,
            **kwargs,
        )
        processing_time = time.perf_counter() - start

        logger.debug(f"Ranking returned {len(ranking_results)} results")
        if ranking_results:
            logger.debug(f"Top result score: {ranking_results[0]}")

        results = []
        for rank_result in ranking_results:
            # corpus_id indexes the filtered documents_list, not request.documents.
            doc_idx = rank_result.get("corpus_id")
            # Do not default a missing/out-of-range id to 0: that would
            # attribute this score to the wrong document.
            if doc_idx is None or not 0 <= doc_idx < len(valid_docs):
                logger.warning(
                    f"Skipping rank result with invalid corpus_id: {rank_result}"
                )
                continue
            original_idx, doc_text = valid_docs[doc_idx]
            results.append(
                RerankResult(
                    text=doc_text,
                    score=rank_result["score"],
                    index=original_idx,
                )
            )

        logger.info(
            f"Reranked {len(results)} documents in {processing_time:.3f}s "
            f"(model: {request.model_id})"
        )

        return RerankResponse(
            model_id=request.model_id,
            processing_time=processing_time,
            query=request.query,
            results=results,
        )

    except (
        ValidationError,
        ModelNotFoundError,
        ModelNotLoadedError,
        RerankingDocumentError,
    ) as e:
        raise HTTPException(status_code=e.status_code, detail=e.message) from e
    except HTTPException:
        # Already a client-facing error (e.g. the wrong-model-type 400 above);
        # re-raise unchanged instead of letting the generic handler below
        # turn it into a 500.
        raise
    except Exception as e:
        logger.exception("Unexpected error in rerank_documents")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to rerank documents: {str(e)}",
        ) from e