ma4389 commited on
Commit
5075efd
·
verified ·
1 Parent(s): ff457f7

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +82 -0
  2. best_gru_model.pth +3 -0
  3. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import DistilBertTokenizer
4
+ import gradio as gr
5
+ import re
6
+ import nltk
7
+ from nltk.corpus import stopwords
8
+ from nltk.tokenize import word_tokenize
9
+ from nltk.stem import WordNetLemmatizer
10
+
11
+ # Load preprocessing tools
12
+ nltk.download('stopwords')
13
+ nltk.download('punkt_tab')
14
+ nltk.download('wordnet')
15
+
16
+ stop_words = set(stopwords.words("english"))
17
+ lemmatizer = WordNetLemmatizer()
18
+
19
+ # Preprocessing function
20
+ def preprocess_text(text):
21
+ text = re.sub(r'[^A-Za-z\s]', '', text)
22
+ text = re.sub(r'https?://\S+|www\.\S+', '', text)
23
+ text = text.lower()
24
+ tokens = word_tokenize(text)
25
+ tokens = [word for word in tokens if word not in stop_words]
26
+ tokens = [lemmatizer.lemmatize(word) for word in tokens]
27
+ return ' '.join(tokens)
28
+
29
+ # Define class mapping
30
+ label_dict = {
31
+ 0: "sadness",
32
+ 1: "joy",
33
+ 2: "love",
34
+ 3: "anger",
35
+ 4: "fear",
36
+ 5: "surprise"
37
+ }
38
+
39
+ # Load tokenizer
40
+ tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
41
+ max_len = 32
42
+
43
+ # Define the GRU Classifier
44
+ class GRUClassifier(nn.Module):
45
+ def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes):
46
+ super(GRUClassifier, self).__init__()
47
+ self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
48
+ self.gru = nn.GRU(embed_dim, hidden_dim, batch_first=True)
49
+ self.fc = nn.Linear(hidden_dim, num_classes)
50
+
51
+ def forward(self, input_ids):
52
+ x = self.embedding(input_ids)
53
+ out, _ = self.gru(x)
54
+ out = out[:, -1, :]
55
+ return self.fc(out)
56
+
57
+ # Load model
58
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
59
+ model = GRUClassifier(vocab_size=tokenizer.vocab_size, embed_dim=128, hidden_dim=64, num_classes=6)
60
+ model.load_state_dict(torch.load("best_gru_model.pth", map_location=device))
61
+ model.to(device)
62
+ model.eval()
63
+
64
+ # Inference function
65
+ def classify_emotion(text):
66
+ cleaned = preprocess_text(text)
67
+ tokens = tokenizer(cleaned, truncation=True, padding='max_length', max_length=max_len, return_tensors='pt')
68
+ input_ids = tokens['input_ids'].to(device)
69
+ with torch.no_grad():
70
+ outputs = model(input_ids)
71
+ prediction = torch.argmax(outputs, dim=1).item()
72
+ return label_dict[prediction]
73
+
74
+ # Gradio Interface
75
+ iface = gr.Interface(fn=classify_emotion,
76
+ inputs=gr.Textbox(lines=2, placeholder="Enter a sentence..."),
77
+ outputs="text",
78
+ title="Emotion Classifier (GRU)",
79
+ description="Predicts emotion from text. Classes: sadness, joy, love, anger, fear, surprise")
80
+
81
+ if __name__ == "__main__":
82
+ iface.launch()
best_gru_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fde25b76a7949a8378b504f2e2458bfbdbbea7700608adb02899767d309436e9
3
+ size 15780232
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ gradio
4
+ nltk