# Uploaded by ma4389 ("Upload 3 files", commit 5075efd, verified); file size 2.73 kB.
import torch
import torch.nn as nn
from transformers import DistilBertTokenizer
import gradio as gr
import re
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from nltk.stem import WordNetLemmatizer
# Fetch the NLTK resources used by preprocessing, then build the shared helpers.
for _resource in ('stopwords', 'punkt_tab', 'wordnet'):
    nltk.download(_resource)

# English stop words as a set for fast membership checks.
stop_words = set(stopwords.words("english"))
# Single lemmatizer instance shared by all calls.
lemmatizer = WordNetLemmatizer()
# Preprocessing function
def preprocess_text(text):
    """Clean raw text into a space-joined string of lemmatized tokens.

    Steps, in order: remove URLs, drop every character that is not an ASCII
    letter or whitespace, lowercase, tokenize, remove English stop words,
    and lemmatize each remaining token.

    Args:
        text: Raw input string.

    Returns:
        The surviving tokens joined by single spaces (possibly empty).

    NOTE(review): URLs used to survive as fused tokens such as
    "httpsexamplecom" because the character filter ran first. If the model
    was trained on text cleaned with that ordering, URL-bearing inputs are
    now tokenized differently -- confirm against the training pipeline.
    """
    # URLs must go first: once ':', '/' and '.' have been stripped the URL
    # pattern can no longer match anything.
    text = re.sub(r'https?://\S+|www\.\S+', '', text)
    text = re.sub(r'[^A-Za-z\s]', '', text)
    text = text.lower()
    tokens = word_tokenize(text)
    return ' '.join(
        lemmatizer.lemmatize(word) for word in tokens if word not in stop_words
    )
# Class index -> emotion name; a name's position in the tuple is its class index.
label_dict = dict(enumerate((
    "sadness",
    "joy",
    "love",
    "anger",
    "fear",
    "surprise",
)))
# Tokenizer setup
# Length (in token ids) that each input is truncated/padded to at inference.
max_len = 32
# Pretrained DistilBERT tokenizer; its vocab_size also sizes the model's embedding table.
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
# Define the GRU Classifier
class GRUClassifier(nn.Module):
    """Text classifier: token embedding -> GRU -> linear head.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Size of each token embedding vector.
        hidden_dim: Size of the GRU hidden state.
        num_classes: Number of logits produced per input sequence.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes):
        super().__init__()
        # Attribute names are the state-dict key prefixes; keep them stable
        # so previously saved checkpoints still load.
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_dim,
            padding_idx=0,  # token id 0 is treated as padding
        )
        self.gru = nn.GRU(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_dim, out_features=num_classes)

    def forward(self, input_ids):
        """Return class logits for a batch of token-id sequences.

        Args:
            input_ids: Integer tensor of token ids, batch dimension first.

        Returns:
            Tensor of logits with one row per sequence and one column per class.
        """
        embedded = self.embedding(input_ids)
        sequence_out = self.gru(embedded)[0]
        # Classify from the GRU output at the final sequence position. No
        # masking is applied, so that position is used even if it is padding.
        last_step = sequence_out[:, -1, :]
        return self.fc(last_step)
# Build the classifier and restore its trained weights.
# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters must match the checkpoint being loaded.
model = GRUClassifier(
    vocab_size=tokenizer.vocab_size,
    embed_dim=128,
    hidden_dim=64,
    num_classes=6,
)
# NOTE: torch.load deserializes the file (pickle-based on older PyTorch
# releases); only load checkpoints from a trusted source.
model.load_state_dict(
    torch.load("best_gru_model.pth", map_location=device)
)
# Move to the target device and switch to inference mode.
model.to(device).eval()
# Inference function
def classify_emotion(text):
    """Predict the emotion label for a single piece of text.

    The text is cleaned with ``preprocess_text``, tokenized to exactly
    ``max_len`` token ids (truncated or padded), and run through the model.

    Args:
        text: Raw input string.

    Returns:
        The ``label_dict`` entry for the highest-scoring class.
    """
    encoded = tokenizer(
        preprocess_text(text),
        truncation=True,
        padding='max_length',
        max_length=max_len,
        return_tensors='pt',
    )
    batch_ids = encoded['input_ids'].to(device)

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        logits = model(batch_ids)
        best_class = logits.argmax(dim=1).item()

    return label_dict[best_class]
# Gradio UI: one text box in, the predicted emotion name out.
_text_input = gr.Textbox(lines=2, placeholder="Enter a sentence...")

iface = gr.Interface(
    fn=classify_emotion,
    inputs=_text_input,
    outputs="text",
    title="Emotion Classifier (GRU)",
    description="Predicts emotion from text. Classes: sadness, joy, love, anger, fear, surprise",
)

# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    iface.launch()