Spaces:
Runtime error
Runtime error
| """ | |
| workflow: | |
| Document | |
| -> (InputEncoding, TargetEncoding) -> TaskEncoding -> TaskBatchEncoding | |
| -> ModelBatchEncoding -> ModelBatchOutput | |
| -> TaskOutput | |
| -> Document | |
| """ | |
| import logging | |
| from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Tuple, TypedDict, Union | |
| import numpy as np | |
| import torch | |
| from pytorch_ie.annotations import BinaryRelation, LabeledSpan, MultiLabeledBinaryRelation, Span | |
| from pytorch_ie.core import TaskEncoding, TaskModule | |
| from pytorch_ie.documents import TextDocument | |
| from pytorch_ie.models import ( | |
| TransformerTextClassificationModelBatchOutput, | |
| TransformerTextClassificationModelStepBatchEncoding, | |
| ) | |
| from pytorch_ie.utils.span import get_token_slice, is_contained_in | |
| from pytorch_ie.utils.window import get_window_around_slice | |
| from transformers import AutoTokenizer | |
| from transformers.file_utils import PaddingStrategy | |
| from transformers.tokenization_utils_base import BatchEncoding, TruncationStrategy | |
| from typing_extensions import TypeAlias | |
# Type aliases for this taskmodule's pipeline stages.
# The model input is a plain feature dict (e.g. {"input_ids": [...]}); the
# target is a sequence of label ids (a single id unless multi_label is used).
TransformerReTextClassificationInputEncoding2: TypeAlias = Dict[str, Any]
TransformerReTextClassificationTargetEncoding2: TypeAlias = Sequence[int]
# One task encoding ties a document to one (input, target) pair for a
# single candidate (head, tail) entity pair.
TransformerReTextClassificationTaskEncoding2: TypeAlias = TaskEncoding[
    TextDocument,
    TransformerReTextClassificationInputEncoding2,
    TransformerReTextClassificationTargetEncoding2,
]
class TransformerReTextClassificationTaskOutput2(TypedDict, total=False):
    """Decoded model output for a single example: predicted label strings and scores."""

    # parallel sequences: probabilities[i] is the score of labels[i]
    labels: Sequence[str]
    probabilities: Sequence[float]
# Fully specialized TaskModule base for this taskmodule.
# _InputEncoding, _TargetEncoding, _TaskBatchEncoding, _ModelBatchOutput, _TaskOutput
_TransformerReTextClassificationTaskModule2: TypeAlias = TaskModule[
    TextDocument,
    TransformerReTextClassificationInputEncoding2,
    TransformerReTextClassificationTargetEncoding2,
    TransformerTextClassificationModelStepBatchEncoding,
    TransformerTextClassificationModelBatchOutput,
    TransformerReTextClassificationTaskOutput2,
]
# Argument roles (also used as task-encoding metadata keys) and marker positions.
HEAD = "head"
TAIL = "tail"
START = "start"
END = "end"
logger = logging.getLogger(__name__)
class RelationArgument:
    """One argument (head or tail entity) of a candidate relation.

    Renders the special marker strings that are inserted into the token
    sequence around the argument's entity span ("[H]" ... "[/H]" or
    "[T]" ... "[/T]", optionally typed as e.g. "[H:PER]") or appended after
    the text (e.g. "[H=PER]").

    Args:
        entity: the argument's entity span; its label is used in typed markers.
        role: either HEAD or TAIL.
        offsets: (start, end) token offsets of the entity in the encoded input.
        add_type_to_marker: if True, start/end markers include the entity label.
    """

    def __init__(
        self,
        entity: LabeledSpan,
        role: str,
        offsets: Tuple[int, int],
        add_type_to_marker: bool,
    ) -> None:
        self.entity = entity
        self.role = role
        assert self.role in (HEAD, TAIL)
        self.offsets = offsets
        self.add_type_to_marker = add_type_to_marker

    # NOTE: these are properties, not plain methods, because all call sites
    # access them without parentheses (e.g.
    # `self.argument_markers_to_id[arg_list[0].as_start_marker]`). As plain
    # methods those dict lookups would KeyError on a bound method, and
    # `'H' if self.is_head else 'T'` in the f-strings below would always be
    # truthy, rendering every marker as "H".
    @property
    def is_head(self) -> bool:
        return self.role == HEAD

    @property
    def is_tail(self) -> bool:
        return self.role == TAIL

    @property
    def as_start_marker(self) -> str:
        return self._get_marker(is_start=True)

    @property
    def as_end_marker(self) -> str:
        return self._get_marker(is_start=False)

    def _get_marker(self, is_start: bool = True) -> str:
        return f"[{'' if is_start else '/'}{'H' if self.is_head else 'T'}" + (
            f":{self.entity.label}]" if self.add_type_to_marker else "]"
        )

    @property
    def as_append_marker(self) -> str:
        return f"[{'H' if self.is_head else 'T'}={self.entity.label}]"
