File size: 2,530 Bytes
fcc8b20 95ebd81 d111374 807f53a aa9b825 807f53a fcc8b20 aa9b825 fcc8b20 aa9b825 fcc8b20 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import gradio as gr
import re
import nltk
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from nltk.tokenize import word_tokenize
# ======== Download NLTK Resources ========
# Fetch the NLTK data used by preprocess_text: the English stopword list,
# tokenizer data for word_tokenize, and WordNet data for the lemmatizer.
for _nltk_resource in ('stopwords', 'punkt_tab', 'wordnet', 'omw-1.4'):
    nltk.download(_nltk_resource)

# ======== Preprocessing Setup ========
# Stored as a set so the per-token membership test in preprocess_text is O(1).
stop_words = set(stopwords.words('english'))
lemmatizer = WordNetLemmatizer()
def preprocess_text(text):
    """Normalize raw user text into the token string fed to the classifier.

    Pipeline, in order: strip every character that is not an ASCII letter or
    whitespace, drop URL-like runs, collapse whitespace, lowercase, tokenize
    with NLTK, discard English stopwords, and lemmatize what remains.

    Args:
        text: Raw input string.

    Returns:
        The surviving lemmas joined by single spaces (may be ``""``).
    """
    # NOTE(review): punctuation is removed *before* the URL pattern runs, so a
    # URL like "http://x.com" has already collapsed to "httpxcom" here; the
    # pattern still catches it because it only needs the "http"/"www" prefix.
    # The pattern is case-sensitive and lowercasing happens later, so "HTTP..."
    # survives. Presumably this mirrors training-time preprocessing — confirm
    # before reordering these steps.
    letters_only = re.sub(r'[^A-Za-z\s]', '', text)
    no_urls = re.sub(r'http\S+|www\S+|https\S+', '', letters_only)
    squashed = re.sub(r'\s+', ' ', no_urls).strip()
    lowered = squashed.lower()

    # Filter stopwords and lemmatize in a single pass over the tokens.
    kept_lemmas = [
        lemmatizer.lemmatize(token)
        for token in word_tokenize(lowered)
        if token not in stop_words
    ]
    return ' '.join(kept_lemmas)
# ======== Class Names ========
# Output labels in logit-index order: predict_text pairs probs[i] with
# class_names[i], so this order must match the label encoding used when
# best_model.pth was trained (alphabetical here — presumably from a label
# encoder; confirm against the training script).
class_names = [
    "Anxiety", "Bipolar", "Depression", "Normal",
    "Personality disorder", "Stress", "Suicidal",
]
# ======== Load Tokenizer & Model ========
_BASE_CHECKPOINT = "distilbert-base-uncased"
_CPU = torch.device("cpu")

tokenizer = DistilBertTokenizer.from_pretrained(_BASE_CHECKPOINT)

# Build the base architecture with a classification head sized to the label
# set, then replace its weights with the fine-tuned checkpoint below
# (load_state_dict is strict by default, so every key must match).
model = DistilBertForSequenceClassification.from_pretrained(
    _BASE_CHECKPOINT,
    num_labels=len(class_names),  # one logit per entry in class_names (7)
)
# NOTE(review): torch.load unpickles the file, so only point it at a trusted
# checkpoint; on torch versions that support it, weights_only=True restricts
# loading to tensors. map_location forces CPU tensors regardless of where the
# checkpoint was saved.
model.load_state_dict(torch.load("best_model.pth", map_location=_CPU))
model.eval()  # switch off dropout for deterministic inference
# ======== Prediction Function ========
def predict_text(text):
    """Classify *text* into the mental-health categories in ``class_names``.

    Args:
        text: Raw user input. ``None`` is treated like an empty string
            (NOTE(review): Gradio's Textbox can hand ``None`` to the function
            when the endpoint is called without a value — confirm for the
            installed Gradio version).

    Returns:
        Dict mapping every label in ``class_names`` to its softmax
        probability. If preprocessing leaves no text, every label maps to 0.0.
    """
    # preprocess_text calls re.sub on its argument, which raises TypeError
    # on None; normalize so missing input takes the empty-text path instead.
    if text is None:
        text = ""

    cleaned_text = preprocess_text(text)
    if not cleaned_text.strip():
        # Nothing left after cleaning (e.g. only punctuation or stopwords):
        # skip the model and report zero confidence for every class.
        return {cls: 0.0 for cls in class_names}

    # Single example, truncated to at most 128 tokens.
    inputs = tokenizer(cleaned_text, truncation=True, padding=True, max_length=128, return_tensors='pt')
    with torch.no_grad():
        outputs = model(**inputs)
    # Logits hold one row for this single input; softmax over the label axis,
    # then flatten to a plain list of Python floats.
    probs = torch.softmax(outputs.logits, dim=1).flatten().tolist()
    return dict(zip(class_names, probs))
# ======== Gradio Interface ========
# Free-text statement in; per-class confidences out (all classes displayed).
_statement_box = gr.Textbox(lines=4, placeholder="Enter your statement here...")
_confidence_label = gr.Label(num_top_classes=len(class_names))

demo = gr.Interface(
    fn=predict_text,
    inputs=_statement_box,
    outputs=_confidence_label,
    title="Mental Health Sentiment Classifier",
    description="Classifies text into mental health categories: Normal, Depression, Suicidal, Anxiety, Bipolar, Stress, Personality disorder.",
)

if __name__ == "__main__":
    # Start the web server only when run as a script, not when imported.
    demo.launch()
|