# rag/modules/corpus_store.py
from typing import List, Dict, Any

from datasets import load_dataset, DatasetDict, Dataset

from config import HF_CORPUS_REPO, HF_CORPUS_SUBSET, HF_CORPUS_SPLIT, MARKER_DIR, CORPUS_READY_MARK
from modules.utils import ensure_dir, exists, touch

# Lazily-populated cache mapping subset name (e.g. "ko", "en") -> loaded Dataset.
_datasets: Dict[str, Dataset] = {}


def prepare_corpus() -> None:
    """Download the parquet splits once; subsequent calls are no-ops.

    A marker file records a completed download, so later runs rely on the
    local Hugging Face cache instead of re-downloading.
    """
    ensure_dir(MARKER_DIR)
    if exists(CORPUS_READY_MARK):
        return
    # HF_CORPUS_SUBSET is a comma-separated list, e.g. "ko,en" -> ["ko", "en"].
    for subset in HF_CORPUS_SUBSET.split(","):
        load_dataset(HF_CORPUS_REPO, subset.strip(), split=HF_CORPUS_SPLIT)
    touch(CORPUS_READY_MARK)


def _get_datasets() -> Dict[str, Dataset]:
    """Return the subset -> Dataset mapping, loading all subsets on first use."""
    if not _datasets:
        for subset in HF_CORPUS_SUBSET.split(","):
            name = subset.strip()  # strip once; reused as the cache key
            _datasets[name] = load_dataset(
                HF_CORPUS_REPO, name, split=HF_CORPUS_SPLIT
            )
    return _datasets


def fetch_contexts_by_ids(ids: List[int]) -> List[Dict[str, Any]]:
    """Fetch context rows for the given page_ids, in the order requested.

    All configured subsets are searched. Results follow the order of `ids`
    regardless of which subset a row lives in, and a page_id present in
    several subsets is returned once (the first subset encountered wins)
    rather than duplicated.

    Args:
        ids: page_id values to look up. An empty list yields an empty result.

    Returns:
        One dict per matched id, with keys "id", "title", "text", "url",
        and a nested "metadata" dict. Ids with no matching row are skipped.
    """
    if not ids:
        return []

    id_set = set(ids)

    # Collect matches from every subset first so the output order can
    # follow `ids` globally, not per subset.
    id_to_row: Dict[int, Dict[str, Any]] = {}
    for ds in _get_datasets().values():
        # NOTE: Dataset.filter still scans every row; the benefit is that
        # non-matching rows are never materialized on the Python side.
        matched = ds.filter(lambda r: r["page_id"] in id_set)
        for r in matched:
            id_to_row.setdefault(r["page_id"], r)  # first subset wins

    results: List[Dict[str, Any]] = []
    for i in ids:
        r = id_to_row.get(i)
        if r is None:
            continue
        results.append({
            "id": r["page_id"],
            "title": r.get("title", ""),
            "text": r.get("wikitext", ""),
            "url": r.get("url", ""),
            "metadata": {
                "date_modified": r.get("date_modified", ""),
                "in_language": r.get("in_language", ""),
                "wikidata_id": r.get("wikidata_id", ""),
            },
        })
    return results