Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import torch | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| import torch.nn.functional as F | |
| from transformers import AutoTokenizer | |
| from huggingface_hub import hf_hub_download | |
| from BertEmotionClassifier import BertEmotionClassifier | |
# ----------------------------
# Config
# ----------------------------
# predict_emotions reports output index i as EMOTIONS[i], so this order must
# match the label order the classifier was trained with (assumed — confirm
# against the training script).
EMOTIONS = ['anger', 'fear', 'joy', 'sadness', 'surprise']
MODEL_NAME = "roberta-base"

# Keep set_page_config ahead of every other st.* call; Streamlit applies page
# settings from the first call in the script run.
st.set_page_config(page_title="Emotion Classifier", page_icon="🎭", layout="centered")
st.title("🎭 Emotion Detection AI")
st.markdown("Predict emotional sentiment from text using a fine-tuned RoBERTa model.")
# ----------------------------
# Load Model
# ----------------------------
@st.cache_resource
def load_model():
    """Build the tokenizer and classifier, then load the fine-tuned weights.

    Streamlit re-executes this whole script on every widget interaction, so
    without ``st.cache_resource`` the model would be rebuilt and the checkpoint
    deserialized again on each rerun. The cached (tokenizer, model) pair is
    created once per server process and reused.

    Returns:
        tuple: ``(tokenizer, model)``; the checkpoint tensors are mapped to CPU
        on load and the model is switched to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = BertEmotionClassifier(model_name=MODEL_NAME, num_labels=len(EMOTIONS))
    model_path = hf_hub_download(repo_id="aadhi3/RoBert_Model", filename="model.pth")
    # NOTE(review): torch.load unpickles the file, and this file is fetched
    # from a remote repository. Only load checkpoints you trust; consider
    # passing weights_only=True on torch versions that support it.
    state_dict = torch.load(model_path, map_location="cpu")
    # Checkpoints saved from a DataParallel/DDP-wrapped model typically prefix
    # every key with "module.". Strip only that leading prefix; str.replace
    # would also rewrite any "module." appearing later inside a key name.
    prefix = "module."
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
    model.load_state_dict(state_dict)
    model.eval()
    return tokenizer, model


tokenizer, model = load_model()
# ----------------------------
# Prediction Function
# ----------------------------
def predict_emotions(text: str):
    """Score *text* against every label in EMOTIONS.

    The text is encoded into a fixed 256-token window (padded and truncated)
    and passed through the module-level ``model`` with gradients disabled.
    The logits are then softmax-normalised.

    Returns:
        dict: ``{label: probability}`` in EMOTIONS order, with each
        probability rounded to 4 decimal places.
    """
    inputs = tokenizer(
        text,
        add_special_tokens=True,
        max_length=256,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        scores = model(**inputs)
        # The batch holds a single sequence, so only row 0 is kept.
        distribution = F.softmax(scores, dim=-1)[0].cpu()
    probabilities = {}
    for idx, label in enumerate(EMOTIONS):
        probabilities[label] = round(float(distribution[idx]), 4)
    return probabilities
# ----------------------------
# UI Layout
# ----------------------------
input_text = st.text_area("Enter text to analyze:", height=120)

if st.button("🔮 Analyze"):
    text = input_text.strip()
    if not text:
        # Tell the user why nothing happened, rather than silently ignoring
        # the click on blank input.
        st.warning("Please enter some text to analyze.")
    else:
        with st.spinner("Analyzing emotions..."):
            results = predict_emotions(text)
        df = pd.DataFrame(list(results.items()), columns=["Emotion", "Probability"])
        # idxmax returns the first row holding the maximum, so ties resolve to
        # the earliest label in EMOTIONS order.
        dominant = df.loc[df["Probability"].idxmax(), "Emotion"]

        st.markdown("### 📊 Prediction Details")
        cols = st.columns(len(EMOTIONS))
        for col, (emo, val) in zip(cols, results.items()):
            col.metric(label=emo.capitalize(), value=f"{val}")
        st.markdown(f"### 🎯 Dominant Emotion: **{dominant.upper()}**")

        # Chart
        fig, ax = plt.subplots(figsize=(5, 3))
        ax.bar(df["Emotion"], df["Probability"])
        ax.set_ylim(0, 1)
        ax.set_ylabel("Probability")
        ax.set_title("Emotion Prediction")
        st.pyplot(fig)
        # pyplot keeps every figure it creates alive until it is closed.
        # Without this, each click on a long-running server leaks a figure.
        plt.close(fig)

st.markdown("---")
st.caption("Built with ❤️ using Streamlit & PyTorch — deployed on Hugging Face Spaces")