| def _enumerate_entity_pairs( | |
| entities: Sequence[Span], | |
| partition: Optional[Span] = None, | |
| relations: Optional[Sequence[BinaryRelation]] = None, | |
| ): | |
| """Given a list of `entities` iterate all valid pairs of entities, including inverted pairs. | |
| If a `partition` is provided, restrict pairs to be contained in that. If `relations` are given, | |
| return only pairs for which a predefined relation exists (e.g. in the case of relation | |
| classification for train,val,test splits in supervised datasets). | |
| """ | |
| existing_head_tail = {(relation.head, relation.tail) for relation in relations or []} | |
| for head in entities: | |
| if partition is not None and not is_contained_in( | |
| (head.start, head.end), (partition.start, partition.end) | |
| ): | |
| continue | |
| for tail in entities: | |
| if partition is not None and not is_contained_in( | |
| (tail.start, tail.end), (partition.start, partition.end) | |
| ): | |
| continue | |
| if head == tail: | |
| continue | |
| if relations is not None and (head, tail) not in existing_head_tail: | |
| continue | |
| yield head, tail | |
class TransformerRETextClassificationTaskModule2(_TransformerReTextClassificationTaskModule2):
    """Marker based relation extraction. This taskmodule prepares the input token ids in such a way
    that before and after the candidate head and tail entities special marker tokens are inserted.
    Then, the modified token ids can be simply passed into a transformer based text classifier
    model.

    parameters:

        partition_annotation: str, optional. If specified, LabeledSpan annotations with this name are
            expected to define partitions of the document that will be processed individually,
            e.g. sentences or sections of the document text.
        none_label: str, defaults to "no_relation". The relation label that indicates dummy/negative
            relations. Predicted relations with that label will not be added to the document(s).
        max_window: int, optional. If specified, use the tokens in a window of maximal this amount of
            tokens around the center of head and tail entities and pass only that into the transformer.
    """

    # Attributes populated by `_prepare` from the training data.
    # NOTE(review): presumably consumed by the TaskModule base so a prepared
    # taskmodule can be saved/restored without re-running preparation — confirm
    # against pytorch_ie.core.TaskModule.
    PREPARED_ATTRIBUTES = ["label_to_id", "entity_labels"]
    def __init__(
        self,
        tokenizer_name_or_path: str,
        entity_annotation: str = "entities",
        relation_annotation: str = "relations",
        partition_annotation: Optional[str] = None,
        none_label: str = "no_relation",
        padding: Union[bool, str, PaddingStrategy] = True,
        truncation: Union[bool, str, TruncationStrategy] = True,
        max_length: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        multi_label: bool = False,
        label_to_id: Optional[Dict[str, int]] = None,
        add_type_to_marker: bool = False,
        single_argument_pair: bool = True,
        append_markers: bool = False,
        entity_labels: Optional[List[str]] = None,
        max_window: Optional[int] = None,
        log_first_n_examples: Optional[int] = None,
        **kwargs,
    ) -> None:
        """Initialize the taskmodule and load the tokenizer.

        Args:
            tokenizer_name_or_path: passed to `AutoTokenizer.from_pretrained`.
            entity_annotation: name of the document annotation layer holding entity spans.
            relation_annotation: name of the layer holding binary relations.
            partition_annotation / none_label / max_window: see the class docstring.
            padding / truncation / max_length / pad_to_multiple_of: forwarded to the
                tokenizer when encoding and when padding batches in `collate`.
            multi_label: if True, targets are kept 2-d in `collate`; decoding multi-label
                output is not implemented (see `unbatch_output`).
            label_to_id: mapping from relation label to id; normally derived by `_prepare`.
            add_type_to_marker: if True, argument markers include the entity label.
            single_argument_pair: NOTE(review): stored but not read anywhere in this
                file — possibly consumed elsewhere or dead; confirm.
            append_markers: if True, append "[H=label]"/"[T=label]" markers (plus SEP)
                after the input sequence.
            entity_labels: known entity labels; normally derived by `_prepare`.
            log_first_n_examples: if set, log the first n encoded training examples.
        """
        super().__init__(**kwargs)
        self.save_hyperparameters()
        self.entity_annotation = entity_annotation
        self.relation_annotation = relation_annotation
        self.padding = padding
        self.truncation = truncation
        self.label_to_id = label_to_id or {}
        self.id_to_label = {v: k for k, v in self.label_to_id.items()}
        self.max_length = max_length
        self.pad_to_multiple_of = pad_to_multiple_of
        self.multi_label = multi_label
        self.add_type_to_marker = add_type_to_marker
        self.single_argument_pair = single_argument_pair
        self.append_markers = append_markers
        self.entity_labels = entity_labels
        self.partition_annotation = partition_annotation
        self.none_label = none_label
        self.max_window = max_window
        self.log_first_n_examples = log_first_n_examples
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
        # set in _post_prepare once labels/markers are known
        self.argument_markers = None
        # number of examples already logged by _log_example
        self._logged_examples_counter = 0
