| | import torch |
| | import gradio as gr |
| | from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer, AutoConfig |
| |
|
| | |
# Hugging Face model ids offered in the UI dropdown.
model_names = [
    "facebook/bart-large-cnn",
    "tsmatz/mt5_summarize_japanese",
    "avisena/bart-base-job-info-summarizer",
    "RussianNLP/FRED-T5-Summarizer",
    "google/flan-t5-small",
    "prithivMLmods/t5-Flan-Prompt-Enhance",
]
| |
|
| | |
# Mutable module-level state, populated by load_model() and read by
# summarize_text(): the active summarization pipeline, its tokenizer, and
# the model's input-token limit.
summarizer = tokenizer = max_tokens = None
| |
|
| | |
# Default demo passage pre-filled into the input textbox.
example_text = (
    "Artificial intelligence (AI) is intelligence—perceiving, synthesizing, "
    "and inferring information—demonstrated by machines, as opposed to "
    "intelligence displayed by non-human animals and humans. Example tasks in "
    "which AI is employed include speech recognition, computer vision, "
    "language translation, autonomous vehicles, and game playing. AI research "
    "has been defined as the field of study of intelligent agents, which "
    "refers to any system that perceives its environment and takes actions "
    "that maximize its chance of achieving its goals."
)
| |
|
| | |
def load_model(model_name):
    """Load a summarization pipeline and record its tokenizer and token limit.

    Updates the module-level ``summarizer``, ``tokenizer`` and ``max_tokens``
    globals that ``summarize_text`` reads.

    Args:
        model_name: Hugging Face model id of a seq2seq summarization model.

    Returns:
        A human-readable status string: success (including the detected token
        limit) or the failure reason.
    """
    global summarizer, tokenizer, max_tokens
    try:
        summarizer = pipeline("summarization", model=model_name, torch_dtype=torch.float32)
        # FIX: the pipeline already loaded the tokenizer and config; reuse
        # them instead of downloading/instantiating both a second time via
        # AutoTokenizer/AutoConfig.
        tokenizer = summarizer.tokenizer
        config = summarizer.model.config

        # Not every architecture exposes max_position_embeddings (e.g.
        # T5-style relative-position models); fall back to 1024 as before.
        max_tokens = getattr(config, 'max_position_embeddings', 1024)

        return f"Model {model_name} loaded successfully! Max tokens: {max_tokens}"
    except Exception as e:
        return f"Failed to load model {model_name}. Error: {str(e)}"
| |
|
| | |
def summarize_text(input, min_length, max_length):
    """Summarize ``input`` with the currently loaded pipeline.

    Args:
        input: Text to summarize. (Parameter name kept for backward
            compatibility with the existing Gradio wiring, even though it
            shadows the builtin.)
        min_length: Minimum summary length as a percentage of the input size.
        max_length: Maximum summary length as a percentage of the input size.

    Returns:
        The summary text, or an error-message string when no model is loaded,
        the input is too long, or generation fails.
    """
    if summarizer is None:
        return "No model loaded!"

    try:
        input_tokens = tokenizer.encode(input, return_tensors="pt")
        num_tokens = input_tokens.shape[1]
        if num_tokens > max_tokens:
            return f"Error: Input exceeds the max token limit of {max_tokens}."

        # Convert the percentage sliders into absolute token counts.
        min_summary_length = max(10, int(num_tokens * (min_length / 100)))
        max_summary_length = min(max_tokens, int(num_tokens * (max_length / 100)))
        # BUGFIX: for short inputs (or min% > max%) the 10-token floor can push
        # the minimum above the maximum — e.g. 40 input tokens at the default
        # 10%/20% gives min=10, max=8 — which makes generation fail. Keep the
        # bounds ordered.
        if max_summary_length <= min_summary_length:
            max_summary_length = min_summary_length + 1

        output = summarizer(input, min_length=min_summary_length, max_length=max_summary_length, truncation=True)
        return output[0]['summary_text']
    except Exception as e:
        return f"Summarization failed: {str(e)}"
| |
|
| | |
# Build and launch the Gradio UI: model picker + loader, length sliders,
# and the summarization input/output widgets.
with gr.Blocks() as demo:
    with gr.Row():
        # BUGFIX: the previous default value ("sshleifer/distilbart-cnn-12-6")
        # was not one of the dropdown choices; default to the first listed
        # model so the initial selection is valid.
        model_dropdown = gr.Dropdown(choices=model_names, label="Choose a model", value=model_names[0])
        load_button = gr.Button("Load Model")

    load_message = gr.Textbox(label="Load Status", interactive=False)

    # Summary bounds expressed as a percentage of the input's token count.
    min_length_slider = gr.Slider(minimum=0, maximum=100, step=1, label="Minimum Summary Length (%)", value=10)
    max_length_slider = gr.Slider(minimum=0, maximum=100, step=1, label="Maximum Summary Length (%)", value=20)

    input_text = gr.Textbox(label="Input text to summarize", lines=6, value=example_text)
    summarize_button = gr.Button("Summarize Text")
    output_text = gr.Textbox(label="Summarized text", lines=4)

    load_button.click(fn=load_model, inputs=model_dropdown, outputs=load_message)
    summarize_button.click(fn=summarize_text, inputs=[input_text, min_length_slider, max_length_slider],
                           outputs=output_text)

demo.launch()
| |
|