# Hugging Face page residue (not code), kept as comments so the file parses:
# keethu's picture
# Update app.py
# 52c20e5 verified
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel, BertForSequenceClassification
import gradio as gr
# Hub repository holding the fine-tuned checkpoint and its tokenizer files.
model_name = "keethu/bert-emotion-classifier"

# Both the tokenizer and the bare BERT encoder are pulled from that same repo;
# the encoder is wrapped by BERTClassifier below.
tokenizer = BertTokenizer.from_pretrained(model_name)
base_bert = BertModel.from_pretrained(model_name)
# Define your classifier architecture (same as training)
class BERTClassifier(nn.Module):
    """BERT encoder -> dropout -> linear head over the pooled output.

    The attribute names ``bert``, ``dropout`` and ``classifier`` determine the
    state-dict keys (``bert.*``, ``classifier.weight``, ``classifier.bias``),
    so renaming them would stop a saved checkpoint from loading.

    Args:
        bert_model: Encoder module exposing ``config.hidden_size`` and returning
            an object with a ``pooler_output`` attribute (e.g. ``BertModel``).
        num_labels: Number of output logits produced by the head.
        dropout: Dropout probability applied to the pooled output.
    """

    def __init__(self, bert_model: nn.Module, num_labels: int = 5, dropout: float = 0.3) -> None:
        super().__init__()
        self.bert = bert_model
        self.dropout = nn.Dropout(dropout)
        # The head's input width is taken from the encoder's own config.
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Return raw (unnormalized) logits from the linear head.

        Args:
            input_ids: Token ids as produced by the tokenizer.
            attention_mask: Attention mask matching ``input_ids``.

        Returns:
            Logits tensor; presumably (batch, num_labels), given BertModel's
            (batch, hidden_size) ``pooler_output``. No sigmoid/softmax is applied here.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = self.dropout(outputs.pooler_output)
        return self.classifier(pooled_output)
# Build the classifier around the encoder, then overwrite its parameters with the
# fine-tuned checkpoint from the repo. BERTClassifier is a plain nn.Module (no
# from_pretrained), so the weights file is downloaded and loaded manually.
from huggingface_hub import hf_hub_download
import os

# Output index i of the model is reported as emotion_labels[i]; this order is
# assumed to match the label order used in training — TODO confirm.
emotion_labels = ['anger', 'fear', 'joy', 'sadness', 'surprise']

# Derive the head size from the label list so the two cannot drift apart.
model = BERTClassifier(base_bert, num_labels=len(emotion_labels), dropout=0.3)

# Fetch the checkpoint file from the Hub; returns a local file path.
model_path = hf_hub_download(repo_id=model_name, filename="pytorch_model.bin")

# .bin checkpoints are pickles. weights_only=True restricts unpickling to tensors
# and plain containers, so a tampered file cannot run arbitrary code on load.
state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
# Strict by default: raises if any key is missing or unexpected.
model.load_state_dict(state_dict)
# Switch to inference behavior (disables dropout).
model.eval()
def predict_emotions(text):
    """Score ``text`` against every emotion label.

    Each label gets an independent sigmoid probability (multi-label, so the
    scores need not sum to 1). The result is a dict of label -> confidence, the
    shape ``gr.Label`` accepts.

    Args:
        text: Input string. ``None`` (e.g. a cleared textbox) is treated as "".

    Returns:
        dict mapping each name in ``emotion_labels`` to a float in [0, 1].
    """
    # A single string needs no padding; long inputs are truncated to 128 tokens.
    inputs = tokenizer(text or "", return_tensors="pt", truncation=True, max_length=128)
    # No autograd graph is needed for inference.
    with torch.no_grad():
        logits = model(inputs['input_ids'], inputs['attention_mask'])
    # Batch of one: take row 0 and convert straight to Python floats.
    probs = torch.sigmoid(logits)[0].tolist()
    return dict(zip(emotion_labels, probs))
# Web UI: free text in, one confidence bar per emotion out.
example_texts = [
    ["I am so happy and excited about this!"],
    ["This is terrible and makes me angry."],
    ["I can't believe this happened, it's shocking!"],
]

iface = gr.Interface(
    fn=predict_emotions,
    inputs=gr.Textbox(
        lines=3,
        placeholder="Enter text here to detect emotions...",
    ),
    # Show every label, not just the top few.
    outputs=gr.Label(num_top_classes=len(emotion_labels)),
    title="Emotion Classification",
    description=(
        "Multi-label emotion detection using fine-tuned BERT. "
        "Enter any text to detect anger, fear, joy, sadness, and surprise."
    ),
    examples=example_texts,
)

# Start the app server.
iface.launch()