# Hugging Face Space: FAISS retrieval + remote ranking demo (Space currently reports a runtime error).
| from haystack.document_stores.faiss import FAISSDocumentStore | |
| from haystack.nodes.retriever import EmbeddingRetriever | |
| from haystack.nodes.ranker import BaseRanker | |
| from haystack.pipelines import Pipeline | |
| from haystack.document_stores.base import BaseDocumentStore | |
| from haystack.schema import Document | |
| from typing import Optional, List | |
| import gradio as gr | |
| import numpy as np | |
| import requests | |
| import os | |
# Remote inference endpoints and credentials, supplied via environment variables.
# All three may be None if unset — TODO confirm the Space always defines them.
RETRIEVER_URL = os.getenv("RETRIEVER_URL")  # embedding endpoint (queries + documents)
RANKER_URL = os.getenv("RANKER_URL")  # reranking endpoint
HF_TOKEN = os.getenv("HF_TOKEN")  # Hugging Face token sent as a Bearer header
class Retriever(EmbeddingRetriever):
    """EmbeddingRetriever that delegates embedding to a remote HTTP endpoint.

    Instead of loading a local embedding model, both query and document
    embedding are POSTed to ``RETRIEVER_URL`` and the returned JSON array is
    converted to a numpy array.
    """

    def __init__(
        self,
        document_store: Optional[BaseDocumentStore] = None,
        top_k: int = 10,
        batch_size: int = 32,
        scale_score: bool = True,
    ):
        # NOTE(review): super().__init__() is deliberately not called — the
        # parent would try to load a local model. Only the attributes the
        # pipeline reads are kept. Verify no other parent attribute is needed.
        self.document_store = document_store
        self.top_k = top_k
        self.batch_size = batch_size
        self.scale_score = scale_score

    def _post(self, payload: dict) -> np.ndarray:
        """POST *payload* to the retriever endpoint and return the embeddings.

        Raises:
            Exception: if the endpoint responds with an ``error`` payload
                (consistent with ``Ranker.predict``).
        """
        response = requests.post(
            RETRIEVER_URL,
            json=payload,
            headers={"Authorization": f"Bearer {HF_TOKEN}"},
        ).json()
        # Surface endpoint errors explicitly instead of letting np.array()
        # produce an opaque object array / downstream crash.
        if isinstance(response, dict) and "error" in response:
            raise Exception(response["error"])
        return np.array(response)

    def embed_queries(self, queries: List[str]) -> np.ndarray:
        """Embed query strings via the remote endpoint."""
        return self._post({"queries": queries, "inputs": ""})

    def embed_documents(self, documents: List[Document]) -> np.ndarray:
        """Embed Haystack documents via the remote endpoint."""
        return self._post(
            {"documents": [d.to_dict() for d in documents], "inputs": ""}
        )
class Ranker(BaseRanker):
    """Ranker that delegates scoring to a remote HTTP endpoint (RANKER_URL)."""

    def predict(
        self, query: str, documents: List[Document], top_k: Optional[int] = None
    ) -> List[Document]:
        """Rank *documents* against *query* via the remote endpoint.

        Raises:
            Exception: if the endpoint responds with an ``error`` payload.
        """
        payload_docs = [d.to_dict() for d in documents]
        for doc in payload_docs:
            # Embeddings are numpy arrays, which are not JSON-serializable.
            doc["embedding"] = doc["embedding"].tolist()
        response = requests.post(
            RANKER_URL,
            json={
                "query": query,
                "documents": payload_docs,
                "top_k": top_k,
                "inputs": "",
            },
            headers={"Authorization": f"Bearer {HF_TOKEN}"},
        ).json()
        if "error" in response:
            raise Exception(response["error"])
        return [Document.from_dict(d) for d in response]

    def predict_batch(
        self,
        queries: List[str],
        documents: List[List[Document]],
        batch_size: Optional[int] = None,
        top_k: Optional[int] = None,
    ) -> List[List[Document]]:
        """Rank each document list against its query via the remote endpoint.

        Raises:
            Exception: if the endpoint responds with an ``error`` payload.
        """
        payload_docs = [[d.to_dict() for d in docs] for docs in documents]
        for docs in payload_docs:
            for doc in docs:
                doc["embedding"] = doc["embedding"].tolist()
        response = requests.post(
            RANKER_URL,
            json={
                "queries": queries,
                "documents": payload_docs,
                "batch_size": batch_size,
                "top_k": top_k,
                "inputs": "",
            },
            # BUG FIX: the auth header was missing here (predict() sends it);
            # without it, calls to a token-protected endpoint fail.
            headers={"Authorization": f"Bearer {HF_TOKEN}"},
        ).json()
        if "error" in response:
            raise Exception(response["error"])
        return [[Document.from_dict(d) for d in docs] for docs in response]
# Number of documents the UI reports back to the user.
TOP_K = 2
# Batch size passed to the retriever when embedding documents.
BATCH_SIZE = 16
# Toy corpus indexed into the FAISS store at startup.
EXAMPLES = [
    "There is a blue house on Oxford Street.",
    "Paris is the capital of France.",
    "The Eiffel Tower is in Paris.",
    "The Louvre is in Paris.",
    "London is the capital of England.",
    "Cairo is the capital of Egypt.",
    "The pyramids are in Egypt.",
    "The Sphinx is in Egypt.",
]
# Start from a clean FAISS store: remove any SQLite file left over from a
# previous run so write_documents does not collide with stale ids.
if os.path.exists("faiss_document_store.db"):
    os.remove("faiss_document_store.db")
# embedding_dim=384 must match the remote embedding model's output size —
# TODO confirm against the endpoint's model card.
document_store = FAISSDocumentStore(embedding_dim=384, return_embedding=True)
document_store.write_documents(
    [Document(content=d, id=i) for i, d in enumerate(EXAMPLES)]
)
retriever = Retriever(document_store=document_store, top_k=TOP_K, batch_size=BATCH_SIZE)
# Compute and store embeddings for the example corpus via the remote endpoint.
document_store.update_embeddings(retriever=retriever)
ranker = Ranker()
# Query -> Retriever -> Ranker pipeline.
pipe = Pipeline()
pipe.add_node(component=retriever, name="Retriever", inputs=["Query"])
pipe.add_node(component=ranker, name="Ranker", inputs=["Retriever"])
def run(query: str) -> str:
    """Run the retrieval + ranking pipeline and format the top documents.

    Args:
        query: Free-text question from the UI.

    Returns:
        A display string listing the contents of up to TOP_K ranked documents.
        (The original annotation said ``dict``, but a ``str`` is returned.)
    """
    output = pipe.run(query=query)
    # Slice defensively: the ranker may return fewer than TOP_K documents,
    # in which case indexing range(TOP_K) would raise IndexError.
    contents = [doc.content for doc in output["documents"][:TOP_K]]
    return f"Closest document(s): {contents}"
# Warm up: trigger one full pipeline pass so the remote endpoints are spun up
# before the first real user request.
run("What is the capital of France?")
# Minimal Gradio UI: one text box in, one text box out.
gr.Interface(
    fn=run,
    inputs="text",
    outputs="text",
    title="Pipeline",
    examples=["What is the capital of France?"],
    description="A pipeline for retrieving and ranking documents.",
).launch()