# Hugging Face Spaces page residue (non-code, preserved as a comment):
# Spaces: Sleeping
import os
import sys

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import Dataset

# Directory holding the fine-tuned BART (SAMSum) checkpoint, relative to the app root.
bart_model_path = "BART_samsum"

# Load the pre-trained BART model and tokenizer from local disk; the Space
# cannot function without them, so fail fast with a clear message otherwise.
if os.path.exists(bart_model_path):
    bart_model = AutoModelForSeq2SeqLM.from_pretrained(bart_model_path)
    bart_tokenizer = AutoTokenizer.from_pretrained(bart_model_path)
    print(f"Loaded BART model from {bart_model_path}")
else:
    print(f"BART model not found in {bart_model_path}. Please ensure the model is uploaded.")
    # sys.exit is the script-safe exit; the bare exit() builtin is meant for
    # interactive sessions and is not guaranteed to exist in all runtimes.
    sys.exit(1)
def summarize_task(input_text):
    """Summarize a task description with the fine-tuned BART model.

    Args:
        input_text: Free-form task/dialogue text to condense.

    Returns:
        The decoded summary string with special tokens stripped.
    """
    # Truncate to BART's 1024-token encoder context window.
    inputs = bart_tokenizer(input_text, return_tensors="pt", max_length=1024, truncation=True)
    # Pass attention_mask explicitly so any padded positions are ignored by
    # generate(); omitting it triggers a transformers warning and can corrupt
    # output when padding is present. Beam search with a length penalty favors
    # fluent summaries of 40-150 tokens.
    summary_ids = bart_model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=150,
        min_length=40,
        length_penalty=2.0,
        num_beams=4,
        early_stopping=True,
    )
    return bart_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
# Set up the Gradio interface: a single text box in, summary text out.
# NOTE(review): the original passed theme="huggingface", a Gradio 3-era string
# theme that was removed in Gradio 4 (it fails to load and falls back to the
# default with a warning) — dropped here; use gr.themes.* to restyle instead.
iface = gr.Interface(
    fn=summarize_task,
    inputs=gr.Textbox(lines=5, label="Input Task"),
    outputs=gr.Textbox(label="Task Summary"),
    title="Task Summarization",
    description="Enter a task description and get a summary.",
    examples=[["Develop a Python script that reads data from a CSV file, processes it to calculate average sales per region, and generates a bar chart to visualize the results."]],
)

# Launch the app only when executed as a script (how Spaces run app.py);
# the guard keeps imports of this module side-effect free.
if __name__ == "__main__":
    iface.launch()