| import os |
| import json |
| import random |
| from jsonschema import RefResolver, validate |
| from collections import defaultdict |
|
|
| from src.rag.retrieval import DocDB |
| from src.data_processor.raw_data_processor import IRawDataProcessor |
| from src.data_processor.fact_score_processor import FactScoreProcessor |
| from src.data_processor.hotpot_qa_processor import HotpotQAProcessor |
| from src.data_processor.pop_qa_processor import PopQAProcessor |
| from src.data_processor.medlf_qa_processor import MedLFQAProcessor |
|
|
|
|
class QueryProcessor(IRawDataProcessor):
    """Main query processor that delegates to dataset-specific processors.

    Looks up the processor registered for a dataset name, extracts queries
    from the raw data, optionally down-samples them reproducibly (group-aware
    when queries carry a "groups" field), and builds the matching document
    list from the local Wikipedia DocDB.
    """

    # Class-level cache of (base_schema, wiki_schema). Loading both schema
    # files once avoids re-reading them from disk for every single query
    # validated in get_documents.
    _schema_cache = None

    def __init__(
        self,
        db_path: str = "data/raw/WikiDB/enwiki-20230401.db",
        query_size: int = None,
    ):
        """
        Args:
            db_path: Path to the SQLite Wikipedia dump used for retrieval.
            query_size: Number of queries to sample; None or -1 keeps all.
        """
        self.db = DocDB(db_path=db_path, data_path=None)
        self.dataset = None
        self.query_size = query_size
        # Registry mapping dataset names to their concrete processors.
        self.processors = {
            "fact_score": FactScoreProcessor(),
            "hotpot_qa": HotpotQAProcessor(),
            "pop_qa": PopQAProcessor(),
            "medlf_qa": MedLFQAProcessor(),
        }

    def get_queries(
        self,
        dataset: str,
        input_file: str,
        output_dir: str,
        output_file: str,
        seed: int = 42,
    ):
        """
        Read raw data, extract queries, and persist them as JSON.

        If self.query_size is set (and not -1) and the dataset has more
        queries than that, a reproducible random sample is taken; queries
        carrying a "groups" field are sampled so groups stay balanced. The
        sample is written to ``sampled_<size>_<output_file>``.

        Args:
            dataset: Name of the dataset to process (key in self.processors).
            input_file: Path to the input file with raw data.
            output_dir: Directory where processed queries are saved.
            output_file: File name for the processed queries.
            seed: Random seed for reproducible sampling.

        Returns:
            tuple: (dict mapping each query input to its answer,
                    path of the JSON file actually used/written).

        Raises:
            ValueError: If the dataset has no registered processor.
        """
        self.dataset = dataset
        self.input_file = input_file

        query_path = os.path.join(output_dir, output_file)
        if os.path.exists(query_path):
            # Reuse previously extracted queries instead of re-processing.
            print(f"{query_path} already exists.")
            with open(query_path, "r", encoding="utf-8") as jsonfile:
                queries = json.load(jsonfile)
        else:
            os.makedirs(output_dir, exist_ok=True)

            processor = self.processors.get(dataset)
            if not processor:
                raise ValueError(f"Unsupported dataset: {dataset}")

            queries = processor.process_queries(input_file)

            with open(query_path, "w", encoding="utf-8") as jsonfile:
                json.dump(queries, jsonfile, indent=4)

            print(f"Queries saved to {output_file}")

        if self.query_size and self.query_size != -1 and len(queries) > self.query_size:
            random.seed(seed)
            if "groups" in queries[0]:
                self.queries = self._sample_by_group(queries)
            else:
                self.queries = random.sample(queries, self.query_size)

            # Persist the sample so later runs (and get_documents) can
            # reuse exactly the same subset.
            query_path = os.path.join(
                output_dir, f"sampled_{self.query_size}_{output_file}"
            )
            with open(query_path, "w", encoding="utf-8") as jsonfile:
                json.dump(self.queries, jsonfile, indent=4)
        else:
            self.queries = queries

        return {
            query["input"]: query["output"]["answer"] for query in self.queries
        }, query_path

    def _sample_by_group(self, queries: list) -> list:
        """Sample self.query_size queries while keeping groups balanced.

        Groups are visited smallest-first: any group at or below its fair
        share of the remaining budget is taken whole; the leftover budget
        is then split by integer division among the larger groups,
        recomputing the fair share as the budget shrinks.

        NOTE(review): a query listing several groups is indexed under each
        of them and can therefore be drawn more than once — confirm this
        duplication is acceptable upstream.
        """
        # Index queries by every group they belong to.
        group_to_queries = defaultdict(list)
        for item in queries:
            for group in item.get("groups", []):
                group_to_queries[group].append(item)

        group_sizes = {g: len(qs) for g, qs in group_to_queries.items()}
        # Smallest-first so tiny groups can be kept whole.
        sorted_groups = sorted(group_sizes.items(), key=lambda x: x[1])

        remaining_size = self.query_size
        group_allocation = {}

        # Pass 1: take small groups entirely; defer groups larger than
        # their fair share to the second pass.
        remaining_groups = []
        for group, size in sorted_groups:
            groups_left = len(sorted_groups) - len(group_allocation)
            fair_share = remaining_size // groups_left if groups_left > 0 else 0
            if size <= fair_share:
                group_allocation[group] = size
                remaining_size -= size
            else:
                remaining_groups.append(group)

        # Pass 2: split the remaining budget among the deferred (large)
        # groups, recomputing the per-group share each iteration so
        # integer-division remainders are absorbed by later groups.
        for group in remaining_groups:
            still_unallocated = len(remaining_groups) - len(
                [g for g in group_allocation if g in remaining_groups]
            )
            fair_share = remaining_size // still_unallocated
            allocated = min(fair_share, group_sizes[group])
            group_allocation[group] = allocated
            remaining_size -= allocated

        # Draw the allocated number of queries from each group.
        sampled = []
        for group, count in group_allocation.items():
            sampled.extend(random.sample(group_to_queries[group], count))
        return sampled

    def get_documents(self, query_dir: str, output_dir: str, output_file: str) -> str:
        """
        Read structured query data and generate the corresponding document list.

        Must be called after get_queries (relies on self.dataset,
        self.queries and self.input_file being set).

        Args:
            query_dir: Directory containing query data.
            output_dir: Directory to save the output file.
            output_file: Name of the output file.

        Returns:
            Path to the output file.

        Raises:
            ValueError: If the current dataset has no registered processor.
        """
        os.makedirs(output_dir, exist_ok=True)

        # NOTE(review): when query_size is None the file name becomes
        # "sampled_None_<output_file>" — kept as-is for backward
        # compatibility with existing output files.
        output_path = os.path.join(
            output_dir, f"sampled_{self.query_size}_{output_file}"
        )

        if os.path.exists(output_path):
            print(f"{output_path} already exists.")
            return output_path

        processor = self.processors.get(self.dataset)
        if not processor:
            raise ValueError(f"Unsupported dataset: {self.dataset}")

        # Wikipedia-backed datasets are validated against the wiki schema.
        if self.dataset in ["fact_score", "hotpot_qa", "pop_qa"]:
            for query in self.queries:
                self._validate_schema(query)

        # Pass the sampled subset only when sampling was requested;
        # otherwise the processor works from the full query file itself.
        queries_to_use = None
        if self.query_size and self.query_size != -1:
            queries_to_use = self.queries

        documents = processor.process_documents(
            query_dir, self.db, queries_to_use, raw_query_dir=self.input_file
        )

        with open(output_path, "w", encoding="utf-8") as jsonfile:
            json.dump(documents, jsonfile, indent=4, ensure_ascii=False)

        print(f"Document list saved to {output_path}.")
        return output_path

    def _validate_schema(self, query: dict):
        """Validate a single query against the wiki JSON schema.

        Both schema files are loaded once and cached on the class, so
        validating many queries in a loop does not re-read them from disk
        on every call.

        Raises:
            jsonschema.ValidationError: If the query violates the schema.
        """
        if QueryProcessor._schema_cache is None:
            with open(
                "data/processed/base_schema.json", "r", encoding="utf-8"
            ) as schemafile:
                base_schema = json.load(schemafile)
            with open(
                "data/processed/wiki_schema.json", "r", encoding="utf-8"
            ) as schemafile:
                wiki_schema = json.load(schemafile)
            QueryProcessor._schema_cache = (base_schema, wiki_schema)

        base_schema, wiki_schema = QueryProcessor._schema_cache
        # NOTE(review): RefResolver is deprecated in recent jsonschema
        # releases (superseded by the `referencing` library) — consider
        # migrating when that dependency is upgraded.
        resolver = RefResolver("data/processed/base_schema.json", base_schema)
        validate(instance=query, schema=wiki_schema, resolver=resolver)
|
|
|
|
| if __name__ == "__main__": |
| |
| |
| |
|
|
| |
| |
| |
|
|
| |
| |
| |
|
|
| medlf_query_processor = QueryProcessor(db_path="data/raw/WikiDB/enwiki-20230401.db") |
| medlf_query_processor.get_queries( |
| dataset="medlf_qa", |
| input_file="data/raw/MedLFQA", |
| output_file="data/processed/MedLFQA/medlf_qa_queries.json", |
| ) |
| medlf_query_processor.get_documents( |
| query_dir="data/processed/MedLFQA/medlf_qa_queries.json", |
| output_file="data/processed/MedLFQA/medlf_qa_documents.txt", |
| ) |
|
|