VenujaDeSilva's picture
Update app.py
1480198 verified
import html

import joblib
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# -----------------------------------------------------------
# 🚀 Streamlit Page Configuration
# -----------------------------------------------------------
# Configure the browser tab (title and icon) and the page layout before any
# other Streamlit element is emitted.
st.set_page_config(
    page_title="StackOverflow Tag Predictor",
    # The icon was stored as mojibake (UTF-8 bytes decoded with a single-byte
    # code page); restored to the intended emoji.
    page_icon="🎯",
    layout="centered",
)
# -----------------------------------------------------------
# 🌈 Custom CSS for a Rich UI
# -----------------------------------------------------------
# Stylesheet injected into the page. The class names defined here (header,
# header-title, header-subtitle, card, result-tag, footer) are referenced by
# the raw HTML snippets emitted elsewhere in this script.
_CUSTOM_CSS = """
<style>
body {
background-color: #F5F7FB;
}
.header {
text-align: center;
margin-top: -20px;
margin-bottom: 10px;
}
.header-title {
font-size: 48px;
font-weight: 900;
background: linear-gradient(90deg, #4A4AFC, #6A6AFF);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
}
.header-subtitle {
font-size: 18px;
color: #555;
margin-top: -10px;
margin-bottom: 25px;
}
.card {
background: white;
padding: 30px;
border-radius: 18px;
box-shadow: 0px 6px 20px rgba(0,0,0,0.08);
margin-bottom: 20px;
}
.result-tag {
background: linear-gradient(90deg, #4A4AFC, #6A6AFF);
padding: 14px 24px;
border-radius: 14px;
color: white;
display: inline-block;
font-size: 22px;
font-weight: 700;
animation: fadeIn 0.4s ease-out;
}
@keyframes fadeIn {
from {opacity: 0; transform: translateY(10px);}
to {opacity: 1; transform: translateY(0);}
}
.footer {
text-align: center;
margin-top: 40px;
color: #777;
font-size: 14px;
}
</style>
"""

