import os
import re
import ast
import bz2
import sqlite3
import json
from collections import defaultdict
from datasets import load_dataset
class DataLoader:
    """Dispatches QA-dataset downloads and builds the Wikipedia SQLite database."""

    def __init__(self, dataset: str):
        # Dataset key; see load_qa_data for the supported values.
        self.dataset = dataset

    def load_qa_data(self, output_path: str):
        """Download/prepare the configured dataset unless output_path already exists.

        Raises:
            ValueError: if self.dataset is not one of the supported names.
        """
        if os.path.exists(output_path):
            print(f"Dataset already exists at {output_path}.")
            return
        print(f"Loading {self.dataset} dataset.")
        if self.dataset == "fact_score":
            load_fact_score_data(output_path)
        elif self.dataset == "hotpot_qa":
            load_hotpot_qa_data(output_path)
        elif self.dataset == "pop_qa":
            load_pop_qa_data(output_path)
        elif self.dataset == "medlfqa":
            # MedLFQA is downloaded into its own source directory, then
            # cleaned in place; the helper returns the directory it used.
            output_path = load_medlfqa_data("data/.source_data/MedLFQA")
            clean_medlfqa_data(data_path=output_path, output_path=output_path)
        else:
            # Previously an unknown name was silently ignored; fail loudly instead.
            raise ValueError(f"Unknown dataset: {self.dataset}")

    def create_wiki_db(
        self,
        source_path: str = "data/raw/WikiDB/enwiki-20171001-pages-meta-current-withlinks-abstracts",
        output_path: str = "data/raw/WikiDB/enwiki_20190113.db",
    ):
        """Create a SQLite database from the Wikipedia dump data.

        Each ``.bz2`` file under ``source_path/<folder>/`` is expected to hold
        one JSON object per line; the ``id``, ``title``, ``url`` and ``text``
        fields are inserted into the ``wiki_content`` table at ``output_path``.

        Raises:
            FileNotFoundError: if source_path does not exist.
        """
        if os.path.exists(output_path):
            print(f"Database already exists at {output_path}.")
            return
        if not os.path.exists(source_path):
            raise FileNotFoundError(f"Source path {source_path} not found.")
        print(f"Reading data from {source_path}")
        conn = sqlite3.connect(output_path)
        try:
            cursor = conn.cursor()
            # Rebuild the table from scratch on every run.
            cursor.execute("""DROP TABLE IF EXISTS wiki_content""")
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS wiki_content (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    title TEXT,
                    url TEXT,
                    text TEXT
                )
                """
            )
            for folder in os.listdir(source_path):
                folder_path = os.path.join(source_path, folder)
                for file_name in os.listdir(folder_path):
                    if not file_name.endswith(".bz2"):
                        continue
                    file_path = os.path.join(folder_path, file_name)
                    # Stream line-by-line instead of reading the whole file,
                    # and parse each line exactly once with json.loads (the
                    # original re-parsed the same line via ast.literal_eval).
                    with bz2.open(file_path, "rt") as file:
                        for line in file:
                            if not line.strip():
                                continue
                            record = json.loads(line)
                            cursor.execute(
                                """
                                INSERT INTO wiki_content (id, title, url, text)
                                VALUES (?, ?, ?, ?)
                                """,
                                (
                                    record.get("id", ""),
                                    record["title"],
                                    record.get("url", ""),
                                    str(record.get("text", "")),
                                ),
                            )
            conn.commit()
        finally:
            # Always release the connection, even if a file fails to parse.
            conn.close()
        print(f"Created database at {output_path}")
def load_fact_score_data(output_path: str):
    """Placeholder for FactScore ingestion; not implemented yet."""
    # TODO: implement FactScore loading and write it to output_path.
    return None
def load_hotpot_qa_data(output_path: str):
    """Load HotpotQA (KILT) and write its validation split as JSON lines."""
    hotpot = load_dataset("kilt_tasks", "hotpotqa")
    validation_split = hotpot["validation"]
    validation_split.to_json(output_path, orient="records", lines=True)
    print("HotpotQA validation set saved to", output_path)
def load_pop_qa_data(output_path: str):
    """Load PopQA and write its test split as JSON lines."""
    pop_qa = load_dataset("akariasai/popQA")
    test_split = pop_qa["test"]
    test_split.to_json(output_path, orient="records", lines=True)
    print("PopQA test set saved to", output_path)
def load_medlfqa_data(output_path: str = "data/.source_data/MedLFQA"):
    """Download the MedLFQA datasets into output_path and return that path.

    Files that already exist are left untouched. Uses stdlib urllib instead of
    the original ``os.system("wget ...")`` shell-out, so neither a shell nor a
    wget binary is required.

    Returns:
        The directory the .jsonl files were written to (== output_path).
    """
    from urllib.request import urlretrieve  # stdlib; replaces wget shell-out

    # Replaces the non-portable `os.system(f"mkdir -p {output_path}")`.
    os.makedirs(output_path, exist_ok=True)
    dataset_names = [
        "healthsearch_qa",
        "kqa_golden",
        "kqa_silver_wogold",
        "live_qa",
        "medication_qa",
    ]
    base_url = (
        "https://raw.githubusercontent.com/jjcherian/conformal-safety/"
        "refs/heads/main/data/MedLFQAv2"
    )
    existing = set(os.listdir(output_path))
    for fname in dataset_names:
        if f"{fname}.jsonl" in existing:
            print(f"Dataset {fname} already exists.")
            continue
        urlretrieve(
            f"{base_url}/{fname}.jsonl",
            os.path.join(output_path, f"{fname}.jsonl"),
        )
    print(f"MedLFQA dataset saved to {output_path}")
    return output_path
def remove_specific_leading_chars(input_string):
    """Strip leading commas, then one leading digit-run followed by comma(s)."""
    trimmed = input_string.lstrip(",")
    prefix = re.match(r"\d+,+", trimmed)
    return trimmed[prefix.end():] if prefix else trimmed
def clean_medlfqa_data(data_path: str, output_path: str):
    """Clean the MedLFQA dataset to remove unwanted characters and fields.

    Reads every ``*.jsonl`` file in data_path, normalizes each record's
    "Question", drops duplicate questions within a dataset, drops questions
    shared across datasets (except in kqa_golden / live_qa), and writes each
    cleaned dataset to ``output_path/<name>.json`` as a single JSON array.
    """
    suffix = ".jsonl"
    datasets = {}
    # Load one record-per-line dataset from each .jsonl file.
    for fname in os.listdir(data_path):
        if fname.endswith(suffix):
            dataset_name = fname[: -len(suffix)]
            with open(os.path.join(data_path, fname), "r") as fp:
                datasets[dataset_name] = [json.loads(line) for line in fp]
    # Clean questions and drop within-dataset duplicates, counting how many
    # datasets contain each (deduplicated) question.
    # NOTE(review): the increment is placed inside the first-seen branch so the
    # count is "number of datasets containing this question" — confirm against
    # the original intent.
    filtered_datasets = {}
    redundant_prompts = defaultdict(int)
    for name, dataset in datasets.items():
        seen_questions = set()
        filtered_dataset = []
        for pt in dataset:
            pt["Question"] = remove_specific_leading_chars(pt["Question"]).strip()
            if pt["Question"] not in seen_questions:
                seen_questions.add(pt["Question"])
                filtered_dataset.append(pt)
                redundant_prompts[pt["Question"]] += 1
        filtered_datasets[name] = filtered_dataset
    # Outside the two reference datasets, drop questions that appear in more
    # than one dataset.
    for name, dataset in filtered_datasets.items():
        if name not in {"kqa_golden", "live_qa"}:
            filtered_datasets[name] = [
                pt for pt in dataset if redundant_prompts[pt["Question"]] == 1
            ]
    # os.makedirs replaces the non-portable `os.system("mkdir -p ...")`.
    os.makedirs(output_path, exist_ok=True)
    # Save each cleaned dataset as a pretty-printed JSON array (the original
    # built an element-by-element copy first, which was redundant).
    for name, dataset in filtered_datasets.items():
        filepath = os.path.join(output_path, f"{name}.json")
        with open(filepath, "w") as outfile:
            json.dump(list(dataset), outfile, indent=4)
        print(f"Saved {name} dataset to {filepath}")
# example code
if __name__ == "__main__":
    # Example usage: prepare MedLFQA and build the Wikipedia SQLite database.
    # Other datasets work the same way, e.g.:
    #   DataLoader("fact_score").load_qa_data("data/raw/FactScore/factscore_names.txt")
    #   DataLoader("hotpot_qa").load_qa_data("data/raw/HotpotQA/hotpotqa_validation_set.jsonl")
    #   DataLoader("pop_qa").load_qa_data("data/raw/PopQA/popQA_test.json")
    medlfqa_loader = DataLoader("medlfqa")
    medlfqa_loader.load_qa_data("data/raw/MedLFQA/")
    medlfqa_loader.create_wiki_db(output_path="data/raw/WikiDB/enwiki-20230401.db")