File size: 2,887 Bytes
46d4b20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac262fe
46d4b20
ac262fe
 
 
 
81cf4b2
ac262fe
027d00f
 
ac262fe
027d00f
46d4b20
 
 
 
 
 
 
 
 
 
 
 
 
 
ac262fe
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# app.py

import gradio as gr
import joblib
import re
import nltk
from nltk.corpus import stopwords
import string

def _ensure_english_stopwords():
    """Fetch the NLTK stopwords corpus if the English list cannot be loaded locally."""
    try:
        # Probe the local NLTK data; raises LookupError when the corpus is absent.
        stopwords.words('english')
    except LookupError:
        nltk.download('stopwords')


# Run the check once at import time, before STOP_WORDS is built from the corpus.
_ensure_english_stopwords()

# Locations of the serialized artifacts and the English stopword set used by
# preprocessing. NOTE: these paths are relative, so they are resolved against
# the process's current working directory, not necessarily this file's folder.
MODEL_PATH = "random_forest_model.joblib"
VECTORIZER_PATH = "tfidf_vectorizer.joblib"
STOP_WORDS = set(stopwords.words('english'))

# Load the trained model and vectorizer once at import time so every request
# reuses them; fail fast with an actionable message if either file is missing.
try:
    model = joblib.load(MODEL_PATH)
    tfidf_vectorizer = joblib.load(VECTORIZER_PATH)
except FileNotFoundError as err:
    # Report the actual configured paths (not hard-coded copies of them), say
    # which one failed, and keep the original error chained for debugging.
    raise FileNotFoundError(
        f"Model or vectorizer file not found: {err.filename!r}. "
        f"Please ensure '{MODEL_PATH}' and '{VECTORIZER_PATH}' "
        "are in the current working directory (the directory the app is launched from)."
    ) from err

def preprocess_text(text):
    """
    Normalise raw document text the same way the training pipeline did.

    The steps run in this order: lowercase, strip ASCII punctuation
    (``string.punctuation``), strip runs of digits, then drop English
    stopwords. The surviving tokens are re-joined with single spaces.
    """
    lowered = text.lower()
    # Punctuation must be removed before the digit pass, matching training.
    without_punctuation = lowered.translate(str.maketrans('', '', string.punctuation))
    without_digits = re.sub(r'\d+', '', without_punctuation)
    kept_tokens = (token for token in without_digits.split() if token not in STOP_WORDS)
    return ' '.join(kept_tokens)

def predict_class(input_text):
    """
    Classify raw document text and return the predicted class label.

    The text is cleaned with ``preprocess_text`` (the training-time
    normalisation), transformed with the loaded TF-IDF vectorizer, and passed
    to the loaded model.

    Args:
        input_text: Raw document text from the UI. ``None`` is treated as an
            empty string.

    Returns:
        The model's predicted label (first element of ``model.predict``). If the
        input has no terms the vectorizer recognises, an explanatory message is
        returned instead, because any prediction would not reflect the text.
    """
    no_signal_message = (
        "Unable to classify: the input contains no recognizable words. "
        "Please paste some document text."
    )

    # Guard against None so the .lower() call inside preprocessing cannot fail.
    preprocessed_text = preprocess_text(input_text or "")

    # Blank input, or input made only of punctuation/digits/stopwords, is empty
    # after cleaning and carries no signal for the model.
    if not preprocessed_text:
        return no_signal_message

    text_vector = tfidf_vectorizer.transform([preprocessed_text])

    # If none of the remaining words are in the vectorizer's vocabulary, the row
    # is all zeros and the model would still return an arbitrary class.
    # NOTE(review): relies on the sparse-matrix ``nnz`` attribute (scikit-learn's
    # TfidfVectorizer returns scipy sparse); other result types skip this check.
    if getattr(text_vector, "nnz", None) == 0:
        return no_signal_message

    prediction = model.predict(text_vector)
    return prediction[0]

# Sample documents offered as one-click examples in the Gradio UI.
example_inputs = [
    # Invoice-style text (financial document)
    "Invoice No: INV-2024-001\nDate: 04/11/2024\nCustomer: John Smith\nItem 1: Laptop, QTY: 1, Price: $1200.00\nTotal Amount Due: $1275.00",

    # Meeting-agenda text. NOTE(review): this sample was previously labelled as a
    # "Legal Document or Contract" example, but its content is a meeting agenda;
    # confirm which category it is meant to exercise.
    "Agenda for our next project synchronization meeting scheduled for Tuesday",

    # Medical record / clinical note text
    "Patient: Jane Doe, DOB: 12/05/1990\nSymptoms: Severe headache, fever, and persistent cough. Diagnosis: Influenza. Treatment: Prescribed Ibuprofen and advised rest.",
]

# Text-in / text-out Gradio UI around predict_class, pre-populated with the
# `example_inputs` samples as clickable examples.
_input_box = gr.Textbox(
    lines=10,
    placeholder="Paste your document text here...",
    label="Input Document Text",
)
_output_box = gr.Textbox(label="Predicted Document Class")

interface = gr.Interface(
    fn=predict_class,
    inputs=_input_box,
    outputs=_output_box,
    title="Document Classification App",
    description="This app classifies an input document text into one of five predefined categories.",
    examples=example_inputs,
)

def main():
    """Start the Gradio server for the document classification interface."""
    interface.launch()


# Launch only when executed as a script, not when imported.
if __name__ == "__main__":
    main()