Updated from https://github.com/ArneBinder/argumentation-structure-identification/pull/529 (commit d868d2e); verified.
| import json | |
| import logging | |
| from typing import Iterable, Optional, Sequence | |
| import gradio as gr | |
| from hydra.utils import instantiate | |
| from pie_modules.document.processing import RegexPartitioner, SpansViaRelationMerger | |
| # this is required to dynamically load the PIE models | |
| from pie_modules.models import * # noqa: F403 | |
| from pie_modules.taskmodules import * # noqa: F403 | |
| from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE | |
| from pytorch_ie import Pipeline | |
| from pytorch_ie.annotations import LabeledSpan | |
| from pytorch_ie.documents import ( | |
| TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions, | |
| TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, | |
| ) | |
| # this is required to dynamically load the PIE models | |
| from pytorch_ie.models import * # noqa: F403 | |
| from pytorch_ie.taskmodules import * # noqa: F403 | |
| from src.utils import parse_config | |
# Module-level logger, scoped to this module's import path per stdlib convention.
logger = logging.getLogger(__name__)
def get_merger() -> SpansViaRelationMerger:
    """Build the merger that fuses spans linked by "parts_of_same" relations.

    Returns:
        A ``SpansViaRelationMerger`` configured to combine connected labeled
        spans into multi-spans, producing documents of type
        ``TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions``.
    """
    # Map source annotation layers to their counterparts in the result document.
    layer_mapping = {
        "labeled_spans": "labeled_multi_spans",
        "binary_relations": "binary_relations",
        "labeled_partitions": "labeled_partitions",
    }
    return SpansViaRelationMerger(
        relation_layer="binary_relations",
        link_relation_label="parts_of_same",
        create_multi_spans=True,
        result_document_type=TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
        result_field_mapping=layer_mapping,
        combine_scores_method="product",
    )
def create_document(
    text: str, doc_id: str, split_regex: Optional[str] = None
) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
    """Wrap raw text in a partitioned PIE document.

    Parameters:
        text: The text to process.
        doc_id: The ID of the document.
        split_regex: A regular expression pattern to use for splitting the text into partitions.

    Returns:
        The processed document.
    """
    result = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(
        id=doc_id, text=text, metadata={}
    )
    if split_regex is None:
        # The model only considers text inside partitions, so cover the whole
        # text with a single partition when no splitting is requested.
        whole_text_span = LabeledSpan(start=0, end=len(text), label="text")
        result.labeled_partitions.append(whole_text_span)
    else:
        result = RegexPartitioner(
            pattern=split_regex, partition_layer_name="labeled_partitions"
        )(result)
    return result
def create_documents(
    texts: Iterable[str], doc_ids: Iterable[str], split_regex: Optional[str] = None
) -> Sequence[TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions]:
    """Create one PIE document per (text, doc_id) pair.

    Parameters:
        texts: The texts to process.
        doc_ids: The IDs of the documents.
        split_regex: A regular expression pattern to use for splitting the text into partitions.

    Returns:
        The processed documents.
    """
    documents = []
    # Pair each text with its id in lockstep; extra entries in the longer
    # iterable are ignored (zip semantics).
    for current_text, current_id in zip(texts, doc_ids):
        documents.append(
            create_document(text=current_text, doc_id=current_id, split_regex=split_regex)
        )
    return documents
def load_argumentation_model(config_str: str, **kwargs) -> Optional[Pipeline]:
    """Instantiate an argumentation model pipeline from a YAML config string.

    Parameters:
        config_str: YAML string holding the (Hydra-instantiable) pipeline config.
        **kwargs: Additional keyword arguments forwarded to ``instantiate``.

    Returns:
        The instantiated pipeline, or ``None`` if the config is empty.

    Raises:
        gr.Error: If parsing or instantiation fails, chained to the original
            exception so the traceback is preserved.
    """
    try:
        config = parse_config(config_str, format="yaml")
        if config is None or config == {}:
            gr.Warning("Empty argumentation model config provided. No model loaded.")
            return None
        # for PIE AutoPipeline, we need to handle the revision separately for
        # the taskmodule and the model
        if (
            config.get("_target_", "").strip().endswith("AutoPipeline.from_pretrained")
            and "revision" in config
        ):
            revision = config.pop("revision")
            # setdefault replaces the manual "key in config" checks: it creates
            # the sub-dict only when missing and returns it either way.
            config.setdefault("taskmodule_kwargs", {})["revision"] = revision
            config.setdefault("model_kwargs", {})["revision"] = revision
        model = instantiate(config, **kwargs)
        gr.Info(f"Loaded argumentation model: {json.dumps({**config, **kwargs})}")
    except Exception as e:
        # Fix: chain with "from e" so the underlying cause is not lost when
        # the error is surfaced to the UI/logs.
        raise gr.Error(f"Failed to load argumentation model: {e}") from e
    return model
def set_relation_types(
    argumentation_model: Pipeline,
    default: Optional[Sequence[str]] = None,
) -> gr.Dropdown:
    """Build a multiselect dropdown over the model's binary-relation labels.

    Parameters:
        argumentation_model: Pipeline whose taskmodule supplies the labels;
            only ``PointerNetworkTaskModuleForEnd2EndRE`` is supported.
        default: Labels pre-selected in the dropdown.

    Raises:
        gr.Error: If the taskmodule type is not supported.
    """
    taskmodule = argumentation_model.taskmodule
    # Guard clause: bail out early on unsupported taskmodule types.
    if not isinstance(taskmodule, PointerNetworkTaskModuleForEnd2EndRE):
        raise gr.Error("Unsupported taskmodule for relation types")
    relation_types = taskmodule.labels_per_layer["binary_relations"]
    return gr.Dropdown(
        choices=relation_types,
        label="Argumentative Relation Types",
        value=default,
        multiselect=True,
    )