Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from transformers import DistilBertTokenizer | |
| import gradio as gr | |
| import re | |
| import nltk | |
| from nltk.corpus import stopwords | |
| from nltk.tokenize import word_tokenize | |
| from nltk.stem import WordNetLemmatizer | |
# Load preprocessing tools: fetch the NLTK data used for stopword filtering,
# word tokenization and lemmatization, in that order.
for _nltk_resource in ("stopwords", "punkt_tab", "wordnet"):
    nltk.download(_nltk_resource)

# English stopwords held in a set so membership tests are O(1).
stop_words = set(stopwords.words("english"))
# One lemmatizer instance shared by every preprocessing call.
lemmatizer = WordNetLemmatizer()
# Preprocessing function
# Patterns are compiled once at import time. URLs must be removed BEFORE
# non-letters: otherwise "://" and "." are deleted first, the URL pattern can
# never match, and fused fragments like "httpswwwexamplecom" leak through.
_URL_PATTERN = re.compile(r'https?://\S+|www\.\S+', flags=re.IGNORECASE)
_NON_ALPHA_PATTERN = re.compile(r'[^A-Za-z\s]')


def preprocess_text(text):
    """Normalise raw user text before it is tokenized for the model.

    Steps: drop URLs, drop every character that is not an ASCII letter or
    whitespace, lowercase, word-tokenize, remove stopwords, and lemmatize the
    remaining tokens.

    NOTE(review): if the checkpoint was trained with the previous (buggy)
    ordering, inputs containing URLs will now be cleaned differently than
    during training -- confirm against the training pipeline.

    Args:
        text: Raw input string. ``None`` is treated as an empty string.

    Returns:
        The cleaned tokens joined by single spaces ("" if nothing remains).
    """
    if text is None:
        text = ""
    text = _URL_PATTERN.sub('', text)
    text = _NON_ALPHA_PATTERN.sub('', text)
    text = text.lower()
    tokens = word_tokenize(text)
    # Tokens are already lowercase here, matching how stop_words is looked up.
    return ' '.join(
        lemmatizer.lemmatize(word) for word in tokens if word not in stop_words
    )
# Define class mapping: model output index -> emotion name.
# Position in this tuple fixes each label's index; it is looked up with the
# argmax of the model's logits, so the order presumably has to match the label
# encoding used in training.
label_dict = dict(enumerate((
    "sadness",
    "joy",
    "love",
    "anger",
    "fear",
    "surprise",
)))
# Load tokenizer. Every input is truncated/padded to exactly max_len ids.
max_len = 32
# Pretrained DistilBERT tokenizer; in this file it only supplies input ids and
# its vocab size for the GRU model.
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
# Define the GRU Classifier
class GRUClassifier(nn.Module):
    """Single-layer, unidirectional GRU classifier over token ids.

    Each id is embedded, the sequence is run through a batch-first GRU, and
    the GRU output at the final time step is projected to one logit per class.

    The submodule attribute names (``embedding``, ``gru``, ``fc``) determine
    the ``state_dict`` keys, so they must not change or saved checkpoints will
    no longer load.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes):
        super().__init__()
        # padding_idx=0 -- presumably the tokenizer's pad id; confirm.
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.gru = nn.GRU(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, input_ids):
        """Return class logits of shape (batch, num_classes).

        ``input_ids`` is expected to be a (batch, seq_len) tensor of token ids.

        NOTE(review): if inputs are right-padded to a fixed length, the final
        time step is usually a pad position rather than the last real token.
        Kept as-is because the trained checkpoint presumably depends on it.
        """
        embedded = self.embedding(input_ids)
        sequence_out, _final_hidden = self.gru(embedded)
        last_step = sequence_out[:, -1]
        return self.fc(last_step)
# Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GRUClassifier(
    vocab_size=tokenizer.vocab_size,
    embed_dim=128,
    hidden_dim=64,
    # Derive the class count from the label mapping so the two cannot drift.
    num_classes=len(label_dict),
)
# The checkpoint is a plain state_dict, so weights_only=True is sufficient and
# avoids unpickling arbitrary objects from the file.
model.load_state_dict(
    torch.load("best_gru_model.pth", map_location=device, weights_only=True)
)
model.to(device)
# Switch to evaluation mode before serving predictions.
model.eval()
# Inference function
def classify_emotion(text):
    """Predict the emotion label for one piece of text.

    The text is cleaned with ``preprocess_text``, encoded into a single row of
    exactly ``max_len`` token ids, scored by ``model``, and the index of the
    highest logit is mapped to its name through ``label_dict``.

    Returns:
        One of the label strings in ``label_dict`` (e.g. "joy").
    """
    encoded = tokenizer(
        preprocess_text(text),
        truncation=True,
        padding='max_length',
        max_length=max_len,
        return_tensors='pt',
    )
    # Only the ids are fed to the GRU; the attention mask is not used.
    batch_ids = encoded['input_ids'].to(device)
    with torch.no_grad():  # inference only, no autograd graph needed
        logits = model(batch_ids)
    best_index = logits.argmax(dim=1).item()
    return label_dict[best_index]
# Gradio Interface
_input_box = gr.Textbox(lines=2, placeholder="Enter a sentence...")
iface = gr.Interface(
    fn=classify_emotion,
    inputs=_input_box,
    outputs="text",
    title="Emotion Classifier (GRU)",
    description="Predicts emotion from text. Classes: sadness, joy, love, anger, fear, surprise",
)

if __name__ == "__main__":
    # Start the web UI only when executed as a script.
    iface.launch()