# Window manager for relik: splits documents into token windows and merges
# processed windows back into per-document samples.
| import collections | |
| import itertools | |
| from dataclasses import dataclass | |
| from typing import List, Optional, Set, Tuple | |
| from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer | |
| from relik.reader.data.relik_reader_sample import RelikReaderSample | |
@dataclass
class Window:
    """One contiguous token window over a document.

    The two ``token2char_*`` dicts map window-local token indices
    (stringified ints) to character offsets in the original document text.
    """

    doc_id: int  # id of the source document this window belongs to
    window_id: int  # position of this window within the document
    text: str  # raw document text covered by the window
    tokens: List[str]  # tokens contained in the window
    doc_topic: Optional[str]  # document topic (first document token by default)
    offset: int  # character offset of the window start in the document
    token2char_start: dict  # window-local token index (str) -> char start offset
    token2char_end: dict  # window-local token index (str) -> char end offset
    window_candidates: Optional[List[str]] = None  # optional candidate labels
class WindowManager:
    """Splits documents into (possibly overlapping) token windows and merges
    processed windows back into one sample per document."""

    def __init__(self, tokenizer: BaseTokenizer) -> None:
        self.tokenizer = tokenizer

    def tokenize(self, document: str) -> Tuple[List[str], List[Tuple[int, int]]]:
        """Tokenize ``document``.

        Returns the token strings and, aligned with them, the
        ``(start_char, end_char)`` span of each token in ``document``.
        """
        tokenized_document = self.tokenizer(document)
        tokens = []
        tokens_char_mapping = []
        for token in tokenized_document:
            tokens.append(token.text)
            tokens_char_mapping.append((token.start_char, token.end_char))
        return tokens, tokens_char_mapping

    def create_windows(
        self,
        document: str,
        window_size: int,
        stride: int,
        doc_id: int = 0,
        doc_topic: Optional[str] = None,
    ) -> List[RelikReaderSample]:
        """Split ``document`` into windows of at most ``window_size`` tokens,
        advancing by ``stride`` tokens between window starts.

        When the document fits in a single window, one sample covering the
        whole text is returned. ``doc_topic`` defaults to the first token of
        the document (or "" for an empty document).
        """
        document_tokens, tokens_char_mapping = self.tokenize(document)
        if doc_topic is None:
            doc_topic = document_tokens[0] if len(document_tokens) > 0 else ""
        document_windows = []
        if len(document_tokens) <= window_size:
            # Single window covering the entire document.
            document_windows.append(
                RelikReaderSample(
                    doc_id=doc_id,
                    window_id=0,
                    text=document,
                    tokens=document_tokens,
                    doc_topic=doc_topic,
                    offset=0,
                    token2char_start={
                        str(idx): tokens_char_mapping[idx][0]
                        for idx in range(len(document_tokens))
                    },
                    token2char_end={
                        str(idx): tokens_char_mapping[idx][1]
                        for idx in range(len(document_tokens))
                    },
                )
            )
        else:
            for window_id, i in enumerate(range(0, len(document_tokens), stride)):
                # If the last window would overrun the document, shift it left
                # so it still spans window_size tokens (re-using tokens from
                # the previous window); drop it entirely when the shift would
                # be a full stride or more, since its tokens are then already
                # fully covered by the previous window.
                if i != 0 and i + window_size > len(document_tokens):
                    overflowing_tokens = i + window_size - len(document_tokens)
                    if overflowing_tokens >= stride:
                        break
                    i -= overflowing_tokens
                # BUGFIX: upper bound is len(document_tokens), not len - 1;
                # the original dropped the final document token from every
                # window span, truncating the last window's text.
                involved_token_indices = list(
                    range(i, min(i + window_size, len(document_tokens)))
                )
                window_tokens = [document_tokens[j] for j in involved_token_indices]
                window_text_start = tokens_char_mapping[involved_token_indices[0]][0]
                window_text_end = tokens_char_mapping[involved_token_indices[-1]][1]
                text = document[window_text_start:window_text_end]
                document_windows.append(
                    RelikReaderSample(
                        doc_id=doc_id,
                        window_id=window_id,
                        text=text,
                        tokens=window_tokens,
                        doc_topic=doc_topic,
                        offset=window_text_start,
                        token2char_start={
                            str(idx): tokens_char_mapping[ti][0]
                            for idx, ti in enumerate(involved_token_indices)
                        },
                        token2char_end={
                            str(idx): tokens_char_mapping[ti][1]
                            for idx, ti in enumerate(involved_token_indices)
                        },
                    )
                )
        return document_windows

    def merge_windows(
        self, windows: List[RelikReaderSample]
    ) -> List[RelikReaderSample]:
        """Group ``windows`` by doc_id and merge each group into one sample."""
        windows_by_doc_id = collections.defaultdict(list)
        for window in windows:
            windows_by_doc_id[window.doc_id].append(window)
        merged_window_by_doc = {
            doc_id: self.merge_doc_windows(doc_windows)
            for doc_id, doc_windows in windows_by_doc_id.items()
        }
        return list(merged_window_by_doc.values())

    def merge_doc_windows(self, windows: List[RelikReaderSample]) -> RelikReaderSample:
        """Fold all windows of a single document into one sample,
        left-to-right by character offset."""
        if len(windows) == 1:
            return windows[0]
        if len(windows) > 0 and getattr(windows[0], "offset", None) is not None:
            windows = sorted(windows, key=(lambda x: x.offset))
        window_accumulator = windows[0]
        for next_window in windows[1:]:
            window_accumulator = self._merge_window_pair(
                window_accumulator, next_window
            )
        return window_accumulator

    def _merge_tokens(
        self, window1: RelikReaderSample, window2: RelikReaderSample
    ) -> Tuple[list, dict, dict]:
        """Concatenate the token sequences of two overlapping windows.

        The windows must share a non-empty token overlap (window2 starts
        inside window1). Special tokens at positions 0 and -1 (e.g. CLS/SEP)
        are stripped before matching and re-attached to the merged sequence.

        Returns the merged token list plus the merged token->char-start and
        token->char-end mappings.
        """
        w1_tokens = window1.tokens[1:-1]
        w2_tokens = window2.tokens[1:-1]
        # Find the longest suffix of window1 that equals a prefix of window2.
        tokens_intersection = None
        for k in reversed(range(1, len(w1_tokens))):
            if w1_tokens[-k:] == w2_tokens[:k]:
                tokens_intersection = k
                break
        assert tokens_intersection is not None, (
            f"{window1.doc_id} - {window1.sent_id} - {window1.offset}"
            + f" {window2.doc_id} - {window2.sent_id} - {window2.offset}\n"
            + f"w1 tokens: {w1_tokens}\n"
            + f"w2 tokens: {w2_tokens}\n"
        )
        final_tokens = (
            [window1.tokens[0]]  # CLS
            + w1_tokens
            + w2_tokens[tokens_intersection:]
            + [window1.tokens[-1]]  # SEP
        )
        # Offset to re-index window2's tokens in the merged sequence.
        w2_starting_offset = len(w1_tokens) - tokens_intersection

        def merge_char_mapping(t2c1: dict, t2c2: dict) -> dict:
            final_t2c = dict()
            final_t2c.update(t2c1)
            for t, c in t2c2.items():
                t = int(t)
                if t < tokens_intersection:
                    # Overlapping token: window1's mapping already has it.
                    continue
                final_t2c[str(t + w2_starting_offset)] = c
            return final_t2c

        return (
            final_tokens,
            merge_char_mapping(window1.token2char_start, window2.token2char_start),
            merge_char_mapping(window1.token2char_end, window2.token2char_end),
        )

    def _merge_span_annotation(
        self, span_annotation1: List[list], span_annotation2: List[list]
    ) -> List[list]:
        """Union two span-annotation lists, dropping duplicates and sorting
        by span start."""
        uniq_store = set()
        final_span_annotation_store = []
        for span_annotation in itertools.chain(span_annotation1, span_annotation2):
            span_annotation_id = tuple(span_annotation)
            if span_annotation_id not in uniq_store:
                uniq_store.add(span_annotation_id)
                final_span_annotation_store.append(span_annotation)
        return sorted(final_span_annotation_store, key=lambda x: x[0])

    def _merge_predictions(
        self,
        window1: RelikReaderSample,
        window2: RelikReaderSample,
    ) -> Tuple[Set[Tuple[int, int, str]], dict]:
        """Union the predicted labels of two windows; for probabilities,
        the first window's entry wins on duplicate spans."""
        merged_predictions = window1.predicted_window_labels_chars.union(
            window2.predicted_window_labels_chars
        )
        span_title_probabilities = dict()
        for span_prediction, predicted_probs in itertools.chain(
            window1.probs_window_labels_chars.items(),
            window2.probs_window_labels_chars.items(),
        ):
            if span_prediction not in span_title_probabilities:
                span_title_probabilities[span_prediction] = predicted_probs
        return merged_predictions, span_title_probabilities

    def _merge_window_pair(
        self,
        window1: RelikReaderSample,
        window2: RelikReaderSample,
    ) -> RelikReaderSample:
        """Merge two adjacent windows of the same document (window1 must
        start before window2) into a single sample."""
        merging_output = dict()
        if getattr(window1, "doc_id", None) is not None:
            assert window1.doc_id == window2.doc_id
        if getattr(window1, "offset", None) is not None:
            assert (
                window1.offset < window2.offset
            ), f"window 2 offset ({window2.offset}) is smaller that window 1 offset({window1.offset})"
        merging_output["doc_id"] = window1.doc_id
        merging_output["offset"] = window2.offset
        m_tokens, m_token2char_start, m_token2char_end = self._merge_tokens(
            window1, window2
        )
        window_labels = None
        if getattr(window1, "window_labels", None) is not None:
            window_labels = self._merge_span_annotation(
                window1.window_labels, window2.window_labels
            )
        (
            predicted_window_labels_chars,
            probs_window_labels_chars,
        ) = self._merge_predictions(
            window1,
            window2,
        )
        merging_output.update(
            dict(
                tokens=m_tokens,
                token2char_start=m_token2char_start,
                token2char_end=m_token2char_end,
                window_labels=window_labels,
                predicted_window_labels_chars=predicted_window_labels_chars,
                probs_window_labels_chars=probs_window_labels_chars,
            )
        )
        return RelikReaderSample(**merging_output)