File size: 2,188 Bytes
a4f91f2
ef8d2b3
 
 
a4f91f2
ef8d2b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import streamlit as st
import torch
import time
from transformers import AutoTokenizer, AutoModelForSequenceClassification


# Load model and tokenizer once
@st.cache_resource
def load_model_and_tokenizer():
    """Load the fine-tuned emotion classifier and its tokenizer.

    ``st.cache_resource`` memoises the return value, so the (potentially slow)
    Hugging Face load and the device transfer run once and are reused on every
    subsequent Streamlit rerun.

    Returns:
        tuple: ``(tokenizer, model, device)`` where ``device`` is ``"cuda"``
        when ``torch.cuda.is_available()`` is true, otherwise ``"cpu"``.
    """
    repo_id = "tsid7710/distillbert-emotion-model"

    with st.spinner("🔄 Loading model and tokenizer... please wait"):
        tokenizer = AutoTokenizer.from_pretrained(repo_id)
        model = AutoModelForSequenceClassification.from_pretrained(repo_id)

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = model.to(device)
        # This app only runs inference; make that explicit so dropout is off
        # regardless of the loader's default mode.
        model.eval()
        # Deliberate short pause so the spinner is visible even on a fast load.
        time.sleep(1.5)

    return tokenizer, model, device


tokenizer, model, device = load_model_and_tokenizer()


# Show the success banner once per browser session: st.session_state survives
# reruns of the same session, so the flag suppresses it on later reruns.
if "model_loaded" not in st.session_state:
    st.success("✅ Model loaded successfully!")
    st.session_state.model_loaded = True


############## Streamlit Code #################

st.title("💬 Emotion Classifier")

st.markdown("""
### 🧠 About This Project
This app uses a fine-tuned **DistilBERT** model to detect emotions from text.  
It classifies your sentence into one of six emotions — **sadness**, **joy**, **love**, **anger**, **fear**, or **surprise**.  
Simply type a sentence below and click **Find Emotion** to see what the model predicts!
""")

user_input = st.text_input("✏️ Enter a sentence to analyze its emotion:")

# Position i is the label for output logit i (find_emotion indexes this list
# with the arg-max class id). NOTE(review): order presumably mirrors the
# fine-tuning dataset's label ids — confirm against model.config.id2label.
classes = ['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']

def find_emotion(user_input: str) -> str:
    """Return the emotion label the model predicts for ``user_input``.

    Relies on the module-level ``tokenizer``, ``model`` and ``device`` returned
    by ``load_model_and_tokenizer`` and on the ``classes`` label list; the
    arg-max over the logits is used as an index into ``classes``.
    """
    encoded = tokenizer(
        text=user_input,
        return_tensors="pt",
        truncation=True,
        padding=True,
    )
    encoded = encoded.to(device)

    # No autograd bookkeeping is needed for a single forward pass.
    with torch.inference_mode():
        scores = model(**encoded).logits
        label_index = scores.argmax(dim=-1).item()

    print("Prediction: ", label_index)
    return classes[label_index]


# Run the classifier when the button is pressed; blank/whitespace-only input
# gets a warning instead of a prediction.
if st.button('Find Emotion'):
    if user_input.strip():
        with st.spinner("🧠 Analyzing emotion..."):
            time.sleep(1)  # Deliberate short delay so the spinner is visible
            result = find_emotion(user_input)

        st.success(f"Predicted Emotion: **{result}**")
    else:
        st.warning("⚠️ Please enter a sentence before clicking the button.")