pls-rag / modules /corpus.py
m97j's picture
Initial codes commit
33b550a
# rag/modules/corpus.py
from typing import List, Dict, Any
from datasets import load_dataset, Dataset
from config import HF_CORPUS_REPO, HF_CORPUS_SUBSET, HF_CORPUS_SPLIT, MARKER_DIR, CORPUS_READY_MARK
from modules.utils import ensure_dir, exists, touch
# Per-subset dataset cache keyed by stripped subset name; filled lazily by _get_datasets().
_datasets: Dict[str, Dataset] = dict()
# page_id -> row lookup table, replaced wholesale by set_id_to_row().
_id_to_row: Dict[int, Dict[str, Any]] = dict()
def prepare_corpus():
    """Download the corpus split(s) once so later loads use the local cache.

    If CORPUS_READY_MARK already exists this is a no-op. Otherwise it calls
    ``load_dataset`` for every subset listed (comma-separated) in
    HF_CORPUS_SUBSET and then touches the marker file.
    """
    ensure_dir(MARKER_DIR)
    if exists(CORPUS_READY_MARK):
        return
    # "ko,en" -> ["ko", "en"]; blank entries (e.g. from a trailing comma or
    # stray whitespace) are dropped so they never reach load_dataset as an
    # empty config name.
    subsets = [s.strip() for s in HF_CORPUS_SUBSET.split(",") if s.strip()]
    for subset in subsets:
        # The result is discarded: this call only warms the local cache.
        load_dataset(HF_CORPUS_REPO, subset, split=HF_CORPUS_SPLIT)
    # Written only after every subset loaded; an exception above leaves the
    # marker absent so the next call retries.
    touch(CORPUS_READY_MARK)
def _get_datasets() -> Dict[str, Dataset]:
    """Return the loaded datasets keyed by subset name, loading on first use.

    Subset names come from the comma-separated HF_CORPUS_SUBSET (stripped;
    blank and repeated entries skipped). Results are memoised in the
    module-level ``_datasets`` dict.
    """
    if not _datasets:
        loaded: Dict[str, Dataset] = {}
        for raw in HF_CORPUS_SUBSET.split(","):
            name = raw.strip()
            if not name or name in loaded:
                continue
            loaded[name] = load_dataset(HF_CORPUS_REPO, name, split=HF_CORPUS_SPLIT)
        # Publish only after every subset loaded. If load_dataset raises
        # midway, the cache stays empty and the next call retries. It does not
        # get stuck returning a partial, non-empty cache. update() mutates the
        # existing dict in place, so no `global` statement is needed.
        _datasets.update(loaded)
    return _datasets
def set_id_to_row(mapping: Dict[int, Dict[str, Any]]):
    """Install the page_id -> row lookup table built by the initializer.

    The given dict is stored as-is (not copied) and replaces any previous
    table; fetch_contexts_by_ids reads from it.
    """
    global _id_to_row
    _id_to_row = mapping
def _text_field(row: Dict[str, Any], key: str) -> Any:
    """Return ``row[key]``, or "" when the key is missing or its value is None."""
    value = row.get(key)
    return "" if value is None else value


def fetch_contexts_by_ids(ids: List[int]) -> List[Dict[str, Any]]:
    """Look up rows by page id and shape them into context records.

    Rows come from the table installed by ``set_id_to_row``. Ids with no
    entry (or an empty row) are skipped. Otherwise output order follows
    ``ids``, and duplicate ids yield duplicate records.

    Each record has the keys "id", "title", "text" (the row's "wikitext"),
    "url" and "metadata" (with "date_modified", "in_language",
    "wikidata_id"). Fields that are missing or hold None come back as "".
    """
    if not ids:
        return []
    results: List[Dict[str, Any]] = []
    for page_id in ids:
        row = _id_to_row.get(page_id)
        if not row:
            continue
        results.append({
            # The table is keyed by page_id, so the lookup key is a safe
            # fallback if the row itself lacks the column.
            "id": row.get("page_id", page_id),
            "title": _text_field(row, "title"),
            "text": _text_field(row, "wikitext"),
            "url": _text_field(row, "url"),
            "metadata": {
                "date_modified": _text_field(row, "date_modified"),
                "in_language": _text_field(row, "in_language"),
                "wikidata_id": _text_field(row, "wikidata_id"),
            },
        })
    return results