Bibek Raj Ghimire committed on
Commit
e95a894
·
1 Parent(s): 201ecae

Update reranker.py (#99)

Browse files
Files changed (1) hide show
  1. sage/reranker.py +36 -21
sage/reranker.py CHANGED
@@ -20,30 +20,45 @@ class RerankerProvider(Enum):
20
  VOYAGE = "voyage"
21
 
22
 
23
- def build_reranker(provider: str, model: Optional[str] = None, top_k: Optional[int] = 5) -> BaseDocumentCompressor:
24
  if provider == RerankerProvider.NONE.value:
25
  return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  if provider == RerankerProvider.HUGGINGFACE.value:
27
- model = model or "cross-encoder/ms-marco-MiniLM-L-6-v2"
28
  encoder_model = HuggingFaceCrossEncoder(model_name=model)
29
  return CrossEncoderReranker(model=encoder_model, top_n=top_k)
30
- if provider == RerankerProvider.COHERE.value:
31
- if not os.environ.get("COHERE_API_KEY"):
32
- raise ValueError("Please set the COHERE_API_KEY environment variable")
33
- model = model or "rerank-english-v3.0"
34
- return CohereRerank(model=model, cohere_api_key=os.environ.get("COHERE_API_KEY"), top_n=top_k)
35
- if provider == RerankerProvider.NVIDIA.value:
36
- if not os.environ.get("NVIDIA_API_KEY"):
37
- raise ValueError("Please set the NVIDIA_API_KEY environment variable")
38
- model = model or "nvidia/nv-rerankqa-mistral-4b-v3"
39
- return NVIDIARerank(model=model, api_key=os.environ.get("NVIDIA_API_KEY"), top_n=top_k, truncate="END")
40
- if provider == RerankerProvider.JINA.value:
41
- if not os.environ.get("JINA_API_KEY"):
42
- raise ValueError("Please set the JINA_API_KEY environment variable")
43
- return JinaRerank(top_n=top_k)
44
- if provider == RerankerProvider.VOYAGE.value:
45
- if not os.environ.get("VOYAGE_API_KEY"):
46
- raise ValueError("Please set the VOYAGE_API_KEY environment variable")
47
- model = model or "rerank-1"
48
- return VoyageAIRerank(model=model, api_key=os.environ.get("VOYAGE_API_KEY"), top_k=top_k)
49
  raise ValueError(f"Invalid reranker provider: {provider}")
 
20
  VOYAGE = "voyage"
21
 
22
 
23
def build_reranker(provider: str, model: Optional[str] = None, top_k: int = 5) -> Optional[BaseDocumentCompressor]:
    """Construct the document reranker selected by `provider`.

    Args:
        provider: One of the `RerankerProvider` values, as a plain string.
        model: Model name to use. When falsy, the per-provider default from
            the table below is used (there is no default entry for Jina).
        top_k: Number of documents the reranker should keep. It is forwarded
            as `top_n` to every backend except Voyage, which takes `top_k`.

    Returns:
        The configured reranker, or None when the provider is the "none"
        provider.

    Raises:
        ValueError: If the provider is not recognised, or if the API key
            environment variable required by the provider is unset or empty.
    """
    if provider == RerankerProvider.NONE.value:
        return None

    default_models = {
        RerankerProvider.HUGGINGFACE.value: "cross-encoder/ms-marco-MiniLM-L-6-v2",
        RerankerProvider.COHERE.value: "rerank-english-v3.0",
        RerankerProvider.NVIDIA.value: "nvidia/nv-rerankqa-mistral-4b-v3",
        RerankerProvider.VOYAGE.value: "rerank-1",
    }
    required_env_vars = {
        RerankerProvider.COHERE.value: "COHERE_API_KEY",
        RerankerProvider.NVIDIA.value: "NVIDIA_API_KEY",
        RerankerProvider.JINA.value: "JINA_API_KEY",
        RerankerProvider.VOYAGE.value: "VOYAGE_API_KEY",
    }

    # An explicitly supplied model wins; otherwise use the provider default.
    model = model or default_models.get(provider)

    # The HuggingFace cross-encoder is the only backend here that is built
    # without looking up an API key.
    if provider == RerankerProvider.HUGGINGFACE.value:
        cross_encoder = HuggingFaceCrossEncoder(model_name=model)
        return CrossEncoderReranker(model=cross_encoder, top_n=top_k)

    env_var = required_env_vars.get(provider)
    if env_var is None:
        raise ValueError(f"Invalid reranker provider: {provider}")

    api_key = os.getenv(env_var)
    if not api_key:
        raise ValueError(f"Please set the {env_var} environment variable")

    if provider == RerankerProvider.COHERE.value:
        return CohereRerank(model=model, cohere_api_key=api_key, top_n=top_k)

    if provider == RerankerProvider.NVIDIA.value:
        return NVIDIARerank(model=model, api_key=api_key, top_n=top_k, truncate="END")

    if provider == RerankerProvider.JINA.value:
        # NOTE(review): neither `model` nor `api_key` is forwarded here, so a
        # caller-supplied model is ignored for Jina. Presumably JinaRerank
        # reads JINA_API_KEY from the environment itself - confirm.
        return JinaRerank(top_n=top_k)

    # Voyage is the only remaining provider with a required API key.
    return VoyageAIRerank(model=model, api_key=api_key, top_k=top_k)