|
|
from typing import Any, cast |
|
|
|
|
|
from langchain.retrievers import ContextualCompressionRetriever |
|
|
|
|
|
from langflow.base.vectorstores.model import ( |
|
|
LCVectorStoreComponent, |
|
|
check_cached_vector_store, |
|
|
) |
|
|
from langflow.field_typing import Retriever, VectorStore |
|
|
from langflow.io import ( |
|
|
DropdownInput, |
|
|
HandleInput, |
|
|
MultilineInput, |
|
|
SecretStrInput, |
|
|
StrInput, |
|
|
) |
|
|
from langflow.schema import Data |
|
|
from langflow.schema.dotdict import dotdict |
|
|
from langflow.template.field.base import Output |
|
|
|
|
|
|
|
|
class NvidiaRerankComponent(LCVectorStoreComponent): |
|
|
display_name = "NVIDIA Rerank" |
|
|
description = "Rerank documents using the NVIDIA API and a retriever." |
|
|
icon = "NVIDIA" |
|
|
legacy: bool = True |
|
|
|
|
|
inputs = [ |
|
|
MultilineInput( |
|
|
name="search_query", |
|
|
display_name="Search Query", |
|
|
), |
|
|
StrInput( |
|
|
name="base_url", |
|
|
display_name="Base URL", |
|
|
value="https://integrate.api.nvidia.com/v1", |
|
|
refresh_button=True, |
|
|
info="The base URL of the NVIDIA API. Defaults to https://integrate.api.nvidia.com/v1.", |
|
|
), |
|
|
DropdownInput( |
|
|
name="model", |
|
|
display_name="Model", |
|
|
options=["nv-rerank-qa-mistral-4b:1"], |
|
|
value="nv-rerank-qa-mistral-4b:1", |
|
|
), |
|
|
SecretStrInput(name="api_key", display_name="API Key"), |
|
|
HandleInput(name="retriever", display_name="Retriever", input_types=["Retriever"]), |
|
|
] |
|
|
|
|
|
outputs = [ |
|
|
Output( |
|
|
display_name="Retriever", |
|
|
name="base_retriever", |
|
|
method="build_base_retriever", |
|
|
), |
|
|
Output( |
|
|
display_name="Search Results", |
|
|
name="search_results", |
|
|
method="search_documents", |
|
|
), |
|
|
] |
|
|
|
|
|
def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None): |
|
|
if field_name == "base_url" and field_value: |
|
|
try: |
|
|
build_model = self.build_model() |
|
|
ids = [model.id for model in build_model.available_models] |
|
|
build_config["model"]["options"] = ids |
|
|
build_config["model"]["value"] = ids[0] |
|
|
except Exception as e: |
|
|
msg = f"Error getting model names: {e}" |
|
|
raise ValueError(msg) from e |
|
|
return build_config |
|
|
|
|
|
def build_model(self): |
|
|
try: |
|
|
from langchain_nvidia_ai_endpoints import NVIDIARerank |
|
|
except ImportError as e: |
|
|
msg = "Please install langchain-nvidia-ai-endpoints to use the NVIDIA model." |
|
|
raise ImportError(msg) from e |
|
|
return NVIDIARerank(api_key=self.api_key, model=self.model, base_url=self.base_url) |
|
|
|
|
|
def build_base_retriever(self) -> Retriever: |
|
|
nvidia_reranker = self.build_model() |
|
|
retriever = ContextualCompressionRetriever(base_compressor=nvidia_reranker, base_retriever=self.retriever) |
|
|
return cast("Retriever", retriever) |
|
|
|
|
|
async def search_documents(self) -> list[Data]: |
|
|
retriever = self.build_base_retriever() |
|
|
documents = await retriever.ainvoke(self.search_query, config={"callbacks": self.get_langchain_callbacks()}) |
|
|
data = self.to_data(documents) |
|
|
self.status = data |
|
|
return data |
|
|
|
|
|
@check_cached_vector_store |
|
|
def build_vector_store(self) -> VectorStore: |
|
|
msg = "NVIDIA Rerank does not support vector stores." |
|
|
raise NotImplementedError(msg) |
|
|
|