from transformers import Pipeline
import torch
from .utilities import padToSize
from .summary import select, splitDocument


def generateSegmentIds(doc_ids, tokenizer):
    """
    Generates the interval segment ids for BERT: the tokens of each sentence
    share a segment id, which alternates between 0 and 1 after every [SEP] token.
    """
    segments_ids = [0] * len(doc_ids)
    curr_segment = 0

    # Flip the segment id each time a [SEP] token is crossed
    for i, token in enumerate(doc_ids):
        segments_ids[i] = curr_segment
        if token == tokenizer.sep_token_id:
            curr_segment = 1 - curr_segment

    return segments_ids
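
# Illustrative example (toy ids; with bert-base-uncased, [CLS] = 101 and [SEP] = 102):
#   generateSegmentIds([101, 7592, 102, 101, 2088, 102], tokenizer)  # -> [0, 0, 0, 1, 1, 1]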


class ExtSummPipeline(Pipeline):
    """
    Extractive summarization pipeline

    Inputs
    ------
    inputs : dict
        'sentences' : list[str]
            Sentences of the document

    strategy : str
        Strategy used to select the sentences of the summary:
        - 'length': the summary has a maximum length (strategy_args is the maximum length).
        - 'count': the summary contains a given number of sentences (strategy_args is the number of sentences).
        - 'ratio': the summary length is proportional to the document length (strategy_args is the ratio, in [0, 1]).
        - 'threshold': the summary only contains the sentences whose score is higher than a given value (strategy_args is the minimum score).

    strategy_args : any
        Parameters of the strategy.

    Outputs
    -------
    selected_sents : list[str]
        List of the selected sentences

    selected_idxs : list[int]
        List of the indices of the selected sentences in the original input
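
    Example
    -------
    Minimal usage sketch (assumes `model` and `tokenizer` come from a compatible
    checkpoint loaded elsewhere; the variable names are illustrative):
        >>> pipe = ExtSummPipeline(model=model, tokenizer=tokenizer)
        >>> sents, idxs = pipe({"sentences": sentences}, strategy="count", strategy_args=3)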
| """ |
|
|
|
|
| def _sanitize_parameters(self, **kwargs): |
| postprocess_kwargs = {} |
| |
| if ("strategy" in kwargs and "strategy_args" not in kwargs) or ("strategy" not in kwargs and "strategy_args" in kwargs): |
| raise ValueError("`strategy` and `strategy_args` have to be both set") |
| if "strategy" in kwargs: |
| postprocess_kwargs["strategy"] = kwargs["strategy"] |
| if "strategy_args" in kwargs: |
| postprocess_kwargs["strategy_args"] = kwargs["strategy_args"] |
|
|
| return {}, {}, postprocess_kwargs |

    def preprocess(self, inputs):
        sentences = inputs["sentences"]

        # Join the sentences as: [CLS] sent_1 [SEP] [CLS] sent_2 [SEP] ... [CLS] sent_n [SEP]
        doc_tokens = self.tokenizer.tokenize( f"{self.tokenizer.sep_token}{self.tokenizer.cls_token}".join(sentences) )
        doc_tokens = [self.tokenizer.cls_token] + doc_tokens + [self.tokenizer.sep_token]
        # Split the document into chunks that fit the model's input size
        doc_chunks = splitDocument(doc_tokens, self.tokenizer.cls_token, self.tokenizer.sep_token, self.model.config.input_size)

        batch = {
            "ids": [],
            "segments_ids": [],
            "clss_mask": [],
            "attn_mask": [],
        }
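
        # Illustrative content of one (unpadded) chunk:
        #   tokens:        [CLS] w1  w2  [SEP] [CLS] w3  [SEP]
        #   segments_ids:    0    0   0    0     1    1    1
        #   clss_mask:       T    F   F    F     T    F    F
        #   attn_mask:       1    1   1    1     1    1    1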
        for chunk_tokens in doc_chunks:
            doc_ids = self.tokenizer.convert_tokens_to_ids(chunk_tokens)
            segment_ids = generateSegmentIds(doc_ids, self.tokenizer)
            # True at the positions of the [CLS] tokens (one per sentence)
            clss_mask = [token == self.tokenizer.cls_token_id for token in doc_ids]
            attn_mask = [1] * len(doc_ids)

            # Pad everything to the model's fixed input size
            batch["ids"].append( padToSize(doc_ids, self.model.config.input_size, self.tokenizer.pad_token_id) )
            batch["segments_ids"].append( padToSize(segment_ids, self.model.config.input_size, 0) )
            batch["clss_mask"].append( padToSize(clss_mask, self.model.config.input_size, False) )
            batch["attn_mask"].append( padToSize(attn_mask, self.model.config.input_size, 0) )

        batch = { key: torch.as_tensor(value) for key, value in batch.items() }
        return { "inputs": batch, "sentences": sentences }

    def _forward(self, args):
        batch = args["inputs"]
        sentences = args["sentences"]
        out_predictions = torch.as_tensor([]).to(self.device)

        self.model.eval()
        with torch.no_grad():
            batch_preds, _ = self.model(batch)
            # Each chunk yields one score per [CLS] token: keep the scores of the
            # real (non-padding) [CLS] positions and concatenate them across chunks
            for i, clss_mask in enumerate(batch["clss_mask"]):
                out_predictions = torch.cat((out_predictions, batch_preds[i][:torch.sum(clss_mask)]))

        return { "predictions": out_predictions, "sentences": sentences }

    def postprocess(self, args, strategy: str="count", strategy_args=3):
        predictions = args["predictions"]
        sentences = args["sentences"]
        # Select the summary sentences according to the requested strategy
        return select(sentences, predictions, strategy, strategy_args)