pls-rag / modules /corpus.py
m97j's picture
Initial codes commit
4fdc679
raw
history blame
2.11 kB
# rag/modules/corpus_store.py
from typing import List, Dict, Any
from datasets import load_dataset, DatasetDict, Dataset
from config import HF_CORPUS_REPO, HF_CORPUS_SUBSET, HF_CORPUS_SPLIT, MARKER_DIR, CORPUS_READY_MARK
from modules.utils import ensure_dir, exists, touch
# Module-level cache: subset name -> loaded Dataset.
# Populated lazily by _get_datasets() on first use and reused afterwards.
_datasets: Dict[str, Dataset] = {}
def prepare_corpus():
    """Download each configured parquet subset split exactly once.

    A marker file records completion, so repeated calls (and restarts) skip
    the download and rely on the local Hugging Face cache instead.
    """
    ensure_dir(MARKER_DIR)
    if exists(CORPUS_READY_MARK):
        return
    # HF_CORPUS_SUBSET is a comma-separated list, e.g. "ko,en" -> ["ko", "en"]
    for raw_name in HF_CORPUS_SUBSET.split(","):
        load_dataset(HF_CORPUS_REPO, raw_name.strip(), split=HF_CORPUS_SPLIT)
    # Only mark ready after every subset downloaded without raising.
    touch(CORPUS_READY_MARK)
def _get_datasets() -> Dict[str, Dataset]:
    """Return the subset-name -> Dataset mapping, loading it on first call.

    Fills the module-level ``_datasets`` cache lazily; subsequent calls
    return the same mapping without touching the Hub again.
    """
    if not _datasets:
        for raw_name in HF_CORPUS_SUBSET.split(","):
            name = raw_name.strip()
            _datasets[name] = load_dataset(
                HF_CORPUS_REPO, name, split=HF_CORPUS_SPLIT
            )
    return _datasets
def fetch_contexts_by_ids(ids: List[int]) -> List[Dict[str, Any]]:
    """Fetch corpus rows whose ``page_id`` appears in *ids*.

    Results follow the order of *ids* and are deduplicated: each id is
    emitted at most once even when it matches rows in several subsets
    (the first subset encountered wins). Ids with no matching row are
    silently skipped.

    Args:
        ids: page_id values to look up.

    Returns:
        One dict per matched id with keys ``id``, ``title``, ``text``,
        ``url`` and a ``metadata`` sub-dict (date_modified, in_language,
        wikidata_id).
    """
    if not ids:
        return []

    id_set = set(ids)

    # Merge matches from every subset first so the final output can follow
    # the caller-supplied id order instead of subset iteration order, and
    # so an id present in multiple subsets is returned only once.
    id_to_row: Dict[int, Any] = {}
    for ds in _get_datasets().values():
        # Dataset.filter scans in batches, which beats a plain Python loop
        # over the full dataset.
        matched = ds.filter(lambda r: r["page_id"] in id_set)
        for row in matched:
            id_to_row.setdefault(row["page_id"], row)

    results: List[Dict[str, Any]] = []
    for page_id in ids:
        row = id_to_row.get(page_id)
        if row:
            results.append({
                "id": row["page_id"],
                "title": row.get("title", ""),
                "text": row.get("wikitext", ""),
                "url": row.get("url", ""),
                "metadata": {
                    "date_modified": row.get("date_modified", ""),
                    "in_language": row.get("in_language", ""),
                    "wikidata_id": row.get("wikidata_id", ""),
                },
            })
    return results