# Page header captured with this file (not part of the program):
# Maddy90 - "Update app.py" - commit 3c40fa9 (verified)
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import Dataset
import os
# Load the BART model and tokenizer from a local directory.
bart_model_path = "BART_samsum"
if not os.path.exists(bart_model_path):
    # Abort start-up when the model directory is missing. Raising SystemExit
    # with a string prints it to stderr and exits with status 1, and does not
    # rely on the `exit()` helper that is only injected by the `site` module.
    raise SystemExit(
        f"BART model not found in {bart_model_path}. Please ensure the model is uploaded."
    )

bart_model = AutoModelForSeq2SeqLM.from_pretrained(bart_model_path)
bart_tokenizer = AutoTokenizer.from_pretrained(bart_model_path)
print(f"Loaded BART model from {bart_model_path}")
# Function to perform task summarization
def summarize_task(input_text):
    """Summarize a task description with the module-level BART model.

    Args:
        input_text: The text to summarize. Tokenized input is truncated to
            1024 tokens before generation.

    Returns:
        The decoded summary with special tokens removed, or an empty string
        when ``input_text`` is ``None``, empty, or only whitespace.
    """
    # Nothing to summarize: skip generation, which would otherwise be forced
    # by min_length=40 to produce text unrelated to any input.
    if not input_text or not input_text.strip():
        return ""

    inputs = bart_tokenizer(
        input_text,
        return_tensors="pt",
        max_length=1024,
        truncation=True,
    )
    summary_ids = bart_model.generate(
        inputs["input_ids"],
        # Pass the tokenizer's mask explicitly rather than letting generate()
        # infer one from the input ids.
        attention_mask=inputs.get("attention_mask"),
        max_length=150,
        min_length=40,
        length_penalty=2.0,
        num_beams=4,
        early_stopping=True,
    )
    # A single input was given, so only the first returned sequence is decoded.
    return bart_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
# Set up the Gradio interface: one multi-line text input, one text output,
# wired to summarize_task.
iface = gr.Interface(
    fn=summarize_task,
    inputs=gr.Textbox(lines=5, label="Input Task"),
    outputs=gr.Textbox(label="Task Summary"),
    title="Task Summarization",
    description="Enter a task description and get a summary.",
    # NOTE(review): confirm the installed Gradio version still recognizes the
    # "huggingface" theme name; left unchanged here.
    theme="huggingface",
    examples=[
        [
            "Develop a Python script that reads data from a CSV file, processes it to calculate average sales per region, and generates a bar chart to visualize the results."
        ]
    ],
)

# Launch the app only when this file is executed as a script, so importing the
# module (for tests or tooling) does not start a web server.
if __name__ == "__main__":
    iface.launch()