Spaces:
Running
Running
| from typing import Any, Coroutine | |
| import httpx | |
| from langchain_core.callbacks.manager import ( | |
| AsyncCallbackManagerForRetrieverRun, | |
| CallbackManagerForRetrieverRun, | |
| ) | |
| from langchain_core.documents import Document | |
| from langchain_core.retrievers import BaseRetriever | |
| from pydantic import Field, PrivateAttr, model_validator | |
| from .settings import EmmRetrieversSettings | |
def as_lc_docs(dicts: list[dict]) -> list[Document]:
    """Convert raw result dicts into LangChain ``Document`` objects.

    Each dict must provide the keys ``"page_content"`` and ``"metadata"``;
    a missing key raises ``KeyError``. The output preserves input order.
    """
    documents: list[Document] = []
    for entry in dicts:
        content = entry["page_content"]
        meta = entry["metadata"]
        documents.append(Document(page_content=content, metadata=meta))
    return documents
# The simple retriever is built with a fixed spec/filter/params/route config
# and can then be used many times with different queries.
# Note that these are cheap to construct.
class EmmRetrieverV1(BaseRetriever):
    """LangChain retriever that queries an HTTP retrieval endpoint.

    An instance carries a fixed spec/filter/params/route configuration and
    POSTs each query to ``route`` (resolved against ``settings.API_BASE``).
    The JSON response is expected to contain a ``"documents"`` list whose
    items have ``"page_content"`` and ``"metadata"`` keys.
    """

    # Connection settings: API base URL, API key, timeout, default cluster/index.
    settings: EmmRetrieversSettings
    # Sent verbatim as the "spec" member of the JSON request body.
    spec: dict
    # Sent verbatim as the "filter" member of the JSON request body.
    filter: dict | None = None
    # Passed to httpx as ``params``; defaults for "cluster_name" and
    # "index" are filled in from ``settings`` by ``apply_default_params``.
    params: dict = Field(default_factory=dict)
    # URL of the query endpoint, passed to httpx as ``url``.
    route: str = "/r/rag-minimal/query"
    # When true, each returned document gets ``metadata["ref_key"]`` set to
    # its position in the result list.
    add_ref_key: bool = True

    _client: httpx.Client = PrivateAttr()
    _aclient: httpx.AsyncClient = PrivateAttr()

    # ------- interface impl:
    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> list[Document]:
        """Run ``query`` synchronously and return the matching documents.

        Raises:
            httpx.HTTPStatusError: if the service answers with an error status.
        """
        r = self._client.post(**self.search_post_kwargs(query))
        return self._docs_from_response(r)

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> list[Document]:
        """Run ``query`` asynchronously and return the matching documents.

        Raises:
            httpx.HTTPStatusError: if the service answers with an error status.
        """
        r = await self._aclient.post(**self.search_post_kwargs(query))
        return self._docs_from_response(r)

    def _docs_from_response(self, r: httpx.Response) -> list[Document]:
        """Validate an HTTP response and convert its payload to documents.

        Shared by the sync and async code paths so both treat errors the
        same way.
        """
        if r.status_code == 422:
            # Print the validation details before raising. Fall back to the
            # raw text so a non-JSON body cannot mask the HTTP error below.
            try:
                print("ERROR:\n", r.json())
            except ValueError:
                print("ERROR:\n", r.text)
        r.raise_for_status()
        resp = r.json()
        return self._as_lc_docs(resp["documents"])

    # ---------
    # NOTE(review): nothing else in the visible source invokes the two methods
    # below, and ``model_validator`` is imported at file level but otherwise
    # unused, so they are registered as after-validators here. Both remain
    # callable explicitly. Confirm against the upstream file.
    @model_validator(mode="after")
    def create_clients(self) -> "EmmRetrieverV1":
        """Create the sync and async HTTP clients from ``settings``.

        Both clients share the same base URL, bearer-token header and
        timeout. Returns ``self``.
        """
        _auth_headers = {
            "Authorization": f"Bearer {self.settings.API_KEY.get_secret_value()}"
        }

        kwargs = dict(
            base_url=self.settings.API_BASE,
            headers=_auth_headers,
            timeout=self.settings.DEFAULT_TIMEOUT,
        )

        self._client = httpx.Client(**kwargs)
        self._aclient = httpx.AsyncClient(**kwargs)
        return self

    @model_validator(mode="after")
    def apply_default_params(self) -> "EmmRetrieverV1":
        """Fill in default ``cluster_name`` and ``index`` request params.

        Values already present in ``params`` take precedence over the
        defaults taken from ``settings``. Returns ``self``.
        """
        self.params = {
            "cluster_name": self.settings.DEFAULT_CLUSTER,
            "index": self.settings.DEFAULT_INDEX,
            # Unpacked last so caller-supplied values override the defaults.
            **(self.params or {}),
        }
        return self

    def _as_lc_docs(self, dicts: list[dict]) -> list[Document]:
        """Convert result dicts to documents, optionally tagging each one.

        When ``add_ref_key`` is true, ``metadata["ref_key"]`` is set to the
        document's zero-based position in the result list.
        """
        docs = as_lc_docs(dicts)
        if self.add_ref_key:
            for i, d in enumerate(docs):
                d.metadata["ref_key"] = i

        return docs

    def search_post_kwargs(self, query: str) -> dict[str, Any]:
        """Build the keyword arguments for the client's ``post`` call."""
        return dict(
            url=self.route,
            params=self.params,
            json={"query": query, "spec": self.spec, "filter": self.filter},
        )