ma4389 committed on
Commit
fcc8b20
·
verified ·
1 Parent(s): 2083483

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -79
app.py CHANGED
@@ -1,79 +1,89 @@
1
- import torch
2
- from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
3
- import gradio as gr
4
- import re
5
- import nltk
6
- from nltk.corpus import stopwords
7
- from nltk.stem import WordNetLemmatizer
8
- from nltk.tokenize import word_tokenize
9
-
10
- # ======== Download NLTK Resources ========
11
- nltk.download('stopwords')
12
- nltk.download('punkt_tab')
13
- nltk.download('wordnet')
14
- nltk.download('omw-1.4')
15
-
16
- # ======== Preprocessing Setup ========
17
- stop_words = set(stopwords.words('english'))
18
- lemmatizer = WordNetLemmatizer()
19
-
20
- def preprocess_text(text):
21
- # Remove non-alphabetic characters
22
- text = re.sub(r'[^A-Za-z\s]', '', text)
23
- # Remove URLs
24
- text = re.sub(r'http\S+|www\S+|https\S+', '', text)
25
- # Normalize spaces
26
- text = re.sub(r'\s+', ' ', text).strip()
27
- # Lowercase
28
- text = text.lower()
29
- # Tokenize
30
- tokens = word_tokenize(text)
31
- # Remove stopwords
32
- tokens = [word for word in tokens if word not in stop_words]
33
- # Lemmatize
34
- tokens = [lemmatizer.lemmatize(word) for word in tokens]
35
- return ' '.join(tokens)
36
-
37
- # ======== Class Names ========
38
- class_names = [
39
- "Normal",
40
- "Depression",
41
- "Suicidal",
42
- "Anxiety",
43
- "Bipolar",
44
- "Personality disorder"
45
- ]
46
-
47
- # ======== Load Tokenizer & Model ========
48
- tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
49
- model = DistilBertForSequenceClassification.from_pretrained(
50
- "distilbert-base-uncased",
51
- num_labels=len(class_names)
52
- )
53
- model.load_state_dict(torch.load("best_model.pth", map_location=torch.device("cpu")))
54
- model.eval()
55
-
56
- # ======== Prediction Function ========
57
- def predict_text(text):
58
- cleaned_text = preprocess_text(text)
59
- if not cleaned_text.strip():
60
- return {cls: 0.0 for cls in class_names}
61
-
62
- inputs = tokenizer(cleaned_text, truncation=True, padding=True, max_length=128, return_tensors='pt')
63
- with torch.no_grad():
64
- outputs = model(**inputs)
65
- probs = torch.softmax(outputs.logits, dim=1).flatten().tolist()
66
-
67
- return {cls: float(prob) for cls, prob in zip(class_names, probs)}
68
-
69
- # ======== Gradio Interface ========
70
- demo = gr.Interface(
71
- fn=predict_text,
72
- inputs=gr.Textbox(lines=4, placeholder="Enter your statement here..."),
73
- outputs=gr.Label(num_top_classes=len(class_names)),
74
- title="Mental Health Sentiment Classifier",
75
- description="Classifies text into mental health categories."
76
- )
77
-
78
- if __name__ == "__main__":
79
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
3
+ import gradio as gr
4
+ import re
5
+ import nltk
6
+ from nltk.corpus import stopwords
7
+ from nltk.stem import WordNetLemmatizer
8
+ from nltk.tokenize import word_tokenize
9
+
10
+ # ======== Download NLTK Resources ========
11
+ nltk.download('stopwords')
12
+ nltk.download('punkt_tab')
13
+ nltk.download('wordnet')
14
+ nltk.download('omw-1.4')
15
+
16
+ class_names = [
17
+ "Normal",
18
+ "Depression",
19
+ "Suicidal",
20
+ "Anxiety",
21
+ "Bipolar",
22
+ "Personality disorder"
23
+ ]
24
+
25
+
26
+ # ======== Preprocessing Setup ========
27
+ stop_words = set(stopwords.words('english'))
28
+ lemmatizer = WordNetLemmatizer()
29
+
30
+ def preprocess_text(text):
31
+ # Remove non-alphabetic characters
32
+ text = re.sub(r'[^A-Za-z\s]', '', text)
33
+ # Remove URLs
34
+ text = re.sub(r'http\S+|www\S+|https\S+', '', text)
35
+ # Normalize spaces
36
+ text = re.sub(r'\s+', ' ', text).strip()
37
+ # Lowercase
38
+ text = text.lower()
39
+ # Tokenize
40
+ tokens = word_tokenize(text)
41
+ # Remove stopwords
42
+ tokens = [word for word in tokens if word not in stop_words]
43
+ # Lemmatize
44
+ tokens = [lemmatizer.lemmatize(word) for word in tokens]
45
+ return ' '.join(tokens)
46
+
47
+ # ======== Class Names ========
48
+ class_names = [
49
+ "Normal",
50
+ "Depression",
51
+ "Suicidal",
52
+ "Anxiety",
53
+ "Bipolar",
54
+ "Personality disorder"
55
+ ]
56
+
57
+ # ======== Load Tokenizer & Model ========
58
+ tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
59
+ model = DistilBertForSequenceClassification.from_pretrained(
60
+ "distilbert-base-uncased",
61
+ num_labels=len(class_names)
62
+ )
63
+ model.load_state_dict(torch.load("best_model.pth", map_location=torch.device("cpu")))
64
+ model.eval()
65
+
66
+ # ======== Prediction Function ========
67
+ def predict_text(text):
68
+ cleaned_text = preprocess_text(text)
69
+ if not cleaned_text.strip():
70
+ return {cls: 0.0 for cls in class_names}
71
+
72
+ inputs = tokenizer(cleaned_text, truncation=True, padding=True, max_length=128, return_tensors='pt')
73
+ with torch.no_grad():
74
+ outputs = model(**inputs)
75
+ probs = torch.softmax(outputs.logits, dim=1).flatten().tolist()
76
+
77
+ return {cls: float(prob) for cls, prob in zip(class_names, probs)}
78
+
79
+ # ======== Gradio Interface ========
80
+ demo = gr.Interface(
81
+ fn=predict_text,
82
+ inputs=gr.Textbox(lines=4, placeholder="Enter your statement here..."),
83
+ outputs=gr.Label(num_top_classes=len(class_names)),
84
+ title="Mental Health Sentiment Classifier",
85
+ description="Classifies text into mental health categories."
86
+ )
87
+
88
+ if __name__ == "__main__":
89
+ demo.launch()