Sazzz02 committed on
Commit
46d4b20
·
verified ·
1 Parent(s): 42456ae

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -0
app.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+
3
+ import gradio as gr
4
+ import joblib
5
+ import re
6
+ import nltk
7
+ from nltk.corpus import stopwords
8
+ import string
9
+
10
# Ensure the NLTK English stopword list is available locally. Accessing the
# corpus raises LookupError when it has not been downloaded yet, in which
# case it is fetched once.
try:
    stopwords.words('english')
except LookupError:
    nltk.download('stopwords')

# Locations of the serialized artifacts. These are plain relative paths, so
# they are resolved against the current working directory at start-up.
MODEL_PATH = "random_forest_model.joblib"
VECTORIZER_PATH = "tfidf_vectorizer.joblib"
# Stored as a set so per-word membership tests during preprocessing are cheap.
STOP_WORDS = set(stopwords.words('english'))

# Load the trained model and vectorizer once, at import time.
# NOTE: joblib artifacts are pickle-based; only load files from a trusted source.
try:
    model = joblib.load(MODEL_PATH)
    tfidf_vectorizer = joblib.load(VECTORIZER_PATH)
except FileNotFoundError as err:
    # Build the message from the constants so it cannot drift from the paths
    # actually used, and chain the original error to keep the failing path
    # visible in the traceback.
    raise FileNotFoundError(
        "Model or vectorizer files not found. "
        f"Please ensure '{MODEL_PATH}' and '{VECTORIZER_PATH}' "
        "are in the same directory as this script."
    ) from err
31
+
32
def preprocess_text(text):
    """
    Normalize raw text before it is vectorized.

    The steps run in a fixed order: lowercase, strip ASCII punctuation,
    strip runs of digits, then drop English stopwords. The surviving tokens
    are re-joined with single spaces. Per the original author's note, this is
    meant to match the cleaning applied to the training data, so the order
    should not be changed independently of the trained artifacts.
    """
    punctuation_table = str.maketrans('', '', string.punctuation)

    lowered = text.lower()
    without_punctuation = lowered.translate(punctuation_table)
    without_digits = re.sub(r'\d+', '', without_punctuation)

    # split() with no argument also collapses any whitespace left behind by
    # the removals above.
    kept_tokens = (
        token for token in without_digits.split() if token not in STOP_WORDS
    )
    return ' '.join(kept_tokens)
45
+
46
def predict_class(input_text):
    """
    Classify a raw document string and return the predicted class.

    The text is cleaned with ``preprocess_text``, transformed by the loaded
    TF-IDF vectorizer, and passed to the loaded model.

    Args:
        input_text: Raw document text. May be ``None`` or blank.

    Returns:
        The first element of the model's prediction for the document, or an
        explanatory message string when there is nothing to classify.
    """
    # Guard against missing or blank input (None, '' or whitespace only);
    # previously None crashed in preprocessing and blank text was classified.
    if not input_text or not input_text.strip():
        return "Please enter some text to classify."

    preprocessed_text = preprocess_text(input_text)

    # Text made up solely of punctuation, digits and stopwords is empty after
    # cleaning; it carries no features, so a prediction would be meaningless.
    if not preprocessed_text:
        return "No classifiable words found in the input text."

    # The vectorizer expects an iterable of documents, hence the one-item list.
    text_vector = tfidf_vectorizer.transform([preprocessed_text])
    prediction = model.predict(text_vector)

    # predict() returns one label per input document; only one was supplied.
    return prediction[0]
61
+
62
# Example documents offered as ready-made inputs in the Gradio UI.
example_inputs = [
    "The company's annual financial report showed a net profit of 50 million dollars, an increase of 15% from the previous year. The key drivers were cost reduction and increased market share in Asia.",
    "Patient medical history reveals a family history of hypertension. Symptoms include elevated blood pressure readings and persistent headaches. The patient has been prescribed a new medication.",
    "Instructions for assembly: Attach part A to part B using the supplied screw. Ensure the connection is tight to prevent detachment. The product is intended for indoor use only.",
]

# Input and output widgets, built separately to keep the Interface call short.
_document_input = gr.Textbox(
    lines=10,
    placeholder="Paste your document text here...",
    label="Input Document Text",
)
_prediction_output = gr.Textbox(label="Predicted Document Class")

# Wire the prediction function to the widgets and attach the examples.
interface = gr.Interface(
    fn=predict_class,
    inputs=_document_input,
    outputs=_prediction_output,
    title="Document Classification App",
    description="This app classifies an input document text into one of five predefined categories.",
    examples=example_inputs,
)

# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    interface.launch(inline=False, share=True)