Spaces:
Sleeping
Sleeping
| # app.py | |
import os
import re
import string

import gradio as gr
import joblib
import nltk
from nltk.corpus import stopwords
def _ensure_english_stopwords():
    """Download the NLTK 'stopwords' corpus when the English list cannot be loaded.

    Loading the English list raises LookupError when the corpus is not installed
    locally; only in that case is a download triggered.
    """
    try:
        stopwords.words('english')
    except LookupError:
        nltk.download('stopwords')


# Run the check once at import time so STOP_WORDS below can be built.
_ensure_english_stopwords()
# Directory containing this script. Artifact paths are resolved against it rather
# than the process's current working directory, so the app finds its files even
# when launched from elsewhere (as the error message below promises).
_APP_DIR = os.path.dirname(os.path.abspath(__file__))

# Global configuration: trained artifact locations and the stop-word set.
MODEL_PATH = os.path.join(_APP_DIR, "random_forest_model.joblib")
VECTORIZER_PATH = os.path.join(_APP_DIR, "tfidf_vectorizer.joblib")
# NLTK English stop words, stored as a set for fast membership tests in preprocessing.
STOP_WORDS = set(stopwords.words('english'))

# Load the trained model and vectorizer once at import time; every prediction reuses them.
try:
    model = joblib.load(MODEL_PATH)
    tfidf_vectorizer = joblib.load(VECTORIZER_PATH)
except FileNotFoundError as err:
    # Re-raise with actionable guidance, keeping the original error as the cause.
    raise FileNotFoundError(
        "Model or vectorizer files not found. "
        f"Please ensure '{os.path.basename(MODEL_PATH)}' and "
        f"'{os.path.basename(VECTORIZER_PATH)}' "
        "are in the same directory as this script "
        f"(looked in: {_APP_DIR})."
    ) from err
# Translation table that deletes every ASCII punctuation character; built once
# instead of on every call.
_PUNCTUATION_TABLE = str.maketrans('', '', string.punctuation)
# Runs of digits (for str patterns, \d also matches non-ASCII Unicode digits).
_DIGITS_RE = re.compile(r'\d+')


def preprocess_text(text, stop_words=None):
    """
    Clean and preprocess text to match the format used during training.

    Steps, applied in this order (keep it in sync with the training pipeline):
    lowercase, remove ASCII punctuation, remove digits, then drop stop words
    and re-join the surviving tokens with single spaces.

    Args:
        text: Raw input text. ``None`` or an empty string yields ``""``.
        stop_words: Optional collection of words to drop. Defaults to the
            module-level ``STOP_WORDS`` (NLTK English stop words).

    Returns:
        The cleaned text, or ``""`` if nothing remains after cleaning.
    """
    if not text:
        return ''
    if stop_words is None:
        stop_words = STOP_WORDS
    text = text.lower()
    text = text.translate(_PUNCTUATION_TABLE)
    text = _DIGITS_RE.sub('', text)
    # split() with no argument also collapses runs of whitespace left by the removals.
    return ' '.join(word for word in text.split() if word not in stop_words)
def predict_class(input_text):
    """
    Preprocess raw text, vectorize it with the loaded TF-IDF vectorizer, and
    return the model's predicted class.

    Args:
        input_text: Raw document text from the Gradio textbox; may be empty,
            whitespace-only, or None.

    Returns:
        The predicted class label as a string, or a prompt asking for input
        when ``input_text`` is empty or whitespace-only.
    """
    # Blank submissions carry no words for the vectorizer, yet the model would
    # still return some label; ask for real input instead.
    if input_text is None or not input_text.strip():
        return "Please enter some document text to classify."
    preprocessed_text = preprocess_text(input_text)
    # transform() takes a sequence of documents, so wrap the single text in a list.
    text_vector = tfidf_vectorizer.transform([preprocessed_text])
    prediction = model.predict(text_vector)
    # One label per input row; convert explicitly since the output is a Textbox.
    return str(prediction[0])
# Sample inputs passed to gr.Interface(examples=...) below. The comments describe
# what each text looks like; they do not establish the model's label names.
example_inputs = [
    # Example resembling a financial report / invoice
    "Invoice No: INV-2024-001\nDate: 04/11/2024\nCustomer: John Smith\nItem 1: Laptop, QTY: 1, Price: $1200.00\nTotal Amount Due: $1275.00",
    # Example resembling meeting notes / an agenda.
    # NOTE(review): this was previously labeled "Legal Document or Contract", but the
    # text is a meeting agenda; replace it with contract text if a legal example was intended.
    "Agenda for our next project synchronization meeting scheduled for Tuesday",
    # Example resembling a medical record / clinical note
    "Patient: Jane Doe, DOB: 12/05/1990\nSymptoms: Severe headache, fever, and persistent cough. Diagnosis: Influenza. Treatment: Prescribed Ibuprofen and advised rest.",
]
# UI components: a multi-line box for the document and a box for the predicted label.
_document_box = gr.Textbox(
    lines=10,
    placeholder="Paste your document text here...",
    label="Input Document Text",
)
_label_box = gr.Textbox(label="Predicted Document Class")

# Wire predict_class into a Gradio interface preloaded with the sample documents.
interface = gr.Interface(
    fn=predict_class,
    inputs=_document_box,
    outputs=_label_box,
    title="Document Classification App",
    description="This app classifies an input document text into one of five predefined categories.",
    examples=example_inputs,
)

# Start the app only when this file is executed directly, not when imported.
if __name__ == "__main__":
    interface.launch()