# Source: model111/larm/data/utils/retrieval_utils.py
# Uploaded by LCZZZZ ("Upload MemGen code and data", commit e34b94f).
from typing import List, Optional

import requests
class Retriever:
    """Client for an HTTP retrieval service that returns top-k passages per query.

    Queries are sent as a JSON POST with ``queries``, ``topk`` and
    ``return_scores`` keys. The JSON reply is expected to contain a
    ``"result"`` entry holding one list of document items per query. Each item
    must provide ``item["document"]["contents"]``.
    """

    def __init__(
        self,
        search_url: str = "http://127.0.0.1:8000/retrieve",
        topk: int = 3,
        timeout: Optional[float] = None,
    ):
        """Store the retrieval endpoint configuration.

        Args:
            search_url: URL of the retrieval endpoint that receives POST requests.
            topk: number of passages to request per query.
            timeout: seconds to wait for the server, forwarded to
                ``requests.post``. ``None`` waits indefinitely, which matches
                the previous hard-coded behaviour.
        """
        self.config = {
            "search_url": search_url,
            "topk": topk,
            "timeout": timeout,
        }

    def batch_search(self, queries: Optional[List[str]] = None) -> List[str]:
        """Batchified search for queries.

        Args:
            queries: queries to call the search engine. ``None`` or an empty
                list returns ``[]`` without contacting the server.

        Returns:
            One string per query, containing that query's retrieved passages
            concatenated into a single string.

        Raises:
            requests.HTTPError: if the server answers with an error status.
        """
        if not queries:
            # Nothing to search: skip the network round-trip entirely.
            return []
        results = self._batch_search(queries)["result"]
        return [self._passages2string(result) for result in results]

    def _batch_search(self, queries: List[str]) -> dict:
        """POST ``queries`` to the search endpoint and return the decoded JSON reply."""
        payload = {
            "queries": queries,
            "topk": self.config["topk"],
            "return_scores": True,
        }
        response = requests.post(
            self.config["search_url"],
            json=payload,
            # .get() keeps working if a caller replaced self.config without a timeout.
            timeout=self.config.get("timeout"),
        )
        # Fail loudly on HTTP errors instead of later hitting a confusing
        # KeyError / JSON decode error while parsing an error page.
        response.raise_for_status()
        return response.json()

    def _passages2string(self, retrieval_result) -> str:
        """Format one query's retrieved documents as numbered passages.

        The first line of each document's ``contents`` is used as its title and
        the remaining lines as its body. Each document yields
        ``Doc <n>(Title: <title>) <body>`` followed by a newline, numbered from 1.
        """
        passages = []
        for idx, doc_item in enumerate(retrieval_result, start=1):
            # partition splits only at the first newline; the body keeps any
            # further newlines, exactly like the former split/join pair.
            title, _, text = doc_item["document"]["contents"].partition("\n")
            passages.append(f"Doc {idx}(Title: {title}) {text}\n")
        return "".join(passages)