| import torch |
|
|
|
|
| def _selectStrategyLength(sentences, predictions, max_length): |
| selected_sents = [] |
| sents_priority = torch.argsort(predictions, descending=True) |
| summary_len = 0 |
| i = 0 |
|
|
| while (summary_len < max_length) and (i < len(sents_priority)): |
| if summary_len + len(sentences[sents_priority[i]]) < max_length: |
| selected_sents.append(sents_priority[i].item()) |
| summary_len += len(sentences[sents_priority[i]]) |
| i += 1 |
|
|
| return sorted(selected_sents) |
|
|
|
|
| def _selectStrategyCount(sentences, predictions, num_sents): |
| selected_idxs = sorted(torch.topk(predictions, min(len(predictions), num_sents)).indices) |
| return [tensor.item() for tensor in selected_idxs] |
|
|
|
|
def _selectStrategyRatio(sentences, predictions, ratio):
    """
    Selects sentences up to a fraction of the document length.

    The document length is the sum of `len()` over all the sentences; the
    selection itself is delegated to `_selectStrategyLength` with
    `document length * ratio` as length budget.

    Parameters
    ----------
        sentences : sequence
            Sentences of the document.

        predictions : torch.Tensor
            Score of each sentence.

        ratio : float
            Multiplier applied to the document length to obtain the budget.

    Returns
    -------
        selected_sents : int[]
            Indexes of the selected sentences, as returned by
            `_selectStrategyLength`.
    """
    total_len = sum(map(len, sentences))
    length_budget = total_len * ratio
    return _selectStrategyLength(sentences, predictions, length_budget)
|
|
|
|
| def _selectStrategyThreshold(sentences, predictions, threshold): |
| return [i for i, score in enumerate(predictions) if score >= threshold] |
|
|
|
|
def select(sentences, predictions, strategy, strategy_args):
    """
    Selects sentences of a document according to a selection strategy.

    Parameters
    ----------
        sentences : sequence
            Sentences of the document.

        predictions : torch.Tensor
            Score of each sentence.

        strategy : str
            Name of the strategy: "length", "count", "ratio" or
            "threshold".

        strategy_args
            Argument forwarded to the chosen strategy (respectively the
            maximum length, the number of sentences, the ratio or the
            threshold).

    Returns
    -------
        selected : tuple
            The selected sentences and the list of their indexes.

    Raises
    ------
        NotImplementedError
            If `strategy` is not one of the known names.
    """
    known_strategies = (
        ("length", _selectStrategyLength),
        ("count", _selectStrategyCount),
        ("ratio", _selectStrategyRatio),
        ("threshold", _selectStrategyThreshold),
    )

    for name, selector in known_strategies:
        if strategy == name:
            selected_idxs = selector(sentences, predictions, strategy_args)
            break
    else:
        # No strategy name matched.
        raise NotImplementedError(f"Unknown strategy {strategy}")

    picked_sentences = [sentences[idx] for idx in selected_idxs]
    return picked_sentences, selected_idxs
|
|
|
|
|
|
def splitDocument(doc_tokens, bos_token, eos_token, max_size):
    """
    Splits a document into chunks of at most a given size.

    Each chunk ends at the last end-of-sentence token found within the
    first `max_size` tokens of what is left of the document. If there is
    none (the leading sentence alone is longer than `max_size`), that
    sentence is truncated to `max_size - 1` tokens and closed with
    `eos_token`; its remaining tokens, up to the next `bos_token` (or the
    end of the document), are dropped.

    Parameters
    ----------
        doc_tokens : str[]
            List of the tokens of the document. The list is not modified.

        bos_token : str
            Begin of sentence token.

        eos_token : str
            End of sentence token.

        max_size : int
            Maximum size of a chunk. Must be at least 1.

    Returns
    -------
        chunks : str[][]
            Split document.

    Raises
    ------
        ValueError
            If `max_size` is lower than 1.
    """
    if max_size < 1:
        # With a non-positive size the loop below never consumes tokens
        # and would run forever.
        raise ValueError(f"max_size must be at least 1, got {max_size}")

    def _findNextBOSFrom(tokens, start_idx):
        # Index of the first BOS at or after start_idx, -1 if none.
        for i in range(start_idx, len(tokens)):
            if tokens[i] == bos_token:
                return i
        return -1

    def _findPreviousEOSFrom(tokens, start_idx):
        # Index of the last EOS at or before start_idx, -1 if none.
        for i in range(start_idx, -1, -1):
            if tokens[i] == eos_token:
                return i
        return -1

    chunks = []
    remaining = doc_tokens

    while len(remaining) > max_size:
        # Cut at the last sentence end that still fits in the chunk.
        eos_idx = _findPreviousEOSFrom(remaining, max_size - 1)

        if eos_idx == -1:
            # The leading sentence does not fit: truncate it, close it
            # with an EOS and resume from the next sentence (if any).
            next_bos_idx = _findNextBOSFrom(remaining, max_size)
            truncated = remaining[:max_size - 1] + [eos_token]
            if next_bos_idx != -1:
                remaining = truncated + remaining[next_bos_idx:]
            else:
                remaining = truncated
            eos_idx = max_size - 1

        chunks.append(remaining[:eos_idx + 1])
        remaining = remaining[eos_idx + 1:]

    # Whatever is left fits in a single chunk.
    if remaining:
        chunks.append(remaining)

    return chunks