ma4389 commited on
Commit
2083483
·
verified ·
1 Parent(s): ed8ba86

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +79 -0
  2. best_model.pth +3 -0
  3. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
3
+ import gradio as gr
4
+ import re
5
+ import nltk
6
+ from nltk.corpus import stopwords
7
+ from nltk.stem import WordNetLemmatizer
8
+ from nltk.tokenize import word_tokenize
9
+
10
+ # ======== Download NLTK Resources ========
11
+ nltk.download('stopwords')
12
+ nltk.download('punkt_tab')
13
+ nltk.download('wordnet')
14
+ nltk.download('omw-1.4')
15
+
16
+ # ======== Preprocessing Setup ========
17
+ stop_words = set(stopwords.words('english'))
18
+ lemmatizer = WordNetLemmatizer()
19
+
20
+ def preprocess_text(text):
21
+ # Remove non-alphabetic characters
22
+ text = re.sub(r'[^A-Za-z\s]', '', text)
23
+ # Remove URLs
24
+ text = re.sub(r'http\S+|www\S+|https\S+', '', text)
25
+ # Normalize spaces
26
+ text = re.sub(r'\s+', ' ', text).strip()
27
+ # Lowercase
28
+ text = text.lower()
29
+ # Tokenize
30
+ tokens = word_tokenize(text)
31
+ # Remove stopwords
32
+ tokens = [word for word in tokens if word not in stop_words]
33
+ # Lemmatize
34
+ tokens = [lemmatizer.lemmatize(word) for word in tokens]
35
+ return ' '.join(tokens)
36
+
37
+ # ======== Class Names ========
38
+ class_names = [
39
+ "Normal",
40
+ "Depression",
41
+ "Suicidal",
42
+ "Anxiety",
43
+ "Bipolar",
44
+ "Personality disorder"
45
+ ]
46
+
47
+ # ======== Load Tokenizer & Model ========
48
+ tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
49
+ model = DistilBertForSequenceClassification.from_pretrained(
50
+ "distilbert-base-uncased",
51
+ num_labels=len(class_names)
52
+ )
53
+ model.load_state_dict(torch.load("best_model.pth", map_location=torch.device("cpu")))
54
+ model.eval()
55
+
56
+ # ======== Prediction Function ========
57
+ def predict_text(text):
58
+ cleaned_text = preprocess_text(text)
59
+ if not cleaned_text.strip():
60
+ return {cls: 0.0 for cls in class_names}
61
+
62
+ inputs = tokenizer(cleaned_text, truncation=True, padding=True, max_length=128, return_tensors='pt')
63
+ with torch.no_grad():
64
+ outputs = model(**inputs)
65
+ probs = torch.softmax(outputs.logits, dim=1).flatten().tolist()
66
+
67
+ return {cls: float(prob) for cls, prob in zip(class_names, probs)}
68
+
69
+ # ======== Gradio Interface ========
70
+ demo = gr.Interface(
71
+ fn=predict_text,
72
+ inputs=gr.Textbox(lines=4, placeholder="Enter your statement here..."),
73
+ outputs=gr.Label(num_top_classes=len(class_names)),
74
+ title="Mental Health Sentiment Classifier",
75
+ description="Classifies text into mental health categories."
76
+ )
77
+
78
+ if __name__ == "__main__":
79
+ demo.launch()
best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:df2a8c6fa062a80c85368b5709e49beab9db5827726500c016637cd3f0abb583
3
+ size 267877222
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ gradio
4
+ nltk