# Dl-Project / app.py
# Author: aadhi3
# Last commit: 714712c ("Update app.py", verified)
import streamlit as st
import torch
import pandas as pd
import matplotlib.pyplot as plt
import torch.nn.functional as F
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download
from BertEmotionClassifier import BertEmotionClassifier
# ----------------------------
# Config
# ----------------------------
# predict_emotions reports model output index i as EMOTIONS[i], so this order
# must match the label order the checkpoint was trained with.
EMOTIONS = ['anger', 'fear', 'joy', 'sadness', 'surprise']
# Pretrained backbone id; used for both the tokenizer and BertEmotionClassifier.
MODEL_NAME = "roberta-base"

# Page chrome. The emoji literals were previously mis-encoded (UTF-8 bytes
# decoded as a legacy code page), which produced a garbage page icon/title.
st.set_page_config(page_title="Emotion Classifier", page_icon="🎭", layout="centered")
st.title("🎭 Emotion Detection AI")
st.markdown("Predict emotional sentiment from text using a fine-tuned RoBERTa model.")
# ----------------------------
# Load Model
# ----------------------------
@st.cache_resource
def load_model():
    """Build the tokenizer and classifier and load the fine-tuned weights.

    Downloads ``model.pth`` from the ``aadhi3/RoBert_Model`` Hugging Face Hub
    repo, strips a leading ``module.`` from each state-dict key, loads the
    weights into a ``BertEmotionClassifier`` on CPU and puts it in eval mode.
    Wrapped in ``st.cache_resource`` so the result is reused across Streamlit
    reruns instead of being rebuilt on every interaction.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = BertEmotionClassifier(model_name=MODEL_NAME, num_labels=len(EMOTIONS))

    model_path = hf_hub_download(repo_id="aadhi3/RoBert_Model", filename="model.pth")
    # The checkpoint comes from the network. weights_only=True makes torch.load
    # refuse arbitrary pickled objects and accept only tensors and plain
    # containers, which is all a state dict needs.
    state_dict = torch.load(model_path, map_location="cpu", weights_only=True)

    # Models saved while wrapped in nn.DataParallel / DDP have every key
    # prefixed with "module.". Strip it only as a leading prefix, so that a
    # parameter name which merely contains "module." elsewhere is not mangled.
    prefix = "module."
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
    model.load_state_dict(state_dict)
    model.eval()
    return tokenizer, model


tokenizer, model = load_model()
# ----------------------------
# Prediction Function
# ----------------------------
def predict_emotions(text: str):
    """Score ``text`` against every label in EMOTIONS.

    The text is tokenized to a fixed length of 256 (padded and truncated) and
    run through the module-level ``model`` without gradient tracking. A softmax
    over the last dimension turns logits into probabilities, and the first row
    is taken as the result.

    NOTE(review): assumes ``model(**batch)`` returns a logits tensor whose
    first row has one entry per label in EMOTIONS; confirm against
    BertEmotionClassifier.forward.

    Returns:
        Dict mapping each emotion name to its probability, rounded to 4 places.
    """
    batch = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=256,
        add_special_tokens=True,
    )
    with torch.no_grad():
        scores = F.softmax(model(**batch), dim=-1)
        first_row = scores[0].cpu().tolist()
    # Position i of the output is labelled EMOTIONS[i].
    return {label: round(first_row[idx], 4) for idx, label in enumerate(EMOTIONS)}
# ----------------------------
# UI Layout
# ----------------------------
input_text = st.text_area("Enter text to analyze:", height=120)
if st.button("🔮 Analyze"):
    cleaned_text = input_text.strip()
    if not cleaned_text:
        # Previously an empty submission silently did nothing.
        st.warning("Please enter some text to analyze.")
    else:
        with st.spinner("Analyzing emotions..."):
            results = predict_emotions(cleaned_text)

        df = pd.DataFrame(list(results.items()), columns=["Emotion", "Probability"])
        dominant = df.loc[df["Probability"].idxmax(), "Emotion"]

        st.markdown("### 📊 Prediction Details")
        cols = st.columns(len(EMOTIONS))
        for col, (emo, val) in zip(cols, results.items()):
            col.metric(label=emo.capitalize(), value=f"{val}")
        st.markdown(f"### 🎯 Dominant Emotion: **{dominant.upper()}**")

        # Chart
        fig, ax = plt.subplots(figsize=(5, 3))
        ax.bar(df["Emotion"], df["Probability"])
        ax.set_ylim(0, 1)
        ax.set_ylabel("Probability")
        ax.set_title("Emotion Prediction")
        st.pyplot(fig)
        # pyplot keeps every figure it creates referenced until it is closed.
        # Close it here, otherwise each click leaks a figure in the long-lived
        # server process.
        plt.close(fig)

st.markdown("---")
st.caption("Built with ❤️ using Streamlit & PyTorch — deployed on Hugging Face Spaces")