# unsafe_allow_html is needed so the <style> element is passed through as HTML.
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# -----------------------------------------------------------
# 📦 Load Model & Tokenizer
# -----------------------------------------------------------
@st.cache_resource
def load_model():
    """Load the sequence-classification model and its tokenizer.

    Both are read from the current working directory (``"."``). The function
    is wrapped in ``st.cache_resource`` so Streamlit can reuse the loaded
    objects instead of reloading them each time the script runs.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    checkpoint_dir = "."
    classifier = AutoModelForSequenceClassification.from_pretrained(checkpoint_dir)
    text_tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
    return classifier, text_tokenizer
# Run on the GPU when PyTorch reports one is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Fetch the (cached) model/tokenizer pair and move the model to that device.
model, tokenizer = load_model()
model = model.to(device)
# -----------------------------------------------------------
# 🔤 Load Label Encoder
# -----------------------------------------------------------
@st.cache_resource
def _load_label_encoder():
    """Load the fitted label encoder from ``label_encoder.joblib``.

    Cached with ``st.cache_resource`` so the file is not re-read and
    re-deserialized on every Streamlit script rerun (each widget interaction
    re-executes this script from the top).

    Returns:
        The object stored in ``label_encoder.joblib``; the caller reads its
        ``classes_`` attribute.
    """
    # NOTE(review): joblib.load is pickle-based and can execute arbitrary code
    # while loading. Only ship a label_encoder.joblib that comes from a
    # trusted source.
    return joblib.load("label_encoder.joblib")


label_encoder = _load_label_encoder()

# Map each class index (position in ``classes_``) to its label.
id2label = dict(enumerate(label_encoder.classes_))
# -----------------------------------------------------------
# 🔮 Prediction Function
# -----------------------------------------------------------
def predict_tag(text):
    """Predict the most likely tag for a question title.

    The text is tokenized (truncated to at most 128 tokens), moved to the
    module-level ``device`` and passed through the module-level ``model``
    with gradient tracking disabled.

    Args:
        text: The question title to classify.

    Returns:
        A ``(tag, confidence)`` tuple: the ``id2label`` entry for the
        highest-scoring class, and the largest softmax probability over the
        logits as a Python float.
    """
    batch = tokenizer(
        text,
        truncation=True,
        padding=True,
        max_length=128,
        return_tensors="pt",
    )
    # The tokenizer output must live on the same device as the model.
    batch = {name: tensor.to(device) for name, tensor in batch.items()}

    with torch.no_grad():
        logits = model(**batch).logits

    predicted_index = torch.argmax(logits, dim=-1).item()
    predicted_label = id2label[predicted_index]
    probabilities = torch.softmax(logits, dim=-1)
    return predicted_label, probabilities.max().item()
# -----------------------------------------------------------
# 🎯 Header
# -----------------------------------------------------------
# Page header rendered as raw HTML so it can use the "header" CSS classes.
# The emoji and the bullet separator were stored as mojibake (UTF-8 bytes
# decoded with a single-byte code page) and are restored here.
st.markdown("""
<div class="header">
<div class="header-title">🎯 StackOverflow Tag Predictor</div>
<div class="header-subtitle">Powered by DistilBERT • Predict the most likely tag from a question title</div>
</div>
""", unsafe_allow_html=True)
# -----------------------------------------------------------
# ๐ŸŽ›๏ธ Sidebar โ€“ About the Model
# -----------------------------------------------------------
# Sidebar: a short description of the app followed by basic model facts.
# The emoji in the title and in the "Model Info" heading were stored as
# mojibake and are restored here.
st.sidebar.title("ℹ️ About This App")
st.sidebar.write("""
This app uses a fine-tuned **DistilBERT** model trained on the
top 50 StackOverflow tags.
You can:
- Type your own question title
- Pick from example titles
- See model confidence
""")
st.sidebar.write("### 🔧 Model Info")
# The class count is taken from the loaded label mapping, not hard-coded.
st.sidebar.write(f"**Labels:** {len(id2label)} classes")
st.sidebar.write("**Framework:** PyTorch + HuggingFace Transformers")
# -----------------------------------------------------------
# 🧪 Example Questions Dropdown
# -----------------------------------------------------------
# Sample question titles offered in the dropdown.
examples = [
    "How to fix NullPointerException in Java?",
    "What is the best way to center a div in CSS?",
    "How do I connect to a MySQL database in Python?",
    "Why is my React component not rendering?",
    "How to optimize a SQL query that is too slow?",
    "How to declare an array in C++?",
]

# "(None)" is the sentinel first option meaning "no example selected"; the
# input section compares against this exact string.
# The sparkle emoji in the label was stored as mojibake and is restored here.
example_choice = st.selectbox(
    "✨ Or choose an example question:",
    ["(None)"] + examples,
)
# -----------------------------------------------------------
# ๐Ÿ“ Main Input Card
# -----------------------------------------------------------
# Open the "card" <div>; the matching </div> is emitted after the prediction
# output further down.
# NOTE(review): each st.markdown call is rendered as its own element, so this
# <div> may not actually wrap the widgets below — confirm in the browser.
st.markdown("<div class='card'>", unsafe_allow_html=True)

# A selected example takes precedence; the free-text box is only rendered
# when the dropdown is left on the "(None)" sentinel.
if example_choice != "(None)":
    user_input = example_choice
else:
    # The speech-bubble emoji was stored as mojibake and is restored here.
    user_input = st.text_area(
        "💬 Enter a StackOverflow question title:",
        height=120,
        placeholder="Example: \"How to fix NullPointerException in Java?\"",
    )

# The button label's emoji was mojibake with one byte lost; a magnifying
# glass is the most plausible original.
predict_btn = st.button("🔍 Predict Tag", use_container_width=True)
# -----------------------------------------------------------
# 📊 Prediction Output
# -----------------------------------------------------------
# Runs only on the script execution triggered by clicking the predict button.
# The emoji and ellipsis in the messages below were stored as mojibake and
# are restored here.
if predict_btn:
    # Reject empty or whitespace-only input before invoking the model.
    if not user_input.strip():
        st.warning("⚠️ Please enter a question title.")
    else:
        with st.spinner("Analyzing with AI… 🔧✨"):
            tag, confidence = predict_tag(user_input)
        st.success("Prediction ready! 🎉")
        # The label is interpolated into raw HTML, so escape it: a tag
        # containing '<', '>' or '&' would otherwise break or inject markup.
        # str() guards against non-string class labels.
        st.markdown(
            f"<div class='result-tag'>{html.escape(str(tag))}</div>",
            unsafe_allow_html=True,
        )
        st.markdown(
            f"### 🔥 Confidence Score: **{confidence*100:.2f}%**"
        )

# Close the card <div> opened above the input widgets.
st.markdown("</div>", unsafe_allow_html=True)
# -----------------------------------------------------------
# 📘 Footer
# -----------------------------------------------------------
# Footer rendered as raw HTML so it can use the "footer" CSS class.
# The heart emoji was stored as mojibake (with one byte lost) and is
# restored here.
st.markdown("""
<div class='footer'>
Made with ❤️ using DistilBERT + Streamlit + HuggingFace Spaces<br>
Try different example titles or write your own!
</div>
""", unsafe_allow_html=True)