fahmiaziz98 commited on
Commit
e3f7ce0
·
1 Parent(s): 90528a8

[UPDATE]: add docstring

Browse files
Files changed (1) hide show
  1. src/api/routers/rerank.py +36 -17
src/api/routers/rerank.py CHANGED
@@ -1,11 +1,12 @@
1
  """
2
- Rerank endpoint
3
 
4
- This module provides routes for rerank documents
 
5
  """
6
 
7
  import time
8
- from typing import Union
9
  from fastapi import APIRouter, Depends, HTTPException, status
10
  from loguru import logger
11
 
@@ -21,21 +22,39 @@ from src.core.exceptions import (
21
  from src.api.dependencies import get_model_manager
22
  from src.utils.validators import extract_embedding_kwargs
23
 
24
- router = APIRouter(prefix="rerank", tags=["rerank"])
25
 
26
 
27
  @router.post(
28
- "/", response_model=RerankResponse, summary="reranking documents", description=""
29
  )
30
  async def rerank_documents(
31
  request: RerankRequest,
32
  manager: ModelManager = Depends(get_model_manager),
33
- ) -> Union[RerankResponse, HTTPException]:
34
- """"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  # Filter out empty documents and keep original indices
36
  valid_docs = [
37
  (i, doc.strip()) for i, doc in enumerate(request.documents) if doc.strip()
38
  ]
 
 
 
 
39
  try:
40
  kwargs = extract_embedding_kwargs(request)
41
  model = manager.get_model(request.model_id)
@@ -44,26 +63,26 @@ async def rerank_documents(
44
  start = time.time()
45
  if config.type == "rerank":
46
  scores = model.rank_document(
47
- request.query, request.documents, request.top_k, **kwargs
48
  )
49
  processing_time = time.time() - start
50
 
51
  original_indices, documents_list = zip(*valid_docs)
52
- results: list[RerankResult] = []
53
 
54
  for i, (orig_idx, doc) in enumerate(zip(original_indices, documents_list)):
55
  results.append(RerankResult(text=doc, score=scores[i], index=orig_idx))
56
 
57
- # Sort by score descending
58
  results.sort(key=lambda x: x.score, reverse=True)
59
 
60
- logger.info(f"Rerank documents in {processing_time:.3f} seconds")
61
- return RerankResponse(
62
- model_id=request.model_id,
63
- processing_time=processing_time,
64
- query=request.query,
65
- results=results,
66
- )
67
 
68
  except (ValidationError, ModelNotFoundError) as e:
69
  raise HTTPException(status_code=e.status_code, detail=e.message)
 
1
  """
2
+ Rerank Endpoint Module
3
 
4
+ This module provides routes for reranking documents based on a query.
5
+ It accepts a list of documents and returns a ranked list based on relevance to the query.
6
  """
7
 
8
  import time
9
+ from typing import Union, List
10
  from fastapi import APIRouter, Depends, HTTPException, status
11
  from loguru import logger
12
 
 
22
  from src.api.dependencies import get_model_manager
23
  from src.utils.validators import extract_embedding_kwargs
24
 
25
+ router = APIRouter(tags=["rerank"])
26
 
27
 
28
  @router.post(
29
+ "/rerank", response_model=RerankResponse, summary="Rerank documents", description="Reranks the provided documents based on the given query."
30
  )
31
  async def rerank_documents(
32
  request: RerankRequest,
33
  manager: ModelManager = Depends(get_model_manager),
34
+ ) -> RerankResponse:
35
+ """
36
+ Rerank documents based on a query.
37
+
38
+ This endpoint processes a list of documents and returns them ranked according to their relevance to the query.
39
+
40
+ Args:
41
+ request (RerankRequest): The request object containing the query and documents to rank.
42
+ manager (ModelManager): The model manager dependency to access the model.
43
+
44
+ Returns:
45
+ RerankResponse: The response containing the ranked documents and processing time.
46
+
47
+ Raises:
48
+ HTTPException: If there are validation errors, model loading issues, or unexpected errors.
49
+ """
50
  # Filter out empty documents and keep original indices
51
  valid_docs = [
52
  (i, doc.strip()) for i, doc in enumerate(request.documents) if doc.strip()
53
  ]
54
+
55
+ if not valid_docs:
56
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="No valid documents provided.")
57
+
58
  try:
59
  kwargs = extract_embedding_kwargs(request)
60
  model = manager.get_model(request.model_id)
 
63
  start = time.time()
64
  if config.type == "rerank":
65
  scores = model.rank_document(
66
+ request.query, [doc for _, doc in valid_docs], request.top_k, **kwargs
67
  )
68
  processing_time = time.time() - start
69
 
70
  original_indices, documents_list = zip(*valid_docs)
71
+ results: List[RerankResult] = []
72
 
73
  for i, (orig_idx, doc) in enumerate(zip(original_indices, documents_list)):
74
  results.append(RerankResult(text=doc, score=scores[i], index=orig_idx))
75
 
76
+ # Sort results by score in descending order
77
  results.sort(key=lambda x: x.score, reverse=True)
78
 
79
+ logger.info(f"Reranked documents in {processing_time:.3f} seconds")
80
+ return RerankResponse(
81
+ model_id=request.model_id,
82
+ processing_time=processing_time,
83
+ query=request.query,
84
+ results=results,
85
+ )
86
 
87
  except (ValidationError, ModelNotFoundError) as e:
88
  raise HTTPException(status_code=e.status_code, detail=e.message)