ma4389 commited on
Commit
aa9b825
·
verified ·
1 Parent(s): fcc8b20

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -12
app.py CHANGED
@@ -13,16 +13,6 @@ nltk.download('punkt_tab')
13
  nltk.download('wordnet')
14
  nltk.download('omw-1.4')
15
 
16
- class_names = [
17
- "Normal",
18
- "Depression",
19
- "Suicidal",
20
- "Anxiety",
21
- "Bipolar",
22
- "Personality disorder"
23
- ]
24
-
25
-
26
  # ======== Preprocessing Setup ========
27
  stop_words = set(stopwords.words('english'))
28
  lemmatizer = WordNetLemmatizer()
@@ -51,6 +41,7 @@ class_names = [
51
  "Suicidal",
52
  "Anxiety",
53
  "Bipolar",
 
54
  "Personality disorder"
55
  ]
56
 
@@ -58,7 +49,7 @@ class_names = [
58
  tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
59
  model = DistilBertForSequenceClassification.from_pretrained(
60
  "distilbert-base-uncased",
61
- num_labels=len(class_names)
62
  )
63
  model.load_state_dict(torch.load("best_model.pth", map_location=torch.device("cpu")))
64
  model.eval()
@@ -82,7 +73,7 @@ demo = gr.Interface(
82
  inputs=gr.Textbox(lines=4, placeholder="Enter your statement here..."),
83
  outputs=gr.Label(num_top_classes=len(class_names)),
84
  title="Mental Health Sentiment Classifier",
85
- description="Classifies text into mental health categories."
86
  )
87
 
88
  if __name__ == "__main__":
 
13
  nltk.download('wordnet')
14
  nltk.download('omw-1.4')
15
 
 
 
 
 
 
 
 
 
 
 
16
  # ======== Preprocessing Setup ========
17
  stop_words = set(stopwords.words('english'))
18
  lemmatizer = WordNetLemmatizer()
 
41
  "Suicidal",
42
  "Anxiety",
43
  "Bipolar",
44
+ "Stress",
45
  "Personality disorder"
46
  ]
47
 
 
49
  tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
50
  model = DistilBertForSequenceClassification.from_pretrained(
51
  "distilbert-base-uncased",
52
+ num_labels=len(class_names) # 7 classes
53
  )
54
  model.load_state_dict(torch.load("best_model.pth", map_location=torch.device("cpu")))
55
  model.eval()
 
73
  inputs=gr.Textbox(lines=4, placeholder="Enter your statement here..."),
74
  outputs=gr.Label(num_top_classes=len(class_names)),
75
  title="Mental Health Sentiment Classifier",
76
+ description="Classifies text into mental health categories: Normal, Depression, Suicidal, Anxiety, Bipolar, Stress, Personality disorder."
77
  )
78
 
79
  if __name__ == "__main__":