Spaces:
Sleeping
Sleeping
import os

import gradio as gr
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from transformers import BertTokenizer, BertModel, BertForSequenceClassification
# Hugging Face Hub repository that hosts both the tokenizer files and the
# fine-tuned checkpoint used below.
model_name = "keethu/bert-emotion-classifier"

# Tokenizer and encoder backbone are both pulled from the same repository.
# The backbone's weights are later overwritten by the strict
# ``load_state_dict`` of the full classifier checkpoint.
tokenizer, base_bert = (
    BertTokenizer.from_pretrained(model_name),
    BertModel.from_pretrained(model_name),
)
| # Define your classifier architecture (same as training) | |
| class BERTClassifier(nn.Module): | |
| def __init__(self, bert_model, num_labels=5, dropout=0.3): | |
| super(BERTClassifier, self).__init__() | |
| self.bert = bert_model | |
| self.dropout = nn.Dropout(dropout) | |
| self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels) | |
| def forward(self, input_ids, attention_mask): | |
| outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask) | |
| pooled_output = outputs.pooler_output | |
| pooled_output = self.dropout(pooled_output) | |
| logits = self.classifier(pooled_output) | |
| return logits | |
# Label order for the model outputs: predict_emotions maps output column i
# to emotion_labels[i].
emotion_labels = ['anger', 'fear', 'joy', 'sadness', 'surprise']

# Create the classifier. The head size is derived from the label list so the
# two cannot drift apart (previously a hard-coded 5).
model = BERTClassifier(base_bert, num_labels=len(emotion_labels), dropout=0.3)

# Download the fine-tuned checkpoint (encoder + head) from the Hub.
model_path = hf_hub_download(repo_id=model_name, filename="pytorch_model.bin")

# torch.load unpickles the file. weights_only=True restricts unpickling to
# tensors and plain containers, so a tampered checkpoint cannot execute
# arbitrary code. This requires torch >= 1.13 and is the default from torch 2.6.
state_dict = torch.load(model_path, map_location='cpu', weights_only=True)

# Strict (default) loading fails loudly if checkpoint keys don't match the
# architecture defined above.
model.load_state_dict(state_dict)

# Switch to inference mode so dropout is disabled.
model.eval()
def predict_emotions(text):
    """Score ``text`` against every emotion label.

    Args:
        text: Input string (from the Gradio textbox). Inputs longer than
            128 tokens are truncated.

    Returns:
        dict mapping each label in ``emotion_labels`` to an independent
        sigmoid probability in [0, 1]. The scores are not normalized across
        labels because the task is multi-label.

    Raises:
        ValueError: if the model emits a different number of scores than
            there are labels. This is a configuration mismatch.
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        logits = model(encoded["input_ids"], encoded["attention_mask"])
    # Apply a per-label sigmoid (multi-label), then take row 0: one text in, one row out.
    scores = torch.sigmoid(logits)[0].tolist()
    if len(scores) != len(emotion_labels):
        raise ValueError(
            f"model produced {len(scores)} scores but {len(emotion_labels)} labels are defined"
        )
    return dict(zip(emotion_labels, scores))
def _build_interface():
    """Assemble the Gradio UI that wraps predict_emotions."""
    sample_texts = [
        "I am so happy and excited about this!",
        "This is terrible and makes me angry.",
        "I can't believe this happened, it's shocking!",
    ]
    return gr.Interface(
        fn=predict_emotions,
        inputs=gr.Textbox(lines=3, placeholder="Enter text here to detect emotions..."),
        outputs=gr.Label(num_top_classes=5),
        title="Emotion Classification",
        description="Multi-label emotion detection using fine-tuned BERT. Enter any text to detect anger, fear, joy, sadness, and surprise.",
        # Each example row holds a single value for the single input component.
        examples=[[sample] for sample in sample_texts],
    )


# Build the interface and start serving it.
iface = _build_interface()
iface.launch()