import os
import json
import random
from jsonschema import RefResolver, validate
from collections import defaultdict
from src.rag.retrieval import DocDB
from src.data_processor.raw_data_processor import IRawDataProcessor
from src.data_processor.fact_score_processor import FactScoreProcessor
from src.data_processor.hotpot_qa_processor import HotpotQAProcessor
from src.data_processor.pop_qa_processor import PopQAProcessor
from src.data_processor.medlf_qa_processor import MedLFQAProcessor
class QueryProcessor(IRawDataProcessor):
    """Main query processor that delegates to specific dataset processors."""

    def __init__(
        self,
        db_path: str = "data/raw/WikiDB/enwiki-20230401.db",
        query_size: int = None,
    ):
        """
        Args:
            db_path: Path to the SQLite wiki document database used for retrieval.
            query_size: Number of queries to sample (None or -1 keeps all).
        """
        self.db = DocDB(db_path=db_path, data_path=None)
        self.dataset = None
        self.query_size = query_size
        # Dataset name -> processor that understands that dataset's raw format.
        self.processors = {
            "fact_score": FactScoreProcessor(),
            "hotpot_qa": HotpotQAProcessor(),
            "pop_qa": PopQAProcessor(),
            "medlf_qa": MedLFQAProcessor(),
        }

    def get_queries(
        self,
        dataset: str,
        input_file: str,
        output_dir: str,
        output_file: str,
        seed: int = 42,
    ):
        """
        Read raw data from a file, extract queries, and store them as JSON.

        If the output JSON already exists it is loaded instead of reprocessed.
        When ``self.query_size`` is a positive number smaller than the number of
        available queries, a reproducible random sample is drawn (group-aware
        when the queries carry a "groups" field) and written to a
        ``sampled_<size>_<output_file>`` sibling file.

        Args:
            dataset: Name of the dataset to process (a key of ``self.processors``).
            input_file: Path to the input file (or directory) with raw data.
            output_dir: Directory where processed queries are written.
            output_file: File name for the processed queries inside ``output_dir``.
            seed: Random seed for reproducible sampling.

        Returns:
            tuple: (dict mapping each query "input" to its "output"/"answer",
                    path of the JSON file holding the possibly-sampled queries)

        Raises:
            ValueError: If no processor is registered for ``dataset``.
        """
        self.dataset = dataset
        self.input_file = input_file
        query_path = os.path.join(output_dir, output_file)
        if os.path.exists(query_path):
            # Case 1: output file already exists - load instead of reprocessing.
            print(f"{query_path} already exists.")
            with open(query_path, "r", encoding="utf-8") as jsonfile:
                queries = json.load(jsonfile)
        else:
            # Case 2: output file doesn't exist - process and save.
            # exist_ok avoids the check-then-create race of exists()+makedirs().
            os.makedirs(output_dir, exist_ok=True)
            processor = self.processors.get(dataset)
            if not processor:
                raise ValueError(f"Unsupported dataset: {dataset}")
            queries = processor.process_queries(input_file)
            with open(query_path, "w", encoding="utf-8") as jsonfile:
                json.dump(queries, jsonfile, indent=4)
            # Report the full path actually written, not just the file name.
            print(f"Queries saved to {query_path}")
        # Sample only when a positive size was requested and there is a surplus.
        if self.query_size and self.query_size != -1 and len(queries) > self.query_size:
            random.seed(seed)
            if "groups" in queries[0]:
                self.queries = self._sample_by_group(queries)
            else:
                self.queries = random.sample(queries, self.query_size)
            # Write the sampled queries to a size-tagged sibling file.
            query_path = os.path.join(
                output_dir, f"sampled_{self.query_size}_{output_file}"
            )
            with open(query_path, "w", encoding="utf-8") as jsonfile:
                json.dump(self.queries, jsonfile, indent=4)
        else:
            self.queries = queries
        # Create input -> answer mapping for the caller.
        return {
            query["input"]: query["output"]["answer"] for query in self.queries
        }, query_path

    def _sample_by_group(self, queries: list) -> list:
        """Sample ``self.query_size`` queries spread as evenly as possible across groups.

        Groups smaller than their fair share contribute all of their items; the
        unused quota rolls over and is redistributed among the larger groups.
        NOTE(review): a query listed under several groups can be drawn more than
        once (it appears in multiple group buckets) - behavior preserved from
        the original inline implementation.
        """
        # Build group-to-queries mapping.
        group_to_queries = defaultdict(list)
        for item in queries:
            for group in item.get("groups", []):
                group_to_queries[group].append(item)
        # Sort groups by size ascending so small groups are resolved first.
        group_sizes = {g: len(qs) for g, qs in group_to_queries.items()}
        sorted_groups = sorted(group_sizes.items(), key=lambda x: x[1])  # (group, count)
        remaining_size = self.query_size
        group_allocation = {}
        # First pass: a group smaller than the current fair share is taken
        # whole; its leftover quota increases the share of the groups after it.
        remaining_groups = []
        for group, size in sorted_groups:
            slots = len(sorted_groups) - len(group_allocation)
            fair_share = remaining_size // slots if slots > 0 else 0
            if size <= fair_share:
                group_allocation[group] = size
                remaining_size -= size
            else:
                remaining_groups.append(group)
        # Second pass: split what is left evenly among the remaining groups.
        for group in remaining_groups:
            fair_share = remaining_size // (
                len(remaining_groups)
                - len([g for g in group_allocation if g in remaining_groups])
            )
            allocated = min(fair_share, group_sizes[group])
            group_allocation[group] = allocated
            remaining_size -= allocated
        # Now draw the per-group samples.
        sampled = []
        for group, count in group_allocation.items():
            sampled.extend(random.sample(group_to_queries[group], count))
        return sampled

    def get_documents(self, query_dir: str, output_dir: str, output_file: str) -> str:
        """
        Read structured query data and generate the corresponding document list.

        Args:
            query_dir: Directory containing query data.
            output_dir: Directory to save the output file.
            output_file: Name of the output file.

        Returns:
            str: Path to the output file.

        Raises:
            ValueError: If no processor is registered for ``self.dataset``.
        """
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(
            output_dir, f"sampled_{self.query_size}_{output_file}"
        )
        # Return early if the document list was already generated.
        if os.path.exists(output_path):
            print(f"{output_path} already exists.")
            return output_path
        processor = self.processors.get(self.dataset)
        if not processor:
            raise ValueError(f"Unsupported dataset: {self.dataset}")
        # Wiki-backed datasets must conform to the wiki query schema.
        if self.dataset in ["fact_score", "hotpot_qa", "pop_qa"]:
            for query in self.queries:
                self._validate_schema(query)
        # Pass the sampled subset only when sampling was actually requested.
        queries_to_use = None
        if self.query_size and self.query_size != -1:
            queries_to_use = self.queries
        documents = processor.process_documents(
            query_dir, self.db, queries_to_use, raw_query_dir=self.input_file
        )
        with open(output_path, "w", encoding="utf-8") as jsonfile:
            json.dump(documents, jsonfile, indent=4, ensure_ascii=False)
        print(f"Document list saved to {output_path}.")
        return output_path

    def _validate_schema(self, query: dict):
        """Validate a single query against the wiki JSON schema.

        The schema files are read from disk once and cached on the instance so
        validating N queries does not re-open both files N times.

        Raises:
            jsonschema.ValidationError: If ``query`` does not match the schema.
        """
        if not hasattr(self, "_wiki_schema"):
            with open(
                "data/processed/base_schema.json", "r", encoding="utf-8"
            ) as schemafile:
                self._base_schema = json.load(schemafile)
            with open(
                "data/processed/wiki_schema.json", "r", encoding="utf-8"
            ) as schemafile:
                self._wiki_schema = json.load(schemafile)
        # NOTE(review): RefResolver is deprecated in recent jsonschema releases;
        # kept here for compatibility with the currently pinned dependency.
        resolver = RefResolver("data/processed/base_schema.json", self._base_schema)
        validate(instance=query, schema=self._wiki_schema, resolver=resolver)
if __name__ == "__main__":
    # Example run for the MedLFQA dataset.
    # NOTE: get_queries/get_documents require a separate `output_dir` argument;
    # the previous calls passed a combined path as `output_file` and omitted
    # `output_dir`, which raises TypeError (missing required positional arg).
    medlf_query_processor = QueryProcessor(
        db_path="data/raw/WikiDB/enwiki-20230401.db"
    )
    # get_queries returns (input->answer mapping, path of the written queries
    # file); feed that path to get_documents so sampled runs stay consistent.
    _, medlf_query_path = medlf_query_processor.get_queries(
        dataset="medlf_qa",
        input_file="data/raw/MedLFQA",
        output_dir="data/processed/MedLFQA",
        output_file="medlf_qa_queries.json",
    )
    medlf_query_processor.get_documents(
        query_dir=medlf_query_path,
        output_dir="data/processed/MedLFQA",
        output_file="medlf_qa_documents.txt",
    )