keethu commited on
Commit
e0c4792
·
verified ·
1 Parent(s): 9803334

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -0
app.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import BertTokenizer, BertModel
4
+ import gradio as gr
5
+
6
+ # Load tokenizer and base BERT model
7
+ model_name = "keerthikapujari25/bert-emotion-classifier" # Replace with your HF username/repo
8
+ tokenizer = BertTokenizer.from_pretrained(model_name)
9
+ bert_model = BertModel.from_pretrained(model_name)
10
+
11
+ # Define your classifier architecture (same as training)
12
+ class BERTClassifier(nn.Module):
13
+ def __init__(self, bert_model, num_labels=5, dropout=0.3):
14
+ super(BERTClassifier, self).__init__()
15
+ self.bert = bert_model
16
+ self.dropout = nn.Dropout(dropout)
17
+ self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)
18
+
19
+ def forward(self, input_ids, attention_mask):
20
+ outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
21
+ pooled_output = outputs.pooler_output
22
+ pooled_output = self.dropout(pooled_output)
23
+ logits = self.classifier(pooled_output)
24
+ return logits
25
+
26
+ # Load model
27
+ model = BERTClassifier(bert_model, num_labels=5, dropout=0.3)
28
+ model.load_state_dict(torch.load(f"{model_name}/pytorch_model.bin", map_location='cpu'))
29
+ model.eval()
30
+
31
+ emotion_labels = ['anger', 'fear', 'joy', 'sadness', 'surprise']
32
+
33
+ def predict_emotions(text):
34
+ # Tokenize input
35
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
36
+
37
+ # Get predictions
38
+ with torch.no_grad():
39
+ outputs = model(inputs['input_ids'], inputs['attention_mask'])
40
+ probs = torch.sigmoid(outputs)[0].numpy()
41
+
42
+ # Create results dictionary
43
+ results = {emotion_labels[i]: float(probs[i]) for i in range(len(emotion_labels))}
44
+ return results
45
+
46
+ # Create Gradio interface
47
+ iface = gr.Interface(
48
+ fn=predict_emotions,
49
+ inputs=gr.Textbox(lines=3, placeholder="Enter text here to detect emotions..."),
50
+ outputs=gr.Label(num_top_classes=5),
51
+ title="Emotion Classification",
52
+ description="Multi-label emotion detection using fine-tuned BERT. Enter any text to detect anger, fear, joy, sadness, and surprise.",
53
+ examples=[
54
+ ["I am so happy and excited about this!"],
55
+ ["This is terrible and makes me angry."],
56
+ ["I can't believe this happened, it's shocking!"]
57
+ ]
58
+ )
59
+
60
+ iface.launch()