Update app.py
Browse files
app.py
CHANGED
|
@@ -13,16 +13,6 @@ nltk.download('punkt_tab')
|
|
| 13 |
nltk.download('wordnet')
|
| 14 |
nltk.download('omw-1.4')
|
| 15 |
|
| 16 |
-
class_names = [
|
| 17 |
-
"Normal",
|
| 18 |
-
"Depression",
|
| 19 |
-
"Suicidal",
|
| 20 |
-
"Anxiety",
|
| 21 |
-
"Bipolar",
|
| 22 |
-
"Personality disorder"
|
| 23 |
-
]
|
| 24 |
-
|
| 25 |
-
|
| 26 |
# ======== Preprocessing Setup ========
|
| 27 |
stop_words = set(stopwords.words('english'))
|
| 28 |
lemmatizer = WordNetLemmatizer()
|
|
@@ -51,6 +41,7 @@ class_names = [
|
|
| 51 |
"Suicidal",
|
| 52 |
"Anxiety",
|
| 53 |
"Bipolar",
|
|
|
|
| 54 |
"Personality disorder"
|
| 55 |
]
|
| 56 |
|
|
@@ -58,7 +49,7 @@ class_names = [
|
|
| 58 |
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
|
| 59 |
model = DistilBertForSequenceClassification.from_pretrained(
|
| 60 |
"distilbert-base-uncased",
|
| 61 |
-
num_labels=len(class_names)
|
| 62 |
)
|
| 63 |
model.load_state_dict(torch.load("best_model.pth", map_location=torch.device("cpu")))
|
| 64 |
model.eval()
|
|
@@ -82,7 +73,7 @@ demo = gr.Interface(
|
|
| 82 |
inputs=gr.Textbox(lines=4, placeholder="Enter your statement here..."),
|
| 83 |
outputs=gr.Label(num_top_classes=len(class_names)),
|
| 84 |
title="Mental Health Sentiment Classifier",
|
| 85 |
-
description="Classifies text into mental health categories."
|
| 86 |
)
|
| 87 |
|
| 88 |
if __name__ == "__main__":
|
|
|
|
| 13 |
nltk.download('wordnet')
|
| 14 |
nltk.download('omw-1.4')
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
# ======== Preprocessing Setup ========
|
| 17 |
stop_words = set(stopwords.words('english'))
|
| 18 |
lemmatizer = WordNetLemmatizer()
|
|
|
|
| 41 |
"Suicidal",
|
| 42 |
"Anxiety",
|
| 43 |
"Bipolar",
|
| 44 |
+
"Stress",
|
| 45 |
"Personality disorder"
|
| 46 |
]
|
| 47 |
|
|
|
|
| 49 |
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
|
| 50 |
model = DistilBertForSequenceClassification.from_pretrained(
|
| 51 |
"distilbert-base-uncased",
|
| 52 |
+
num_labels=len(class_names) # 7 classes
|
| 53 |
)
|
| 54 |
model.load_state_dict(torch.load("best_model.pth", map_location=torch.device("cpu")))
|
| 55 |
model.eval()
|
|
|
|
| 73 |
inputs=gr.Textbox(lines=4, placeholder="Enter your statement here..."),
|
| 74 |
outputs=gr.Label(num_top_classes=len(class_names)),
|
| 75 |
title="Mental Health Sentiment Classifier",
|
| 76 |
+
description="Classifies text into mental health categories: Normal, Depression, Suicidal, Anxiety, Bipolar, Stress, Personality disorder."
|
| 77 |
)
|
| 78 |
|
| 79 |
if __name__ == "__main__":
|