|
|
import streamlit as st |
|
|
import torch |
|
|
import joblib |
|
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Browser-tab chrome and page layout; must run before any other st.* call.
# NOTE(review): the page_icon literal looks mojibake (a UTF-8 emoji decoded
# with the wrong codepage) — reproduced verbatim; confirm the intended glyph.
_PAGE_CONFIG = {
    "page_title": "StackOverflow Tag Predictor",
    "page_icon": "๐ฏ",
    "layout": "centered",
}
st.set_page_config(**_PAGE_CONFIG)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Global stylesheet: gradient masthead, white "card" containers, an animated
# result pill, and a muted footer. Injected as raw HTML, which requires
# unsafe_allow_html — acceptable here because the CSS is static and trusted.
_PAGE_CSS = """
<style>

body {
    background-color: #F5F7FB;
}

.header {
    text-align: center;
    margin-top: -20px;
    margin-bottom: 10px;
}

.header-title {
    font-size: 48px;
    font-weight: 900;
    background: linear-gradient(90deg, #4A4AFC, #6A6AFF);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
}

.header-subtitle {
    font-size: 18px;
    color: #555;
    margin-top: -10px;
    margin-bottom: 25px;
}

.card {
    background: white;
    padding: 30px;
    border-radius: 18px;
    box-shadow: 0px 6px 20px rgba(0,0,0,0.08);
    margin-bottom: 20px;
}

.result-tag {
    background: linear-gradient(90deg, #4A4AFC, #6A6AFF);
    padding: 14px 24px;
    border-radius: 14px;
    color: white;
    display: inline-block;
    font-size: 22px;
    font-weight: 700;
    animation: fadeIn 0.4s ease-out;
}

@keyframes fadeIn {
    from {opacity: 0; transform: translateY(10px);}
    to {opacity: 1; transform: translateY(0);}
}

.footer {
    text-align: center;
    margin-top: 40px;
    color: #777;
    font-size: 14px;
}

</style>
"""
st.markdown(_PAGE_CSS, unsafe_allow_html=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@st.cache_resource
def load_model():
    """Load the fine-tuned classifier and its tokenizer from the app directory.

    Decorated with st.cache_resource so the (expensive) weight files are read
    from disk only once per server process, not on every script rerun.

    Returns:
        tuple: ``(model, tokenizer)`` ready for inference.
    """
    # Both artifacts are expected to live next to this script (".").
    tokenizer = AutoTokenizer.from_pretrained(".")
    classifier = AutoModelForSequenceClassification.from_pretrained(".")
    return classifier, tokenizer
|
|
|
|
|
# Materialize the cached model/tokenizer once at import time, then move the
# model onto a CUDA device when one is available (CPU otherwise). The same
# `device` is used later to place the tokenized inputs in predict_tag().
model, tokenizer = load_model()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Recover the human-readable tag names saved alongside the model at training
# time; position i in `classes_` corresponds to model output class id i.
label_encoder = joblib.load("label_encoder.joblib")
id2label = dict(enumerate(label_encoder.classes_))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict_tag(text):
    """Classify a question title and report the model's certainty.

    Args:
        text: Raw question title (a single string).

    Returns:
        tuple: ``(tag, confidence)`` where ``tag`` is the decoded class name
        and ``confidence`` is the softmax probability of that class in [0, 1].
    """
    batch = tokenizer(
        text,
        truncation=True,
        padding=True,
        max_length=128,
        return_tensors="pt",
    )
    # Inputs must live on the same device as the model weights.
    batch = {name: tensor.to(device) for name, tensor in batch.items()}

    with torch.no_grad():  # inference only — skip autograd bookkeeping
        logits = model(**batch).logits

    # Softmax is monotonic, so argmax over probabilities equals argmax over
    # logits; computing it once yields both the label and its confidence.
    probs = torch.softmax(logits, dim=-1)
    best_id = torch.argmax(probs, dim=-1).item()
    return id2label[best_id], probs.max().item()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Page masthead, rendered as raw HTML so the injected CSS classes apply.
# NOTE(review): the emoji/bullet literals look mojibake (wrong-codepage
# decode of UTF-8) — reproduced verbatim; confirm intended characters.
_HEADER_HTML = """
<div class="header">
    <div class="header-title">๐ฏ StackOverflow Tag Predictor</div>
    <div class="header-subtitle">Powered by DistilBERT โข Predict the most likely tag from a question title</div>
</div>
"""
st.markdown(_HEADER_HTML, unsafe_allow_html=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---- Sidebar: app description and model metadata ----
st.sidebar.title("โน๏ธ About This App")
st.sidebar.write(
    """
This app uses a fine-tuned **DistilBERT** model trained on the
top 50 StackOverflow tags.

You can:
- Type your own question title
- Pick from example titles
- See model confidence
"""
)

st.sidebar.write("### ๐ง Model Info")
# Label count comes from the loaded encoder, so it always matches the model.
st.sidebar.write(f"**Labels:** {len(id2label)} classes")
st.sidebar.write("**Framework:** PyTorch + HuggingFace Transformers")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Canned question titles spanning several popular tags, so visitors can try
# the model without typing anything.
examples = [
    "How to fix NullPointerException in Java?",
    "What is the best way to center a div in CSS?",
    "How do I connect to a MySQL database in Python?",
    "Why is my React component not rendering?",
    "How to optimize a SQL query that is too slow?",
    "How to declare an array in C++?",
]

# "(None)" is a sentinel that keeps the free-text path as the default choice.
example_choice = st.selectbox("โจ Or choose an example question:", ["(None)"] + examples)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---- Input card: either the chosen example or a free-text title ----
st.markdown("<div class='card'>", unsafe_allow_html=True)

if example_choice == "(None)":
    # No example picked: offer a free-form text box.
    user_input = st.text_area(
        "๐ฌ Enter a StackOverflow question title:",
        height=120,
        placeholder='Example: "How to fix NullPointerException in Java?"',
    )
else:
    # A selected example takes over as the input (the text box is not shown).
    user_input = example_choice

predict_btn = st.button("๐ Predict Tag", use_container_width=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---- Run the model when the button is pressed ----
if predict_btn:
    if not user_input.strip():
        # Covers both an untouched text box and whitespace-only input.
        st.warning("โ ๏ธ Please enter a question title.")
    else:
        with st.spinner("Analyzing with AIโฆ ๐งโจ"):
            tag, confidence = predict_tag(user_input)

        st.success("Prediction ready! ๐")
        # Styled pill for the tag, then the probability as a percentage.
        st.markdown(f"<div class='result-tag'>{tag}</div>", unsafe_allow_html=True)
        st.markdown(f"### ๐ฅ Confidence Score: **{confidence*100:.2f}%**")

# Close the card <div> opened before the input widgets.
st.markdown("</div>", unsafe_allow_html=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---- Footer ----
st.markdown(
    """
<div class='footer'>
Made with โค๏ธ using DistilBERT + Streamlit + HuggingFace Spaces<br>
Try different example titles or write your own!
</div>
""",
    unsafe_allow_html=True,
)
|
|
|