| def _prepare(self, documents: Sequence[TextDocument]) -> None: | |
| entity_labels: Set[str] = set() | |
| relation_labels: Set[str] = set() | |
| for document in documents: | |
| entities: Sequence[LabeledSpan] = document[self.entity_annotation] | |
| relations: Sequence[BinaryRelation] = document[self.relation_annotation] | |
| for entity in entities: | |
| entity_labels.add(entity.label) | |
| for relation in relations: | |
| relation_labels.add(relation.label) | |
| if self.none_label in relation_labels: | |
| relation_labels.remove(self.none_label) | |
| self.label_to_id = {label: i + 1 for i, label in enumerate(sorted(relation_labels))} | |
| self.label_to_id[self.none_label] = 0 | |
| self.entity_labels = sorted(entity_labels) | |
| def _post_prepare(self): | |
| self.argument_markers = self._initialize_argument_markers() | |
| self.tokenizer.add_tokens(self.argument_markers, special_tokens=True) | |
| self.argument_markers_to_id = { | |
| marker: self.tokenizer.vocab[marker] for marker in self.argument_markers | |
| } | |
| self.sep_token_id = self.tokenizer.vocab[self.tokenizer.sep_token] | |
| self.id_to_label = {v: k for k, v in self.label_to_id.items()} | |
| def _initialize_argument_markers(self) -> List[str]: | |
| argument_markers: Set[str] = set() | |
| for arg_type in [HEAD, TAIL]: | |
| for arg_pos in [START, END]: | |
| is_head = arg_type == HEAD | |
| is_start = arg_pos == START | |
| argument_markers.add(f"[{'' if is_start else '/'}{'H' if is_head else 'T'}]") | |
| if self.add_type_to_marker: | |
| for entity_type in self.entity_labels: # type: ignore | |
| argument_markers.add( | |
| f"[{'' if is_start else '/'}{'H' if is_head else 'T'}" | |
| f"{':' + entity_type if self.add_type_to_marker else ''}]" | |
| ) | |
| if self.append_markers: | |
| for entity_type in self.entity_labels: # type: ignore | |
| argument_markers.add(f"[{'H' if is_head else 'T'}={entity_type}]") | |
| return sorted(list(argument_markers)) | |
| def _encode_text( | |
| self, | |
| document: TextDocument, | |
| partition: Optional[Span] = None, | |
| add_special_tokens: bool = True, | |
| ) -> BatchEncoding: | |
| text = ( | |
| document.text[partition.start : partition.end] | |
| if partition is not None | |
| else document.text | |
| ) | |
| encoding = self.tokenizer( | |
| text, | |
| padding=False, | |
| truncation=self.truncation, | |
| max_length=self.max_length, | |
| is_split_into_words=False, | |
| return_offsets_mapping=False, | |
| add_special_tokens=add_special_tokens, | |
| ) | |
| return encoding | |
    def encode_input(
        self,
        document: TextDocument,
        is_training: bool = False,
    ) -> Optional[
        Union[
            TransformerReTextClassificationTaskEncoding2,
            Sequence[TransformerReTextClassificationTaskEncoding2],
        ]
    ]:
        """Create one task encoding per candidate (head, tail) entity pair.

        For every pair (restricted to pairs with a gold relation when the document
        has relations, and to a single partition when `partition_annotation` is set),
        the entity character spans are mapped to token offsets, the token sequence is
        optionally reduced to a `max_window`-sized window around both arguments, and
        marker token ids are inserted before/after each argument (plus optional
        append markers). Pairs whose spans do not align with token boundaries, or do
        not fit into the window, are skipped.
        """
        assert (
            self.argument_markers is not None
        ), "No argument markers available, was `prepare` already called?"
        entities: Sequence[Span] = document[self.entity_annotation]
        relations: Sequence[BinaryRelation] = document[self.relation_annotation]
        # if no relations are predefined, use None so that enumerate_entities yields all pairs
        # should be fixed to a parameter "generate_all" or "restrict_to_existing_relations"
        if len(relations) == 0:
            relations = None
        partitions: Sequence[Optional[Span]]
        if self.partition_annotation is not None:
            partitions = document[self.partition_annotation]
        else:
            # use single dummy partition
            partitions = [None]
        task_encodings: List[TransformerReTextClassificationTaskEncoding2] = []
        for partition_idx, partition in enumerate(partitions):  # NOTE(review): partition_idx unused
            # character offset of the partition within the document text; token offsets
            # of the partition-local encoding are relative to this
            partition_offset = 0 if partition is None else partition.start
            # with windowing, special tokens are added manually after slicing (see below)
            add_special_tokens = self.max_window is None
            encoding = self._encode_text(
                document=document, partition=partition, add_special_tokens=add_special_tokens
            )
            for (head, tail,) in _enumerate_entity_pairs(
                entities=entities,
                partition=partition,
                relations=relations,
            ):
                head_token_slice = get_token_slice(
                    character_slice=(head.start, head.end),
                    char_to_token_mapper=encoding.char_to_token,
                    character_offset=partition_offset,
                )
                tail_token_slice = get_token_slice(
                    character_slice=(tail.start, tail.end),
                    char_to_token_mapper=encoding.char_to_token,
                    character_offset=partition_offset,
                )
                # this happens if the head/tail start/end does not match a token start/end
                if head_token_slice is None or tail_token_slice is None:
                    # if statistics is not None:
                    #     statistics["entity_token_alignment_error"][
                    #         relation_mapping.get((head, tail), "TO_PREDICT")
                    #     ] += 1
                    logger.warning(
                        f"Skipping invalid example {document.id}, cannot get token slice(s)"
                    )
                    continue
                # NOTE(review): assumed to be a plain list of ids (no return_tensors
                # was requested in _encode_text) — list concatenation below relies on it
                input_ids = encoding["input_ids"]
                # not sure if this is the correct way to get the tokens corresponding to the input_ids
                tokens = encoding.encodings[0].tokens
                # windowing
                if self.max_window is not None:
                    head_start, head_end = head_token_slice
                    tail_start, tail_end = tail_token_slice
                    # The actual number of tokens will be lower than max_window because we add the
                    # 4 marker tokens (before / after the head /tail) and the default special tokens
                    # (e.g. CLS and SEP).
                    num_added_special_tokens = len(
                        self.tokenizer.build_inputs_with_special_tokens([])
                    )
                    max_tokens = self.max_window - 4 - num_added_special_tokens
                    # the slice from the beginning of the first entity to the end of the second is required
                    slice_required = (min(head_start, tail_start), max(head_end, tail_end))
                    window_slice = get_window_around_slice(
                        slice=slice_required,
                        max_window_size=max_tokens,
                        available_input_length=len(input_ids),
                    )
                    # this happens if slice_required does not fit into max_tokens
                    if window_slice is None:
                        # if statistics is not None:
                        #     statistics["out_of_token_window"][
                        #         relation_mapping.get((head, tail), "TO_PREDICT")
                        #     ] += 1
                        continue
                    window_start, window_end = window_slice
                    input_ids = input_ids[window_start:window_end]
                    # re-base the entity token slices onto the window
                    head_token_slice = head_start - window_start, head_end - window_start
                    tail_token_slice = tail_start - window_start, tail_end - window_start
                # maybe expand to n-ary relations?
                head_arg = RelationArgument(head, HEAD, head_token_slice, self.add_type_to_marker)
                tail_arg = RelationArgument(tail, TAIL, tail_token_slice, self.add_type_to_marker)
                arg_list = [head_arg, tail_arg]
                if head_token_slice[0] < tail_token_slice[0]:
                    assert (
                        head_token_slice[1] <= tail_token_slice[0]
                    ), f"the head and tail entities are not allowed to overlap in {document.id}"
                else:
                    assert (
                        tail_token_slice[1] <= head_token_slice[0]
                    ), f"the head and tail entities are not allowed to overlap in {document.id}"
                    # expand to n-ary relations?
                    # tail occurs before head in the text, so insert its markers first
                    arg_list.reverse()
                # NOTE(review): as_start_marker / as_end_marker / as_append_marker are
                # accessed without parentheses throughout — they must be properties on
                # RelationArgument; as plain methods these lookups would KeyError. Confirm.
                first_arg_start_id = self.argument_markers_to_id[arg_list[0].as_start_marker]
                first_arg_end_id = self.argument_markers_to_id[arg_list[0].as_end_marker]
                second_arg_start_id = self.argument_markers_to_id[arg_list[1].as_start_marker]
                second_arg_end_id = self.argument_markers_to_id[arg_list[1].as_end_marker]
                # splice marker ids around both argument spans (arg_list is in text order)
                new_input_ids = (
                    input_ids[: arg_list[0].offsets[0]]
                    + [first_arg_start_id]
                    + input_ids[arg_list[0].offsets[0] : arg_list[0].offsets[1]]
                    + [first_arg_end_id]
                    + input_ids[arg_list[0].offsets[1] : arg_list[1].offsets[0]]
                    + [second_arg_start_id]
                    + input_ids[arg_list[1].offsets[0] : arg_list[1].offsets[1]]
                    + [second_arg_end_id]
                    + input_ids[arg_list[1].offsets[1] :]
                )
                if self.append_markers:
                    new_input_ids.extend(
                        [
                            self.argument_markers_to_id[head_arg.as_append_marker],
                            self.sep_token_id,
                            self.argument_markers_to_id[tail_arg.as_append_marker],
                            self.sep_token_id,
                        ]
                    )
                # when windowing is used, we have to add the special tokens manually
                if not add_special_tokens:
                    new_input_ids = self.tokenizer.build_inputs_with_special_tokens(
                        token_ids_0=new_input_ids
                    )
                # lots of logging from here on
                # NOTE(review): the counter starts at 0 and is incremented after logging,
                # so "<=" logs log_first_n_examples + 1 examples — "<" was probably
                # intended; confirm before changing.
                log_this_example = (
                    relations is not None
                    and self.log_first_n_examples is not None
                    and self._logged_examples_counter <= self.log_first_n_examples
                )
                if log_this_example:
                    self._log_example(document, arg_list, new_input_ids, relations, tokens)
                task_encodings.append(
                    TaskEncoding(
                        document=document,
                        inputs={"input_ids": new_input_ids},
                        metadata={
                            HEAD: head,
                            TAIL: tail,
                        },
                    )
                )
        return task_encodings
