"""Streamlit front-end for a fine-tuned RoBERTa emotion classifier.

Downloads fine-tuned weights from the Hugging Face Hub, scores user-supplied
text against the labels in ``EMOTIONS``, and shows the per-label
probabilities as metrics and a bar chart.
"""

import matplotlib.pyplot as plt
import pandas as pd
import streamlit as st
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer

from BertEmotionClassifier import BertEmotionClassifier

# ----------------------------
# Config
# ----------------------------
# Output index i of the model is reported as EMOTIONS[i], so this order must
# match the label order used when the classifier was trained (confirm).
EMOTIONS = ['anger', 'fear', 'joy', 'sadness', 'surprise']
MODEL_NAME = "roberta-base"
HF_REPO_ID = "aadhi3/RoBert_Model"
WEIGHTS_FILENAME = "model.pth"
MAX_LENGTH = 256
# Prefix stripped from saved parameter names; presumably added by wrapping the
# model in nn.DataParallel / DistributedDataParallel during training.
_WRAPPER_PREFIX = "module."

st.set_page_config(page_title="Emotion Classifier", page_icon="🎭", layout="centered")
st.title("🎭 Emotion Detection AI")
st.markdown("Predict emotional sentiment from text using a fine-tuned RoBERTa model.")


# ----------------------------
# Load Model
# ----------------------------
def _strip_wrapper_prefix(key):
    """Remove a leading ``module.`` from a state-dict key; other keys pass through.

    Only the prefix is removed: ``str.replace`` would also rewrite any
    ``module.`` occurring in the middle of a parameter name.
    """
    if key.startswith(_WRAPPER_PREFIX):
        return key[len(_WRAPPER_PREFIX):]
    return key


@st.cache_resource
def load_model():
    """Build the tokenizer and classifier and load the fine-tuned weights.

    ``st.cache_resource`` caches the returned objects, so the download and
    deserialisation happen only on the first call rather than on every rerun.

    Returns:
        ``(tokenizer, model)``; the weights are loaded onto CPU and the model
        is switched to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = BertEmotionClassifier(model_name=MODEL_NAME, num_labels=len(EMOTIONS))

    model_path = hf_hub_download(repo_id=HF_REPO_ID, filename=WEIGHTS_FILENAME)
    # The checkpoint is fetched from a remote repo: weights_only=True prevents
    # torch.load from unpickling arbitrary objects. A state dict needs only tensors.
    state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
    state_dict = {_strip_wrapper_prefix(k): v for k, v in state_dict.items()}
    model.load_state_dict(state_dict)
    model.eval()
    return tokenizer, model


tokenizer, model = load_model()


# ----------------------------
# Prediction Function
# ----------------------------
def predict_emotions(text: str):
    """Score *text* against every label in ``EMOTIONS``.

    Args:
        text: Input text; tokenized with special tokens and truncated/padded
            to ``MAX_LENGTH`` tokens.

    Returns:
        dict mapping each emotion name (in ``EMOTIONS`` order) to its softmax
        probability rounded to 4 decimal places.
    """
    encoding = tokenizer(
        text,
        add_special_tokens=True,
        max_length=MAX_LENGTH,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        logits = model(**encoding)
    # Single-example batch: row 0 holds this text's scores
    # (assumes the model returns a (batch, num_labels) logits tensor — confirm).
    probs = F.softmax(logits, dim=-1)[0].cpu()
    return {emo: round(float(probs[i]), 4) for i, emo in enumerate(EMOTIONS)}


# ----------------------------
# UI Layout
# ----------------------------
input_text = st.text_area("Enter text to analyze:", height=120)

if st.button("🔮 Analyze"):
    cleaned_text = input_text.strip()
    if not cleaned_text:
        # Previously a blank submission silently did nothing.
        st.warning("Please enter some text to analyze.")
    else:
        with st.spinner("Analyzing emotions..."):
            results = predict_emotions(cleaned_text)

        df = pd.DataFrame(list(results.items()), columns=["Emotion", "Probability"])
        dominant = df.loc[df["Probability"].idxmax(), "Emotion"]

        st.markdown("### 📊 Prediction Details")
        cols = st.columns(len(EMOTIONS))
        for col, (emo, val) in zip(cols, results.items()):
            col.metric(label=emo.capitalize(), value=f"{val}")

        st.markdown(f"### 🎯 Dominant Emotion: **{dominant.upper()}**")

        # Chart
        fig, ax = plt.subplots(figsize=(5, 3))
        ax.bar(df["Emotion"], df["Probability"])
        ax.set_ylim(0, 1)
        ax.set_ylabel("Probability")
        ax.set_title("Emotion Prediction")
        st.pyplot(fig)
        # Release the figure: pyplot keeps every figure alive until closed, and
        # Streamlit reruns this script on every interaction.
        plt.close(fig)

st.markdown("---")
st.caption("Built with ❤️ using Streamlit & PyTorch — deployed on Hugging Face Spaces")