|
|
import torch |
|
|
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification |
|
|
import gradio as gr |
|
|
import re |
|
|
import nltk |
|
|
from nltk.corpus import stopwords |
|
|
from nltk.stem import WordNetLemmatizer |
|
|
from nltk.tokenize import word_tokenize |
|
|
|
|
|
|
|
|
# Fetch the NLTK data packages used by the preprocessing code in this file.
# The downloads are triggered at import time, one call per package.
for _resource in ('stopwords', 'punkt_tab', 'wordnet', 'omw-1.4'):
    nltk.download(_resource)

# One shared lemmatizer, and the English stop words as a set for fast
# membership tests during token filtering.
lemmatizer = WordNetLemmatizer()
stop_words = set(stopwords.words('english'))
|
|
|
|
|
def preprocess_text(text):
    """Normalise raw text into a space-joined string of lemmatized tokens.

    Steps, in order:
      1. delete every character that is not an ASCII letter or whitespace;
      2. delete any run that starts with ``http`` or ``www`` and extends to
         the next whitespace;
      3. collapse whitespace runs to single spaces and trim the ends;
      4. lowercase;
      5. tokenize with ``word_tokenize``;
      6. drop tokens found in ``stop_words`` and lemmatize the rest.

    NOTE(review): step 2 runs after step 1, so the URL pattern sees text with
    punctuation already removed (e.g. ``httpsexamplecom``). The order is kept
    as-is; presumably it has to match the preprocessing used when the model
    was trained -- confirm before changing.
    """
    letters_only = re.sub(r'[^A-Za-z\s]', '', text)
    without_urls = re.sub(r'http\S+|www\S+|https\S+', '', letters_only)
    normalised = re.sub(r'\s+', ' ', without_urls).strip().lower()

    lemmas = []
    for token in word_tokenize(normalised):
        if token in stop_words:
            continue
        lemmas.append(lemmatizer.lemmatize(token))
    return ' '.join(lemmas)
|
|
|
|
|
|
|
|
# Display name for each classifier output; entry i labels the i-th probability
# when predictions are zipped against this list.
# NOTE(review): the order presumably has to match the label encoding used at
# training time -- confirm before reordering or inserting entries.
class_names = [
    "Anxiety",
    "Bipolar",
    "Depression",
    "Normal",
    "Personality disorder",
    "Stress",
    "Suicidal",
]
|
|
|
|
|
|
|
|
# Tokenizer and network architecture come from the stock DistilBERT
# checkpoint; the classification head gets one output per entry in class_names.
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels=len(class_names),
)

# Replace the weights with the ones stored in best_model.pth, mapped to CPU.
# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot run arbitrary code while being loaded. The
# file is consumed as a state dict, which that mode supports.
model.load_state_dict(
    torch.load(
        "best_model.pth",
        map_location=torch.device("cpu"),
        weights_only=True,
    )
)

# Put the module in evaluation mode before serving predictions.
model.eval()
|
|
|
|
|
|
|
|
def predict_text(text):
    """Classify a statement and return one score per entry in ``class_names``.

    Args:
        text: Raw user input. ``None``, any non-string value, or a blank
            string is treated as empty input.

    Returns:
        A dict mapping each name in ``class_names`` to a float. When the
        input is empty, or becomes empty after ``preprocess_text``, every
        value is 0.0. Otherwise the values are the softmax of the model's
        logits, paired with ``class_names`` by position.
    """
    # Guard before preprocessing: preprocess_text passes its argument straight
    # to re.sub, which raises TypeError on None or other non-string values.
    if not isinstance(text, str) or not text.strip():
        return {cls: 0.0 for cls in class_names}

    cleaned_text = preprocess_text(text)
    if not cleaned_text.strip():
        # Nothing survived cleaning (e.g. only stop words or symbols).
        return {cls: 0.0 for cls in class_names}

    inputs = tokenizer(
        cleaned_text,
        truncation=True,
        padding=True,
        max_length=128,
        return_tensors='pt',
    )

    # Inference only: no gradients need to be tracked.
    with torch.no_grad():
        logits = model(**inputs).logits

    # A single string was encoded, so flatten() yields one value per class;
    # tolist() already produces plain Python floats.
    probs = torch.softmax(logits, dim=1).flatten().tolist()
    return dict(zip(class_names, probs))
|
|
|
|
|
|
|
|
# Gradio front end: a four-line text box feeds predict_text, and the returned
# {label: score} dict is shown in a Label component sized to list every class.
_statement_box = gr.Textbox(lines=4, placeholder="Enter your statement here...")
_scores_label = gr.Label(num_top_classes=len(class_names))

demo = gr.Interface(
    fn=predict_text,
    inputs=_statement_box,
    outputs=_scores_label,
    title="Mental Health Sentiment Classifier",
    description="Classifies text into mental health categories: Normal, Depression, Suicidal, Anxiety, Bipolar, Stress, Personality disorder.",
)

# Launch the app only when this file is executed directly, not on import.
if __name__ == "__main__":
    demo.launch()
|
|
|