# ma4389's picture
# Update app.py
# 807f53a verified
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import gradio as gr
import re
import nltk
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from nltk.tokenize import word_tokenize
# ======== Download NLTK Resources ========
# Data packages used below: the English stopword list, the punkt_tab tables
# used by word_tokenize, and WordNet (+ Open Multilingual WordNet) for the
# lemmatizer.
for _nltk_resource in ('stopwords', 'punkt_tab', 'wordnet', 'omw-1.4'):
    nltk.download(_nltk_resource)
del _nltk_resource  # keep the module namespace free of the loop variable
# ======== Preprocessing Setup ========
# Built once at import time and reused by every preprocess_text call.
lemmatizer = WordNetLemmatizer()
stop_words = set(stopwords.words('english'))
# Patterns are compiled once at import time and reused on every call.
# URL tokens: anything starting with "http"/"https"/"www" up to the next
# whitespace, in any letter case (lowercasing happens later, so an
# upper-case URL would otherwise survive as a junk token).
_URL_RE = re.compile(r'(?:http|www)\S+', re.IGNORECASE)
# Every character that is not an ASCII letter or whitespace (digits,
# punctuation, and non-ASCII letters such as accented characters).
_NON_ALPHA_RE = re.compile(r'[^A-Za-z\s]')
_WHITESPACE_RE = re.compile(r'\s+')


def preprocess_text(text):
    """Normalize raw user text into the form fed to the classifier.

    The text goes through these steps in order:

    1. Drop URLs.
    2. Drop every character that is not an ASCII letter or whitespace.
    3. Collapse runs of whitespace and lowercase the result.
    4. Tokenize with NLTK.
    5. Remove English stopwords.
    6. Lemmatize each remaining token.

    Args:
        text: Raw input string. ``None`` or ``''`` is treated as empty.

    Returns:
        The cleaned tokens joined by single spaces, or ``''`` when nothing
        survives cleaning.
    """
    if not text:
        # '' already produced '' before; this also turns None into '' instead
        # of letting re.sub raise TypeError.
        return ''
    # URLs must go before punctuation: once ':' '/' '.' are stripped,
    # "https://a.com/x" has become "httpsacomx" and no longer looks like a URL.
    text = _URL_RE.sub('', text)
    text = _NON_ALPHA_RE.sub('', text)
    # Normalize spaces, then lowercase so stopword lookup (lowercase list) works.
    text = _WHITESPACE_RE.sub(' ', text).strip().lower()
    tokens = [tok for tok in word_tokenize(text) if tok not in stop_words]
    return ' '.join(lemmatizer.lemmatize(tok) for tok in tokens)
# ======== Class Names ========
# Human-readable labels in model-output order: predict_text pairs entry i
# with softmax probability i, so this order must match the label encoding
# the checkpoint was trained with (presumably alphabetical, as listed —
# confirm against the training code before reordering).
class_names = [
    "Anxiety", "Bipolar", "Depression", "Normal",
    "Personality disorder", "Stress", "Suicidal",
]
# ======== Load Tokenizer & Model ========
# The pretrained base checkpoint supplies the architecture and config; the
# fine-tuned weights from best_model.pth are then loaded over it (strict
# load_state_dict, so every key must match).
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels=len(class_names),  # one output logit per entry in class_names
)
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code during load
# (supported since torch 1.13; a saved state_dict loads fine this way).
model.load_state_dict(
    torch.load("best_model.pth", map_location="cpu", weights_only=True)
)
# Switch to evaluation mode (e.g. dropout disabled) for inference.
model.eval()
# ======== Prediction Function ========
def predict_text(text):
    """Classify one piece of text and return per-class probabilities.

    The input is cleaned with ``preprocess_text``. If nothing is left after
    cleaning, the model is skipped and every class gets 0.0. Otherwise the
    cleaned text is tokenized (truncated to 128 tokens), run through the model
    without gradient tracking, and the softmax over the logits is returned.

    Returns:
        dict mapping each name in ``class_names`` to a float probability.
    """
    cleaned = preprocess_text(text)
    if cleaned.strip() == '':
        return dict.fromkeys(class_names, 0.0)

    encoded = tokenizer(
        cleaned,
        truncation=True,
        padding=True,
        max_length=128,
        return_tensors='pt',
    )
    with torch.no_grad():
        logits = model(**encoded).logits
    # Single input, so flattening the (1, num_labels) result gives one
    # probability per class, in class_names order.
    probabilities = torch.softmax(logits, dim=1).flatten().tolist()
    return dict(zip(class_names, map(float, probabilities)))
# ======== Gradio Interface ========
# Components are built up front and handed to the Interface below.
_text_input = gr.Textbox(lines=4, placeholder="Enter your statement here...")
_label_output = gr.Label(num_top_classes=len(class_names))  # show every class

demo = gr.Interface(
    fn=predict_text,
    inputs=_text_input,
    outputs=_label_output,
    title="Mental Health Sentiment Classifier",
    description="Classifies text into mental health categories: Normal, Depression, Suicidal, Anxiety, Bipolar, Stress, Personality disorder.",
)

# Start the web UI only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()