| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | from ..models.auto import AutoModelForSeq2SeqLM, AutoTokenizer |
| | from .base import PipelineTool |
| |
|
| |
|
class TextSummarizationTool(PipelineTool):
    """
    Tool that summarizes an English text with a sequence-to-sequence language model.

    The class implements three steps: `encode` tokenizes the input text, `forward` calls the model's
    `generate` method on the encoded inputs and keeps the first generated sequence, and `decode` converts
    that sequence back to a string.

    Example:

    ```py
    from transformers.tools import TextSummarizationTool

    summarizer = TextSummarizationTool()
    summarizer(long_text)
    ```
    """

    # Checkpoint identifier (presumably what `PipelineTool` loads when no model is given; confirm in base class).
    default_checkpoint = "philschmid/bart-large-cnn-samsum"
    # Human-readable explanation of what the tool does and what it expects / returns.
    description = (
        "This is a tool that summarizes an English text. It takes an input `text` containing the text to summarize, "
        "and returns a summary of the text."
    )
    name = "summarizer"
    # Classes used for the pre-processor (tokenizer) and the model.
    pre_processor_class = AutoTokenizer
    model_class = AutoModelForSeq2SeqLM

    # One text input, one text output.
    inputs = ["text"]
    outputs = ["text"]

    def encode(self, text):
        """Tokenize `text` with truncation enabled, requesting "pt" (PyTorch) tensors from the pre-processor."""
        tokenizer = self.pre_processor
        return tokenizer(text, return_tensors="pt", truncation=True)

    def forward(self, inputs):
        """Call `generate` on the model with the encoded `inputs` and return the first generated sequence."""
        generated = self.model.generate(**inputs)
        return generated[0]

    def decode(self, outputs):
        """Decode the generated `outputs` into a string, skipping special tokens and cleaning up spaces."""
        decode_options = {"skip_special_tokens": True, "clean_up_tokenization_spaces": True}
        return self.pre_processor.decode(outputs, **decode_options)
| |
|