# StackOverflow Tag Predictor — Streamlit app (Hugging Face Space)
import altair as alt
import joblib
import numpy as np
import pandas as pd
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# ---------------------------------------------------------
# Custom CSS for Fun, Colorful UI
# ---------------------------------------------------------
# Injected once at the top of the page; styles are referenced below via the
# `title-gradient`, `tag-pill` and `center` class names.
_PAGE_CSS = """
<style>
/* Animated gradient title */
.title-gradient {
    font-size: 40px;
    font-weight: 900;
    text-align: center;
    background: linear-gradient(90deg, #ff0080, #ff8c00, #40e0d0, #8a2be2);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    animation: glow 4s ease-in-out infinite;
}
@keyframes glow {
    0% { filter: drop-shadow(0 0 2px #ff0080); }
    50% { filter: drop-shadow(0 0 8px #40e0d0); }
    100% { filter: drop-shadow(0 0 2px #ff0080); }
}
/* Tag pill styling */
.tag-pill {
    display: inline-block;
    padding: 8px 14px;
    margin: 4px;
    background-color: #ff6ec7;
    color: white;
    border-radius: 20px;
    font-size: 14px;
    font-weight: 600;
}
/* Centered subtle text */
.center {
    text-align: center;
    color: #666;
}
</style>
"""

st.markdown(_PAGE_CSS, unsafe_allow_html=True)
# ---------------------------------------------------------
# Load Model + Tokenizer
# ---------------------------------------------------------
@st.cache_resource
def load_model():
    """Load the fine-tuned classifier and its tokenizer from the app directory.

    Decorated with ``st.cache_resource``: Streamlit re-executes this whole
    script on every widget interaction, so without caching the (expensive)
    model load would run on every click. The cache keeps one shared copy
    per server process.

    Returns:
        (model, tokenizer) tuple ready for inference.
    """
    model = AutoModelForSequenceClassification.from_pretrained(".")
    tokenizer = AutoTokenizer.from_pretrained(".")
    return model, tokenizer


model, tokenizer = load_model()

# Run inference on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()  # explicit inference mode (disables dropout); idempotent.

# MultiLabelBinarizer maps class indices back to human-readable tag names.
mlb = joblib.load("mlb.joblib")
labels = mlb.classes_
# ---------------------------------------------------------
# Prediction function
# ---------------------------------------------------------
def predict_tags(text, threshold=0.3):
    """Predict tags for a question.

    Tokenizes `text`, runs the classifier on `device`, applies a sigmoid to
    the logits (multi-label setup) and keeps every tag whose probability is
    at least `threshold`.

    Returns:
        (selected tag names, per-tag probability array for all classes).
    """
    batch = tokenizer(
        text,
        padding=True,
        truncation=True,
        max_length=256,
        return_tensors="pt",
    )
    batch = {name: tensor.to(device) for name, tensor in batch.items()}

    with torch.no_grad():
        logits = model(**batch).logits

    probabilities = torch.sigmoid(logits).cpu().numpy()[0]
    keep = probabilities >= threshold
    return labels[keep], probabilities
# ---------------------------------------------------------
# 🎨 Sidebar
# ---------------------------------------------------------
sidebar = st.sidebar
sidebar.header("⚙️ Settings")

# Decision threshold applied to the per-tag sigmoid scores in predict_tags.
threshold = sidebar.slider(
    "Prediction Threshold",
    0.0,
    1.0,
    0.30,
    help="Lower = more tags, Higher = fewer but more confident",
)

sidebar.markdown(
    """
### 🤖 Model Info
- BERT-based tag predictor
- Multi-label classification
- Trained on StackOverflow dataset
"""
)
sidebar.markdown("---")
sidebar.markdown("Made with ❤️ using Streamlit + Transformers")
# ---------------------------------------------------------
# 🎉 Title + Description
# ---------------------------------------------------------
st.markdown("<h1 class='title-gradient'>✨ StackOverflow Tag Predictor ✨</h1>", unsafe_allow_html=True)
st.markdown("<p class='center'>Ask any technical question and watch the magic happen! 🪄</p>", unsafe_allow_html=True)

# ---------------------------------------------------------
# Example Questions
# ---------------------------------------------------------
st.markdown("### 🎯 Try an example:")

examples = [
    "How do I fix a TypeError in Python when concatenating lists?",
    "What is the recommended way to deploy a React application?",
    "Why does my SQL JOIN return duplicate rows?",
]

# One button per example; clicking stores the question in session state so
# the text area below is pre-filled on the rerun it triggers.
for idx, (col, question) in enumerate(zip(st.columns(len(examples)), examples), start=1):
    if col.button(f"Example {idx}"):
        st.session_state["example_text"] = question

user_text = st.text_area(
    "✍️ Enter your StackOverflow question:",
    value=st.session_state.get("example_text", ""),
    height=150,
)
# ---------------------------------------------------------
# Predict Button
# ---------------------------------------------------------
if st.button("🔮 Predict Tags!"):
    if not user_text.strip():
        # Guard: nothing to classify.
        st.warning("Please enter a question first ✏️")
    else:
        with st.spinner("✨ Analyzing your question… summoning the tag spirits… 🔮"):
            predicted_tags, probs = predict_tags(user_text, threshold)

        # Display tags as styled pills.
        st.markdown("## 🏷️ Predicted Tags:")
        if len(predicted_tags) == 0:
            st.error("😕 No tags predicted — try lowering the threshold!")
        else:
            for t in predicted_tags:
                st.markdown(f"<span class='tag-pill'>#{t}</span>", unsafe_allow_html=True)

        # Probability Chart — only the most likely tags. Plotting a bar for
        # every class is unreadable inside a fixed-height chart when the
        # label space is large (StackOverflow tag sets are typically 100+).
        st.markdown("### 📊 Tag Probability Chart")
        df = pd.DataFrame({
            "Tag": labels,
            "Probability": probs,
        })
        top_df = df.nlargest(15, "Probability")
        chart = (
            alt.Chart(top_df)
            .mark_bar(color="#ff6ec7")
            .encode(
                x="Probability:Q",
                y=alt.Y("Tag:N", sort="-x"),
            )
            .properties(height=350)
        )
        st.altair_chart(chart, use_container_width=True)
# ---------------------------------------------------------
# Footer
# ---------------------------------------------------------
st.markdown(
    "<p class='center'>✨ Powered by BERT • Hugging Face • Streamlit</p>",
    unsafe_allow_html=True,
)