Sarp Bilgiç commited on
Commit
60917e8
·
1 Parent(s): 71afb18

reranking client fix

Browse files
Files changed (1) hide show
  1. src/clients/reranker_client.py +6 -4
src/clients/reranker_client.py CHANGED
@@ -1,17 +1,19 @@
1
  from fastembed.rerank.cross_encoder import TextCrossEncoder
2
  from typing import Optional
3
  import logging
4
- from src.core.settings import settings
5
  logger = logging.getLogger(__name__)
6
 
7
 
8
  class RerankerClient:
9
- def __init__(self, cache_dir: str = "./model_cache"):
10
- self.model_name = settings.reranker_model
 
 
11
  logger.info(f"Loading reranker: {self.model_name}")
12
 
13
  self.model = TextCrossEncoder(model_name=self.model_name, cache_dir=cache_dir)
14
  logger.info("Reranker ready")
15
 
16
  def rerank(self, query: str, documents: list[str]):
17
- return list(self.model.rerank(query, documents))
 
1
  from fastembed.rerank.cross_encoder import TextCrossEncoder
2
  from typing import Optional
3
  import logging
4
+
5
  logger = logging.getLogger(__name__)
6
 
7
 
8
  class RerankerClient:
9
+ DEFAULT_MODEL = "jinaai/jina-reranker-v1-turbo-en"
10
+
11
+ def __init__(self, model_name: Optional[str] = None, cache_dir: str = "./model_cache"):
12
+ self.model_name = model_name or self.DEFAULT_MODEL
13
  logger.info(f"Loading reranker: {self.model_name}")
14
 
15
  self.model = TextCrossEncoder(model_name=self.model_name, cache_dir=cache_dir)
16
  logger.info("Reranker ready")
17
 
18
  def rerank(self, query: str, documents: list[str]):
19
+ return self.model.rerank(query, documents)