update from https://github.com/ArneBinder/argumentation-structure-identification/pull/529
d868d2e
verified
| import json | |
| import logging | |
| import os | |
| import shutil | |
| from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Tuple | |
| from datasets import Dataset as HFDataset | |
| from langchain_core.documents import Document as LCDocument | |
| from pie_datasets import Dataset, DatasetDict, concatenate_datasets | |
| from pytorch_ie.documents import TextBasedDocument | |
| from .pie_document_store import PieDocumentStore | |
# Module-level logger, named after this module per the standard logging convention.
logger = logging.getLogger(__name__)
class DatasetsPieDocumentStore(PieDocumentStore):
    """PIE Document store that uses Huggingface Datasets as the backend.

    Documents live in a single, append-only ``Dataset`` (``self._data``).
    ``self._keys`` maps each store key to its row index, and ``self._metadata``
    holds the per-key LangChain metadata that is re-attached when rows are
    wrapped back into ``LCDocument`` objects. Deleting a key only removes its
    mapping; the orphaned row is dropped lazily by ``_purge_invalid_entries``.
    """

    def __init__(self) -> None:
        # backing Huggingface dataset; created lazily on first insert
        self._data: Optional[Dataset] = None
        # keys map to row indices in the dataset
        self._keys: Dict[str, int] = {}
        # per-key metadata, merged back into the LCDocument on retrieval
        self._metadata: Dict[str, Any] = {}

    def __len__(self) -> int:
        # number of *live* entries (the dataset may contain orphaned rows)
        return len(self._keys)

    def _get_pie_docs_by_indices(self, indices: Iterable[int]) -> Sequence[TextBasedDocument]:
        """Select the rows at ``indices`` from the backing dataset.

        Returns an empty list when no dataset has been created yet. Note that
        ``HFDataset.select`` re-indexes the selected rows from 0 in the order
        of ``indices``.
        """
        if self._data is None:
            return []
        return self._data.apply_hf_func(func=HFDataset.select, indices=indices)

    def mget(self, keys: Sequence[str]) -> List[LCDocument]:
        """Return the documents for all *known* keys; unknown keys are skipped.

        The returned list follows the order of the known keys in ``keys``.
        """
        if self._data is None or len(keys) == 0:
            return []
        keys_in_data = [key for key in keys if key in self._keys]
        indices = [self._keys[key] for key in keys_in_data]
        dataset = self._get_pie_docs_by_indices(indices)
        metadatas = [self._metadata.get(key, {}) for key in keys_in_data]
        return [self.wrap(pie_doc, **metadata) for pie_doc, metadata in zip(dataset, metadatas)]

    def mset(self, items: Sequence[Tuple[str, LCDocument]]) -> None:
        """Add the given ``(key, document)`` pairs to the store.

        Re-setting an existing key appends a fresh row and re-points the key
        to it; the previous row becomes an orphan until the next purge.
        """
        if len(items) == 0:
            return
        keys, new_docs = zip(*items)
        pie_docs, metadatas = zip(*[self.unwrap_with_metadata(doc) for doc in new_docs])
        if self._data is None:
            idx_start = 0
            self._data = Dataset.from_documents(pie_docs)
        else:
            # we pass the features to the new dataset to mitigate issues caused by
            # slightly different inferred features
            dataset = Dataset.from_documents(pie_docs, features=self._data.features)
            idx_start = len(self._data)
            self._data = concatenate_datasets([self._data, dataset], clear_metadata=False)
        keys_dict = {key: idx for idx, key in zip(range(idx_start, len(self._data)), keys)}
        self._keys.update(keys_dict)
        # only keep non-empty metadata dicts to avoid cluttering the mapping
        self._metadata.update(
            {key: metadata for key, metadata in zip(keys, metadatas) if metadata}
        )

    def add_pie_dataset(
        self,
        dataset: Dataset,
        keys: Optional[List[str]] = None,
        metadatas: Optional[List[Dict[str, Any]]] = None,
    ) -> None:
        """Append a whole PIE ``Dataset`` to the store.

        Args:
            dataset: documents to add; appended to the backing dataset as-is.
            keys: store keys, one per document. Defaults to each document's
                ``id``. Must be unique and must not contain ``None``.
            metadatas: optional per-document metadata dicts (defaults to empty).

        Raises:
            ValueError: on duplicate keys, ``None`` keys, or length mismatch
                between ``keys``, ``dataset`` and ``metadatas``.
        """
        if len(dataset) == 0:
            return
        if keys is None:
            keys = [doc.id for doc in dataset]
        if len(keys) != len(set(keys)):
            raise ValueError("Keys must be unique.")
        if None in keys:
            raise ValueError("Keys must not be None.")
        if metadatas is None:
            metadatas = [{} for _ in range(len(dataset))]
        if len(keys) != len(dataset) or len(keys) != len(metadatas):
            raise ValueError("Keys, dataset and metadatas must have the same length.")
        if self._data is None:
            idx_start = 0
            self._data = dataset
        else:
            idx_start = len(self._data)
            self._data = concatenate_datasets([self._data, dataset], clear_metadata=False)
        keys_dict = {key: idx for idx, key in zip(range(idx_start, len(self._data)), keys)}
        self._keys.update(keys_dict)
        metadatas_dict = {key: metadata for key, metadata in zip(keys, metadatas) if metadata}
        self._metadata.update(metadatas_dict)

    def mdelete(self, keys: Sequence[str]) -> None:
        """Remove the given keys (and their metadata) from the store.

        The backing dataset rows are not removed here; they are dropped
        lazily by ``_purge_invalid_entries``.
        """
        for key in keys:
            idx = self._keys.pop(key, None)
            if idx is not None:
                self._metadata.pop(key, None)

    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
        """Yield all keys, optionally restricted to those starting with ``prefix``."""
        return (key for key in self._keys if prefix is None or key.startswith(prefix))

    def _purge_invalid_entries(self) -> None:
        """Drop dataset rows that no longer have a key (e.g. after ``mdelete``)."""
        if self._data is None or len(self._keys) == len(self._data):
            return
        self._data = self._get_pie_docs_by_indices(self._keys.values())
        # fix: select() re-indexes the surviving rows from 0 in the iteration
        # order of self._keys, so the key -> row-index mapping must be rebuilt;
        # otherwise subsequent mget() calls would use stale (possibly
        # out-of-range) indices.
        self._keys = {key: idx for idx, key in enumerate(self._keys)}

    def _save_to_directory(self, path: str, batch_size: Optional[int] = None, **kwargs) -> None:
        """Persist the store to ``path`` as JSON (dataset, keys, metadata).

        Orphaned rows are purged first, so saved row order matches the saved
        key order — ``_load_from_directory`` relies on this positional match.
        """
        self._purge_invalid_entries()
        if len(self) == 0:
            logger.warning("No documents to save.")
            return
        all_doc_ids = list(self._keys)
        all_metadatas: List[Dict[str, Any]] = [self._metadata.get(key, {}) for key in all_doc_ids]
        pie_documents_path = os.path.join(path, "pie_documents")
        if os.path.exists(pie_documents_path):
            # remove existing directory
            logger.warning(f"Removing existing directory: {pie_documents_path}")
            shutil.rmtree(pie_documents_path)
        os.makedirs(pie_documents_path, exist_ok=True)
        DatasetDict({"train": self._data}).to_json(pie_documents_path, mode="w")
        doc_ids_path = os.path.join(path, "doc_ids.json")
        with open(doc_ids_path, "w") as f:
            json.dump(all_doc_ids, f)
        metadata_path = os.path.join(path, "metadata.json")
        with open(metadata_path, "w") as f:
            json.dump(all_metadatas, f)

    def _load_from_directory(self, path: str, **kwargs) -> None:
        """Load a store previously written by ``_save_to_directory``.

        Missing ``doc_ids.json`` / ``metadata.json`` files are tolerated (a
        warning is logged and defaults are used); a missing ``pie_documents``
        directory aborts the load entirely.
        """
        doc_ids_path = os.path.join(path, "doc_ids.json")
        if os.path.exists(doc_ids_path):
            with open(doc_ids_path, "r") as f:
                all_doc_ids = json.load(f)
        else:
            logger.warning(f"File {doc_ids_path} does not exist, don't load any document ids.")
            all_doc_ids = None
        metadata_path = os.path.join(path, "metadata.json")
        if os.path.exists(metadata_path):
            with open(metadata_path, "r") as f:
                all_metadata = json.load(f)
        else:
            logger.warning(f"File {metadata_path} does not exist, don't load any metadata.")
            all_metadata = None
        pie_documents_path = os.path.join(path, "pie_documents")
        if not os.path.exists(pie_documents_path):
            logger.warning(
                f"Directory {pie_documents_path} does not exist, don't load any documents."
            )
            return None
        # If we have a dataset already loaded, we use its features to load the new dataset
        # This is to mitigate issues caused by slightly different inferred features.
        features = self._data.features if self._data is not None else None
        pie_dataset = DatasetDict.from_json(data_dir=pie_documents_path, features=features)
        pie_docs = pie_dataset["train"]
        self.add_pie_dataset(pie_docs, keys=all_doc_ids, metadatas=all_metadata)
        logger.info(f"Loaded {len(pie_docs)} documents from {path} into docstore")