|
|
from functools import lru_cache

from transformers import T5Tokenizer, T5ForConditionalGeneration
|
|
|
|
|
@lru_cache(maxsize=2)
def _load_t5(model_name: str):
    """Load and memoize the (tokenizer, model) pair for *model_name*.

    Reading pre-trained weights is the most expensive step of a summarization
    call, so the pair is cached and reused instead of being reloaded each time.
    The cache is kept small because every entry holds a full model in memory.
    """
    tokenizer = T5Tokenizer.from_pretrained(model_name, legacy=False)
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    return tokenizer, model


def summarize_text(text: str,
                   model_name: str = "t5-base",
                   max_length: int = 150,
                   min_length: int = 40,
                   num_beams: int = 4,
                   max_input_length: int = 512) -> str:
    """
    Summarizes the given text using a T5 model.

    Parameters:
    - text: The long input text to be summarized.
    - model_name: The pre-trained T5 model to use (e.g., "t5-base", "t5-small", etc.)
    - max_length: The maximum length (in tokens) of the generated summary.
    - min_length: The minimum length (in tokens) of the generated summary.
    - num_beams: The number of beams for beam search (affects summary quality).
    - max_input_length: The maximum number of input tokens fed to the model;
      longer inputs are truncated. Defaults to 512, the value previously
      hard-coded here.

    Returns:
    - The summarized text (str). An empty string is returned when `text` is
      empty or contains only whitespace, without loading any model.
    """
    # Nothing to summarize: skip the (expensive) model load and avoid asking
    # the model to "summarize" a bare prompt prefix.
    if not text or not text.strip():
        return ""

    tokenizer, model = _load_t5(model_name)

    # T5 is a text-to-text model; the task is selected by this prompt prefix.
    input_text = "summarize: " + text.strip()

    # Calling the tokenizer directly yields both the token ids and the
    # attention mask, so the mask can be passed to generate() explicitly.
    encoded = tokenizer(input_text,
                        return_tensors="pt",
                        max_length=max_input_length,
                        truncation=True)

    summary_ids = model.generate(encoded["input_ids"],
                                 attention_mask=encoded["attention_mask"],
                                 max_length=max_length,
                                 min_length=min_length,
                                 num_beams=num_beams,
                                 early_stopping=True)

    # Only the first (and only) sequence of the batch is decoded.
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
|
|
|
|
|
if __name__ == "__main__":
    # Demo run: summarize a short passage about the global economy using the
    # default settings of summarize_text and print the result.
    sample_passage = (
        "In recent years, the global economy has faced various challenges. Trade tensions, "
        "inflationary pressures, and rapid technological advancements have contributed to "
        "significant changes in market dynamics. Experts believe that these factors will continue "
        "to influence economic trends, while governments around the world are exploring policies "
        "to stabilize the economy. Meanwhile, the rise of the digital economy and the transition "
        "to green energy are emerging as key drivers of future economic growth."
    )

    # Generate first, then print, so nothing is written if summarization fails.
    generated = summarize_text(sample_passage)
    print("Summary:", generated, sep="\n")
|
|
|