File size: 2,245 Bytes
19fc84f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import json

from src.rag.retrieval import DocDB
from src.utils.string_utils import extract_tag_content
from src.data_processor.raw_data_processor import DatasetProcessor


class HotpotQAProcessor(DatasetProcessor):
    """Dataset processor that reads HotpotQA queries and their provenance documents."""

    def process_queries(self, input_file: str) -> list:
        """Parse a JSON Lines query file into normalized query records.

        Every non-blank line of ``input_file`` is decoded as one JSON object.
        Only the first element of its ``"output"`` list is used.

        Args:
            input_file: Path to a UTF-8 encoded file with one JSON object per line.

        Returns:
            A list of dicts of the form
            ``{"input": ..., "output": {"answer": ..., "provenance": [...]}}``,
            where each provenance entry holds the ``"title"`` and the
            ``"wikipedia_id"`` converted to ``int``.

        Raises:
            json.JSONDecodeError: If a non-blank line is not valid JSON.
            KeyError, IndexError: If a record lacks one of the expected fields
                or has an empty ``"output"`` list.
            ValueError: If a ``"wikipedia_id"`` cannot be converted to ``int``.
        """
        queries = []
        with open(input_file, "r", encoding="utf-8") as file:
            for line in file:
                # Blank lines (e.g. a trailing newline at end of file) are not
                # records; json.loads would raise on them, so skip them.
                if not line.strip():
                    continue
                data = json.loads(line)
                first_output = data["output"][0]
                queries.append(
                    {
                        "input": data["input"],
                        "output": {
                            "answer": first_output["answer"],
                            "provenance": [
                                {
                                    "title": item["title"],
                                    "wikipedia_id": int(item["wikipedia_id"]),
                                }
                                for item in first_output["provenance"]
                            ],
                        },
                    }
                )
        return queries

    def process_documents(
        self, query_file: str, db: DocDB, queries: list = None, **kwargs
    ) -> dict:
        """Collect the document content for every title cited as provenance.

        Args:
            query_file: Path to a JSON file containing the queries. It is read
                only when ``queries`` is ``None``, and is loaded as a single
                JSON document (not JSON Lines) whose entries expose
                ``query["output"]["provenance"]`` like the records returned
                by ``process_queries``.
            db: Document store queried by title through
                ``_get_documents_per_query``.
            queries: Optional already-loaded (e.g. sampled) queries. When
                given, they are used instead of the contents of
                ``query_file``.
            **kwargs: Not used by this method; presumably accepted to stay
                compatible with the base-class signature (confirm against
                ``DatasetProcessor``).

        Returns:
            A dict mapping each distinct provenance title to the value
            returned by ``_get_documents_per_query`` for that title.
        """
        # Fall back to the query file only when no queries were handed in.
        if queries is None:
            with open(query_file, "r", encoding="utf-8") as jsonfile:
                queries = json.load(jsonfile)

        documents = {}
        for query in queries:
            for provenance in query["output"]["provenance"]:
                title = provenance["title"]
                # Several queries may cite the same title; fetch it only once.
                if title not in documents:
                    documents[title] = self._get_documents_per_query(title, db)
        return documents

    def _get_documents_per_query(self, title: str, db: DocDB) -> list:
        """Fetch and parse the stored text for a single document title.

        The ``"text"`` fields of all records returned by
        ``db.get_text_from_title(title)`` are concatenated in order and the
        combined string is passed to ``extract_tag_content``.

        Args:
            title: Title of the document to look up.
            db: Document store providing ``get_text_from_title``.

        Returns:
            The result of ``extract_tag_content`` on the combined text
            (presumably a list, per the annotation; verify against the
            helper), or an empty list if the lookup or parsing raises.
        """
        try:
            docs = db.get_text_from_title(title)
            contents = "".join(data["text"] for data in docs)
            return extract_tag_content(contents)
        except Exception as e:
            # Deliberate best-effort: a title that cannot be retrieved is
            # reported and mapped to no documents instead of aborting the run.
            print(f"Error retrieving documents for title {title}: {e}")
            return []