import streamlit as st
import torch
import joblib
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# -----------------------------------------------------------
# 🚀 Streamlit Page Configuration
# -----------------------------------------------------------
st.set_page_config(
    page_title="StackOverflow Tag Predictor",
    page_icon="🎯",
    layout="centered",
)

# -----------------------------------------------------------
# 🌈 Custom CSS for a Rich UI
# -----------------------------------------------------------
# NOTE(review): the markup inside this call is empty in this copy of the
# file; presumably a <style> block belongs here -- confirm against the
# original source before relying on any custom CSS classes.
st.markdown(""" """, unsafe_allow_html=True)


# -----------------------------------------------------------
# 📦 Load Model & Tokenizer
# -----------------------------------------------------------
@st.cache_resource
def load_model():
    """Load the classifier and its tokenizer from the current directory.

    Wrapped in ``st.cache_resource`` so the load happens once per server
    process instead of on every Streamlit rerun.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = AutoModelForSequenceClassification.from_pretrained(".")
    tokenizer = AutoTokenizer.from_pretrained(".")
    return model, tokenizer


@st.cache_resource
def load_label_encoder():
    """Load the fitted label encoder from ``label_encoder.joblib``.

    Cached for the same reason as ``load_model``: Streamlit re-executes the
    whole script on every widget interaction, and re-reading the file each
    time is wasted work.
    """
    # NOTE: joblib deserializes with pickle, which can execute arbitrary
    # code. Only load artifacts that come from a trusted source.
    return joblib.load("label_encoder.joblib")


model, tokenizer = load_model()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# -----------------------------------------------------------
# 🔤 Load Label Encoder
# -----------------------------------------------------------
label_encoder = load_label_encoder()
# Map class index -> tag name, in the encoder's class order.
id2label = dict(enumerate(label_encoder.classes_))


# -----------------------------------------------------------
# 🔮 Prediction Function
# -----------------------------------------------------------
def predict_tag(text: str) -> tuple[str, float]:
    """Predict the most likely tag for a single question title.

    Args:
        text: The question title to classify. It is tokenized with
            truncation at 128 tokens.

    Returns:
        A ``(tag, confidence)`` tuple where ``tag`` is the label looked up in
        ``id2label`` and ``confidence`` is the softmax probability of that
        label as a Python float.
    """
    encoding = tokenizer(
        text,
        truncation=True,
        padding=True,
        max_length=128,
        return_tensors="pt",
    )
    # Move every input tensor to the same device as the model.
    encoding = {name: tensor.to(device) for name, tensor in encoding.items()}

    with torch.no_grad():
        outputs = model(**encoding)

    # One softmax yields both the winning class and its probability; the
    # argmax of the probabilities equals the argmax of the logits.
    probabilities = torch.softmax(outputs.logits, dim=-1)
    confidence, pred_id = torch.max(probabilities, dim=-1)
    tag = id2label[pred_id.item()]
    return tag, confidence.item()


# -----------------------------------------------------------
# 🎯 Header
# -----------------------------------------------------------
st.markdown("""
🎯 StackOverflow Tag Predictor
Powered by DistilBERT • Predict the most likely tag from a question title
""", unsafe_allow_html=True)

# -----------------------------------------------------------
# 🎛️ Sidebar – About the Model
# -----------------------------------------------------------
st.sidebar.title("ℹ️ About This App")
# Line breaks matter here: Markdown only renders the bullets as a list when
# each item sits on its own line.
st.sidebar.write("""
This app uses a fine-tuned **DistilBERT** model trained on the top 50 StackOverflow tags.

You can:
- Type your own question title
- Pick from example titles
- See model confidence
""")

st.sidebar.write("### 🔧 Model Info")
st.sidebar.write(f"**Labels:** {len(id2label)} classes")
st.sidebar.write("**Framework:** PyTorch + HuggingFace Transformers")

# -----------------------------------------------------------
# 🧪 Example Questions Dropdown
# -----------------------------------------------------------
examples = [
    "How to fix NullPointerException in Java?",
    "What is the best way to center a div in CSS?",
    "How do I connect to a MySQL database in Python?",
    "Why is my React component not rendering?",
    "How to optimize a SQL query that is too slow?",
    "How to declare an array in C++?",
]

example_choice = st.selectbox(
    "✨ Or choose an example question:",
    ["(None)"] + examples,
)

# -----------------------------------------------------------
# 📝 Main Input Card
# -----------------------------------------------------------
# NOTE(review): this call and the matching one after the prediction output
# contain only a newline in this copy of the file; presumably they opened and
# closed an HTML wrapper element -- confirm against the original source.
st.markdown("\n", unsafe_allow_html=True)

# A selected example takes precedence; the free-text box is only shown when
# no example is chosen.
if example_choice != "(None)":
    user_input = example_choice
else:
    user_input = st.text_area(
        "💬 Enter a StackOverflow question title:",
        height=120,
        placeholder="Example: \"How to fix NullPointerException in Java?\"",
    )

predict_btn = st.button("🔍 Predict Tag", use_container_width=True)

# -----------------------------------------------------------
# 📊 Prediction Output
# -----------------------------------------------------------
if predict_btn:
    if not user_input.strip():
        st.warning("⚠️ Please enter a question title.")
    else:
        with st.spinner("Analyzing with AI… 🔧✨"):
            tag, confidence = predict_tag(user_input)

        st.success("Prediction ready! 🎉")
        # The template spanned several physical lines inside a single-quoted
        # literal, which does not parse; the newlines are now explicit.
        st.markdown(f"\n{tag}\n", unsafe_allow_html=True)
        st.markdown(
            f"### 🔥 Confidence Score: **{confidence*100:.2f}%**"
        )

st.markdown("\n", unsafe_allow_html=True)

# -----------------------------------------------------------
# 📘 Footer
# -----------------------------------------------------------
# NOTE(review): the footer markup is empty in this copy of the file.
st.markdown(""" """, unsafe_allow_html=True)