"""Streamlit app: classify e-commerce search queries into search-volume buckets."""

import streamlit as st
import torch
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Hugging Face Hub model ID for the fine-tuned sequence classifier.
MODEL_ID = "dejanseo/ecommerce-query-volume-classifier"

# Index order must match the model's output head (class 0 .. class 4).
LABELS = ["very_low", "low", "medium", "high", "very_high"]

# Internal label -> human-readable label shown in the UI.
DISPLAY_LABELS = {
    "very_high": "Very High",
    "high": "High",
    "medium": "Medium",
    "low": "Low",
    "very_low": "Very Low",
}

# Internal label -> estimated monthly session range shown in the UI.
VOLUME_RANGES = {
    "very_high": "10,000+",
    "high": "1,000–9,999",
    "medium": "100–999",
    "low": "10–99",
    "very_low": "<10",
}

# Display label -> inline CSS used to tint the "Class" column cells.
CLASS_COLORS = {
    "Very High": "background-color: rgba(40, 167, 69, 0.1); color: #28a745;",
    "High": "background-color: rgba(113, 194, 133, 0.1); color: #71c285;",
    "Medium": "background-color: rgba(255, 193, 7, 0.1); color: #c79100;",
    "Low": "background-color: rgba(253, 126, 20, 0.1); color: #fd7e14;",
    "Very Low": "background-color: rgba(220, 53, 69, 0.1); color: #dc3545;",
}


@st.cache_resource
def load_model():
    """Download (or reuse cached) tokenizer + model and put the model in eval mode.

    Cached with st.cache_resource so the pair is loaded once per server process.

    Returns:
        (tokenizer, model) tuple ready for inference.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
    model.eval()
    return tokenizer, model


def predict(queries, tokenizer, model):
    """Classify a batch of query strings.

    Args:
        queries: list of raw query strings.
        tokenizer: tokenizer returned by load_model().
        model: classification model returned by load_model().

    Returns:
        One dict per query with keys "query", "label" (internal label),
        "confidence" (top-class probability) and "distribution"
        (internal label -> probability).
    """
    encoded = tokenizer(
        queries,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=32,  # queries are short; anything longer is truncated
    )
    with torch.no_grad():
        probabilities = torch.softmax(model(**encoded).logits, dim=-1)

    results = []
    for query, row in zip(queries, probabilities):
        top = int(torch.argmax(row))
        results.append(
            {
                "query": query,
                "label": LABELS[top],
                "confidence": row[top].item(),
                "distribution": {name: row[idx].item() for idx, name in enumerate(LABELS)},
            }
        )
    return results


def style_class_cells(val):
    """Pandas Styler hook: return the CSS for a display label ("" if unknown)."""
    return CLASS_COLORS.get(val, "")


st.set_page_config(page_title="Query Volume Classifier", page_icon="🔍", layout="wide")
st.markdown(""" """, unsafe_allow_html=True)
st.logo(
    "https://dejan.ai/wp-content/uploads/2024/02/dejan-300x103.png",
    link="https://dejan.ai/",
    size="large",
)
st.subheader("eCommerce Query Volume Classifier")
st.caption("Predicts search volume class for product queries.")

tokenizer, model = load_model()

tab_single, tab_batch = st.tabs(["Single Query", "Batch"])

with tab_single:
    query = st.text_input(
        "Enter a product search query",
        placeholder="e.g. wireless mouse, airpods, replacement gasket for instant pot",
    )
    if st.button("Classify", key="single") and query.strip():
        result = predict([query.strip()], tokenizer, model)[0]
        label = result["label"]
        display_label = DISPLAY_LABELS[label]
        volume_range = VOLUME_RANGES[label]

        # Headline metrics: predicted class, its session range, and confidence.
        col1, col2, col3 = st.columns(3)
        col1.metric("Predicted Class", display_label)
        col2.metric("Estimated Sessions", volume_range)
        col3.metric("Confidence", f"{result['confidence']:.1%}")

        # Full probability distribution, highest volume class first.
        st.markdown("**Class probabilities**")
        for lbl in ["very_high", "high", "medium", "low", "very_low"]:
            prob = result["distribution"][lbl]
            st.progress(prob, text=f"{DISPLAY_LABELS[lbl]} — {prob:.1%}")

with tab_batch:
    st.markdown("Enter one query per line:")
    text = st.text_area(
        "Queries",
        height=200,
        placeholder="laptop\nairpods\norganic flurb capsules\nreplacement gasket for instant pot duo 8 quart",
    )
    if st.button("Classify", key="batch") and text.strip():
        queries = [q.strip() for q in text.strip().splitlines() if q.strip()]
        if len(queries) > 500:
            # Soft guard against huge batches hammering the model.
            st.warning("Please enter 500 queries or fewer.")
        else:
            results = predict(queries, tokenizer, model)
            # BUG FIX: ProgressColumn applies its printf format to the raw
            # cell value, so a 0-1 confidence with "%.1f%%" rendered e.g.
            # 0.85 as "0.9%". Store the value on a 0-100 scale (and set
            # max_value=100) so the label reads "85.0%" and the bar length
            # stays proportional — matching the single-query tab's ":.1%".
            df = pd.DataFrame([
                {
                    "Query": r["query"],
                    "Class": DISPLAY_LABELS[r["label"]],
                    "Sessions": VOLUME_RANGES[r["label"]],
                    "Confidence": r["confidence"] * 100.0,
                }
                for r in results
            ])
            styled = df.style.map(style_class_cells, subset=["Class"])
            st.dataframe(
                styled,
                column_config={
                    "Confidence": st.column_config.ProgressColumn(
                        "Confidence",
                        format="%.1f%%",
                        min_value=0,
                        max_value=100,
                    ),
                },
                use_container_width=True,
                hide_index=True,
            )