# dejanseo's picture
# Update src/streamlit_app.py
# 0ade00c verified
import streamlit as st
import torch
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Hugging Face Hub identifier handed to `from_pretrained` in `load_model`.
MODEL_ID = "dejanseo/ecommerce-query-volume-classifier"

# Class names addressed by output index: `predict` uses the argmax position
# of the logits to index into this list.
# NOTE(review): presumably mirrors the checkpoint's id2label order - confirm.
LABELS = ["very_low", "low", "medium", "high", "very_high"]

# Human-readable name per label, ordered from highest to lowest volume.
DISPLAY_LABELS = {
    name: name.replace("_", " ").title() for name in reversed(LABELS)
}

# Volume bucket text shown alongside each label in the UI.
VOLUME_RANGES = {
    "very_high": "10,000+",
    "high": "1,000–9,999",
    "medium": "100–999",
    "low": "10–99",
    "very_low": "<10",
}

# Inline CSS keyed by display label; `style_class_cells` looks values up here.
CLASS_COLORS = {
    "Very High": "background-color: rgba(40, 167, 69, 0.1); color: #28a745;",
    "High": "background-color: rgba(113, 194, 133, 0.1); color: #71c285;",
    "Medium": "background-color: rgba(255, 193, 7, 0.1); color: #c79100;",
    "Low": "background-color: rgba(253, 126, 20, 0.1); color: #fd7e14;",
    "Very Low": "background-color: rgba(220, 53, 69, 0.1); color: #dc3545;",
}
@st.cache_resource
def load_model():
    """Load the tokenizer and sequence classifier published as ``MODEL_ID``.

    The function is decorated with ``st.cache_resource``. The classifier is
    put into eval mode before it is handed back.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tok = AutoTokenizer.from_pretrained(MODEL_ID)
    classifier = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
    classifier.eval()
    return tok, classifier
def predict(queries, tokenizer, model):
    """Classify a batch of queries into the volume classes listed in ``LABELS``.

    All queries are tokenized together (padded, truncated to 32 tokens) and
    run through the model in a single forward pass with gradients disabled.

    Args:
        queries: Sequence of query strings.
        tokenizer: Callable invoked as ``tokenizer(queries, return_tensors="pt",
            ...)``; its result is unpacked as keyword arguments to ``model``.
        model: Callable whose return value exposes a ``logits`` tensor.

    Returns:
        One dict per query, in input order, with keys ``"query"``,
        ``"label"`` (the ``LABELS`` entry at the argmax index),
        ``"confidence"`` (softmax probability of that label) and
        ``"distribution"`` (label -> probability). An empty ``queries``
        yields an empty list without calling the tokenizer or the model.
    """
    # Nothing to encode: skip the tokenizer/model entirely instead of
    # handing them an empty batch.
    if not queries:
        return []

    encoded = tokenizer(
        queries,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=32,
    )
    with torch.no_grad():
        logits = model(**encoded).logits

    probs = torch.softmax(logits, dim=-1)
    # Convert to plain Python numbers once rather than calling .item() for
    # every cell inside the loop.
    top_indices = torch.argmax(probs, dim=-1).tolist()
    prob_rows = probs.tolist()

    return [
        {
            "query": query,
            "label": LABELS[top],
            "confidence": row[top],
            "distribution": dict(zip(LABELS, row)),
        }
        for query, top, row in zip(queries, top_indices, prob_rows)
    ]
def style_class_cells(val):
    """Return the CSS registered in ``CLASS_COLORS`` for *val*, else ``""``."""
    if val in CLASS_COLORS:
        return CLASS_COLORS[val]
    return ""
# Page-level settings. The icon literal had been corrupted by a bad
# encoding round-trip; it is restored to the magnifying-glass emoji.
st.set_page_config(
    page_title="Query Volume Classifier",
    page_icon="🔍",
    layout="wide",
)

# Inject a stylesheet that applies the Montserrat web font to the listed
# selectors; unsafe_allow_html is needed so the <style> tag is emitted as HTML.
st.markdown(
    """
<style>
@import url('https://fonts.googleapis.com/css2?family=Montserrat:wght@400;500;600;700&display=swap');
html, body, [class*="css"], .stMarkdown, .stTextInput, .stTextArea,
.stMetric, .stDataFrame, h1, h2, h3, h4, h5, h6, p, span, div, label {
font-family: 'Montserrat', sans-serif !important;
font-weight: 500;
}
</style>
""",
    unsafe_allow_html=True,
)
# App logo; clicking it opens the linked site.
st.logo(
    "https://dejan.ai/wp-content/uploads/2024/02/dejan-300x103.png",
    size="large",
    link="https://dejan.ai/",
)

# Heading and one-line description.
st.subheader("eCommerce Query Volume Classifier")
st.caption("Predicts search volume class for product queries.")

# Tokenizer and model are shared by both tabs below.
tokenizer, model = load_model()

(tab_single, tab_batch) = st.tabs(["Single Query", "Batch"])
# Single-query tab: classify one query and show the predicted class, its
# volume bucket, the confidence, and the full probability distribution.
with tab_single:
    query = st.text_input(
        "Enter a product search query",
        placeholder="e.g. wireless mouse, airpods, replacement gasket for instant pot",
    )
    # Only run the model when the button was pressed and the input is not blank.
    if st.button("Classify", key="single") and query.strip():
        result = predict([query.strip()], tokenizer, model)[0]
        label = result["label"]
        display_label = DISPLAY_LABELS[label]
        volume_range = VOLUME_RANGES[label]

        col1, col2, col3 = st.columns(3)
        col1.metric("Predicted Class", display_label)
        col2.metric("Estimated Sessions", volume_range)
        col3.metric("Confidence", f"{result['confidence']:.1%}")

        st.markdown("**Class probabilities**")
        # LABELS runs low -> high; show the bars from highest class to lowest.
        for lbl in reversed(LABELS):
            prob = result["distribution"][lbl]
            # The separator was a mis-encoded em dash; restored here.
            st.progress(prob, text=f"{DISPLAY_LABELS[lbl]} — {prob:.1%}")
# Batch tab: classify up to 500 newline-separated queries and show the
# results as a styled table.
with tab_batch:
    st.markdown("Enter one query per line:")
    text = st.text_area(
        "Queries",
        height=200,
        placeholder="laptop\nairpods\norganic flurb capsules\nreplacement gasket for instant pot duo 8 quart",
    )
    if st.button("Classify", key="batch") and text.strip():
        # One query per non-blank line, surrounding whitespace removed.
        queries = [q.strip() for q in text.strip().splitlines() if q.strip()]
        if len(queries) > 500:
            st.warning("Please enter 500 queries or fewer.")
        else:
            results = predict(queries, tokenizer, model)
            df = pd.DataFrame([
                {
                    "Query": r["query"],
                    "Class": DISPLAY_LABELS[r["label"]],
                    "Sessions": VOLUME_RANGES[r["label"]],
                    # Kept on a 0-100 scale: the printf format below prints
                    # the raw number, so a 0-1 value would read as "0.9%"
                    # instead of "93.0%".
                    "Confidence": r["confidence"] * 100,
                }
                for r in results
            ])
            # Tint only the "Class" column using the per-class CSS.
            styled = df.style.map(style_class_cells, subset=["Class"])
            st.dataframe(
                styled,
                column_config={
                    "Confidence": st.column_config.ProgressColumn(
                        "Confidence",
                        format="%.1f%%",
                        min_value=0,
                        max_value=100,
                    ),
                },
                use_container_width=True,
                hide_index=True,
            )