| from transformers import Pipeline |
| import torch |
| from .utilities import padToSize |
| from .summary import select, splitDocument |
|
|
|
|
|
|
class ExtSummPipeline(Pipeline):
    """
    Extractive summarization pipeline.

    Inputs
    ------
        inputs : dict
            'sentences' : list[str]
                Sentences of the document

        strategy : str
            Strategy to summarize the document (defaults to 'count'):
            - 'length': summary with a maximum length (strategy_args is the maximum length).
            - 'count': summary with the given number of sentences (strategy_args is the number of sentences).
            - 'ratio': summary proportional to the length of the document (strategy_args is the ratio [0, 1]).
            - 'threshold': summary only with sentences with a score higher than a given value (strategy_args is the minimum score).

        strategy_args : any
            Parameters of the strategy (defaults to 3).
            `strategy` and `strategy_args` must be passed together or not at all.

        out_scores : bool
            If True, the score for each sentence is returned (defaults to False).

    Outputs
    -------
        selected_sents : list[str]
            List of the selected sentences

        selected_idxs : list[int]
            List of the indexes of the selected sentences in the original input

        sents_scores : Tensor (optional)
            Scores predicted by the model for the sentences.
            Only appended to the output when `out_scores` is True.
    """

    def _sanitize_parameters(self, **kwargs):
        """
        Route the call-time keyword arguments to the pipeline stages.

        Only `postprocess` takes parameters, so the dictionaries for
        `preprocess` and `_forward` are always empty.

        Raises
        ------
            ValueError
                If exactly one of `strategy` / `strategy_args` is given.
        """
        # The two options only make sense as a pair.
        if ("strategy" in kwargs) != ("strategy_args" in kwargs):
            raise ValueError("`strategy` and `strategy_args` have to be both set")

        postprocess_kwargs = {
            key: kwargs[key]
            for key in ("strategy", "strategy_args", "out_scores")
            if key in kwargs
        }
        return {}, {}, postprocess_kwargs

    def preprocess(self, inputs):
        """
        Tokenize the document and build the padded model inputs.

        The sentences are joined with `<sep><cls>` so that, once the leading
        CLS and trailing SEP tokens are added, every sentence is wrapped in its
        own CLS ... SEP pair. The token sequence is then split into chunks by
        `splitDocument` and every chunk is padded to `model.config.input_size`.

        Returns
        -------
            dict
                'inputs': dict of tensors ('ids', 'clss_mask', 'attn_mask',
                'global_attn_mask') with one row per chunk.
                'sentences': the original list of sentences, passed through.
        """
        sentences = inputs["sentences"]
        tokenizer = self.tokenizer
        cls_token = tokenizer.cls_token
        sep_token = tokenizer.sep_token
        cls_token_id = tokenizer.cls_token_id
        pad_token_id = tokenizer.pad_token_id
        input_size = self.model.config.input_size

        doc_tokens = tokenizer.tokenize(f"{sep_token}{cls_token}".join(sentences))
        doc_tokens = [cls_token] + doc_tokens + [sep_token]
        doc_chunks = splitDocument(doc_tokens, cls_token, sep_token, input_size)

        batch = {
            "ids": [],
            "clss_mask": [],
            "attn_mask": [],
            "global_attn_mask": [],
        }
        for chunk_tokens in doc_chunks:
            doc_ids = tokenizer.convert_tokens_to_ids(chunk_tokens)
            # True exactly at the CLS positions, i.e. at the start of each sentence.
            clss_mask = [token_id == cls_token_id for token_id in doc_ids]
            # Every real (non-padding) token is attended to.
            attn_mask = [1] * len(doc_ids)
            # Global attention is enabled on the CLS tokens only.
            global_attn_mask = [1 if is_cls else 0 for is_cls in clss_mask]

            batch["ids"].append(padToSize(doc_ids, input_size, pad_token_id))
            batch["clss_mask"].append(padToSize(clss_mask, input_size, False))
            batch["attn_mask"].append(padToSize(attn_mask, input_size, 0))
            batch["global_attn_mask"].append(padToSize(global_attn_mask, input_size, 0))

        batch = {key: torch.as_tensor(rows) for key, rows in batch.items()}
        return {"inputs": batch, "sentences": sentences}

    def _forward(self, args):
        """
        Run the model on every chunk and gather the sentence scores.

        Returns
        -------
            dict
                'predictions': 1-D concatenation, in chunk order, of the scores
                kept from each chunk.
                'sentences': the original list of sentences, passed through.
        """
        batch = args["inputs"]
        sentences = args["sentences"]

        # The empty seed keeps the result well-defined (an empty tensor on the
        # pipeline device) when there is no chunk to score.
        kept_scores = [torch.as_tensor([]).to(self.device)]

        self.model.eval()
        with torch.no_grad():
            batch_preds, _ = self.model(batch)
            for i, clss_mask in enumerate(batch["clss_mask"]):
                # Keep as many leading scores as the chunk has CLS tokens.
                # NOTE(review): assumes the model packs the per-sentence scores
                # at the start of each row and pads the rest - confirm.
                n_sentences = int(clss_mask.sum())
                kept_scores.append(batch_preds[i][:n_sentences])
            # Concatenate once instead of re-copying the accumulator per chunk.
            out_predictions = torch.cat(kept_scores)

        return {"predictions": out_predictions, "sentences": sentences}

    def postprocess(self, args, strategy: str = "count", strategy_args=3, out_scores=False):
        """
        Select the summary sentences from the model scores.

        The selection is delegated to `select`, called with the sentences, the
        predictions and the chosen strategy.

        Returns
        -------
            The output of `select`, with the predictions appended as an extra
            element when `out_scores` is True.
        """
        predictions = args["predictions"]
        sentences = args["sentences"]
        out = select(sentences, predictions, strategy, strategy_args)

        if out_scores:
            out += (predictions,)

        return out