# Sazzz02's picture
# Update app.py
# 027d00f verified
# app.py
import gradio as gr
import joblib
import re
import nltk
from nltk.corpus import stopwords
import string
# Make sure the NLTK stopword corpus is usable before anything reads it.
def _ensure_stopwords_corpus():
    """Probe the English stopword list and download the corpus if it is missing."""
    try:
        # The result is discarded; the call only checks that the corpus loads.
        stopwords.words("english")
    except LookupError:
        # Corpus could not be found locally, so fetch it.
        nltk.download("stopwords")


_ensure_stopwords_corpus()
# Module-level configuration shared by the functions below.
# The two artifact paths are bare file names, so they are looked up relative
# to the process's current working directory.
MODEL_PATH = "random_forest_model.joblib"
VECTORIZER_PATH = "tfidf_vectorizer.joblib"
# English stopwords held in a set for fast membership checks.
STOP_WORDS = {*stopwords.words("english")}
# Load the trained model and vectorizer once, at import time.
# NOTE(security): joblib.load is pickle-based and can execute arbitrary code
# while loading; only point these paths at artifacts from a trusted source.
try:
    model = joblib.load(MODEL_PATH)
    tfidf_vectorizer = joblib.load(VECTORIZER_PATH)
except FileNotFoundError as err:
    # Re-raise with setup instructions, explicitly chained to the original
    # error so the traceback still shows which path could not be opened.
    raise FileNotFoundError(
        "Model or vectorizer files not found. "
        "Please ensure 'random_forest_model.joblib' and 'tfidf_vectorizer.joblib' "
        "are in the same directory as this script."
    ) from err
# Built once at import time instead of on every call: a translation table
# that deletes ASCII punctuation, and a compiled pattern for runs of digits.
_PUNCTUATION_TABLE = str.maketrans('', '', string.punctuation)
_DIGITS_RE = re.compile(r'\d+')


def preprocess_text(text):
    """
    Clean raw text before it is handed to the TF-IDF vectorizer.

    The steps run in a fixed order: lowercase, delete punctuation, delete
    digits, then drop every whitespace-separated token found in STOP_WORDS.
    The order and the delete-not-replace behaviour are kept exactly as they
    were; they are presumably what the model was trained with (the training
    code is not visible here), so changing them would alter the features.

    Args:
        text: Raw document text. ``None`` or an empty string is accepted.

    Returns:
        The remaining tokens joined by single spaces; ``""`` when the input
        is ``None``, empty, or contains nothing but removed content.
    """
    # Guard against a missing value so callers get an empty result instead of
    # an AttributeError from calling .lower() on None.
    if not text:
        return ""
    # Convert to lowercase
    text = text.lower()
    # Punctuation is deleted outright (not replaced by a space), so the
    # characters on either side of it are joined together.
    text = text.translate(_PUNCTUATION_TABLE)
    # Remove digits
    text = _DIGITS_RE.sub('', text)
    # split() with no argument also collapses any run of whitespace.
    return ' '.join(word for word in text.split() if word not in STOP_WORDS)
def predict_class(input_text):
    """
    Classify a raw document string.

    The text is cleaned with preprocess_text, turned into a one-row feature
    matrix by the loaded TF-IDF vectorizer, and passed to the loaded model.
    The first (and only) entry of the model's prediction is returned.
    """
    cleaned = preprocess_text(input_text)
    # The vectorizer expects a collection of documents, hence the one-item list.
    features = tfidf_vectorizer.transform([cleaned])
    predicted_labels = model.predict(features)
    # One document in, so the label of interest is at index 0.
    return predicted_labels[0]
# Sample documents offered as one-click examples in the Gradio UI.
# Long samples are split across adjacent string literals, which Python joins
# into a single string at compile time.
example_inputs = [
    # Invoice-style text
    (
        "Invoice No: INV-2024-001\n"
        "Date: 04/11/2024\n"
        "Customer: John Smith\n"
        "Item 1: Laptop, QTY: 1, Price: $1200.00\n"
        "Total Amount Due: $1275.00"
    ),
    # Short meeting-agenda style text
    "Agenda for our next project synchronization meeting scheduled for Tuesday",
    # Clinical-note style text
    (
        "Patient: Jane Doe, DOB: 12/05/1990\n"
        "Symptoms: Severe headache, fever, and persistent cough. "
        "Diagnosis: Influenza. "
        "Treatment: Prescribed Ibuprofen and advised rest."
    ),
]
# Build the Gradio UI: a multi-line text input, a text output, and the sample
# documents from example_inputs offered as examples.
_document_box = gr.Textbox(
    lines=10,
    placeholder="Paste your document text here...",
    label="Input Document Text",
)
_prediction_box = gr.Textbox(label="Predicted Document Class")

interface = gr.Interface(
    fn=predict_class,
    inputs=_document_box,
    outputs=_prediction_box,
    title="Document Classification App",
    description="This app classifies an input document text into one of five predefined categories.",
    examples=example_inputs,
)

# Start the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    interface.launch()