update from https://github.com/ArneBinder/argumentation-structure-identification/pull/529 (commit d868d2e)
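"""Backend utilities for the demo app: document ingestion (plain text, PDF, arXiv,
ACL Anthology, PIE datasets), retriever store management, and annotation rendering."""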
import json
import logging
import os
import tempfile
from pathlib import Path
from typing import Iterable, List, Optional, Sequence

import gradio as gr
import pandas as pd
from acl_anthology import Anthology
from pie_datasets import Dataset, IterableDataset, load_dataset
from pytorch_ie import Pipeline
from pytorch_ie.documents import (
    TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
    TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
)
from tqdm import tqdm

from src.demo.annotation_utils import create_documents, get_merger
from src.demo.data_utils import load_text_from_arxiv
from src.demo.rendering_utils import (
    RENDER_WITH_DISPLACY,
    RENDER_WITH_PRETTY_TABLE,
    render_displacy,
    render_pretty_table,
)
from src.demo.retriever_utils import get_text_spans_and_relations_from_document
from src.langchain_modules import (
    DocumentAwareSpanRetriever,
    DocumentAwareSpanRetrieverWithRelations,
)
from src.utils.pdf_utils.acl_anthology_utils import XML2RawPapers
from src.utils.pdf_utils.process_pdf import FulltextExtractor, PDFDownloader

logger = logging.getLogger(__name__)


def add_annotated_pie_documents(
    retriever: DocumentAwareSpanRetriever,
    pie_documents: Sequence[TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions],
    use_predicted_annotations: bool,
    verbose: bool = False,
) -> None:
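    """Embed the spans of the given PIE documents and add them to the retriever store.

    Warns if previously stored documents with the same IDs were overwritten.
    """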
    if verbose:
        gr.Info(f"Creating span embeddings for {len(pie_documents)} documents ...")
    num_docs_before = len(retriever.docstore)
    retriever.add_pie_documents(pie_documents, use_predicted_annotations=use_predicted_annotations)
    # number of documents that were overwritten
    num_overwritten_docs = num_docs_before + len(pie_documents) - len(retriever.docstore)
    # warn if any documents were overwritten
    if num_overwritten_docs > 0:
        gr.Warning(f"{num_overwritten_docs} documents were overwritten.")


def process_texts(
    texts: Iterable[str],
    doc_ids: Iterable[str],
    argumentation_model: Optional[Pipeline],
    retriever: DocumentAwareSpanRetriever,
    split_regex_escaped: Optional[str],
    handle_parts_of_same: bool = False,
    verbose: bool = False,
) -> None:
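    """Convert raw texts into PIE documents, annotate them with the argumentation
    model (if one is loaded), and add them to the retriever store.

    Raises a Gradio error if the document IDs are not unique.
    """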
    # materialize the iterables so they can be traversed multiple times
    doc_ids = list(doc_ids)
    texts = list(texts)
    # check that the doc_ids are unique
    if len(set(doc_ids)) != len(doc_ids):
        raise gr.Error("Document IDs must be unique.")
    pie_documents = create_documents(
        texts=texts,
        doc_ids=doc_ids,
        split_regex=split_regex_escaped,
    )
    if argumentation_model is not None:
        if verbose:
            gr.Info(f"Annotating {len(pie_documents)} documents ...")
        pie_documents = argumentation_model(pie_documents, inplace=True)
    else:
        gr.Warning(
            "Annotation is disabled (no model was loaded). No annotations will be added to the documents."
        )
    # This needs to happen even if the documents were not annotated because it
    # adjusts the document type.
    if handle_parts_of_same:
        merger = get_merger()
        pie_documents = [merger(document) for document in pie_documents]
    add_annotated_pie_documents(
        retriever=retriever,
        pie_documents=pie_documents,
        use_predicted_annotations=True,
        verbose=verbose,
    )


def add_annotated_pie_documents_from_dataset(
    retriever: DocumentAwareSpanRetriever, verbose: bool = False, **load_dataset_kwargs
) -> None:
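    """Load a PIE dataset, convert it to the multi-span document type (merging single
    spans if no direct converter is registered), strip the metadata, and add the
    gold annotations to the retriever store.
    """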
    try:
        gr.Info(
            "Loading PIE dataset with parameters:\n" + json.dumps(load_dataset_kwargs, indent=2)
        )
        dataset = load_dataset(**load_dataset_kwargs)
        if not isinstance(dataset, (Dataset, IterableDataset)):
            raise gr.Error("Loaded dataset is not of type PIE (Iterable)Dataset.")
        try:
            dataset_converted = dataset.to_document_type(
                TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions
            )
        except ValueError:
            gr.Warning(
                "The dataset does not seem to have a registered converter to create multi-spans. "
                "Trying to load as single-spans and to convert to multi-spans manually ..."
            )
            dataset_converted_single_span = dataset.to_document_type(
                TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
            )
            merger = get_merger()
            dataset_converted = dataset_converted_single_span.map(
                merger,
                result_document_type=TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
            )

        def _clear_metadata(
            doc: TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
        ) -> TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions:
            result = doc.copy()
            result.metadata = dict()
            return result

        # Adding documents with different metadata formats to the retriever breaks it,
        # so we clear the metadata field beforehand.
        dataset_converted_without_metadata = dataset_converted.map(
            _clear_metadata,
            result_document_type=TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
        )
        add_annotated_pie_documents(
            retriever=retriever,
            pie_documents=dataset_converted_without_metadata,
            use_predicted_annotations=False,
            verbose=verbose,
        )
    except Exception as e:
        raise gr.Error(f"Failed to load dataset: {e}")


def wrapped_process_text(
    doc_id: str, text: str, retriever: DocumentAwareSpanRetriever, **kwargs
) -> str:
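    """Process a single text and return its doc_id, wrapping any failure in a Gradio error."""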
    try:
        process_texts(doc_ids=[doc_id], texts=[text], retriever=retriever, **kwargs)
    except Exception as e:
        raise gr.Error(f"Failed to process text: {e}")
    # Return only the doc_id (instead of the document) to avoid serialization issues.
    return doc_id


def process_uploaded_files(
    file_names: List[str],
    retriever: DocumentAwareSpanRetriever,
    layer_captions: dict[str, str],
    **kwargs,
) -> pd.DataFrame:
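    """Read the uploaded text files, process them, and return an overview of the
    retriever's document store. Only .txt files are supported.
    """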
    try:
        doc_ids = []
        texts = []
        for file_name in file_names:
            if file_name.lower().endswith(".txt"):
                # read the file content
                with open(file_name, "r", encoding="utf-8") as f:
                    text = f.read()
                base_file_name = os.path.basename(file_name)
                doc_ids.append(base_file_name)
                texts.append(text)
            else:
                raise gr.Error(f"Unsupported file format: {file_name}")
        process_texts(texts=texts, doc_ids=doc_ids, retriever=retriever, verbose=True, **kwargs)
    except Exception as e:
        raise gr.Error(f"Failed to process uploaded files: {e}")

    return retriever.docstore.overview(layer_captions=layer_captions, use_predictions=True)


def process_uploaded_pdf_files(
    pdf_fulltext_extractor: Optional[FulltextExtractor],
    file_names: List[str],
    retriever: DocumentAwareSpanRetriever,
    layer_captions: dict[str, str],
    **kwargs,
) -> pd.DataFrame:
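    """Extract the fulltext from the uploaded PDF files, process it, and return an
    overview of the retriever's document store. Only .pdf files are supported.
    """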
    try:
        if pdf_fulltext_extractor is None:
            raise gr.Error("PDF fulltext extractor is not available.")
        doc_ids = []
        texts = []
        for file_name in file_names:
            if file_name.lower().endswith(".pdf"):
                # extract the fulltext from the pdf
                text_and_extraction_data = pdf_fulltext_extractor(file_name)
                if text_and_extraction_data is None:
                    raise gr.Error(f"Failed to extract fulltext from PDF: {file_name}")
                text, _ = text_and_extraction_data
                base_file_name = os.path.basename(file_name)
                doc_ids.append(base_file_name)
                texts.append(text)
            else:
                raise gr.Error(f"Unsupported file format: {file_name}")
        process_texts(texts=texts, doc_ids=doc_ids, retriever=retriever, verbose=True, **kwargs)
    except Exception as e:
        raise gr.Error(f"Failed to process uploaded PDF files: {e}")

    return retriever.docstore.overview(layer_captions=layer_captions, use_predictions=True)


def load_acl_anthology_venues(
    venues: List[str],
    pdf_fulltext_extractor: Optional[FulltextExtractor],
    retriever: DocumentAwareSpanRetriever,
    layer_captions: dict[str, str],
    acl_anthology_data_dir: Optional[str],
    pdf_output_dir: Optional[str],
    show_progress: bool = True,
    **kwargs,
) -> pd.DataFrame:
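    """Collect all papers from the given ACL Anthology venues, download their PDFs,
    extract the fulltext, process the texts, and return an overview of the
    retriever's document store.
    """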
    try:
        if pdf_fulltext_extractor is None:
            raise gr.Error("PDF fulltext extractor is not available.")
        if acl_anthology_data_dir is None:
            raise gr.Error("ACL Anthology data directory is not provided.")
        if pdf_output_dir is None:
            raise gr.Error("PDF output directory is not provided.")
        xml2raw_papers = XML2RawPapers(
            anthology=Anthology(datadir=Path(acl_anthology_data_dir)),
            venue_id_whitelist=venues,
            verbose=False,
        )
        pdf_downloader = PDFDownloader()
        doc_ids = []
        texts = []
        os.makedirs(pdf_output_dir, exist_ok=True)
        papers = xml2raw_papers()
        if show_progress:
            papers_list = list(papers)
            papers = tqdm(papers_list, desc="extracting fulltext")
            gr.Info(
                f"Downloading and extracting fulltext from {len(papers_list)} papers in venues: {venues}"
            )
        for paper in papers:
            if paper.url is not None:
                pdf_save_path = pdf_downloader.download(
                    paper.url, opath=Path(pdf_output_dir) / f"{paper.name}.pdf"
                )
                fulltext_extraction_output = pdf_fulltext_extractor(pdf_save_path)
                if fulltext_extraction_output:
                    text, _ = fulltext_extraction_output
                    doc_id = f"aclanthology.org/{paper.name}"
                    doc_ids.append(doc_id)
                    texts.append(text)
                else:
                    gr.Warning(f"Failed to extract fulltext from PDF: {paper.url}")
        process_texts(texts=texts, doc_ids=doc_ids, retriever=retriever, verbose=True, **kwargs)
    except Exception as e:
        raise gr.Error(f"Failed to process ACL Anthology venues: {e}")

    return retriever.docstore.overview(layer_captions=layer_captions, use_predictions=True)


def wrapped_add_annotated_pie_documents_from_dataset(
    retriever: DocumentAwareSpanRetriever, verbose: bool, layer_captions: dict[str, str], **kwargs
) -> pd.DataFrame:
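    """Add gold-annotated documents from a PIE dataset and return an overview of the
    retriever's document store, wrapping any failure in a Gradio error.
    """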
    try:
        add_annotated_pie_documents_from_dataset(retriever=retriever, verbose=verbose, **kwargs)
    except Exception as e:
        raise gr.Error(f"Failed to add annotated PIE documents from dataset: {e}")

    return retriever.docstore.overview(layer_captions=layer_captions, use_predictions=True)


def download_processed_documents(
    retriever: DocumentAwareSpanRetriever,
    file_name: str = "retriever_store",
) -> Optional[str]:
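    """Zip the retriever store into a temporary archive and return its path, or None
    if the store is empty.
    """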
    if len(retriever.docstore) == 0:
        gr.Warning("No documents to download.")
        return None

    # zip the retriever store directory
    file_path = os.path.join(tempfile.gettempdir(), file_name)
    gr.Info(f"Zipping the retriever store to '{file_name}' ...")
    result_file_path = retriever.save_to_archive(base_name=file_path, format="zip")
    return result_file_path


def upload_processed_documents(
    file_name: str,
    retriever: DocumentAwareSpanRetriever,
    layer_captions: dict[str, str],
) -> pd.DataFrame:
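    """Load previously saved documents into the retriever store and return an
    overview of the document store.
    """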
    # load the documents from the zip file or directory
    retriever.load_from_disc(file_name)
    # return the overview of the document store
    return retriever.docstore.overview(layer_captions=layer_captions, use_predictions=True)


def process_text_from_arxiv(
    arxiv_id: str, retriever: DocumentAwareSpanRetriever, abstract_only: bool = False, **kwargs
) -> str:
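    """Fetch the abstract or fulltext of an arXiv paper, process it, and return the
    resulting doc_id.
    """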
    try:
        text, doc_id = load_text_from_arxiv(arxiv_id=arxiv_id, abstract_only=abstract_only)
    except Exception as e:
        raise gr.Error(f"Failed to load text from arXiv: {e}")
    return wrapped_process_text(doc_id=doc_id, text=text, retriever=retriever, **kwargs)


def render_annotated_document(
    retriever: DocumentAwareSpanRetrieverWithRelations,
    document_id: str,
    render_with: str,
    render_kwargs_json: str,
    highlight_span_ids: Optional[List[str]] = None,
) -> str:
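    """Render the annotated document with the given ID as HTML, either as a pretty
    table or via displaCy.
    """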
    text, spans, span_id2idx, relations = get_text_spans_and_relations_from_document(
        retriever=retriever, document_id=document_id
    )
    render_kwargs = json.loads(render_kwargs_json)
    if render_with == RENDER_WITH_PRETTY_TABLE:
        html = render_pretty_table(
            text=text,
            spans=spans,
            span_id2idx=span_id2idx,
            binary_relations=relations,
            **render_kwargs,
        )
    elif render_with == RENDER_WITH_DISPLACY:
        html = render_displacy(
            text=text,
            spans=spans,
            span_id2idx=span_id2idx,
            binary_relations=relations,
            highlight_span_ids=highlight_span_ids,
            **render_kwargs,
        )
    else:
        raise ValueError(f"Unknown render_with value: {render_with}")

    return html