from typing import List, Optional

import requests
class Retriever:
    """Thin HTTP client for a passage-retrieval service.

    Queries are POSTed as JSON to ``config["search_url"]``; each hit that
    comes back is flattened into a ``Doc N(Title: ...) ...`` text block.

    Args:
        search_url: Endpoint the search payload is POSTed to.
        topk: Value sent as ``"topk"`` in the payload.
        timeout: Timeout forwarded to ``requests.post``. ``None`` (the
            default) means the call may block indefinitely.
    """

    def __init__(
        self,
        search_url: str = "http://127.0.0.1:8000/retrieve",
        topk: int = 3,
        timeout: Optional[float] = None,
    ):
        self.config = {
            "search_url": search_url,
            "topk": topk,
            "timeout": timeout,
        }

    def batch_search(self, queries: Optional[List[str]] = None) -> List[str]:
        """Run a batch of queries and return one formatted string per result.

        Args:
            queries: Queries to send to the search endpoint.

        Returns:
            One string per entry of the response's ``"result"`` list, each
            being the entry's passages concatenated by ``_passages2string``.
            An empty list when ``queries`` is ``None`` or empty.

        Raises:
            requests.HTTPError: If the endpoint answers with an error status.
        """
        # Nothing to look up: skip the network round-trip entirely.
        if not queries:
            return []
        results = self._batch_search(queries)["result"]
        return [self._passages2string(result) for result in results]

    def _batch_search(self, queries):
        """POST ``queries`` to the search endpoint and return the decoded JSON."""
        payload = {
            "queries": queries,
            "topk": self.config["topk"],
            # Scores are requested from the service; this class does not
            # read them when formatting.
            "return_scores": True,
        }
        response = requests.post(
            self.config["search_url"],
            json=payload,
            # .get() so a config dict assigned without a "timeout" key
            # still works.
            timeout=self.config.get("timeout"),
        )
        # Fail with a clear HTTP error rather than a JSON/KeyError later on.
        response.raise_for_status()
        return response.json()

    def _passages2string(self, retrieval_result):
        """Format retrieved passages as numbered ``Doc N(Title: ...)`` blocks.

        Each item must provide ``item["document"]["contents"]``; the first
        line of that text is used as the title and the remainder as the body.
        Returns the concatenation of all blocks (``""`` for no passages).
        """
        blocks = []
        for idx, doc_item in enumerate(retrieval_result, start=1):
            content = doc_item["document"]["contents"]
            # First line is the title; everything after the first newline
            # is the passage text (empty when there is no newline).
            title, _, text = content.partition("\n")
            blocks.append(f"Doc {idx}(Title: {title}) {text}\n")
        return "".join(blocks)