import streamlit as st
import torch
import joblib
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# -----------------------------------------------------------
# Streamlit page configuration
# -----------------------------------------------------------
# Kept as the first Streamlit call in the script; older Streamlit releases
# reject set_page_config once any other element has been rendered.
st.set_page_config(
    page_title="StackOverflow Tag Predictor",
    page_icon="🎯",  # replaces mis-decoded characters; must be a real emoji or image
    layout="centered",
)
# -----------------------------------------------------------
# Custom CSS for a rich UI
# -----------------------------------------------------------
# NOTE(review): the payload of this call is empty (only a newline), so no CSS
# is actually injected. Restore the intended <style> block or delete the call.
st.markdown("""
""", unsafe_allow_html=True)
# -----------------------------------------------------------
# Load model & tokenizer
# -----------------------------------------------------------
# Prefer the GPU when one is visible to PyTorch; otherwise run on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@st.cache_resource
def load_model():
    """Load the fine-tuned sequence classifier and its tokenizer.

    Both are read from the current working directory ("."). The function is
    wrapped in ``st.cache_resource``, so it runs once per server process rather
    than on every Streamlit rerun. Device placement and eval mode are applied
    here, inside the cached function, so they are also one-time setup.

    Returns:
        tuple: ``(model, tokenizer)`` with the model already on ``device``.
    """
    classifier = AutoModelForSequenceClassification.from_pretrained(".")
    classifier.to(device)
    # Make inference mode explicit (disables dropout) so predictions are deterministic.
    classifier.eval()
    text_tokenizer = AutoTokenizer.from_pretrained(".")
    return classifier, text_tokenizer


model, tokenizer = load_model()
# -----------------------------------------------------------
# Load label encoder
# -----------------------------------------------------------
@st.cache_resource
def load_label_encoder():
    """Load the fitted label encoder that maps class ids back to tag names.

    Streamlit re-executes this whole script on every widget interaction.
    Caching with ``st.cache_resource`` reads the file from disk once per server
    process instead of on every rerun.

    Returns:
        The object stored in ``label_encoder.joblib``. It must expose a
        ``classes_`` sequence. It is presumably a scikit-learn ``LabelEncoder``;
        confirm against the training code.
    """
    # SECURITY: joblib.load unpickles the file and can execute arbitrary code.
    # Only point this at an artifact you produced yourself and trust.
    return joblib.load("label_encoder.joblib")


label_encoder = load_label_encoder()
# Class index -> tag name. Assumes the encoder's class order matches the
# model's output logits (i.e. both came from the same training run).
id2label = dict(enumerate(label_encoder.classes_))
# -----------------------------------------------------------
# Prediction function
# -----------------------------------------------------------
def predict_tag(text):
    """Predict the StackOverflow tag for one question title.

    Steps:
    - Tokenize the title, truncated to at most 128 tokens.
    - Run it through the module-level ``model`` on ``device`` without gradient
      tracking.
    - Map the highest-scoring class id to its tag name via ``id2label``.

    Args:
        text: The question title to classify.

    Returns:
        A ``(tag, confidence)`` tuple. ``confidence`` is the largest softmax
        probability in the output, as a Python float. For a single input
        string, this is the probability assigned to ``tag``.
    """
    batch = tokenizer(
        text,
        return_tensors="pt",
        max_length=128,
        padding=True,
        truncation=True,
    )
    inputs = {name: tensor.to(device) for name, tensor in batch.items()}

    with torch.no_grad():
        logits = model(**inputs).logits

    best_id = logits.argmax(dim=-1).item()
    probabilities = logits.softmax(dim=-1)
    return id2label[best_id], probabilities.max().item()
# -----------------------------------------------------------
# Header
# -----------------------------------------------------------
# NOTE(review): the markup passed here is empty (only a newline), so no header
# is rendered. Restore the intended header HTML or delete the call.
st.markdown("""
""", unsafe_allow_html=True)
# -----------------------------------------------------------
# Sidebar: about the model
# -----------------------------------------------------------
st.sidebar.title("ℹ️ About This App")
st.sidebar.write("""
This app uses a fine-tuned **DistilBERT** model trained on the
top 50 StackOverflow tags.
You can:
- Type your own question title
- Pick from example titles
- See model confidence
""")
st.sidebar.write("### 🧠 Model Info")
# The class count is computed from the loaded label encoder, not hard-coded.
st.sidebar.write(f"**Labels:** {len(id2label)} classes")
st.sidebar.write("**Framework:** PyTorch + HuggingFace Transformers")
# -----------------------------------------------------------
# Example questions dropdown
# -----------------------------------------------------------
examples = [
    "How to fix NullPointerException in Java?",
    "What is the best way to center a div in CSS?",
    "How do I connect to a MySQL database in Python?",
    "Why is my React component not rendering?",
    "How to optimize a SQL query that is too slow?",
    "How to declare an array in C++?",
]
# "(None)" is the sentinel for "no example selected". The input section
# compares example_choice against this exact string.
example_choice = st.selectbox(
    "✨ Or choose an example question:",
    ["(None)", *examples],
)
# -----------------------------------------------------------
# Main input card
# -----------------------------------------------------------
# To group these widgets visually, use a Streamlit container. A raw HTML tag
# opened in one st.markdown call cannot wrap widgets rendered by later calls.
if example_choice != "(None)":
    # A selected example takes precedence; the free-text box is not shown.
    user_input = example_choice
else:
    user_input = st.text_area(
        "💬 Enter a StackOverflow question title:",
        height=120,
        placeholder="Example: \"How to fix NullPointerException in Java?\"",
    )
predict_btn = st.button("🚀 Predict Tag", use_container_width=True)
# -----------------------------------------------------------
# Prediction output
# -----------------------------------------------------------
if predict_btn:
    # Reject empty or whitespace-only input before invoking the model.
    if not user_input.strip():
        st.warning("⚠️ Please enter a question title.")
    else:
        with st.spinner("Analyzing with AI… 🧠✨"):
            tag, confidence = predict_tag(user_input)
        st.success("Prediction ready! 🎉")
        # Rendered as plain markdown. The previous multi-line f"..." literal was
        # a syntax error and its HTML wrapper was empty, so no raw HTML is needed.
        st.markdown(f"## 🏷️ Predicted Tag: `{tag}`")
        st.markdown(
            f"### 🔥 Confidence Score: **{confidence * 100:.2f}%**"
        )
# -----------------------------------------------------------
# Footer
# -----------------------------------------------------------
# NOTE(review): the markup passed here is empty (only a newline), so no footer
# is rendered. Restore the intended footer HTML or delete the call.
st.markdown("""
""", unsafe_allow_html=True)