Spaces:
Runtime error
Runtime error
| import logging | |
| from typing import Optional, Sequence, Type | |
| from langchain_core.documents import Document as LCDocument | |
| from pie_datasets import Dataset, IterableDataset | |
| from pytorch_ie import Document, WithDocumentTypeMixin | |
| from pytorch_ie.annotations import BinaryRelation, LabeledSpan | |
| from pytorch_ie.documents import TextBasedDocument | |
| from src.langchain_modules import DocumentAwareSpanRetriever | |
| logger = logging.getLogger(__name__) | |
class DummyTaskmodule(WithDocumentTypeMixin):
    """Minimal taskmodule stand-in that only carries a document type.

    The pipeline below exposes an instance of this class as its ``taskmodule``
    so that a document type can be read via ``pipeline.taskmodule.document_type``.

    Args:
        document_type: The document type to report.
    """

    def __init__(self, document_type: Type[Document]):
        self._document_type = document_type

    # Exposed as a property so that attribute-style access
    # (``taskmodule.document_type``) yields the type itself rather than a
    # bound method.
    @property
    def document_type(self) -> Optional[Type[Document]]:
        """Return the document type that was passed to the constructor."""
        return self._document_type
class SpanRetrievalBasedRelationExtractionPipeline:
    """Pipeline for adding binary relations between spans based on span retrieval within the same document.

    This pipeline retrieves spans for all existing spans as query and adds binary relations between the
    query spans and the retrieved spans.

    Args:
        retriever: The span retriever to use for retrieving spans. Its
            ``retrieve_from_same_document`` flag must be truthy.
        relation_label: The label to use for the binary relations.
        relation_layer_name: The name of the annotation layer to add the binary relations to.
        use_predicted_annotations: Forwarded to the retriever when documents are added and
            when the base span layer is fetched.
        load_store_path: If provided, the retriever store(s) will be loaded from this path
            when the pipeline is constructed.
        save_store_path: If provided, the retriever store(s) will be saved to this path after processing.
        fast_dev_run: Whether to run the pipeline in fast dev mode, i.e. only processing the first 2 documents.

    Raises:
        NotImplementedError: If the retriever does not retrieve from the same document.
    """

    def __init__(
        self,
        retriever: DocumentAwareSpanRetriever,
        relation_label: str,
        relation_layer_name: str = "binary_relations",
        use_predicted_annotations: bool = False,
        load_store_path: Optional[str] = None,
        save_store_path: Optional[str] = None,
        fast_dev_run: bool = False,
    ):
        self.retriever = retriever
        if not self.retriever.retrieve_from_same_document:
            raise NotImplementedError("Retriever must retrieve from the same document")
        self.relation_label = relation_label
        self.relation_layer_name = relation_layer_name
        self.use_predicted_annotations = use_predicted_annotations
        self.load_store_path = load_store_path
        self.save_store_path = save_store_path
        if self.load_store_path is not None:
            self.retriever.load_from_directory(path=self.load_store_path)
        self.fast_dev_run = fast_dev_run

    # to make auto-conversion work: we request documents of type pipeline.taskmodule.document_type
    # from the dataset
    @property
    def taskmodule(self) -> DummyTaskmodule:
        """Return a dummy taskmodule wrapping the retriever's ``pie_document_type``."""
        return DummyTaskmodule(self.retriever.pie_document_type)

    def _construct_similarity_relations(
        self,
        query_results: list[LCDocument],
        query_span: LabeledSpan,
    ) -> list[BinaryRelation]:
        """Build one relation per retrieved result, from the query span to the retrieved span.

        Args:
            query_results: Retrieved entries; each must carry ``"attached_span"`` and
                ``"relevance_score"`` in its metadata.
            query_span: The span used as query; it becomes the head of every relation.

        Returns:
            The relations, labeled with ``self.relation_label`` and scored with the
            relevance score converted to ``float``.
        """
        return [
            BinaryRelation(
                head=query_span,
                tail=lc_doc.metadata["attached_span"],
                label=self.relation_label,
                score=float(lc_doc.metadata["relevance_score"]),
            )
            for lc_doc in query_results
        ]

    def _process_single_document(
        self,
        document: Document,
    ) -> TextBasedDocument:
        """Query the retriever with every span of the document and store the resulting relations.

        The new relations are appended to the predictions of the layer named
        ``self.relation_layer_name``.

        Args:
            document: The document to process.

        Returns:
            The same document object, with the predictions extended.

        Raises:
            ValueError: If the document is not a ``TextBasedDocument`` or lacks the
                target relation layer.
        """
        if not isinstance(document, TextBasedDocument):
            raise ValueError("Document must be a TextBasedDocument")
        # Validate the target layer up front: failing only after retrieval would waste
        # the queries and leave the document behind in the retriever store.
        if self.relation_layer_name not in document:
            raise ValueError(f"Document does not have a layer named {self.relation_layer_name}")

        self.retriever.add_pie_documents(
            [document], use_predicted_annotations=self.use_predicted_annotations
        )

        spans = self.retriever.get_base_layer(
            document, use_predicted_annotations=self.use_predicted_annotations
        )
        span_id2idx = self.retriever.get_span_id2idx_from_doc(document.id)

        all_new_rels: list[BinaryRelation] = []
        for span_id, span_idx in span_id2idx.items():
            query_result = self.retriever.invoke(input=span_id)
            all_new_rels.extend(
                self._construct_similarity_relations(query_result, query_span=spans[span_idx])
            )
        document[self.relation_layer_name].predictions.extend(all_new_rels)

        # The store content is only kept when it is going to be saved afterwards;
        # otherwise the document is removed again right away.
        if self.retriever.retrieve_from_same_document and self.save_store_path is None:
            self.retriever.delete_documents([document.id])

        return document

    def __call__(self, documents: Sequence[Document], inplace: bool = False) -> Sequence[Document]:
        """Process all documents and return the mapped result.

        Args:
            documents: The documents to process. Anything that is not already a
                ``Dataset`` or ``IterableDataset`` is converted via ``Dataset.from_documents``.
            inplace: Must be ``False``; in-place processing is not implemented.

        Returns:
            The result of mapping the single-document processing over the dataset.

        Raises:
            NotImplementedError: If ``inplace`` is true.
        """
        if inplace:
            raise NotImplementedError("Inplace processing is not supported yet")

        if self.fast_dev_run:
            logger.warning("Fast dev run enabled, only processing the first 2 documents")
            # NOTE(review): slicing assumes a sliceable container; an IterableDataset
            # may not support this - confirm.
            documents = documents[:2]

        if not isinstance(documents, (Dataset, IterableDataset)):
            documents = Dataset.from_documents(documents)

        mapped_documents = documents.map(self._process_single_document)

        if self.save_store_path is not None:
            self.retriever.save_to_directory(path=self.save_store_path)

        return mapped_documents