"""Gradio demo that classifies free text into mental-health categories.

Input text is cleaned with an NLTK pipeline (URL and non-letter removal,
lowercasing, stop-word removal, lemmatization) and then scored by a
DistilBERT sequence classifier whose weights are loaded from
``best_model.pth``.
"""

import re

import gradio as gr
import nltk
import torch
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from nltk.tokenize import word_tokenize
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer

# ======== Download NLTK Resources ========
# Resources needed by stopwords, word_tokenize and WordNetLemmatizer below.
# quiet=True suppresses the per-package progress output at startup.
for _resource in ("stopwords", "punkt_tab", "wordnet", "omw-1.4"):
    nltk.download(_resource, quiet=True)

# ======== Preprocessing Setup ========
stop_words = set(stopwords.words("english"))
lemmatizer = WordNetLemmatizer()

# Compiled once at import time because preprocess_text runs on every request.
_URL_RE = re.compile(r"http\S+|www\S+")  # also covers "https..."
_NON_ALPHA_RE = re.compile(r"[^A-Za-z\s]")
_WHITESPACE_RE = re.compile(r"\s+")


def preprocess_text(text):
    """Normalize raw user text into the form fed to the tokenizer.

    The steps are: remove URLs, drop every character that is not an ASCII
    letter or whitespace, collapse runs of whitespace, lowercase, tokenize,
    remove English stop words, and lemmatize the remaining tokens.

    Args:
        text: Raw input string.

    Returns:
        The cleaned tokens joined by single spaces. This may be ``""`` when
        nothing survives cleaning.
    """
    # Remove URLs *before* stripping punctuation, so the pattern sees intact
    # URLs instead of letters fused together once ':' '/' '.' are deleted.
    # NOTE(review): this pipeline should mirror the preprocessing used when
    # best_model.pth was trained -- confirm against the training code.
    text = _URL_RE.sub("", text)
    text = _NON_ALPHA_RE.sub("", text)
    text = _WHITESPACE_RE.sub(" ", text).strip().lower()

    tokens = word_tokenize(text)
    return " ".join(
        lemmatizer.lemmatize(word) for word in tokens if word not in stop_words
    )


# ======== Class Names ========
# Output index i of the model is reported as class_names[i] (see the zip in
# predict_text), so this order must match the label encoding used in training.
class_names = [
    "Anxiety",
    "Bipolar",
    "Depression",
    "Normal",
    "Personality disorder",
    "Stress",
    "Suicidal",
]

# ======== Load Tokenizer & Model ========
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels=len(class_names),  # 7 classes
)
# best_model.pth is loaded with load_state_dict, so it must be a plain
# state_dict of tensors. weights_only=True loads exactly that and refuses
# arbitrary pickled objects, which could otherwise execute code on load.
_state_dict = torch.load(
    "best_model.pth", map_location=torch.device("cpu"), weights_only=True
)
model.load_state_dict(_state_dict)
model.eval()  # inference mode for dropout etc.


# ======== Prediction Function ========
def predict_text(text):
    """Score *text* against every class in ``class_names``.

    Args:
        text: Raw user input. ``None`` and empty strings are accepted.

    Returns:
        A dict mapping each class name to its softmax probability. When
        nothing survives preprocessing, every class maps to ``0.0``.
    """
    # Guard against None (e.g. API callers), which would crash re.sub.
    cleaned_text = preprocess_text(text or "")
    if not cleaned_text.strip():
        return {cls: 0.0 for cls in class_names}

    inputs = tokenizer(
        cleaned_text,
        truncation=True,
        padding=True,
        max_length=128,
        return_tensors="pt",
    )
    with torch.no_grad():
        logits = model(**inputs).logits

    # A single input is one batch row: softmax over the class axis, then
    # flatten away the batch dimension.
    probs = torch.softmax(logits, dim=1).flatten().tolist()
    return {cls: float(prob) for cls, prob in zip(class_names, probs)}


# ======== Gradio Interface ========
demo = gr.Interface(
    fn=predict_text,
    inputs=gr.Textbox(lines=4, placeholder="Enter your statement here..."),
    outputs=gr.Label(num_top_classes=len(class_names)),
    title="Mental Health Sentiment Classifier",
    description="Classifies text into mental health categories: Normal, Depression, Suicidal, Anxiety, Bipolar, Stress, Personality disorder.",
)

if __name__ == "__main__":
    demo.launch()