import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import joblib
import numpy as np
import pandas as pd
import altair as alt

# ---------------------------------------------------------
# Custom CSS for Fun, Colorful UI
# ---------------------------------------------------------
# NOTE(review): the injected stylesheet is currently empty, so this call
# renders nothing visible; the CSS rules appear to have been lost — restore
# them inside the triple-quoted string.
st.markdown(""" """, unsafe_allow_html=True)

# ---------------------------------------------------------
# Load Model + Tokenizer
# ---------------------------------------------------------
# Prefer the GPU when one is available; all tensors and the model live here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@st.cache_resource
def load_model():
    """Load the fine-tuned classifier and its tokenizer from the working directory.

    Wrapped in ``st.cache_resource`` so the weights are read, moved to
    ``device`` and put in eval mode once and then reused across Streamlit
    reruns, instead of repeating that work on every widget interaction.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    model = AutoModelForSequenceClassification.from_pretrained(".")
    tokenizer = AutoTokenizer.from_pretrained(".")
    model = model.to(device)
    # Explicit for clarity: inference only (dropout disabled).
    model.eval()
    return model, tokenizer


@st.cache_resource
def load_label_binarizer():
    """Load the fitted MultiLabelBinarizer that names the model's output columns.

    Cached for the same reason as ``load_model``: otherwise the file would be
    re-read from disk on every Streamlit rerun.
    """
    # joblib deserializes with pickle — only load this file from a trusted source.
    return joblib.load("mlb.joblib")


model, tokenizer = load_model()

# Load MultiLabelBinarizer
mlb = load_label_binarizer()
labels = mlb.classes_


# ---------------------------------------------------------
# Prediction function
# ---------------------------------------------------------
def predict_tags(text, threshold=0.3):
    """Predict StackOverflow tags for a single question.

    Args:
        text: The question text; tokenized with truncation at 256 tokens.
        threshold: Minimum sigmoid probability for a tag to be returned.

    Returns:
        tuple: ``(predicted_tags, probs)``. ``probs`` is the row of per-tag
        probabilities for this input (assumed to align one-to-one with
        ``labels`` — TODO confirm against the training setup), and
        ``predicted_tags`` is the subset of ``labels`` whose probability is
        ``>= threshold``, in label order.
    """
    encoded = tokenizer(
        text,
        padding=True,
        truncation=True,
        max_length=256,
        return_tensors="pt",
    )
    encoded = {k: v.to(device) for k, v in encoded.items()}

    with torch.no_grad():
        outputs = model(**encoded)

    # Multi-label head: an independent sigmoid per tag rather than a softmax.
    # [0] selects the single input in this one-item batch.
    probs = torch.sigmoid(outputs.logits).cpu().numpy()[0]

    predicted_mask = probs >= threshold
    predicted_tags = labels[predicted_mask]
    return predicted_tags, probs


# ---------------------------------------------------------
# 🎨 Sidebar
# ---------------------------------------------------------
st.sidebar.header("⚙️ Settings")
threshold = st.sidebar.slider(
    "Prediction Threshold",
    0.0,
    1.0,
    0.30,
    help="Lower = more tags, Higher = fewer but more confident",
)

st.sidebar.markdown("""
### 🤖 Model Info
- BERT-based tag predictor
- Multi-label classification
- Trained on StackOverflow dataset
""")
st.sidebar.markdown("---")
st.sidebar.markdown("Made with ❤️ using Streamlit + Transformers")
# ---------------------------------------------------------
# 🎉 Title + Description
# ---------------------------------------------------------
# NOTE(review): only the subtitle text survived here; the title markup that
# unsafe_allow_html was presumably meant to render appears to be missing —
# restore it.
st.markdown(
    "Ask any technical question and watch the magic happen! 🪄",
    unsafe_allow_html=True,
)

# ---------------------------------------------------------
# Example Questions
# ---------------------------------------------------------
st.markdown("### 🎯 Try an example:")
examples = [
    "How do I fix a TypeError in Python when concatenating lists?",
    "What is the recommended way to deploy a React application?",
    "Why does my SQL JOIN return duplicate rows?",
]

cols = st.columns(len(examples))
for i, (col, ex) in enumerate(zip(cols, examples)):
    # The buttons are evaluated before the text area below in the same rerun,
    # so the stored example becomes the text area's initial value.
    if col.button(f"Example {i + 1}"):
        st.session_state["example_text"] = ex

user_text = st.text_area(
    "✍️ Enter your StackOverflow question:",
    value=st.session_state.get("example_text", ""),
    height=150,
)

# ---------------------------------------------------------
# Predict Button
# ---------------------------------------------------------
if st.button("🔮 Predict Tags!"):
    if not user_text.strip():
        st.warning("Please enter a question first ✏️")
    else:
        with st.spinner("✨ Analyzing your question… summoning the tag spirits… 🔮"):
            predicted_tags, probs = predict_tags(user_text, threshold)

        # Display tags
        st.markdown("## 🏷️ Predicted Tags:")
        # len() rather than truthiness: predicted_tags is a numpy array.
        if len(predicted_tags) == 0:
            st.error("😕 No tags predicted — try lowering the threshold!")
        else:
            for t in predicted_tags:
                st.markdown(f"#{t}", unsafe_allow_html=True)

        # Probability Chart: every label's probability, highest first.
        st.markdown("### 📊 Tag Probability Chart")
        df = pd.DataFrame({
            "Tag": labels,
            "Probability": probs,
        })
        chart = alt.Chart(df).mark_bar(color="#ff6ec7").encode(
            x="Probability:Q",
            y=alt.Y("Tag:N", sort="-x"),
        ).properties(height=350)
        st.altair_chart(chart, use_container_width=True)

# ---------------------------------------------------------
# Footer
# ---------------------------------------------------------
st.markdown("✨ Powered by BERT • Hugging Face • Streamlit", unsafe_allow_html=True)