| def _log_example( | |
| self, | |
| document: TextDocument, | |
| arg_list: List[RelationArgument], | |
| input_ids: List[int], | |
| relations: Sequence[BinaryRelation], | |
| tokens: List[str], | |
| ): | |
| first_arg_start = arg_list[0].as_start_marker | |
| first_arg_end = arg_list[0].as_end_marker | |
| second_arg_start = arg_list[1].as_start_marker | |
| second_arg_end = arg_list[1].as_end_marker | |
| new_tokens = ( | |
| tokens[: arg_list[0].offsets[0]] | |
| + [first_arg_start] | |
| + tokens[arg_list[0].offsets[0] : arg_list[0].offsets[1]] | |
| + [first_arg_end] | |
| + tokens[arg_list[0].offsets[1] : arg_list[1].offsets[0]] | |
| + [second_arg_start] | |
| + tokens[arg_list[1].offsets[0] : arg_list[1].offsets[1]] | |
| + [second_arg_end] | |
| + tokens[arg_list[1].offsets[1] :] | |
| ) | |
| head_idx = 0 if arg_list[0].role == HEAD else 1 | |
| tail_idx = 0 if arg_list[0].role == TAIL else 1 | |
| if self.append_markers: | |
| head_marker = arg_list[head_idx].as_append_marker | |
| tail_marker = arg_list[tail_idx].as_append_marker | |
| new_tokens.extend( | |
| [head_marker, self.tokenizer.sep_token, tail_marker, self.tokenizer.sep_token] | |
| ) | |
| logger.info("*** Example ***") | |
| logger.info("doc id: %s", document.id) | |
| logger.info("tokens: %s", " ".join([str(x) for x in new_tokens])) | |
| logger.info("input_ids: %s", " ".join([str(x) for x in input_ids])) | |
| rel_labels = [relation.label for relation in relations] | |
| rel_label_ids = [self.label_to_id[label] for label in rel_labels] | |
| logger.info("Expected labels: %s (ids = %s)", rel_labels, rel_label_ids) | |
| self._logged_examples_counter += 1 | |
| def encode_target( | |
| self, | |
| task_encoding: TransformerReTextClassificationTaskEncoding2, | |
| ) -> TransformerReTextClassificationTargetEncoding2: | |
| metadata = task_encoding.metadata | |
| document = task_encoding.document | |
| relations: Sequence[BinaryRelation] = document[self.relation_annotation] | |
| head_tail_to_labels = { | |
| (relation.head, relation.tail): [relation.label] for relation in relations | |
| } | |
| labels = head_tail_to_labels.get((metadata[HEAD], metadata[TAIL]), [self.none_label]) | |
| target = [self.label_to_id[label] for label in labels] | |
| return target | |
| def unbatch_output( | |
| self, model_output: TransformerTextClassificationModelBatchOutput | |
| ) -> Sequence[TransformerReTextClassificationTaskOutput2]: | |
| logits = model_output["logits"] | |
| output_label_probs = logits.sigmoid() if self.multi_label else logits.softmax(dim=-1) | |
| output_label_probs = output_label_probs.detach().cpu().numpy() | |
| unbatched_output = [] | |
| if self.multi_label: | |
| raise NotImplementedError | |
| else: | |
| label_ids = np.argmax(output_label_probs, axis=-1) | |
| for batch_idx, label_id in enumerate(label_ids): | |
| label = self.id_to_label[label_id] | |
| prob = float(output_label_probs[batch_idx, label_id]) | |
| result: TransformerReTextClassificationTaskOutput2 = { | |
| "labels": [label], | |
| "probabilities": [prob], | |
| } | |
| unbatched_output.append(result) | |
| return unbatched_output | |
| def create_annotations_from_output( | |
| self, | |
| task_encoding: TransformerReTextClassificationTaskEncoding2, | |
| task_output: TransformerReTextClassificationTaskOutput2, | |
| ) -> Iterator[Tuple[str, Union[BinaryRelation, MultiLabeledBinaryRelation]]]: | |
| labels = task_output["labels"] | |
| probabilities = task_output["probabilities"] | |
| if labels != [self.none_label]: | |
| yield ( | |
| self.relation_annotation, | |
| BinaryRelation( | |
| head=task_encoding.metadata[HEAD], | |
| tail=task_encoding.metadata[TAIL], | |
| label=labels[0], | |
| score=probabilities[0], | |
| ), | |
| ) | |
| def collate( | |
| self, task_encodings: Sequence[TransformerReTextClassificationTaskEncoding2] | |
| ) -> TransformerTextClassificationModelStepBatchEncoding: | |
| input_features = [task_encoding.inputs for task_encoding in task_encodings] | |
| inputs: Dict[str, torch.Tensor] = self.tokenizer.pad( | |
| input_features, | |
| padding=self.padding, | |
| max_length=self.max_length, | |
| pad_to_multiple_of=self.pad_to_multiple_of, | |
| return_tensors="pt", | |
| ) | |
| if not task_encodings[0].has_targets: | |
| return inputs, None | |
| target_list: List[TransformerReTextClassificationTargetEncoding2] = [ | |
| task_encoding.targets for task_encoding in task_encodings | |
| ] | |
| targets = torch.tensor(target_list, dtype=torch.int64) | |
| if not self.multi_label: | |
| targets = targets.flatten() | |
| return inputs, targets | |